├── .devcontainer
└── Dockerfile
├── .gitignore
├── LICENSE
├── README.md
├── imgs
├── deisam_logo.png
├── deisam_logo_eye.png
└── deisam_task.png
├── neumann
├── LICENSE
├── README.md
├── neumann
│ ├── __init__.py
│ ├── clause_generator.py
│ ├── data_behind_the_scenes.py
│ ├── data_clevr.py
│ ├── data_kandinsky.py
│ ├── data_logic.py
│ ├── data_vilp.py
│ ├── explain_clevr.py
│ ├── explanation_utils.py
│ ├── facts_converter.py
│ ├── fol
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── data_utils.py
│ │ ├── exp.lark
│ │ ├── exp_parser.py
│ │ ├── language.py
│ │ ├── logic.py
│ │ └── logic_ops.py
│ ├── img2facts.py
│ ├── lark
│ │ └── exp.lark
│ ├── logic_utils.py
│ ├── message_passing.py
│ ├── mode_declaration.py
│ ├── neumann.py
│ ├── neumann_utils.py
│ ├── neural_utils.py
│ ├── percept.py
│ ├── predict.py
│ ├── reasoning_graph.py
│ ├── refinement.py
│ ├── scatter.py
│ ├── soft_logic.py
│ ├── solve_behind_the_scenes.py
│ ├── solve_kandinsky.py
│ ├── torch_utils.py
│ ├── train_neumann.py
│ ├── valuation.py
│ ├── valuation_func.py
│ └── visualize.py
├── requirements.txt
└── setup.py
├── prompt
├── gen_constants.txt
├── gen_predicates.txt
└── gen_rules.txt
└── src
├── __init__.py
├── data_vg.py
├── deisam.py
├── deisam_utils.py
├── lark
└── exp.lark
├── learning_demo.py
├── learning_utils.py
├── llm_logic_generator.py
├── sam_utils.py
├── semantic_unifier.py
├── solve_deivg.py
├── test_sgg.py
├── visual_genome_utils.py
└── visualization_utils.py
/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
1 | # Select the base image
2 |
3 | ### To run with GPUs, use the following:
4 | FROM nvcr.io/nvidia/pytorch:21.10-py3
5 | ###
6 |
7 |
8 | ### To run without GPUs, use the following:
9 | # FROM ubuntu:18.04
10 | # RUN apt-get update
11 | # RUN apt-get install -y software-properties-common
12 | # RUN add-apt-repository ppa:deadsnakes/ppa
13 | # RUN apt-get install -y python3.8-dev python3-pip
14 | # RUN rm /usr/bin/python3 && ln -s /usr/bin/python3.8 /usr/bin/python3
15 | # RUN python3 --version
16 | # RUN pip3 --version
17 | # RUN update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1
18 | ###
19 |
20 |
21 | ### other setup
22 | ENV TZ=Europe/Berlin
23 | ARG DEBIAN_FRONTEND=noninteractive
24 |
25 | # Select the working directory
26 | WORKDIR /Workspace
27 |
28 | # Install system libraries required by OpenCV.
29 | RUN apt-get update \
30 | && apt-get install -y libgl1-mesa-glx libgtk2.0-0 libsm6 libxext6 \
31 | && rm -rf /var/lib/apt/lists/*
32 |
33 | RUN pip install --upgrade pip
34 | # Install Python requirements
35 | # RUN pip install opencv-python==4.5.5.64
36 | RUN pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
37 | RUN pip install networkx==3.0 lark-parser joblib scikit-learn torchsummary setuptools tensorboard "numpy>=1.18.5" "tqdm>=4.41.0" "matplotlib>=3.2.2" "opencv-python>=4.1.2" Pillow "PyYAML>=5.3.1" "scipy>=1.4.1" seaborn pandas rtpt
38 |
39 | RUN pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
40 | RUN pip install torch-geometric
41 | RUN pip install wandb anytree
42 |
43 | # RUN python -m pip install -e Grounded-Segment-Anything/GroundingDINO
44 | #RUN cd GroundingDINO
45 | #RUN pip install -e .
46 | #RUN cd ..
47 | RUN pip install "opencv-python-headless<4.3"
48 | RUN pip install --upgrade diffusers[torch]
49 | RUN pip install openai==0.28.1
50 | RUN pip install pydantic==1.9.0
51 | RUN pip install visual_genome
52 |
53 | # fix opencv
54 | #RUN pip uninstall opencv-python
55 | #RUN pip uninstall opencv-contrib-python
56 | #RUN pip uninstall opencv-contrib-python-headless
57 | #RUN pip3 install opencv-contrib-python==4.5.5.62
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Hikaru Shindo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### *Refactoring is ongoing.*
2 |
5 |
6 | # DeiSAM: Segment Anything with Deictic Prompting (NeurIPS 2024)
7 | [Hikaru Shindo](https://www.hikarushindo.com/), Manuel Brack, Gopika Sudhakaran, Devendra Singh Dhami, Patrick Schramowski, Kristian Kersting
8 |
9 | [AI/ML Lab @ TU Darmstadt](https://ml-research.github.io/index.html)
10 |
11 |
12 |
13 |
14 | We propose DeiSAM, which integrates large pre-trained neural networks with differentiable logic reasoners. Given a complex, textual segmentation description, DeiSAM leverages Large Language Models (LLMs) to generate first-order logic rules and performs differentiable forward reasoning on generated scene graphs.
15 |
16 |
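For illustration only, the pipeline can be sketched as follows; every name below is a hypothetical placeholder rather than the actual API (see `src/deisam.py` and `src/solve_deivg.py` for the real entry points):

```
# Illustrative sketch of the DeiSAM pipeline described above.
# All function and class names are hypothetical placeholders.
def deisam_pipeline(image, deictic_prompt, llm, scene_graph_generator, reasoner, sam):
    scene_graph = scene_graph_generator(image)      # objects and their relations
    rules = llm.generate_rules(deictic_prompt)      # first-order logic rules as text
    target_objects = reasoner(scene_graph, rules)   # differentiable forward reasoning
    return [sam.segment(image, obj) for obj in target_objects]
```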
17 | # Install
18 | A [Dockerfile](.devcontainer/Dockerfile) is available in the [.devcontainer](.devcontainer) folder.
19 |
20 | To install further dependencies, clone [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) and then:
21 |
22 |
23 | ```
24 | cd neumann/
25 | pip install -e .
26 | cd ../Grounded-Segment-Anything/
27 | cd segment_anything
28 | pip install -e .
29 | cd ../GroundingDINO
30 | pip install -e .
31 | ```
32 |
33 | If an error appears regarding OpenCV (circular import), try:
34 | ```
35 | pip uninstall opencv-python
36 | pip uninstall opencv-contrib-python
37 | pip uninstall opencv-contrib-python-headless
38 | pip3 install opencv-contrib-python==4.5.5.62
39 | ```
40 |
41 | Download the SAM ViT model:
42 | ```
43 | wget https://huggingface.co/spaces/abhishek/StableSAM/resolve/main/sam_vit_h_4b8939.pth
44 | ```
45 |
46 | # Dataset
47 | **The DeiVG datasets can be downloaded here: [link](https://hessenbox.tu-darmstadt.de/getlink/fiJwsDNjdY9HDrUMf3btjoHG/).**
48 | Please place the downloaded files in `data/` as follows (make sure you are in the root folder of this project):
49 | ```
50 | mkdir data/
51 | cd data
52 | wget https://hessenbox.tu-darmstadt.de/dl/fiJwsDNjdY9HDrUMf3btjoHG/.dir -O deivg.zip
53 | unzip deivg.zip
54 | cd visual_genome
55 | unzip by-id.zip
56 | ```
57 |
58 |
59 | Please download the Visual Genome images ([link](https://homes.cs.washington.edu/~ranjay/visualgenome/api.html)) and place them in `data/visual_genome/` as follows:
60 | ```
61 | cd data/visual_genome
62 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip
63 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
64 | unzip images.zip
65 | unzip images2.zip
66 | mv VG_100K_2/* VG_100K/
67 | ```
68 |
69 |
70 | # Experiments
71 | To solve DeiVG using DeiSAM:
72 | ```
73 | python src/solve_deivg.py --api-key YOUR_OPENAI_API_KEY -c 1
74 | python src/solve_deivg.py --api-key YOUR_OPENAI_API_KEY -c 2
75 | python src/solve_deivg.py --api-key YOUR_OPENAI_API_KEY -c 3
76 | ```
77 |
78 |
79 | A demonstration of learning can be performed by:
80 | ```
81 | python src/learning_demo.py --api-key YOUR_OPENAI_API_KEY -c 1 -sgg VETO -su
82 | python src/learning_demo.py --api-key YOUR_OPENAI_API_KEY -c 2 -sgg VETO -su
83 | ```
84 | *Note that DeiSAM is essentially a training-free model.* Learning here demonstrates the capability to learn via gradients. The best performance is always achieved by using the model with ground-truth scene graphs, which corresponds to `solve_deivg.py`.
85 | In other words, DeiSAM does not need to be trained when scene graphs are available. Handling the case where scene graphs are not available is planned as future work.
86 |
87 |
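As a rough illustration of what learning by gradients means here, the following minimal sketch (made-up tensors, not the code in `src/learning_demo.py`) assigns learnable weights to candidate rules and updates them by backpropagation against target labels:

```
import torch

num_rules, num_objects = 4, 6
rule_weights = torch.zeros(num_rules, requires_grad=True)  # learnable rule logits
per_rule_scores = torch.rand(num_rules, num_objects)       # object scores under each candidate rule
target = torch.tensor([1., 0., 0., 1., 0., 0.])            # objects that should be segmented

optimizer = torch.optim.Adam([rule_weights], lr=0.1)
for _ in range(100):
    optimizer.zero_grad()
    probs = torch.softmax(rule_weights, dim=0) @ per_rule_scores  # soft rule selection
    loss = torch.nn.functional.binary_cross_entropy(probs.clamp(0, 1), target)
    loss.backward()
    optimizer.step()
```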
88 | # Bibtex
89 | ```
90 | @inproceedings{shindo24deisam,
91 | author = {Hikaru Shindo and
92 | Manuel Brack and
93 | Gopika Sudhakaran and
94 | Devendra Singh Dhami and
95 | Patrick Schramowski and
96 | Kristian Kersting},
97 | title = {DeiSAM: Segment Anything with Deictic Prompting},
98 | booktitle = {Proceedings of the Conference on Advances in Neural Information Processing Systems (NeurIPS)},
99 | year = {2024},
100 | }
101 |
102 | ```
103 |
104 |
105 |
106 | # LICENSE
107 | See [LICENSE](./LICENSE).
108 |
109 |
--------------------------------------------------------------------------------
/imgs/deisam_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/imgs/deisam_logo.png
--------------------------------------------------------------------------------
/imgs/deisam_logo_eye.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/imgs/deisam_logo_eye.png
--------------------------------------------------------------------------------
/imgs/deisam_task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/imgs/deisam_task.png
--------------------------------------------------------------------------------
/neumann/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 anonymous
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/neumann/README.md:
--------------------------------------------------------------------------------
1 | # Learning Differentiable Logic Programs for Abstract Visual Reasoning
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | # Abstract
10 | Visual reasoning is essential for building intelligent agents that understand the world and perform problem-solving beyond perception. Differentiable forward reasoning has been developed to integrate reasoning with gradient-based machine learning paradigms.
11 | However, due to their memory intensity, most existing approaches do not exploit the full expressivity of first-order logic, excluding a crucial ability to solve *abstract visual reasoning*, where agents need to reason by using analogies on abstract concepts in different scenarios.
12 | To overcome this problem, we propose *NEUro-symbolic Message-pAssiNg reasoNer (NEUMANN)*, which is a graph-based differentiable forward reasoner, passing messages in a memory-efficient manner and handling structured programs with functors.
13 | Moreover, we propose a computationally-efficient structure learning algorithm to perform explanatory program induction on complex visual scenes.
14 | To evaluate, in addition to conventional visual reasoning tasks, we propose a new task, *visual reasoning behind-the-scenes*, where agents need to learn abstract programs and then answer queries by imagining scenes that are not observed.
15 | We empirically demonstrate that NEUMANN solves visual reasoning tasks efficiently, outperforming neural, symbolic, and neuro-symbolic baselines.
16 |
17 |
18 | 
19 |
20 | **NEUMANN solves the Behind-the-Scenes task.**
21 | Reasoning behind the scenes: The goal of this task is to compute the answer to a query, e.g., *``What is the color of the second left-most object after deleting a gray object?''*, given a visual scene. To answer this query, the agent needs to reason behind the scenes and understand abstract operations on objects. In the first task, the agent needs to induce an explicit program given visual examples, where each example consists of several visual scenes that describe the input and the output of the operation to be learned. The abstract operations can be described and computed by first-order logic with functors.
22 | In the second task, the agent needs to apply the learned programs to new situations to solve queries reasoning about non-observational scenes.
23 |
24 | ## How does it work?
25 | NEUMANN compiles *first-order logic* programs into a *graph neural network*. Logical entailment is computed over probabilistic atoms and weighted rules using fuzzy logic operations.
26 | 
27 |
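A minimal sketch of one such fuzzy operation is a smooth logical *or* implemented with log-sum-exp and a smoothing parameter `gamma` (illustrative only; see `soft_logic.py` in this repository for the actual implementation):

```
import torch

def softor(values, gamma=0.01, dim=0):
    # Smooth approximation of logical OR over truth values in [0, 1];
    # approaches max(values) as gamma -> 0.
    return torch.clamp(gamma * torch.logsumexp(values / gamma, dim=dim), max=1.0)

print(softor(torch.tensor([0.2, 0.9])))  # close to 0.9
```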
28 | # Relevant Repositories
29 | [Visual ILP: A repository of the dataset generation of CLEVR images for abstract operations.](https://github.com/ml-research/visual-ilp)
30 |
31 | [Behind-the-Scenes: A repository for the generation of visual scenes and queries for the behind-the-scenes task.](https://github.com/ml-research/behind-the-scenes)
32 |
33 | # Experiments
34 |
35 | ## Prerequisites
36 | A Docker container is available in the folder [.devcontainer](./.devcontainer/Dockerfile),
37 | which is compatible with the [packages](./pip_requirements.txt) (produced by pip freeze).
38 | The main dependencies are:
39 | ```
40 | pytorch
41 | torch-geometric
42 | networkx
43 | ```
44 | We used Python 3.8 for the experiments.
45 | See [Dockerfile](.devcontainer/Dockerfile) for more details.
46 |
47 | ## Build a Docker container
48 | Simply use VSCode to open the container, or build the container manually.
49 | To run on machines without GPUs:
50 | ```
51 | cp .devcontainer/Dockerfile_nogpu ./Dockerfile
52 | docker build -t neumann .
53 | docker run -it -v <path-to-this-repository>:/neumann --name neumann neumann
54 | ```
55 | For example, the local path `<path-to-this-repository>` could be `/Users/username/Workspace/github/neumann`, i.e., the directory where this repository has been cloned.
56 |
57 | For GPU-equipped machines, use:
58 | ```
59 | cp .devcontainer/Dockerfile ./Dockerfile
60 | docker build -t neumann .
61 | docker run -it -v <path-to-this-repository>:/neumann --name neumann neumann
62 | ```
63 | To open the container on machines without GPUs using VSCode, run
64 | ```
65 | cp .devcontainer/Dockerfile_nogpu .devcontainer/Dockerfile
66 | ```
67 | and use the VSCode remotehost extension (recommended).
68 |
69 |
70 |
71 | ## Perform learning
72 | For example, in the container, learning the red-triangle Kandinsky pattern on the demo dataset can be performed as follows:
73 | ```
74 | cd /neumann
75 | python3 src/train_neumann.py --dataset-type kandinsky --dataset red-triangle --num-objects 6 --batch-size 12 --no-cuda --epochs 30 --infer-step 4 --trial 5 --n-sample 10 --program-size 1 --max-var 6 --min-body-len 6 --pos-ratio 1.0 --neg-ratio 1.0
76 | ```
77 | An exemplary log can be found in [redtriangle_log.txt](./logs/redtriangle_log.txt).
78 |
79 | More scripts are available:
80 |
81 | [Learning kandinsky/clevr-hans patterns](./scripts/solve_kandinsky_clevr.sh)
82 |
83 | [Solving Behind-the-Scenes](./scripts/solve_behind-the-scenes.sh)
84 |
85 | # LICENSE
86 | See [LICENSE](./LICENSE). The [src/yolov5](./src/yolov5) folder is licensed under [GPL3](./src/yolov5/LICENSE).
--------------------------------------------------------------------------------
/neumann/neumann/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/neumann/neumann/__init__.py
--------------------------------------------------------------------------------
/neumann/neumann/clause_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from anytree import Node, PreOrderIter, RenderTree
5 | from anytree.search import find_by_attr, findall
6 |
7 | from .fol.logic import Clause
8 | from .logic_utils import add_true_atom, add_true_atoms, remove_true_atoms, true
9 |
10 |
11 | class ClauseGenerator(object):
12 | """Refinement-based clause generator that holds a tree representation of the generation steps.
13 | """
14 | def __init__(self, refinement_generator, root_clauses, th_depth, n_sample):
15 | self.refinement_generator = refinement_generator
16 | self.th_depth = th_depth
17 | self.root_clauses = add_true_atoms(root_clauses)
18 | self.tree = Node(name="root", clause=Clause(true,[]))
19 | for c in root_clauses:
20 | Node(name=str(c), clause=c, parent=self.tree)
21 | #Node(name=str(nc), clause=nc, =target_node)
22 | self.n_sample = n_sample
23 | self.is_root = True
24 | self.refinement_history = set()
25 | self.refinement_score_history = set()
26 |
27 |
28 | def generate(self, clauses, clause_scores):
29 | clauses_to_refine = self.sample_clauses_by_scores(clauses, clause_scores)
30 | print("=== CLAUSES TO BE REFINED ===")
31 | for i, cr in enumerate(clauses_to_refine):
32 | print(i, ': ', cr)
33 |
34 | self.refinement_history = self.refinement_history.union(set(clauses_to_refine))
35 | # self.refinement_history = list(set(self.clause_history))
36 | new_clauses = add_true_atoms(self.apply_refinement(remove_true_atoms(clauses_to_refine)))
37 | # prune already appeared clauses
38 | new_clauses = [c for c in new_clauses if not c in self.refinement_history]
39 | return list(set(new_clauses))
40 |
41 |
42 | def sample_clauses_by_scores(self, clauses, clause_scores):
43 | clauses_to_refine = []
44 | print("Logits for the sampling: ")
45 | print(np.round(clause_scores.cpu().numpy(), 2))
46 |
47 | n_sampled = 0
48 | while n_sampled < self.n_sample:
49 | if len(clauses) == 0:
50 | # no more clauses to be sampled
51 | break
52 | i_sampled_onehot = F.gumbel_softmax(clause_scores, tau=1.0, hard=True)
53 | i_sampled = int(torch.argmax(i_sampled_onehot, dim=0).item())
54 | # selected_clause_indices = [i for i, j in enumerate(selected_clause_indices)]
55 | # clauses_to_refine_i = [c for i, c in enumerate(clauses) if selected_clause_indices[i] > 0]
56 | sampled_clause = clauses[i_sampled]
57 | score = np.round(clause_scores[i_sampled].cpu().numpy(), 2)
58 |
59 | if score in self.refinement_score_history:
60 | # if a clause with the same score is already sampled, just skip this
61 | # renormalize clause scores
62 | if i_sampled != len(clauses)-1:
63 | clause_scores = torch.cat([clause_scores[:i_sampled], clause_scores[i_sampled + 1:]])
64 | clauses.remove(sampled_clause)
65 | else:
66 | clause_scores = clause_scores[:i_sampled]
67 | clauses.remove(sampled_clause)
68 | else:
69 | # append to the result
70 | clauses_to_refine.append(sampled_clause)
71 | # update history
72 | self.refinement_score_history.add(score)
73 | self.refinement_history.add(sampled_clause)
74 | # renormalize scores
75 | if i_sampled != len(clauses)-1:
76 | clause_scores = torch.cat([clause_scores[:i_sampled], clause_scores[i_sampled + 1:]])
77 | clauses.remove(sampled_clause)
78 | else:
79 | clause_scores = clause_scores[:i_sampled]
80 | clauses.remove(sampled_clause)
81 |
82 | n_sampled += 1
83 |
84 | clauses_to_refine = list(set(clauses_to_refine))
85 | return clauses_to_refine
86 |
87 | def split_by_head_preds(self, clauses, clause_scores):
88 | head_pred_clauses_dic = {}
89 | head_pred_scores_dic = {}
90 | for i, c in enumerate(clauses):
91 | if c.head.pred in head_pred_clauses_dic:
92 | head_pred_clauses_dic[c.head.pred].append(c)
93 | head_pred_scores_dic[c.head.pred].append(clause_scores[i])
94 | else:
95 | head_pred_clauses_dic[c.head.pred] = [c]
96 | head_pred_scores_dic[c.head.pred] = [clause_scores[i]]
97 |
98 | for p in head_pred_scores_dic.keys():
99 | head_pred_scores_dic[p] = torch.tensor(head_pred_scores_dic[p])
100 | return head_pred_clauses_dic, head_pred_scores_dic
101 |
102 |
103 | def apply_refinement(self, clauses):
104 | all_new_clauses = []
105 | for clause in clauses:
106 | new_clauses = self.generate_clauses_by_refinement(clause)
107 | all_new_clauses.extend(new_clauses)
108 | # add to the refinement tree
109 | #self.print_tree()
110 | # the true atom for the clause to be refined has been removed
111 | target_node = find_by_attr(self.tree, name='clause', value=add_true_atom(clause))
112 | #print(target_node)
113 | for nc in new_clauses:
114 | all_nodes =list(PreOrderIter(self.tree))
115 | if not nc in [n.clause for n in all_nodes]:
116 | Node(name=str(nc), clause=nc, parent=target_node)
117 | """
118 | clauses_exist = [n.clause for n in target_node.children]
119 | if not nc in clauses_exist:
120 | print(target_node)
121 | print('clauses_exist: ', clauses_exist)
122 | Node(name=str(nc), clause=nc, parent=target_node)
123 | """
124 | # target_node.children.append(child_node)
125 | return all_new_clauses
126 |
127 |
128 | def generate_clauses_by_refinement(self, clause):
129 | return list(set(add_true_atoms(self.refinement_generator.refine_clause(clause))))
130 |
131 | def print_tree(self):
132 | print("-- rule generation tree --")
133 | print(RenderTree(self.tree).by_attr('clause'))
134 |
135 | def get_clauses_by_th_depth(self, th_depth):
136 | """Get all clauses that are located deeper nodes than given threashold."""
137 | nodes = findall(self.tree, filter_=lambda node: node.depth >= th_depth)
138 | return [node.clause for node in nodes]
139 | """
140 | def __sample_clauses_by_scores(self, clauses, clause_scores):
141 | clauses_to_refine = []
142 | print("Logits for the sampling: ")
143 | print(np.round(clause_scores.cpu().numpy(), 2))
144 | for i in range(clause_scores.size(0)):
145 | clauses_dic, scores_dic = self.split_by_head_preds(clauses, clause_scores[i])
146 | for p, clauses_p in clauses_dic.items():
147 | selected_clause_indices = torch.stack([F.gumbel_softmax(scores_dic[p] * 100, tau=1.0, hard=True) for j in range(int(self.n_sample / len(clauses_dic.keys())))])
148 | selected_clause_indices, _ = torch.max(selected_clause_indices, dim=0)
149 | # selected_clause_indices = [i for i, j in enumerate(selected_clause_indices)]
150 | clauses_to_refine_i = [c for i, c in enumerate(clauses_p) if selected_clause_indices[i] > 0]
151 | clauses_to_refine.extend(clauses_to_refine_i)
152 | clauses_to_refine = list(set(clauses_to_refine))
153 | return clauses_to_refine
154 |
155 |
156 | def sample_clauses_by_scores(self, clauses, clause_scores):
157 | clauses_to_refine = []
158 | print("Logits for the sampling: ")
159 | print(np.round(clause_scores.cpu().numpy(), 2))
160 | for i in range(clause_scores.size(0)):
161 | selected_clause_indices = torch.stack([F.gumbel_softmax(clause_scores[i], tau=1.0, hard=True) for j in range(int(self.n_sample))])
162 | selected_clause_indices, _ = torch.max(selected_clause_indices, dim=0)
163 | # selected_clause_indices = [i for i, j in enumerate(selected_clause_indices)]
164 | clauses_to_refine_i = [c for i, c in enumerate(clauses) if selected_clause_indices[i] > 0]
165 | clauses_to_refine.extend(clauses_to_refine_i)
166 | clauses_to_refine = list(set(clauses_to_refine))
167 | return clauses_to_refine
168 | """
--------------------------------------------------------------------------------
/neumann/neumann/data_behind_the_scenes.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torch.utils.data
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 |
11 |
12 | def load_image_clevr(path):
13 | """Load an image using given path.
14 | """
15 | img = cv2.imread(path) # BGR
16 | assert img is not None, 'Image Not Found ' + path
17 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
18 | img = np.ascontiguousarray(img)
19 | return img
20 |
21 | def load_question_json(question_json_path):
22 | with open(question_json_path) as f:
23 | question = json.load(f)
24 | #questions = question["questions"]
25 | #return questions
26 | return question
27 |
28 | class BehindTheScenes(torch.utils.data.Dataset):
29 | def __init__(self, question_json_path, lang, n_data, device, img_size=128, base='data/behind-the-scenes/'):
30 | super().__init__()
31 | #self.colors = ["cyan", "blue", "yellow",\
32 | # "purple", "red", "green", "gray", "brown"]
33 | self.colors = ["cyan", "gray", "red", "yellow"]
34 | self.query_types = ["delete", "append", "reverse", "sort"]
35 | self.positions = ["1st", "2nd", "3rd"]
36 | self.base = base
37 | self.lang = lang
38 | self.device = device
39 | self.questions = random.sample(load_question_json(question_json_path), int(n_data))
40 | self.img_size = img_size
41 | self.transform = transforms.Compose(
42 | [transforms.Resize((img_size, img_size))]
43 | )
44 | #self.image_paths, self.answer_paths = load_images_and_labels(
45 | # dataset=dataset, split=split, base=base)
46 | # {"program": "query2(sort,2nd)", "split": "train", "image_index": 45, "answer": "red", \
47 | # "image": "BehindTheScenes_train_000045", "question_index": 10290, "question": ["sort", "2nd"], \
48 | # "image_filename": "BehindTheScenes_train_000045.png"},
49 |
50 | def __getitem__(self, item):
51 | question = self.questions[item]
52 | # print('question: ', question)
53 | image_path = self.base + 'images/' + question["image_filename"]
54 | image = Image.open(image_path).convert("RGB")
55 | image = self.image_preprocess(image)
56 | answer = self.to_onehot(self.colors.index(question["answer"]), len(self.colors))
57 | query_tuple = question["question"]
58 | query = self.to_query_vector(query_tuple)
59 | # TODO: concate and return??
60 | # answer = load_answer(self.answer_paths[item])
61 | return image, query, answer
62 |
63 | def __len__(self):
64 | return len(self.questions)
65 |
66 | def to_onehot(self, index, size):
67 | onehot = torch.zeros(size, ).to(self.device)
68 | onehot[index] = 1.0
69 | return onehot
70 |
71 | def to_query_vector(self, query_tuple):
72 | if len(query_tuple) == 3:
73 | query_type, color, position = query_tuple
74 | # ("delete", "red", "1st")
75 | q_1 = self.to_onehot(self.query_types.index(query_type), len(self.query_types))
76 | q_2 = self.to_onehot(self.colors.index(color), len(self.colors))
77 | q_3 = self.to_onehot(self.positions.index(position), len(self.positions))
78 | return torch.cat([q_1, q_2, q_3])
79 | elif len(query_tuple) == 2:
80 | query_type, position = query_tuple
81 | # ("sort", "1st")
82 | q_1 = self.to_onehot(self.query_types.index(query_type), len(self.query_types))
83 | q_2 = torch.zeros(len(self.colors, )).to(self.device)
84 | q_3 = self.to_onehot(self.positions.index(position), len(self.positions))
85 | return torch.cat([q_1, q_2, q_3])
86 |
87 |
88 | def image_preprocess(self, image):
89 | image = transforms.ToTensor()(image)[:3, :, :]
90 | image = self.transform(image)
91 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1].
92 | return image.unsqueeze(0)
93 |
94 |
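95 | # ---------------------------------------------------------------------------
96 | # Usage sketch (illustrative addition, not from the original repository): how a
97 | # query tuple such as ("delete", "red", "1st") is turned into a query vector by
98 | # to_query_vector() above, i.e. concatenated one-hot encodings over query types,
99 | # colors, and positions. The category lists mirror BehindTheScenes.__init__.
100 | # ---------------------------------------------------------------------------
101 | if __name__ == "__main__":
102 |     colors = ["cyan", "gray", "red", "yellow"]
103 |     query_types = ["delete", "append", "reverse", "sort"]
104 |     positions = ["1st", "2nd", "3rd"]
105 |
106 |     def onehot(index, size):
107 |         v = torch.zeros(size)
108 |         v[index] = 1.0
109 |         return v
110 |
111 |     query_type, color, position = ("delete", "red", "1st")
112 |     query_vector = torch.cat([onehot(query_types.index(query_type), len(query_types)),
113 |                               onehot(colors.index(color), len(colors)),
114 |                               onehot(positions.index(position), len(positions))])
115 |     print(query_vector)  # an 11-dimensional vector (4 + 4 + 3)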
--------------------------------------------------------------------------------
/neumann/neumann/data_clevr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torch.utils.data
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 |
10 | def load_images_and_labels(dataset='clevr-hans3', split='train', pos_ratio=1.0, neg_ratio=1.0, base=None):
11 | """Load image paths and labels for clevr-hans dataset.
12 | """
13 | image_paths = []
14 | labels = []
15 | folder = 'data/clevr-hans/' + dataset + '/' + split + '/'
16 | true_folder = folder + 'true/'
17 | false_folder = folder + 'false/'
18 |
19 | filenames = sorted(os.listdir(true_folder))
20 | n = int(pos_ratio * len(filenames))
21 | for filename in filenames[:n]:
22 | if filename != '.DS_Store':
23 | image_paths.append(os.path.join(true_folder, filename))
24 | labels.append(1)
25 |
26 | filenames = sorted(os.listdir(false_folder))
27 | n = int(neg_ratio * len(filenames))
28 | for filename in filenames[:n]:
29 | if filename != '.DS_Store':
30 | image_paths.append(os.path.join(false_folder, filename))
31 | labels.append(0)
32 | return image_paths, labels
33 |
34 |
35 | def load_image_clevr(path):
36 | """Load an image using given path.
37 | """
38 | img = cv2.imread(path) # BGR
39 | assert img is not None, 'Image Not Found ' + path
40 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
41 | img = np.ascontiguousarray(img)
42 | return img
43 |
44 |
45 |
46 | class CLEVRHans(torch.utils.data.Dataset):
47 | """CLEVRHans dataset.
48 | The implementation is mainly adapted from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py.
49 | """
50 |
51 | def __init__(self, dataset, split, pos_ratio=1.0, neg_ratio=1.0, img_size=128, base=None):
52 | super().__init__()
53 | self.img_size = img_size
54 | self.dataset = dataset
55 | assert split in {
56 | "train",
57 | "val",
58 | "test",
59 | } # note: test isn't very useful since it doesn't have ground-truth scene information
60 | self.split = split
61 | self.transform = transforms.Compose(
62 | [transforms.Resize((img_size, img_size))]
63 | )
64 | self.image_paths, self.labels = load_images_and_labels(
65 | dataset=dataset, split=split, pos_ratio=pos_ratio, neg_ratio=neg_ratio)
66 |
67 | def __getitem__(self, item):
68 | path = self.image_paths[item]
69 | image = Image.open(path).convert("RGB")
70 | image = transforms.ToTensor()(image)[:3, :, :]
71 | image = self.transform(image)
72 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1].
73 | label = torch.tensor(self.labels[item], dtype=torch.float32)
74 | return image.unsqueeze(0), label
75 |
76 |
77 | def __old__getitem__(self, item):
78 | path = self.image_paths[item]
79 | image = Image.open(path).convert("RGB")
80 | image = transforms.ToTensor()(image)[:3, :, :]
81 | image = self.transform(image)
82 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1].
83 | if self.dataset == 'clevr-hans3':
84 | labels = torch.zeros((3, ), dtype=torch.float32)
85 | elif self.dataset == 'clevr-hans7':
86 | labels = torch.zeros((7, ), dtype=torch.float32)
87 | labels[self.labels[item]] = 1.0
88 | return image, labels
89 |
90 | def __len__(self):
91 | return len(self.labels)
92 |
93 | class CLEVRConcept(torch.utils.data.Dataset):
94 | """The Concept-learning dataset for CLEVR-Hans.
95 | """
96 |
97 | def __init__(self, dataset, split):
98 | self.dataset = dataset
99 | self.split = split
100 | self.data, self.labels = self.load_csv()
101 | print('concept data: ', self.data.shape, 'labels: ', len(self.labels))
102 |
103 | def load_csv(self):
104 | data = []
105 | labels = []
106 | pos_csv_data = pd.read_csv(
107 | 'data/clevr/concept_data/' + self.split + '/' + self.dataset + '_pos' + '.csv', delimiter=' ')
108 | pos_data = pos_csv_data.values
109 | #pos_labels = np.ones((len(pos_data, )))
110 | pos_labels = np.zeros((len(pos_data, )))
111 | neg_csv_data = pd.read_csv(
112 | 'data/clevr/concept_data/' + self.split + '/' + self.dataset + '_neg' + '.csv', delimiter=' ')
113 | neg_data = neg_csv_data.values
114 | #neg_labels = np.zeros((len(neg_data, )))
115 | neg_labels = np.ones((len(neg_data, )))
116 | data = torch.tensor(np.concatenate(
117 | [pos_data, neg_data], axis=0), dtype=torch.float32)
118 | labels = torch.tensor(np.concatenate(
119 | [pos_labels, neg_labels], axis=0), dtype=torch.float32)
120 | return data, labels
121 |
122 | def __getitem__(self, item):
123 | return self.data[item], self.labels[item]
124 |
125 | def __len__(self):
126 | return len(self.data)
127 |
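128 | # ---------------------------------------------------------------------------
129 | # Usage sketch (illustrative addition, not from the original repository):
130 | # wrapping CLEVRHans in a standard DataLoader. It assumes the data layout read
131 | # by load_images_and_labels(), i.e. data/clevr-hans/<dataset>/<split>/{true,false}/.
132 | # ---------------------------------------------------------------------------
133 | if __name__ == "__main__":
134 |     dataset = CLEVRHans(dataset="clevr-hans3", split="train", img_size=128)
135 |     loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
136 |     images, labels = next(iter(loader))
137 |     print(images.shape, labels.shape)  # e.g. torch.Size([4, 1, 3, 128, 128]) torch.Size([4])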
--------------------------------------------------------------------------------
/neumann/neumann/data_kandinsky.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class KANDINSKY(torch.utils.data.Dataset):
11 | """Kandinsky Patterns dataset.
12 | """
13 |
14 | def __init__(self, dataset, split, pos_ratio=1.0, neg_ratio=1.0, img_size=128):
15 | self.img_size = img_size
16 | assert split in {
17 | "train",
18 | "val",
19 | "test",
20 | }
21 | self.image_paths, self.labels = load_images_and_labels(
22 | dataset=dataset, split=split, pos_ratio=pos_ratio, neg_ratio=neg_ratio)
23 | print("{} {} images loaded!!".format(len(self.image_paths), split))
24 |
25 | def __getitem__(self, item):
26 | image = load_image_yolo(
27 | self.image_paths[item], img_size=self.img_size)
28 | image = torch.from_numpy(image).type(torch.float32) / 255.
29 |
30 | label = torch.tensor(self.labels[item], dtype=torch.float32)
31 |
32 | # return image as one image, not two or more
33 | return image.unsqueeze(0), label
34 |
35 | def __len__(self):
36 | return len(self.labels)
37 |
38 |
39 | def load_images_and_labels(dataset='twopairs', split='train', pos_ratio=1.0, neg_ratio=1.0, img_size=128):
40 | """Load image paths and labels for kandinsky dataset.
41 | """
42 | image_paths = []
43 | labels = []
44 | folder = 'data/kandinsky/' + dataset + '/' + split + '/'
45 | true_folder = folder + 'true/'
46 | false_folder = folder + 'false/'
47 |
48 | filenames = sorted(os.listdir(true_folder))
49 | #if split == 'train':
50 | # n = int(len(filenames)/10)
51 | #else:
52 | # n = len(filenames)
53 | n = int(pos_ratio * len(filenames))
54 | for filename in filenames[:n]:
55 | if filename != '.DS_Store':
56 | image_paths.append(os.path.join(true_folder, filename))
57 | labels.append(1)
58 |
59 | filenames = sorted(os.listdir(false_folder))
60 | n = int(neg_ratio * len(filenames))
61 | for filename in filenames[:n]:
62 | if filename != '.DS_Store':
63 | image_paths.append(os.path.join(false_folder, filename))
64 | labels.append(0)
65 | return image_paths, labels
66 |
67 |
68 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
69 | """A utilitiy function for yolov5 model to make predictions. The implementation is from the yolov5 repository.
70 | """
71 | import cv2
72 |
73 | # Resize and pad image while meeting stride-multiple constraints
74 | shape = img.shape[:2] # current shape [height, width]
75 | if isinstance(new_shape, int):
76 | new_shape = (new_shape, new_shape)
77 |
78 | # Scale ratio (new / old)
79 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
80 | if not scaleup: # only scale down, do not scale up (for better test mAP)
81 | r = min(r, 1.0)
82 |
83 | # Compute padding
84 | ratio = r, r # width, height ratios
85 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
86 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - \
87 | new_unpad[1] # wh padding
88 | if auto: # minimum rectangle
89 | dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
90 | elif scaleFill: # stretch
91 | dw, dh = 0.0, 0.0
92 | new_unpad = (new_shape[1], new_shape[0])
93 | ratio = new_shape[1] / shape[1], new_shape[0] / \
94 | shape[0] # width, height ratios
95 |
96 | dw /= 2 # divide padding into 2 sides
97 | dh /= 2
98 |
99 | if shape[::-1] != new_unpad: # resize
100 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
101 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
102 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
103 | img = cv2.copyMakeBorder(
104 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
105 | return img, ratio, (dw, dh)
106 |
107 |
108 | def load_image_yolo(path, img_size, stride=32):
109 | """Load an image using given path.
110 | """
111 | img0 = cv2.imread(path) # BGR
112 | assert img0 is not None, 'Image Not Found ' + path
113 | img = cv2.resize(img0, (img_size, img_size))
114 |
115 | # Convert
116 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
117 | img = np.ascontiguousarray(img)
118 | return img
119 |
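120 |
121 | # ---------------------------------------------------------------------------
122 | # Usage sketch (illustrative addition, not from the original repository):
123 | # loading a Kandinsky pattern split with a DataLoader. It assumes the data
124 | # layout read by load_images_and_labels(), i.e.
125 | # data/kandinsky/<dataset>/<split>/{true,false}/.
126 | # ---------------------------------------------------------------------------
127 | if __name__ == "__main__":
128 |     dataset = KANDINSKY(dataset="red-triangle", split="train", img_size=128)
129 |     loader = torch.utils.data.DataLoader(dataset, batch_size=12, shuffle=True)
130 |     images, labels = next(iter(loader))
131 |     print(images.shape, labels.shape)  # e.g. torch.Size([12, 1, 3, 128, 128]) torch.Size([12])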
--------------------------------------------------------------------------------
/neumann/neumann/data_logic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | class RandomValuation(torch.utils.data.Dataset):
10 | """CLEVRHans dataset.
11 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py.
12 | """
13 |
14 | def __init__(self, dataset, split, atoms, n_data=2000):
15 | super().__init__()
16 | self.atoms = atoms
17 | self.n_data = n_data
18 | self.dataset = dataset
19 | assert split in {
20 | "train",
21 | "val",
22 | "test",
23 | } # note: test isn't very useful since it doesn't have ground-truth scene information
24 | self.split = split
25 |
26 |
27 | def __getitem__(self, item):
28 | return torch.rand((len(self.atoms), ))
29 |
30 | def __len__(self):
31 | return self.n_data
32 |
33 | class ZeroValuation(torch.utils.data.Dataset):
34 | """CLEVRHans dataset.
35 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py.
36 | """
37 |
38 | def __init__(self, dataset, split, atoms, n_data=2000):
39 | super().__init__()
40 | self.atoms = atoms
41 | self.n_data = n_data
42 | self.dataset = dataset
43 | assert split in {
44 | "train",
45 | "val",
46 | "test",
47 | } # note: test isn't very useful since it doesn't have ground-truth scene information
48 | self.split = split
49 |
50 |
51 | def __getitem__(self, item):
52 | return torch.zeros((len(self.atoms), ))
53 |
54 | def __len__(self):
55 | return self.n_data
56 |
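57 | # ---------------------------------------------------------------------------
58 | # Usage sketch (illustrative addition, not from the original repository): the
59 | # valuation datasets above only need the length of the atom list, so any
60 | # placeholder sequence is enough for a quick smoke test.
61 | # ---------------------------------------------------------------------------
62 | if __name__ == "__main__":
63 |     dummy_atoms = ["atom_{}".format(i) for i in range(5)]  # stand-ins for ground atoms
64 |     dataset = RandomValuation(dataset="dummy", split="train", atoms=dummy_atoms, n_data=8)
65 |     loader = torch.utils.data.DataLoader(dataset, batch_size=4)
66 |     batch = next(iter(loader))
67 |     print(batch.shape)  # torch.Size([4, 5]): one random valuation per atom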
--------------------------------------------------------------------------------
/neumann/neumann/data_vilp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torch.utils.data
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 |
9 |
10 | def load_images_and_labels(dataset='member', split='train', base=None, pos_ratio=1.0, neg_ratio=1.0):
11 | """Load image paths and labels for clevr-hans dataset.
12 | """
13 | image_paths_list = []
14 | labels = []
15 | base_folder = 'data/vilp/' + dataset + '/' + split + '/true/'
16 | folder_names = sorted(os.listdir(base_folder))
17 | if '.DS_Store' in folder_names:
18 | folder_names.remove('.DS_Store')
19 |
20 | if split == 'train':
21 | n = int(len(folder_names) * pos_ratio)
22 | else:
23 | n = len(folder_names)
24 | for folder_name in folder_names[:n]:
25 | folder = base_folder + folder_name + '/'
26 | filenames = sorted(os.listdir(folder))
27 | image_paths = []
28 | for filename in filenames:
29 | if filename != '.DS_Store':
30 | image_paths.append(os.path.join(folder, filename))
31 | image_paths_list.append(image_paths)
32 | labels.append(1.0)
33 | base_folder = 'data/vilp/' + dataset + '/' + split + '/false/'
34 | if split == 'train':
35 | n = int(len(folder_names) * neg_ratio)
36 | else:
37 | n = len(folder_names)
38 | folder_names = sorted(os.listdir(base_folder))
39 | if '.DS_Store' in folder_names:
40 | folder_names.remove('.DS_Store')
41 | for folder_name in folder_names[:n]:
42 | folder = base_folder + folder_name + '/'
43 | filenames = sorted(os.listdir(folder))
44 | image_paths = []
45 | for filename in filenames:
46 | if filename != '.DS_Store':
47 | image_paths.append(os.path.join(folder, filename))
48 | image_paths_list.append(image_paths)
49 | labels.append(0.0)
50 | return image_paths_list, labels
51 |
52 |
53 | def load_images_and_labels_positive(dataset='member', split='train', base=None):
54 | """Load image paths and labels for clevr-hans dataset.
55 | """
56 | image_paths_list = []
57 | labels = []
58 | base_folder = 'data/vilp/' + dataset + '/' + split + '/true/'
59 | folder_names = sorted(os.listdir(base_folder))
60 | if '.DS_Store' in folder_names:
61 | folder_names.remove('.DS_Store')
62 |
63 | for folder_name in folder_names:
64 | folder = base_folder + folder_name + '/'
65 | filenames = sorted(os.listdir(folder))
66 | image_paths = []
67 | for filename in filenames:
68 | if filename != '.DS_Store':
69 | image_paths.append(os.path.join(folder, filename))
70 | image_paths_list.append(image_paths)
71 | labels.append(1.0)
72 | return image_paths_list, labels
73 |
74 | def load_image_clevr(path):
75 | """Load an image using given path.
76 | """
77 | img = cv2.imread(path) # BGR
78 | assert img is not None, 'Image Not Found ' + path
79 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
80 | img = np.ascontiguousarray(img)
81 | return img
82 |
83 |
84 | class VisualILP(torch.utils.data.Dataset):
85 | """CLEVRHans dataset.
86 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py.
87 | """
88 |
89 | def __init__(self, dataset, split, img_size=128, base=None, pos_ratio=1.0, neg_ratio=1.0):
90 | super().__init__()
91 | self.img_size = img_size
92 | self.dataset = dataset
93 | assert split in {
94 | "train",
95 | "val",
96 | "test",
97 | } # note: test isn't very useful since it doesn't have ground-truth scene information
98 | self.split = split
99 | self.transform = transforms.Compose(
100 | [transforms.Resize((img_size, img_size))]
101 | )
102 | self.pos_ratio = pos_ratio
103 | self.neg_ratio = neg_ratio
104 | self.image_paths, self.labels = load_images_and_labels(
105 | dataset=dataset, split=split, base=base, pos_ratio=pos_ratio, neg_ratio=neg_ratio)
106 |
107 | def __getitem__(self, item):
108 | paths = self.image_paths[item]
109 | images = []
110 | for path in paths:
111 | image = Image.open(path).convert("RGB")
112 | image = transforms.ToTensor()(image)[:3, :, :]
113 | image = self.transform(image)
114 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1].
115 | images.append(image)
116 | # TODO: concate and return??
117 | image = torch.stack(images, dim=0)
118 | return image, self.labels[item]
119 |
120 | def __len__(self):
121 | return len(self.labels)
122 |
123 | class VisualILP_POSITIVE(torch.utils.data.Dataset):
124 | """CLEVRHans dataset.
125 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py.
126 | """
127 |
128 | def __init__(self, dataset, split, img_size=128, base=None):
129 | super().__init__()
130 | self.img_size = img_size
131 | self.dataset = dataset
132 | assert split in {
133 | "train",
134 | "val",
135 | "test",
136 | } # note: test isn't very useful since it doesn't have ground-truth scene information
137 | self.split = split
138 | self.transform = transforms.Compose(
139 | [transforms.Resize((img_size, img_size))]
140 | )
141 | self.image_paths, self.labels = load_images_and_labels_positive(
142 | dataset=dataset, split=split, base=base)
143 |
144 | def __getitem__(self, item):
145 | paths = self.image_paths[item]
146 | images = []
147 | for path in paths:
148 | image = Image.open(path).convert("RGB")
149 | image = transforms.ToTensor()(image)[:3, :, :]
150 | image = self.transform(image)
151 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1].
152 | images.append(image)
153 | # TODO: concate and return??
154 | image = torch.stack(images, dim=0)
155 | return image, self.labels[item]
156 |
157 | def __len__(self):
158 | return len(self.labels)
159 |
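160 | # ---------------------------------------------------------------------------
161 | # Usage sketch (illustrative addition, not from the original repository):
162 | # iterating over VisualILP examples. Each item stacks every image of one
163 | # example folder, so batch_size=1 is used because folders may contain different
164 | # numbers of images. It assumes the layout read by load_images_and_labels(),
165 | # i.e. data/vilp/<dataset>/<split>/{true,false}/<example>/.
166 | # ---------------------------------------------------------------------------
167 | if __name__ == "__main__":
168 |     dataset = VisualILP(dataset="member", split="train", img_size=128)
169 |     loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
170 |     images, label = next(iter(loader))
171 |     print(images.shape, label)  # e.g. torch.Size([1, N, 3, 128, 128]) and the label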
--------------------------------------------------------------------------------
/neumann/neumann/explain_clevr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import time
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | from rtpt import RTPT
10 | from torch.utils.tensorboard import SummaryWriter
11 | from tqdm import tqdm
12 |
13 | from explanation_utils import *
14 | from logic_utils import get_lang
15 | from neumann_utils import get_data_loader, get_model, get_prob
16 |
17 | torch.set_num_threads(10)
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--batch-size", type=int, default=10,
22 | help="Batch size to infer with")
23 | parser.add_argument("--batch-size-bs", type=int,
24 | default=1, help="Batch size in beam search")
25 | parser.add_argument("--num-objects", type=int, default=6,
26 | help="The maximum number of objects in one image")
27 | parser.add_argument("--dataset", default="delete") # , choices=["member"])
28 | parser.add_argument("--dataset-type", default="behind-the-scenes")
29 | parser.add_argument('--device', default='cpu',
30 | help='cuda device, i.e. 0 or cpu')
31 | parser.add_argument("--no-cuda", action="store_true",
32 | help="Run on CPU instead of GPU (not recommended)")
33 | parser.add_argument("--no-train", action="store_true",
34 | help="Perform prediction without training model")
35 | parser.add_argument("--small-data", action="store_true",
36 | help="Use small training data.")
37 | parser.add_argument("--num-workers", type=int, default=0,
38 | help="Number of threads for data loader")
39 | parser.add_argument('--gamma', default=0.01, type=float,
40 | help='Smooth parameter in the softor function')
41 | parser.add_argument("--plot", action="store_true",
42 | help="Plot images with captions.")
43 | parser.add_argument("--t-beam", type=int, default=4,
44 | help="Number of rule expantion of clause generation.")
45 | parser.add_argument("--n-beam", type=int, default=5,
46 | help="The size of the beam.")
47 | parser.add_argument("--n-max", type=int, default=50,
48 | help="The maximum number of clauses.")
49 | parser.add_argument("--program-size", type=int, default=1,
50 | help="The size of the logic program.")
51 | #parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.")
52 | parser.add_argument("--epochs", type=int, default=20,
53 | help="The number of epochs.")
54 | parser.add_argument("--lr", type=float, default=1e-2,
55 | help="The learning rate.")
56 | parser.add_argument("--n-ratio", type=float, default=1.0,
57 | help="The ratio of data to be used.")
58 | parser.add_argument("--pre-searched", action="store_true",
59 | help="Using pre searched clauses.")
60 | parser.add_argument("--infer-step", type=int, default=6,
61 | help="The number of steps of forward reasoning.")
62 | parser.add_argument("--term-depth", type=int, default=3,
63 | help="The number of steps of forward reasoning.")
64 | parser.add_argument("--question-json-path", default="data/behind-the-scenes/BehindTheScenes_questions.json")
65 | args = parser.parse_args()
66 | return args
67 |
68 | # def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False):
69 |
70 |
71 | def discretise_NEUMANN(NEUMANN, args, device):
72 | lark_path = 'src/lark/exp.lark'
73 | lang_base_path = 'data/lang/'
74 | lang, clauses_, bk, terms, atoms = get_lang(
75 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth)
76 | # Discretise NEUMANN rules
77 | clauses = NEUMANN.get_clauses()
78 | return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False)
79 |
80 | def predict(NEUMANN, I2F, loader, args, device, th=None, split='train'):
81 | predicted_list = []
82 | target_list = []
83 | count = 0
84 |
85 | start = time.time()
86 | for epoch in tqdm(range(args.epochs)):
87 | for i, sample in enumerate(tqdm(loader), start=0):
88 | imgs, target_set = map(lambda x: x.to(device), sample)
89 | # to cuda
90 | target_set = target_set.float()
91 |
92 | #imgs: torch.Size([1, 1, 3, 128, 128])
93 | V_0 = I2F(imgs)
94 | V_T = NEUMANN(V_0)
95 | #a NEUMANN.print_valuation_batch(V_T)
96 | predicted = get_prob(V_T, NEUMANN, args)
97 |
98 | # compute explanation for each input image
99 | for pred in predicted:
100 | if pred > 0.9:
101 | pred.backward(retain_graph=True)
102 | atom_grads = NEUMANN.mpm.dummy_zeros.grad.squeeze(-1).unsqueeze(0)
103 | attention_maps = I2F.pm.model.slot_attention.attention_maps.squeeze(0)
104 | target_attention_maps = get_target_maps(NEUMANN.atoms, atom_grads, attention_maps)
105 | #print(atom_grads, torch.max(atom_grads), atom_grads.shape)
106 | NEUMANN.print_valuation_batch(atom_grads)
107 |
108 |
109 |
110 | imgs_to_plot = to_plot_images_clevr(imgs.squeeze(0).detach().cpu())
111 | captions = generate_captions(atom_grads, NEUMANN.atoms, args.num_objects, th=0.33)
112 | # + args.dataset + '/' + split + '/', \
113 | save_images_with_captions_and_attention_maps(imgs_to_plot, target_attention_maps, captions, folder='explanation/clevr/', \
114 | img_id=count, dataset=args.dataset)
115 | NEUMANN.mpm.dummy_zeros.grad.detach_()
116 | NEUMANN.mpm.dummy_zeros.grad.zero_()
117 | count += 1
118 | reasoning_time = time.time() - start
119 | print('Reasoning Time: ', reasoning_time)
120 | return 0, 0, 0, reasoning_time
121 |
122 |
123 | def to_one_label(ys, labels, th=0.7):
124 | ys_new = []
125 | for i in range(len(ys)):
126 | y = ys[i]
127 | label = labels[i]
128 | # check in case answers are computed
129 | num_class = 0
130 | for p_j in y:
131 | if p_j > th:
132 | num_class += 1
133 | if num_class >= 2:
134 | # drop the value using label (the label is one-hot)
135 | drop_index = torch.argmin(label - y)
136 | y[drop_index] = y.min()
137 | ys_new.append(y)
138 | return torch.stack(ys_new)
139 |
140 |
141 | def main(n):
142 | seed_everything(n)
143 | args = get_args()
144 | assert args.batch_size == 1, "Set batch_size=1."
145 | #name = 'VILP'
146 | print('args ', args)
147 | if args.no_cuda:
148 | device = torch.device('cpu')
149 | elif len(args.device.split(',')) > 1:
150 | # multi gpu
151 | device = torch.device('cuda')
152 | else:
153 | device = torch.device('cuda:' + args.device)
154 |
155 | print('device: ', device)
156 | name = 'neumann/behind-the-scenes/' + str(n)
157 | writer = SummaryWriter(f"runs/{name}", purge_step=0)
158 |
159 | # Create RTPT object
160 | rtpt = RTPT(name_initials='HS', experiment_name=name,
161 | max_iterations=args.epochs)
162 | # Start the RTPT tracking
163 | rtpt.start()
164 |
165 |
166 | ## train_pos_loader, val_pos_loader, test_pos_loader = get_vilp_pos_loader(args)
167 | #####train_pos_loader, val_pos_loader, test_pos_loader = get_data_loader(args)
168 |
169 | # load logical representations
170 | lark_path = 'src/lark/exp.lark'
171 | lang_base_path = 'data/lang/'
172 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang(
173 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth, use_learned_clauses=True)
174 |
175 | print("{} Atoms:".format(len(atoms)))
176 |
177 | # get torch data loader
178 | #question_json_path = 'data/behind-the-scenes/BehindTheScenes_questions_{}.json'.format(args.dataset)
179 | # test_loader = get_behind_the_scenes_loader(question_json_path, args.batch_size, lang, args.n_data, device)
180 | train_loader, val_loader, test_loader = get_data_loader(args, device)
181 |
182 | NEUMANN, I2F = get_model(lang=lang, clauses=clauses, atoms=atoms, terms=terms, bk=bk, bk_clauses=bk_clauses,
183 | program_size=args.program_size, device=device, dataset=args.dataset, dataset_type=args.dataset_type,
184 | num_objects=args.num_objects, infer_step=args.infer_step, train=False, explain=True)#train=not(args.no_train))
185 |
186 | writer.add_scalar("graph/num_atom_nodes", len(NEUMANN.rgm.atom_node_idxs))
187 | writer.add_scalar("graph/num_conj_nodes", len(NEUMANN.rgm.conj_node_idxs))
188 | num_nodes = len(NEUMANN.rgm.atom_node_idxs) + len(NEUMANN.rgm.conj_node_idxs)
189 | writer.add_scalar("graph/num_nodes", num_nodes)
190 |
191 | num_edges = NEUMANN.rgm.edge_index.size(1)
192 | writer.add_scalar("graph/num_edges", num_edges)
193 |
194 | writer.add_scalar("graph/memory_total", num_nodes + num_edges)
195 |
196 | print("=====================")
197 | print("NUM NODES: ", num_nodes)
198 | print("NUM EDGES: ", num_edges)
199 | print("MEMORY TOTAL: ", num_nodes + num_edges)
200 | print("=====================")
201 |
202 | params = list(NEUMANN.parameters())
203 | print('parameters: ', list(params))
204 |
205 |     print("Predicting on the test data set...")
206 | times = []
207 |     # test split
208 | for j in range(n):
209 |         acc_test, rec_test, th_test, inference_time = predict(
210 |             NEUMANN, I2F, test_loader, args, device, th=0.5, split='test')
211 |         times.append(inference_time)
212 |
213 | with open('out/inference_time/time_{}_ratio_{}.txt'.format(args.dataset, args.n_ratio), 'w') as f:
214 | f.write("\n".join(str(item) for item in times))
215 |
216 |     print("test acc: ", acc_test, "threshold: ", th_test, "recall: ", rec_test)
217 |
218 |
219 | def seed_everything(seed: int):
220 | import os
221 | import random
222 |
223 | import numpy as np
224 | import torch
225 |
226 | random.seed(seed)
227 | os.environ['PYTHONHASHSEED'] = str(seed)
228 | np.random.seed(seed)
229 | torch.manual_seed(seed)
230 | torch.cuda.manual_seed(seed)
231 | torch.backends.cudnn.deterministic = True
232 |     torch.backends.cudnn.benchmark = False  # benchmark must be off for deterministic runs
233 |
234 | if __name__ == "__main__":
235 | main(n=1)
236 |
237 |
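The explanation loop above follows an input-gradient pattern: a zero-valued dummy tensor is added to the atom-node features inside the message-passing module, a high-scoring prediction is backpropagated with retain_graph=True, and the gradient of that dummy tensor is read off as a per-atom attribution (NEUMANN.mpm.dummy_zeros.grad). Below is a minimal, self-contained sketch of the same pattern on a toy scoring function; toy_model, weights, and valuation are illustrative and not part of this repository.

import torch

# Toy stand-in for the reasoner: a fixed linear scoring function over a
# valuation vector (one probability per ground atom).
weights = torch.tensor([0.2, 0.7, 0.1])

def toy_model(v):
    return torch.sigmoid((weights * v).sum())

valuation = torch.tensor([0.9, 0.8, 0.1])
# Adding zeros does not change the forward pass, but provides a leaf tensor
# whose gradient equals the gradient w.r.t. the valuation itself.
dummy_zeros = torch.zeros_like(valuation, requires_grad=True)
pred = toy_model(valuation + dummy_zeros)
pred.backward()
atom_grads = dummy_zeros.grad  # per-atom attribution scores
print(atom_grads)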
--------------------------------------------------------------------------------
/neumann/neumann/explanation_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 |
8 | attrs = ['color', 'shape', 'material', 'size']
9 |
10 | def get_target_maps(atoms, atom_grads, attention_maps):
11 | if atom_grads.size(0) == 1:
12 | atom_grads = atom_grads.squeeze(0)
13 |
14 | obj_ids = []
15 | scores = []
16 | for i, g in enumerate(atom_grads):
17 | atom = atoms[i]
18 | if g > 0.8 and 'obj' in str(atom):
19 | obj_id = str(atom).split(',')[0][-1]
20 | obj_id = int(obj_id)
21 | if not obj_id in obj_ids:
22 | obj_ids.append(obj_id)
23 | scores.append(g)
24 | return [attention_maps[id] * scores[i] for i, id in enumerate(obj_ids)]
25 |
26 |
27 | def valuation_to_attr_string(v, atoms, e, th=0.5):
28 | """Generate string explanations of the scene.
29 | """
30 |
31 | st = ''
32 | for i in range(e):
33 | st_i = ''
34 | for j, atom in enumerate(atoms):
35 | #print(atom, [str(term) for term in atom.terms])
36 | if 'obj' + str(i) in [str(term) for term in atom.terms] and atom.pred.name in attrs:
37 | if v[j] > th:
38 | prob = np.round(v[j].detach().cpu().numpy(), 2)
39 | st_i += str(prob) + ':' + str(atom) + ','
40 | if st_i != '':
41 | st_i = st_i[:-1]
42 | st += st_i + '\n'
43 | return st
44 |
45 |
46 | def valuation_to_rel_string(v, atoms, th=0.5):
47 | l = 15
48 | st = ''
49 | n = 0
50 | for j, atom in enumerate(atoms):
51 | if v[j] > th and not (atom.pred.name in attrs+['in', '.']):
52 | prob = np.round(v[j].detach().cpu().numpy(), 2)
53 | st += str(prob) + ':' + str(atom) + ','
54 | n += len(str(prob) + ':' + str(atom) + ',')
55 | if n > l:
56 | st += '\n'
57 | n = 0
58 | return st[:-1] + '\n'
59 |
60 |
61 | def valuation_to_string(v, atoms, e, th=0.5):
62 | return valuation_to_attr_string(v, atoms, e, th) + valuation_to_rel_string(v, atoms, th)
63 |
64 |
65 | def valuations_to_string(V, atoms, e, th=0.5):
66 | """Generate string explanation of the scenes.
67 | """
68 | st = ''
69 | for i in range(V.size(0)):
70 | st += 'image ' + str(i) + '\n'
71 | # for each data in the batch
72 | st += valuation_to_string(V[i], atoms, e, th)
73 | return st
74 |
75 |
76 | def generate_captions(V, atoms, e, th):
77 | captions = []
78 | for v in V:
79 | # for each data in the batch
80 | captions.append(valuation_to_string(v, atoms, e, th))
81 | return captions
82 |
83 |
84 | def save_images_with_captions_and_attention_maps(imgs, attention_maps, captions, folder, img_id, dataset):
85 | if not os.path.exists(folder):
86 | os.makedirs(folder)
87 |
88 | figsize = (12, 6)
89 | # imgs should be denormalized.
90 |
91 | img_size = imgs[0].shape[0]
92 | attention_maps = np.array([m.cpu().detach().numpy() for m in attention_maps])
93 | attention_map = np.zeros_like(attention_maps[0])
94 | for am in attention_maps:
95 | attention_map += am
96 | attention_map = attention_map.reshape(32, 32)
97 | attention_map = cv2.resize(attention_map, (img_size, img_size))
98 | #attention_map = torch.tensor(attention_map).extend()
99 |
100 | # apply attention maps to filter
101 | for i, img in enumerate(imgs):
102 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(4, 2)) #, sharex=True, sharey=True)
103 | ax1.axis('off')
104 | ax2.axis('off')
105 |
106 | #e=figsize, dpi=80)
107 | ax1.imshow(img)
108 | #ax1.xlabel(captions[i])
109 | #ax1.tight_layout()
110 | #ax1.savefig(folder+str(img_id)+'_original.png')
111 |
112 | # ax2.figure(figsize=figsize, dpi=80)
113 | ax2.imshow(attention_map, cmap='cividis')
114 | #ax2.set_xlabel(captions[i], fontsize=12)
115 | #plt.axis('off')
116 | fig.tight_layout()
117 | fig.savefig(folder + dataset + '_img' + str(img_id) + '_explanation.svg')
118 |
119 | plt.close()
120 |
121 | def save_images_with_captions_and_attention_maps_indivisual(imgs, attention_maps, captions, folder, img_id, dataset):
122 | if not os.path.exists(folder):
123 | os.makedirs(folder)
124 |
125 | figsize = (12, 6)
126 | # imgs should be denormalized.
127 |
128 | img_size = imgs[0].shape[0]
129 | attention_maps = np.array([m.cpu().detach().numpy() for m in attention_maps])
130 | attention_map = np.zeros_like(attention_maps[0])
131 | for am in attention_maps:
132 | attention_map += am
133 | attention_map = attention_map.reshape(32,32)
134 | attention_map = cv2.resize(attention_map, (img_size, img_size))
135 | #attention_map = torch.tensor(attention_map).extend()
136 |
137 | # apply attention maps to filter
138 | for i, img in enumerate(imgs):
139 | plt.figure(figsize=figsize, dpi=80)
140 | plt.imshow(img)
141 | plt.xlabel(captions[i])
142 | plt.tight_layout()
143 | plt.savefig(folder+str(img_id)+'_original.png')
144 |
145 | plt.figure(figsize=figsize, dpi=80)
146 | plt.imshow(attention_map, cmap='cividis')
147 | plt.xlabel(captions[i])
148 | plt.tight_layout()
149 | plt.savefig(folder+str(img_id)+'_attention_map.png')
150 |
151 | plt.close()
152 |
153 | def denormalize_clevr(imgs):
154 | """denormalize clevr images
155 | """
156 | # normalizing: image = (image - 0.5) * 2.0 # Rescale to [-1, 1].
157 | return (0.5 * imgs) + 0.5
158 |
159 |
160 | def denormalize_kandinsky(imgs):
161 | """denormalize kandinsky images
162 | """
163 | return imgs
164 |
165 |
166 | def to_plot_images_clevr(imgs):
167 | return [img.permute(1, 2, 0).detach().numpy() for img in denormalize_clevr(imgs)]
168 |
169 |
170 | def to_plot_images_kandinsky(imgs):
171 | return [img.permute(1, 2, 0).detach().numpy() for img in denormalize_kandinsky(imgs)]
172 |
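For orientation, the plotting helpers above boil down to one overlay step: the gradient-weighted attention maps returned by get_target_maps are summed, reshaped to the 32x32 slot-attention grid, and resized to the image resolution. This is a minimal sketch of just that step, with random tensors standing in for model outputs; the 32x32 grid size follows the code above and the 128-pixel image size is an illustrative assumption.

import cv2
import numpy as np
import torch

img_size = 128
# Gradient-weighted attention maps (flattened 32x32 grids); random values
# stand in for the output of get_target_maps.
target_maps = [torch.rand(32 * 32) * 0.9, torch.rand(32 * 32) * 0.4]

attention_map = np.zeros(32 * 32, dtype=np.float32)
for m in target_maps:
    attention_map += m.numpy()
attention_map = attention_map.reshape(32, 32)
# Upsample the attention grid to the image resolution before plotting.
attention_map = cv2.resize(attention_map, (img_size, img_size))
print(attention_map.shape)  # (128, 128)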
--------------------------------------------------------------------------------
/neumann/neumann/facts_converter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from tqdm import tqdm
4 |
5 | from .fol.logic import NeuralPredicate
6 |
7 |
8 | class FactsConverter(nn.Module):
9 |     """FactsConverter converts the output from the perception module to the valuation vector.
10 | """
11 |
12 | def __init__(self, lang, atoms, bk, device=None):
13 | super(FactsConverter, self).__init__()
14 | # self.e = perception_module.e
15 | # self.d = perception_module.d
16 | self.lang = lang
17 | # self.vm = valuation_module # valuation functions
18 | self.device = device
19 | self.atoms = atoms
20 | self.bk = bk
21 | # init indices
22 | self.np_indices = self._get_np_atom_indices()
23 | self.bk_indices = self._get_bk_atom_indices()
24 |
25 | def __str__(self):
26 |         return "FactsConverter(num_atoms={})".format(len(self.atoms))
27 |
28 | def __repr__(self):
29 |         return "FactsConverter(num_atoms={})".format(len(self.atoms))
30 |
31 | def _get_np_atom_indices(self):
32 |         """Precompute the indices of atoms with neural predicates."""
33 | indices = []
34 | for i, atom in enumerate(self.atoms):
35 | if type(atom.pred) == NeuralPredicate:
36 | indices.append(i)
37 | return indices
38 |
39 | def _get_bk_atom_indices(self):
40 |         """Precompute the indices of atoms in background knowledge."""
41 | indices = []
42 | for i, atom in enumerate(self.atoms):
43 | if atom in self.bk:
44 | indices.append(i)
45 | return indices
46 |
47 | def forward(self, Z):
48 | return self.convert(Z)
49 |
50 | def get_params(self):
51 | return self.vm.get_params()
52 |
53 | def init_valuation(self, n, batch_size):
54 | v = torch.zeros((batch_size, n)).to(self.device)
55 | v[:, 1] = 1.0
56 | return v
57 |
58 |     def filter_by_datatype(self):
59 | pass
60 |
61 | def to_vec(self, term, zs):
62 | pass
63 |
64 |
65 | def convert(self, bk_atoms=None):
66 | V = torch.zeros(1, len(self.atoms)).to(self.device)
67 |
68 | # add background knowledge
69 | for i, atom in enumerate(self.atoms):
70 | if atom in bk_atoms:
71 | V[0, i] += 1.0
72 |
73 | V[0, 0] += 1.0
74 | return V
75 |
76 |
77 | class FactsConverterWithQuery(nn.Module):
78 |     """FactsConverter converts the output from the perception module to the valuation vector.
79 | """
80 |
81 | # def __init__(self, lang, perception_module, valuation_module, device=None):
82 | def __init__(self, lang, atoms, bk, perception_module, valuation_module, device=None):
83 | super(FactsConverterWithQuery, self).__init__()
84 | self.e = perception_module.e
85 | self.d = perception_module.d
86 | self.lang = lang
87 | self.vm = valuation_module # valuation functions
88 | self.device = device
89 | self.atoms = atoms
90 | self.bk = bk
91 | # init indices
92 | self.np_indices = self._get_np_atom_indices()
93 | self.bk_indices = self._get_bk_atom_indices()
94 |
95 | def __str__(self):
96 | return "FactsConverter(entities={}, dimension={})".format(self.e, self.d)
97 |
98 | def __repr__(self):
99 | return "FactsConverter(entities={}, dimension={})".format(self.e, self.d)
100 |
101 | def _get_np_atom_indices(self):
102 |         """Precompute the indices of atoms with neural predicates."""
103 | indices = []
104 | for i, atom in enumerate(self.atoms):
105 | if type(atom.pred) == NeuralPredicate:
106 | indices.append(i)
107 | return indices
108 |
109 | def _get_bk_atom_indices(self):
110 |         """Precompute the indices of atoms in background knowledge."""
111 | indices = []
112 | for i, atom in enumerate(self.atoms):
113 | if atom in self.bk:
114 | indices.append(i)
115 | return indices
116 |
117 | def forward(self, Z, Q):
118 | return self.convert(Z, Q)
119 |
120 | def get_params(self):
121 | return self.vm.get_params()
122 |
123 | def init_valuation(self, n, batch_size):
124 | v = torch.zeros((batch_size, n)).to(self.device)
125 | v[:, 1] = 1.0
126 | return v
127 |
128 |     def filter_by_datatype(self):
129 | pass
130 |
131 | def to_vec(self, term, zs):
132 | pass
133 |
134 | def convert(self, Z, Q):
135 | batch_size = Z.size(0)
136 |
137 | # V = self.init_valuation(len(G), Z.size(0))
138 | #V = torch.zeros((batch_size, len(G))).to(
139 | # torch.float32).to(self.device)
140 | V = torch.zeros((batch_size, len(self.atoms))).to(
141 | torch.float32).to(self.device)
142 |
143 | # T to be 1.0
144 | V[:, 0] = 1.0
145 |
146 | for i in self.np_indices:
147 | V[:, i] = self.vm(Z, Q, self.atoms[i])
148 |
149 | for i in self.bk_indices:
150 | V[:, i] += torch.ones((batch_size, )).to(
151 | torch.float32).to(self.device)
152 | """
153 | for i, atom in enumerate(G):
154 | if type(atom.pred) == NeuralPredicate:
155 | V[:, i] = self.vm(Z, Q, atom)
156 | elif atom in B:
157 | # V[:, i] += 1.0
158 | V[:, i] += torch.ones((batch_size, )).to(
159 | torch.float32).to(self.device)
160 | """
161 | return V
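The valuation vector produced by FactsConverterWithQuery.convert has a fixed layout: index 0 is the constant true atom, atoms with neural predicates receive probabilities from the valuation module, and atoms contained in the background knowledge are set to 1. A minimal sketch of that layout follows; dummy_vm, the index lists, and the tensor shapes are illustrative placeholders, not the repository's actual valuation module.

import torch

batch_size, num_atoms = 2, 6
np_indices = [2, 3]  # atoms with neural predicates (illustrative)
bk_indices = [5]     # atoms contained in the background knowledge

def dummy_vm(Z, Q, atom_index):
    # Placeholder for the neural valuation functions: one probability
    # per batch element for the given atom.
    return torch.full((batch_size,), 0.8)

Z, Q = torch.zeros(batch_size, 3, 4), None  # fake perception output / query
V = torch.zeros(batch_size, num_atoms)
V[:, 0] = 1.0                    # the constant "true" atom
for i in np_indices:
    V[:, i] = dummy_vm(Z, Q, i)  # neural predicate probabilities
for i in bk_indices:
    V[:, i] += 1.0               # background knowledge facts
print(V)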
--------------------------------------------------------------------------------
/neumann/neumann/fol/README.md:
--------------------------------------------------------------------------------
1 | # First-Order Logic
2 | The implementation of the first-order logic based on the [Differentiable Inductive Logic Programming](https://arxiv.org/abs/2103.01719) project.
--------------------------------------------------------------------------------
/neumann/neumann/fol/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/neumann/neumann/fol/__init__.py
--------------------------------------------------------------------------------
/neumann/neumann/fol/data_utils.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | from lark import Lark
4 |
5 | from .exp_parser import ExpTree
6 | from .language import DataType, Language
7 | from .logic import Const, FuncSymbol, NeuralPredicate, Predicate
8 |
9 |
10 | class DataUtils(object):
11 | """Utilities about I/O of logic.
12 | """
13 |
14 | def __init__(self, lark_path, lang_base_path, dataset_type='kandinsky', dataset=None):
15 | #if dataset == 'behind-the-scenes':
16 | # for behind the scenes
17 | # self.base_path = lang_base_path + dataset_type + '/'
18 | #if:
19 | self.base_path = lang_base_path + dataset_type + '/' + dataset + '/'
20 | with open(lark_path, encoding="utf-8") as grammar:
21 | self.lp_atom = Lark(grammar.read(), start="atom")
22 | with open(lark_path, encoding="utf-8") as grammar:
23 | self.lp_clause = Lark(grammar.read(), start="clause")
24 |
25 | def load_clauses(self, path, lang):
26 | """Read lines and parse to Atom objects.
27 | """
28 | clauses = []
29 | if os.path.isfile(path):
30 | with open(path) as f:
31 | for line in f:
32 | if line[-1] == '\n':
33 | line = line[:-1]
34 | if len(line) == 0 or line[0] == '#':
35 | continue
36 | tree = self.lp_clause.parse(line)
37 | clause = ExpTree(lang).transform(tree)
38 | clauses.append(clause)
39 | return clauses
40 |
41 | def load_atoms(self, path, lang):
42 | """Read lines and parse to Atom objects.
43 | """
44 | atoms = []
45 |
46 | if os.path.isfile(path):
47 | with open(path) as f:
48 | for line in f:
49 | if line[-1] == '\n':
50 | line = line[:-2]
51 | else:
52 | line = line[:-1]
53 | tree = self.lp_atom.parse(line)
54 | atom = ExpTree(lang).transform(tree)
55 | atoms.append(atom)
56 | return atoms
57 |
58 | def load_preds(self, path):
59 | f = open(path)
60 | lines = f.readlines()
61 | preds = [self.parse_pred(line) for line in lines]
62 | return preds
63 |
64 | def load_neural_preds(self, path):
65 | f = open(path)
66 | lines = f.readlines()
67 | preds = [self.parse_neural_pred(line) for line in lines]
68 | return preds
69 |
70 | def load_consts(self, path):
71 | f = open(path)
72 | lines = f.readlines()
73 | consts = []
74 | for line in lines:
75 | consts.extend(self.parse_const(line))
76 | return consts
77 |
78 | def load_funcs(self, path):
79 | funcs = []
80 | if os.path.isfile(path):
81 | with open(path) as f:
82 | lines = f.readlines()
83 | for line in lines:
84 | funcs.append(self.parse_func(line))
85 | return funcs
86 |
87 | def load_terms(self, path, lang):
88 | f = open(path)
89 | lines = f.readlines()
90 | terms = []
91 | for line in lines:
92 | terms.extend(self.parse_term(line, lang))
93 |         return terms
94 | def parse_pred(self, line):
95 | """Parse string to predicates.
96 | """
97 | line = line.replace('\n', '')
98 | pred, arity, dtype_names_str = line.split(':')
99 | dtype_names = dtype_names_str.split(',')
100 | dtypes = [DataType(dt) for dt in dtype_names]
101 | assert int(arity) == len(
102 | dtypes), 'Invalid arity and dtypes in ' + pred + '.'
103 | return Predicate(pred, int(arity), dtypes)
104 |
105 | def parse_neural_pred(self, line):
106 | """Parse string to predicates.
107 | """
108 | line = line.replace('\n', '')
109 | pred, arity, dtype_names_str = line.split(':')
110 | dtype_names = dtype_names_str.split(',')
111 | dtypes = [DataType(dt) for dt in dtype_names]
112 | assert int(arity) == len(
113 | dtypes), 'Invalid arity and dtypes in ' + pred + '.'
114 | return NeuralPredicate(pred, int(arity), dtypes)
115 |
116 | def parse_func(self, line):
117 | """Parse string to function symbols.
118 | (Format) name:arity:input_type:output_type
119 | """
120 | name, arity, in_dtypes, out_dtype = line.replace("\n", "").split(':')
121 | in_dtypes = in_dtypes.split(',')
122 | in_dtypes = [DataType(in_dtype) for in_dtype in in_dtypes]
123 | out_dtype = DataType(out_dtype)
124 | return FuncSymbol(name, int(arity), in_dtypes, out_dtype)
125 |
126 |
127 |
128 | def parse_const(self, line):
129 | """Parse string to constants.
130 | """
131 | line = line.replace('\n', '')
132 | dtype_name, const_names_str = line.split(':')
133 | dtype = DataType(dtype_name)
134 | const_names = const_names_str.split(',')
135 | return [Const(const_name, dtype) for const_name in const_names]
136 |
137 | def parse_term(self, line, lang):
138 | """Parse string to func_terms.
139 | """
140 | line = line.replace('\n', '')
141 | dtype_name, term_names_str = line.split(':')
142 | dtype = DataType(dtype_name)
143 | term_strs = term_names_str.split(',')
144 | terms = []
145 | for term_str in term_strs:
146 | if not term_str == '':
147 | print("term_str: ", term_str)
148 | tree = self.lp_term.parse(term_str)
149 | terms.append(ExpTree(lang).transform(tree))
150 | return terms
151 | #return [Const(const_name, dtype) for const_name in const_names]
152 |
153 | def parse_clause(self, clause_str, lang):
154 | tree = self.lp_clause.parse(clause_str)
155 | return ExpTree(lang).transform(tree)
156 |
157 | def get_clauses(self, lang):
158 | return self.load_clauses(self.base_path + 'clauses.txt', lang)
159 |
160 | def get_bk(self, lang):
161 | return self.load_atoms(self.base_path + 'bk.txt', lang)
162 |
163 | def get_facts(self, lang):
164 | return self.load_atoms(self.base_path + 'facts.txt', lang)
165 |
166 | def load_language(self):
167 | """Load language, background knowledge, and clauses from files.
168 | """
169 | preds = self.load_preds(self.base_path + 'preds.txt') + \
170 | self.load_neural_preds(self.base_path + 'neural_preds.txt')
171 | consts = self.load_consts(self.base_path + 'consts.txt')
172 | funcs = self.load_funcs(self.base_path + 'funcs.txt')
173 | #terms = self.load_terms(self.base_path + 'terms.txt')
174 | lang = Language(preds, funcs, consts)
175 | return lang
176 |
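The loaders above expect colon-separated text files: preds.txt and neural_preds.txt hold lines of the form name:arity:dtype1,dtype2,..., consts.txt holds dtype:const1,const2,..., and funcs.txt holds name:arity:input_dtypes:output_dtype. A minimal sketch of the splitting performed by parse_pred, on a made-up line (the predicate color and its datatypes are illustrative):

line = "color:2:object,color\n"

# Mirrors DataUtils.parse_pred: strip the newline, then split the line
# into predicate name, arity, and comma-separated datatype names.
line = line.replace("\n", "")
pred, arity, dtype_names_str = line.split(":")
dtype_names = dtype_names_str.split(",")
assert int(arity) == len(dtype_names), "arity must match the number of dtypes"
print(pred, int(arity), dtype_names)  # color 2 ['object', 'color']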
--------------------------------------------------------------------------------
/neumann/neumann/fol/exp.lark:
--------------------------------------------------------------------------------
1 | clause : atom ":-" body
2 |
3 | body : atom "," body
4 | | atom "."
5 | | "."
6 |
7 | atom : predicate "(" args ")"
8 |
9 | args : term "," args
10 | | term
11 |
12 | term : functor "(" args ")"
13 | | const
14 | | variable
15 |
16 | const : /[a-z0-9\*]+/
17 |
18 | variable : /[A-Z0-9]+/
19 |
20 | functor : /[a-z0-9]+/
21 |
22 | predicate : /[a-z0-9\_]+/
23 |
24 | var_name : /[A-Z]/
25 | small_chars : /[a-z0-9]+/
26 | chars : /[^\+\|\s\(\)']+/[/\n+/]
27 | allchars : /[^']+/[/\n+/]
28 |
--------------------------------------------------------------------------------
/neumann/neumann/fol/exp_parser.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from lark import Lark
3 | from lark import Transformer
4 | from .logic import *
5 |
6 |
7 | def flatten(x): return [z for y in x for z in (
8 | flatten(y) if hasattr(y, '__iter__') and not isinstance(y, str) else (y,))]
9 |
10 |
11 | class ExpTree(Transformer):
12 | '''Functions to parse strings into logical objects using Lark
13 |
14 | Attrs:
15 | lang (language): the language of first-order logic.
16 | '''
17 |
18 | def __init__(self, lang):
19 | self.lang = lang
20 |
21 | def clause(self, trees):
22 | head = trees[0]
23 | body = flatten([trees[1]])
24 | return Clause(head, body)
25 |
26 | def body(self, trees):
27 | if len(trees) == 0:
28 | return []
29 | elif len(trees) == 1:
30 | return trees[0]
31 | else:
32 | return [trees[0]] + trees[1:]
33 |
34 | def atom(self, trees):
35 | pred = trees[0]
36 | args = flatten([trees[1]])
37 | return Atom(pred, args)
38 |
39 | def args(self, content):
40 | if len(content) == 1:
41 | return content[0]
42 | else:
43 | return [content[0]] + content[1:]
44 |
45 | def const(self, name):
46 | dtype = self.lang.get_by_dtype_name(name[0])
47 | return Const(name[0], dtype)
48 |
49 | def variable(self, name):
50 | return Var(name[0])
51 |
52 | def functor(self, name):
53 | func = [f for f in self.lang.funcs if f.name == name[0]][0]
54 | return func
55 |
56 | def predicate(self, alphas):
57 | pred = [p for p in self.lang.preds if p.name == alphas[0]][0]
58 | return pred
59 |
60 | def term(self, content):
61 | if type(content[0]) == FuncSymbol:
62 | func = content[0]
63 | args = flatten([content[1]])
64 | return FuncTerm(func, args)
65 | else:
66 | return content[0]
67 |
68 | def small_chars(self, content):
69 | return content[0]
70 |
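A minimal sketch of how the exp.lark grammar and this parser are typically used: load the grammar, build a Lark parser with start="clause", and parse a clause string into a tree (ExpTree would then transform that tree into Clause/Atom objects, which requires a populated Language). The clause text is illustrative and the grammar path may need adjusting to your checkout.

from lark import Lark

# Grammar file as shipped in this repository; adjust the path if needed.
with open("neumann/neumann/lark/exp.lark", encoding="utf-8") as f:
    grammar = f.read()

lp_clause = Lark(grammar, start="clause")
tree = lp_clause.parse("kp(X):-in(O1,X),color(O1,red).")
print(tree.pretty())  # nested clause/atom/args/term nodes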
--------------------------------------------------------------------------------
/neumann/neumann/fol/language.py:
--------------------------------------------------------------------------------
1 | from .logic import Var
2 | import itertools
3 |
4 |
5 | class Language(object):
6 | """Language of first-order logic.
7 |
8 | A class of languages in first-order logic.
9 |
10 | Args:
11 | preds (List[Predicate]): A set of predicate symbols.
12 | funcs (List[FunctionSymbol]): A set of function symbols.
13 | consts (List[Const]): A set of constants.
14 |
15 | Attrs:
16 | preds (List[Predicate]): A set of predicate symbols.
17 | funcs (List[FunctionSymbol]): A set of function symbols.
18 | consts (List[Const]): A set of constants.
19 | """
20 |
21 | def __init__(self, preds, funcs, consts):
22 | self.preds = preds
23 | self.funcs = funcs
24 | self.consts = consts
25 | self.var_gen = VariableGenerator()
26 |
27 |
28 | def __str__(self):
29 | s = "===Predicates===\n"
30 | for pred in self.preds:
31 | s += pred.__str__() + '\n'
32 | s += "===Function Symbols===\n"
33 | for func in self.funcs:
34 | s += func.__str__() + '\n'
35 | s += "===Constants===\n"
36 | for const in self.consts:
37 | s += const.__str__() + '\n'
38 | return s
39 |
40 | def __repr__(self):
41 | return self.__str__()
42 |
43 | def get_var_and_dtype(self, atom):
44 | """Get all variables in an input atom with its dtypes by enumerating variables in the input atom.
45 |
46 | Note:
47 |             under the assumption that atoms are function-free.
48 |
49 | Args:
50 | atom (Atom): The atom.
51 |
52 | Returns:
53 | List of tuples (var, dtype)
54 | """
55 | var_dtype_list = []
56 | for i, arg in enumerate(atom.terms):
57 | if arg.is_var():
58 | dtype = atom.pred.dtypes[i]
59 | var_dtype_list.append((arg, dtype))
60 | return var_dtype_list
61 |
62 | def get_by_dtype(self, dtype):
63 | """Get constants that match given dtypes.
64 |
65 | Args:
66 | dtype (DataType): The data type.
67 |
68 | Returns:
69 | List of constants whose data type is the given data type.
70 | """
71 | return [c for c in self.consts if c.dtype == dtype]
72 |
73 | def get_by_dtype_name(self, dtype_name):
74 | """Get constants that match given dtype name.
75 |
76 | Args:
77 | dtype_name (str): The name of the data type to be used.
78 |
79 | Returns:
80 | List of constants whose datatype has the given name.
81 | """
82 | return [c for c in self.consts if c.dtype.name == dtype_name]
83 |
84 | def term_index(self, term):
85 | """Get the index of a term in the language.
86 |
87 | Args:
88 | term (Term): The term to be used.
89 |
90 | Returns:
91 | int: The index of the term.
92 | """
93 | terms = self.get_by_dtype(term.dtype)
94 | return terms.index(term)
95 |
96 | def get_const_by_name(self, const_name):
97 | """Get the constant by its name.
98 |
99 | Args:
100 | const_name (str): The name of the constant.
101 |
102 | Returns:
103 | Const: The matched constant with the given name.
104 |
105 | """
106 | const = [c for c in self.consts if const_name == c.name]
107 |         assert len(const) == 1, 'Expected exactly one constant named ' + const_name
108 | return const[0]
109 |
110 | def get_pred_by_name(self, pred_name):
111 | """Get the predicate by its name.
112 |
113 | Args:
114 | pred_name (str): The name of the predicate.
115 |
116 | Returns:
117 |             Predicate: The matched predicate with the given name.
118 | """
119 | pred = [pred for pred in self.preds if pred.name == pred_name]
120 |         assert len(pred) == 1, 'Expected exactly one predicate named ' + pred_name
121 | return pred[0]
122 |
123 |
124 | class DataType(object):
125 | """Data type in first-order logic.
126 |
127 | A class of data types in first-order logic.
128 |
129 | Args:
130 | name (str): The name of the data type.
131 |
132 | Attrs:
133 | name (str): The name of the data type.
134 | """
135 |
136 | def __init__(self, name):
137 | self.name = name
138 |
139 | def __eq__(self, other):
140 | if type(other) == str:
141 | return self.name == other
142 | else:
143 | return self.name == other.name
144 |
145 | def __str__(self):
146 | return self.name
147 |
148 | def __repr__(self):
149 | return self.__str__()
150 |
151 | def __hash__(self):
152 | return hash(self.__str__())
153 |
154 |
155 |
156 |
157 | class VariableGenerator():
158 | """
159 | generator of variables
160 | Parameters
161 | __________
162 | base_name : str
163 | base name of variables
164 | """
165 |
166 | def __init__(self, base_name='x'):
167 | self.counter = 0
168 | self.base_name = base_name
169 |
170 | def generate(self):
171 | """
172 | generate variable with new name
173 | Returns
174 | -------
175 | generated_var : .logic.Var
176 | generated variable
177 | """
178 | generated_var = Var(self.base_name + str(self.counter))
179 | self.counter += 1
180 | return generated_var
181 |
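A minimal usage sketch of the Language class, assuming the neumann package from this repository is installed (see its setup.py); the predicate, datatypes, and constants below are made up for illustration.

from neumann.fol.language import DataType, Language
from neumann.fol.logic import Const, Predicate

obj_t, color_t = DataType("object"), DataType("color")
preds = [Predicate("color", 2, [obj_t, color_t])]
consts = [Const("obj0", obj_t), Const("red", color_t), Const("blue", color_t)]
lang = Language(preds, funcs=[], consts=consts)

print(lang.get_by_dtype_name("color"))  # the two color constants
print(lang.get_pred_by_name("color"))   # the color/2 predicate
print(lang.term_index(consts[2]))       # 1: index of 'blue' among the color constants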
--------------------------------------------------------------------------------
/neumann/neumann/fol/logic_ops.py:
--------------------------------------------------------------------------------
1 | from .logic import Clause, Atom, FuncTerm, Const, Var
2 |
3 |
4 | def subs(exp, target_var, const):
5 | """
6 | Substitute var = const
7 |
8 | Inputs
9 | ------
10 |     exp : .logic.Clause .logic.Atom .logic.FuncTerm .logic.Const .logic.Var
11 | logical expression
12 | atom, clause, or term
13 | target_var : .logic.Var
14 | target variable of the substitution
15 | const : .logic.Const
16 | constant to be substituted
17 |
18 | Returns
19 | -------
20 |     exp : .logic.Clause .logic.Atom .logic.FuncTerm .logic.Const .logic.Var
21 | result of the substitution
22 | logical expression
23 | atom, clause, or term
24 | """
25 | if type(exp) == Clause:
26 | head = subs(exp.head, target_var, const)
27 | body = [subs(bi, target_var, const) for bi in exp.body]
28 | return Clause(head, body)
29 | elif type(exp) == Atom:
30 | terms = [subs(term, target_var, const) for term in exp.terms]
31 | return Atom(exp.pred, terms)
32 | elif type(exp) == FuncTerm:
33 | args = [subs(arg, target_var, const) for arg in exp.args]
34 | return FuncTerm(exp.func_symbol, args)
35 | elif type(exp) == Var:
36 | if exp.name == target_var.name:
37 | return const
38 | else:
39 | return exp
40 | elif type(exp) == Const:
41 | return exp
42 | else:
43 | assert 1 == 0, 'Unknown type in substitution: ' + str(exp)
44 |
45 |
46 | def subs_list(exp, theta_list):
47 | if type(exp) == Clause:
48 | head = exp.head
49 | body = exp.body
50 | for target_var, const in theta_list:
51 | head = subs(head, target_var, const)
52 | body = [subs(bi, target_var, const) for bi in body]
53 | return Clause(head, body)
54 | elif type(exp) == Atom:
55 | terms = exp.terms
56 | for target_var, const in theta_list:
57 | terms = [subs(term, target_var, const) for term in terms]
58 | return Atom(exp.pred, terms)
59 | #elif type(exp) == FuncTerm:
60 | # for target_var, const in theta_list:
61 | # args = [subs(arg, target_var, const) for arg in exp.args]
62 | # return FuncTerm(exp.func_symbol, args)
63 | #elif type(exp) == Var:
64 | # if exp.name == target_var.name:
65 | # return const
66 | # else:
67 | # return exp
68 | #elif type(exp) == Const:
69 | # return exp
70 | else:
71 | assert 1 == 0, 'Unknown type in substitution: ' + str(exp)
72 |
73 |
74 |
75 | def __subs_list(clause, theta_list):
76 | """
77 | perform list of substitutions
78 |
79 | Inputs
80 | ------
81 | clause : .logic.Clause
82 | target clause
83 | theta_list : List[(.logic.Var, .logic.Const)]
84 | list of substitute operations to be performed
85 | """
86 | result = clause
87 | for theta in theta_list:
88 | result = subs(result, theta[0], theta[1])
89 | return result
90 |
91 |
92 | def unify(atoms):
93 | """
94 | Unification of first-order logic expressions
95 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.]
96 |
97 | Inputs
98 | ------
99 | atoms : List[.logic.Atom]
100 | Returns
101 | -------
102 | flag : bool
103 | unifiable or not
104 | unifier : List[(.logic.Var, .logic.Const)]
105 | unifiable - unifier (list of substitutions)
106 | not unifiable - empty list
107 | """
108 | # empty set
109 | if len(atoms) == 0:
110 | return (1, [])
111 | # check predicates
112 | for i in range(len(atoms)-1):
113 | if atoms[i].pred != atoms[i+1].pred:
114 | return (0, [])
115 |
116 | # check all the same
117 | all_same_flag = True
118 | for i in range(len(atoms)-1):
119 | all_same_flag = all_same_flag and (atoms[i] == atoms[i+1])
120 | if all_same_flag:
121 | return (1, [])
122 |
123 | k = 0
124 | theta_list = []
125 |
126 | atoms_ = atoms
127 | while(True):
128 | # check terms from left
129 | for i in range(atoms_[0].pred.arity):
130 | # atom_1(term_1, ..., term_i, ...), ..., atom_j(term_1, ..., term_i, ...), ...
131 | terms_i = [atoms_[j].terms[i] for j in range(len(atoms_))]
132 | disagree_flag, disagree_set = get_disagreements(terms_i)
133 | if not disagree_flag:
134 | continue
135 | var_list = [x for x in disagree_set if type(x) == Var]
136 | if len(var_list) == 0:
137 | return (0, [])
138 | else:
139 | # substitute
140 | subs_var = var_list[0]
141 | # find term where the var does not occur
142 | subs_flag, subs_term = find_subs_term(
143 | subs_var, disagree_set)
144 | if subs_flag:
145 | k += 1
146 | theta_list.append((subs_var, subs_term))
147 | subs_flag = True
148 | # UNIFICATION SUCCESS
149 | atoms_ = [subs(atom, subs_var, subs_term)
150 | for atom in atoms_]
151 | if is_singleton(atoms_):
152 | return (1, theta_list)
153 | else:
154 | # UNIFICATION FAILED
155 | return (0, [])
156 |
157 |
158 | def get_disagreements(terms):
159 | """
160 |     get disagreements in the unification algorithm
161 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.]
162 |
163 | Inputs
164 | ------
165 |     terms : List[Term]
166 | Term : .logic.FuncTerm .logic.Const .logic.Var
167 | list of terms
168 |
169 | Returns
170 | -------
171 | disagree_flag : bool
172 | flag of disagreement
173 | disagree_terms : List[Term]
174 | Term : .logic.FuncTerm .logic.Const .logic.Var
175 | terms of disagreement
176 | """
177 | disagree_flag, disagree_index = get_disagree_index(terms)
178 | if disagree_flag:
179 | disagree_terms = [term.get_ith_term(
180 | disagree_index) for term in terms]
181 | return disagree_flag, disagree_terms
182 | else:
183 | return disagree_flag, []
184 |
185 |
186 | def get_disagree_index(terms):
187 | """
188 |     get the disagreement index in the unification algorithm
189 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.]
190 |
191 | Inputs
192 | ------
193 | terms : List[Term]
194 | Term : .logic.FuncTerm .logic.Const .logic.Var
195 | list of terms
196 |
197 | Returns
198 | -------
199 | disagree_flag : bool
200 | flag of disagreement
201 | disagree_index : int
202 | index of the disagreement term in the args of predicates
203 | """
204 | symbols_list = [term.to_list() for term in terms]
205 | n = min([len(symbols) for symbols in symbols_list])
206 | for i in range(n):
207 | ith_symbols = [symbols[i] for symbols in symbols_list]
208 | for j in range(len(ith_symbols)-1):
209 | if ith_symbols[j] != ith_symbols[j+1]:
210 | return (True, i)
211 | # all the same terms
212 | return (False, 0)
213 |
214 |
215 | def occur_check(variable, term):
216 | """
217 | occur check function
218 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.]
219 |
220 | Inputs
221 | ------
222 | variable : .logic.Var
223 | term : Term
224 | Term : .logic.FuncTerm .logic.Const .logic.Var
225 |
226 | Returns
227 | -------
228 | occur_flag : bool
229 |         flag of the occurrence of the variable
230 | """
231 | if type(term) == Const:
232 | return False
233 | elif type(term) == Var:
234 | return variable.name == term.name
235 | else:
236 | # func term case
237 | for arg in term.args:
238 | if occur_check(variable, arg):
239 | return True
240 | return False
241 |
242 |
243 | def find_subs_term(subs_var, disagree_set):
244 | """
245 | Find term where the var does not occur
246 |
247 | Inputs
248 | ------
249 | subs_var : .logic.Var
250 | disagree_set : List[.logic.Term]
251 |
252 | Returns
253 | -------
254 | flag : bool
255 | term : .logic.Term
256 | """
257 | for term in disagree_set:
258 | if not occur_check(subs_var, term):
259 | return True, term
260 |     return False, None  # no term without the variable was found
261 |
262 |
263 | def is_singleton(atoms):
264 | """
265 | returns whether all the input atoms are the same or not
266 |
267 | Inputs
268 | ------
269 | atoms: List[.logic.Atom]
270 | [a_1, a_2, ..., a_n]
271 |
272 | Returns
273 | -------
274 | flag : bool
275 | a_1 == a_2 == ... == a_n
276 | """
277 | result = True
278 | for i in range(len(atoms)-1):
279 | result = result and (atoms[i] == atoms[i+1])
280 | return result
281 |
282 |
283 | def is_entailed(e, clause, facts, n):
284 | """
285 | decision function of ground atom is entailed by a clause and facts by n-step inference
286 |
287 | Inputs
288 | ------
289 | e : .logic.Atom
290 | ground atom
291 | clause : .logic.Clause
292 | clause
293 | facts : List[.logic.Atom]
294 | set of facts
295 | n : int
296 | infer step
297 |
298 | Returns
299 | -------
300 | flag : bool
301 | ${clause} \cup facts \models e$
302 | """
303 | if len(clause.body) == 0:
304 | flag, thetas = unify([e, clause.head])
305 | return flag
306 | if len(clause.body) == 1:
307 | return e in t_p_n(clause, facts, n)
308 |
309 |
310 | def t_p_n(clause, facts, n):
311 | """
312 | applying the T_p operator n-times taking union of results
313 |
314 | Inputs
315 | ------
316 | clause : .logic.Clause
317 | clause
318 | facts : List[.logic.Atom]
319 | set of facts
320 | n : int
321 | infer step
322 |
323 | Returns
324 | -------
325 | G : Set[.logic.Atom]
326 | set of ground atoms entailed by ${clause} \cup facts$
327 | """
328 | G = set(facts)
329 | for i in range(n):
330 | G = G.union(t_p(clause, G))
331 | return G
332 |
333 |
334 | def t_p(clause, facts):
335 | """
336 | T_p operator
337 | limited to clauses with one body atom
338 |
339 | Inputs
340 | ------
341 | clause : .logic.Clause
342 | clause
343 | facts : List[.logic.Atom]
344 | set of facts
345 |
346 | Returns
347 | -------
348 | S : List[.logic.Atom]
349 | set of ground atoms entailed by one step forward-chaining inference
350 | """
351 | # |body| == 1
352 | S = []
353 | unify_dic = {}
354 | for fact in facts:
355 | flag, thetas = unify([clause.body[0], fact])
356 | if flag:
357 | head_fact = subs_list(clause.head, thetas)
358 | S = S + [head_fact]
359 | return list(set(S))
360 |
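A minimal sketch of the unification routine above, assuming the neumann package from this repository is installed; the predicate and constants are made up for illustration.

from neumann.fol.language import DataType
from neumann.fol.logic import Atom, Const, Predicate, Var
from neumann.fol.logic_ops import subs_list, unify

obj_t, color_t = DataType("object"), DataType("color")
color = Predicate("color", 2, [obj_t, color_t])

a1 = Atom(color, [Var("O1"), Const("red", color_t)])
a2 = Atom(color, [Const("obj0", obj_t), Const("red", color_t)])

flag, theta = unify([a1, a2])
print(flag)                  # 1: the two atoms are unifiable
print(theta)                 # a single substitution O1 -> obj0
print(subs_list(a1, theta))  # a1 with O1 replaced by obj0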
--------------------------------------------------------------------------------
/neumann/neumann/img2facts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Img2Facts(nn.Module):
6 |     """Img2Facts module converts raw images into a form of probabilistic facts. Each image is fed to the perception module, and the per-image results are concatenated and fed to the facts converter.
7 | """
8 |
9 | def __init__(self, perception_module, facts_converter, atoms, bk, device):
10 | super().__init__()
11 | self.pm = perception_module
12 | self.fc = facts_converter
13 | self.atoms = atoms
14 | self.bk = bk
15 | self.device = device
16 |
17 | def forward(self, x):
18 |         # x: batch_size * num_images * C * W * H
19 | num_images = x.size(1)
20 | # feed each input image
21 | zs_list = []
22 | # TODO: concat image ids
23 | for i in range(num_images):
24 | # B * E * D
25 | zs = self.pm(x[:, i, :, :])
26 | # image_ids = torch.tensor([i for i in range(x.size(0))]).unsqueeze(0).unsqueeze(0).to(self.device)
27 | image_ids = torch.tensor(i).expand(
28 | (zs.size(0), zs.size(1), 1)).to(self.device)
29 | zs = torch.cat((zs, image_ids), dim=2)
30 | zs_list.append(zs)
31 | zs = torch.cat(zs_list, dim=1)
32 |         # zs: batch_size * (num_images * num_objects) * num_attributes
33 | return self.fc(zs)
34 |
35 | class Img2FactsWithQuery(nn.Module):
36 |     """Img2Facts module converts raw images into a form of probabilistic facts. Each image is fed to the perception module, and the per-image results are concatenated and fed to the facts converter.
37 | """
38 |
39 | def __init__(self, perception_module, facts_converter, atoms, bk, device):
40 | super().__init__()
41 | self.pm = perception_module
42 | self.fc = facts_converter
43 | self.atoms = atoms
44 | self.bk = bk
45 | self.device = device
46 |
47 | def forward(self, x, query):
48 |         # x: batch_size * num_images * C * W * H
49 | num_images = x.size(1)
50 | # feed each input image
51 | zs_list = []
52 | # TODO: concat image ids
53 | for i in range(num_images):
54 | # B * E * D
55 | zs = self.pm(x[:, i, :, :])
56 | # image_ids = torch.tensor([i for i in range(x.size(0))]).unsqueeze(0).unsqueeze(0).to(self.device)
57 | image_ids = torch.tensor(i).expand(
58 | (zs.size(0), zs.size(1), 1)).to(self.device)
59 | zs = torch.cat((zs, image_ids), dim=2)
60 | zs_list.append(zs)
61 | zs = torch.cat(zs_list, dim=1)
62 |         # zs: batch_size * (num_images * num_objects) * num_attributes
63 | return self.fc(zs, query)
--------------------------------------------------------------------------------
/neumann/neumann/lark/exp.lark:
--------------------------------------------------------------------------------
1 | clause : atom ":-" body
2 |
3 | body : atom "," body
4 | | atom "."
5 | | "."
6 |
7 | atom : predicate "(" args ")"
8 |
9 | args : term "," args
10 | | term
11 |
12 | term : functor "(" args ")"
13 | | const
14 | | variable
15 |
16 | const : /[a-z0-9\*\_]+/
17 |
18 | variable : /[A-Z]+[A-Za-z0-9]*/
19 |
20 | functor : /[a-z0-9]+/
21 |
22 | predicate : /[a-z0-9\_]+/
23 |
24 | var_name : /[A-Z]/
25 | small_chars : /[a-z0-9]+/
26 | chars : /[^\+\|\s\(\)']+/[/\n+/]
27 | allchars : /[^']+/[/\n+/]
28 |
--------------------------------------------------------------------------------
/neumann/neumann/message_passing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric.nn
4 | from torch import Tensor
5 | from torch.nn import Linear, ReLU, Sequential
6 | from torch_geometric.nn import MessagePassing
7 | from torch_scatter import gather_csr, scatter, segment_csr
8 |
9 | # from neural_utils import MLP
10 | from .scatter import *
11 |
12 |
13 | class Atom2ConjConv(MessagePassing):
14 | """Message Passing class for messages from atom nodes to conjunction nodes.
15 |
16 | Args:
17 | soft_logic (softlogic): An implementation of the soft-logic operations.
18 | device (device): A device.
19 | """
20 |
21 | def __init__(self, soft_logic, device):
22 | super().__init__(aggr='add')
23 | self.soft_logic = soft_logic
24 | self.device = device
25 |
26 | def forward(self, x, edge_index, conj_node_idxs, batch_size):
27 | """Perform atom2conj message-passing.
28 | Args:
29 | x (tensor): A data (node features).
30 | edge_index (tensor): An edge index.
31 |             conj_node_idxs (tensor): A list of indices of conjunction nodes extended for the batch size.
32 | batch_size (int): A batch size.
33 | """
34 | return self.propagate(edge_index, x=x, conj_node_idxs=conj_node_idxs, batch_size=batch_size).view((batch_size, -1))
35 |
36 | def message(self, x_j):
37 | """Compute the message.
38 | """
39 | return x_j
40 |
41 | def update(self, message, x, conj_node_idxs):
42 | """Update the node features.
43 | Args:
44 | message (tensor, [node_num, node_dim]): Messages aggregated by `aggregate`.
45 | x (tensor, [node_num, node_dim]): Node features on the previous step.
46 | """
47 | return self.soft_logic._or(torch.stack([x[conj_node_idxs], message[conj_node_idxs]]))
48 |
49 | def aggregate(self, inputs, index):
50 | """Aggregate the messages.
51 | Args:
52 | inputs (tensor, [num_edges, num_features]): The values come from each edge.
53 | index (tensor, [2, num_of_edges]): The indices of the terminal nodes (conjunction nodes).
54 | """
55 | return scatter_mul(inputs, index, dim=0)
56 |
57 |
58 | class Conj2AtomConv(MessagePassing):
59 |     """Message Passing class for messages from conjunction nodes to atom nodes.
60 |
61 | Args:
62 | soft_logic (softlogic): An implementation of the soft-logic operations.
63 | device (device): A device.
64 | """
65 |
66 | def __init__(self, soft_logic, device):
67 | super().__init__(aggr='add')
68 | self.soft_logic = soft_logic
69 | self.device = device
70 | self.eps = 1e-4
71 | # self.linear = nn.Linear()
72 |
73 | def forward(self, x, edge_index, edge_weight, edge_clause_index, atom_node_idxs, n_nodes, batch_size):
74 | """Perform conj2atom message-passing.
75 | Args:
76 | edge_index (tensor): An edge index.
77 | x (tensor): A data (node features).
78 | edge_weight (tensor): The edge weights.
79 | edge_clause_index (tensor): A list of indices of clauses representing which clause produced each edge in the reasoning graph.
80 |             atom_node_idxs (tensor): A list of indices of atom nodes extended for the batch size.
81 | n_nodes (int): The number of nodes in the reasoning graph.
82 | batch_size (int): A batch size.
83 | """
84 | return self.propagate(edge_index, x=x, edge_weight=edge_weight, edge_clause_index=edge_clause_index,
85 | atom_node_idxs=atom_node_idxs, n_nodes=n_nodes, batch_size=batch_size).view(batch_size, -1)
86 |
87 | def message(self, x_j, edge_weight, edge_clause_index):
88 | """Compute the message.
89 | """
90 | return edge_weight.view(-1, 1) * x_j
91 |
92 | def update(self, message, x, atom_node_idxs):
93 | """Update the node features.
94 | Args:
95 | message (tensor, [node_num, node_dim]): Messages aggregated by `aggregate`.
96 | x (tensor, [node_num, node_dim]): Node features on the previous step.
97 | """
98 | return self.soft_logic._or(torch.stack([x[atom_node_idxs], message[atom_node_idxs]]))
99 |
100 | def aggregate(self, inputs, index, n_nodes):
101 | """Aggregate the messages.
102 | Args:
103 | inputs (tensor, [num_edges, num_features]): The values come from each edge.
104 | index (tensor, [2, num_of_edges]): The indices of the terminal nodes (conjunction nodes).
105 | n_nodes (int): The number of nodes in the reasoning graph.
106 | """
107 | # softor
108 | # gamma = 0.05
109 | gamma = 0.015
110 | log_sum_exp = gamma * \
111 | self._logsumexp((inputs) * (1/gamma), index, n_nodes)
112 | if log_sum_exp.max() > 1.0:
113 | return log_sum_exp / log_sum_exp.max()
114 | else:
115 | return log_sum_exp
116 |
117 | def _logsumexp(self, inputs, index, n_nodes):
118 | return torch.log(scatter(src=torch.exp(inputs), index=index, dim=0, dim_size=n_nodes, reduce='sum') + self.eps)
119 |
120 |
121 | class MessagePassingModule(torch.nn.Module):
122 | """The bi-directional message-passing module.
123 |
124 | Args:
125 | soft_logic (softlogic): An implementation of the soft-logic operations.
126 | device (device): A device.
127 | T (int): The number of steps for reasoning.
128 | """
129 |
130 | def __init__(self, soft_logic, device, T):
131 | super().__init__()
132 | self.soft_logic = soft_logic
133 | self.device = device
134 | self.T = T
135 | self.atom2conj = Atom2ConjConv(soft_logic, device)
136 | self.conj2atom = Conj2AtomConv(soft_logic, device)
137 |
138 | def forward(self, data, clause_weights, edge_clause_index, edge_type, atom_node_idxs, conj_node_idxs, batch_size, explain=False):
139 | """
140 | Args:
141 |             data (torch_geometric.Data): A logic program and probabilistic facts as graph data.
142 |             clause_weights (torch.Tensor): Weights for clauses.
143 |             edge_clause_index (torch.Tensor): Clause indices for each edge, representing which clause produces the edge.
144 | edge_type (torch.Tensor): Edge types (atom2conj:0 or conj2atom:1)
145 | atom_node_idxs (torch.Tensor): The indices of atom nodes.
146 | conj_node_idxs (torch.Tensor): The indices of conjunction nodes.
147 | batch_size (int): The batch size of input.
148 | """
149 | x_atom = data.x[atom_node_idxs]
150 | x_conj = data.x[conj_node_idxs]
151 |
152 |
153 | # filter the edge index using the edge type obtaining the set of atom2conj edges and the set of conj2atom edges
154 | atom2conj_edge_index = self._filter_edge_index(
155 | data.edge_index, edge_type, 'atom2conj', batch_size).to(self.device)
156 | conj2atom_edge_index = self._filter_edge_index(
157 | data.edge_index, edge_type, 'conj2atom', batch_size).to(self.device)
158 |
159 | # filter the edge-clause index using the edge type obtaining the set of atom2conj edge-clause indices and the set of conj2atom edge-clause indices
160 | conj2atom_edge_clause_index = self._filter_edge_clause_index(
161 | edge_clause_index, edge_type, batch_size).to(self.device)
162 | edge_weight = torch.gather(
163 | input=clause_weights, dim=0, index=conj2atom_edge_clause_index)
164 |
165 | n_nodes = data.x.size(0)
166 |
167 | self.x_atom_list = [data.x[atom_node_idxs].view((batch_size, -1)).detach().cpu().numpy()[:,1:]]
168 | x = data.x
169 |
170 |
171 |         # dummy variable to compute input gradients
172 | if explain:
173 | #print(x[atom_node_idxs], x[atom_node_idxs].shape)
174 | self.dummy_zeros = torch.zeros_like(x[atom_node_idxs], requires_grad=True).to(torch.float32).to(self.device)
175 | self.dummy_zeros.requires_grad_()
176 | self.dummy_zeros.retain_grad()
177 | #print(self.dummy_zeros)
178 | # add dummy zeros to get input gradients
179 | x[atom_node_idxs] = x[atom_node_idxs] + self.dummy_zeros
180 |
181 | # iterate message passing T times
182 | for t in range(self.T):
183 |
184 | # step 1: Atom -> Conj
185 | x_conj_new = self.atom2conj(
186 | x, atom2conj_edge_index, conj_node_idxs, batch_size)
187 |
188 | # create new tensor (node features) by updating conjunction embeddings
189 | x = self._cat_x_atom_x_conj(
190 | x[atom_node_idxs].view((batch_size, -1)), x_conj_new)
191 |
192 | # step 2: Conj -> Atom
193 | x_atom_new = self.conj2atom(x=x, edge_weight=edge_weight, edge_index=conj2atom_edge_index, edge_clause_index=edge_clause_index, atom_node_idxs=atom_node_idxs,
194 | n_nodes=n_nodes, batch_size=batch_size)
195 | self.x_atom_list.append(x_atom_new.detach().cpu().numpy()[:,1:])
196 |
197 | x = self._cat_x_atom_x_conj(x_atom_new, x_conj_new)
198 | self.x_atom_final = x_atom_new
199 | return x_atom_new
200 |
201 | def _cat_x_atom_x_conj(self, x_atom, x_conj):
202 |         """Concatenate the features of atom nodes and those of conj nodes.
203 | Args:
204 | x_atom : batch_size * n_atom
205 | x_conj : batch_size * n_conj
206 |
207 | Returns:
208 | [x_atom_1, x_conj_1, x_atom_2, x_conj_2, ...]
209 | """
210 | xs = []
211 | for i in range(x_atom.size(0)):
212 | x_i = torch.cat([x_atom[i], x_conj[i]])
213 | xs.append(x_i)
214 | return torch.cat(xs).unsqueeze(-1)
215 |
216 | def _filter_edge_clause_index(self, edge_clause_index, edge_type, batch_size):
217 | """Filter the edge index by the edge type.
218 | """
219 | edge_clause_index = torch.stack(
220 | [edge_clause_index for i in range(batch_size)]).view((-1))
221 | edge_type = torch.stack([edge_type for i in range(batch_size)])
222 | mask = (edge_type == 1).view((-1))
223 | return edge_clause_index[mask]
224 |
225 | def _filter_edge_index(self, edge_index, edge_type, mode, batch_size):
226 | """Filter the edge index by the edge type.
227 | """
228 | edge_type = torch.stack([edge_type for i in range(batch_size)])
229 | if mode == 'atom2conj':
230 | mask = (edge_type == 0).view((-1))
231 | return edge_index[:, mask]
232 | elif mode == 'conj2atom':
233 | mask = (edge_type == 1).view((-1))
234 | return edge_index[:, mask]
235 | else:
236 | assert 0, "Invalid mode in _filter_edge_index"
237 |
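The soft OR in Conj2AtomConv.aggregate is a temperature-scaled log-sum-exp over the messages arriving at each atom node: for small gamma it behaves like a smooth maximum, and the result is renormalized when it exceeds 1. A minimal sketch of that operation for a single atom node; the message values are made up, while gamma follows the code above.

import torch

gamma = 0.015
# Messages from three conjunction nodes arriving at one atom node.
msgs = torch.tensor([0.1, 0.8, 0.3])

soft_or = gamma * torch.logsumexp(msgs / gamma, dim=0)
print(soft_or)      # close to msgs.max() = 0.8 for small gamma
print(msgs.max())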
--------------------------------------------------------------------------------
/neumann/neumann/mode_declaration.py:
--------------------------------------------------------------------------------
1 | from .fol.language import DataType
2 |
3 |
4 | class ModeDeclaration(object):
5 | """from https://www.cs.ox.ac.uk/activities/programinduction/Aleph/aleph.html
6 | p(ModeType, ModeType,...)
7 | Here are some examples of how they appear in a file:
8 | :- mode(1,mem(+number,+list)).
9 | :- mode(1,dec(+integer,-integer)).
10 | :- mode(1,mult(+integer,+integer,-integer)).
11 | :- mode(1,plus(+integer,+integer,-integer)).
12 | :- mode(1,(+integer)=(#integer)).
13 | :- mode(*,has_car(+train,-car)).
14 | Each ModeType is either (a) simple; or (b) structured.
15 | A simple ModeType is one of:
16 | (a) +T specifying that when a literal with predicate symbol p appears in a
17 | hypothesised clause, the corresponding argument should be an "input" variable of type T;
18 | (b) -T specifying that the argument is an "output" variable of type T; or
19 | (c) #T specifying that it should be a constant of type T.
20 | All the examples above have simple modetypes.
21 | A structured ModeType is of the form f(..) where f is a function symbol,
22 | each argument of which is either a simple or structured ModeType.
23 | Here is an example containing a structured ModeType:
24 | To make this more clear, here is an example for the mode declarations for
25 |     To make this more clear, here is an example of the mode declarations for
26 |     the grandfather task from above:
27 |         :- modeh(1, grandfather(+human, +human)).
28 |         :- modeb(*, parent(-human, +human)).
29 |         :- modeb(*, male(+human)).
30 |     The first mode states that the head of the rule
31 |     (and therefore the target predicate) will be the atom grandfather.
32 |     Its parameters have to be of the type human.
33 |     The + annotation says that the rule head needs two variables.
34 |     The second mode declaration states the parent atom and declares again
35 |     that the parameters have to be of type human.
36 |     Here, the + at the second parameter tells that the system is only allowed to
37 |     introduce the atom parent in the clause if it already contains a variable of type human.
38 |     The first attribute introduces a new variable into the clause.
39 |     The modes consist of a recall n that states how many versions of the
40 |     literal are allowed in a rule and an atom with place-markers that state the literal together
41 |     with annotations on input- and output-variables as well as constants (see [Mug95]).
42 | Args:
43 |         recall (int): The recall number, i.e. how many times the declaration can be instantiated
44 | pred (Predicate): The predicate.
45 | mode_terms (ModeTerm): Terms for mode declarations.
46 | """
47 |
48 | def __init__(self, mode_type, recall, pred, mode_terms, ordered=True):
49 | self.mode_type = mode_type # head or body
50 | self.recall = recall
51 | self.pred = pred
52 | self.mode_terms = mode_terms
53 | self.ordered = ordered
54 |
55 | def __str__(self):
56 | s = 'mode_' + self.mode_type + '('
57 | for mt in self.mode_terms:
58 | s += str(mt)
59 | s += ','
60 | s = s[0:-1]
61 | s += ')'
62 | return s
63 |
64 | def __repr__(self):
65 | return self.__str__()
66 |
67 | def __hash__(self):
68 | return hash(self.__str__())
69 |
70 |
71 | class ModeTerm(object):
72 | """Terms for mode declarations. It has mode (+, -, #) and data types.
73 | """
74 |
75 | def __init__(self, mode, dtype):
76 | self.mode = mode
77 | assert mode in ['+', '-', '#'], "Invalid mode declaration."
78 | self.dtype = dtype
79 |
80 | def __str__(self):
81 | return self.mode + self.dtype.name
82 |
83 | def __repr__(self):
84 | return self.__str__()
85 |
86 |
87 | def get_mode_declarations_clevr(lang, obj_num):
88 | p_image = ModeTerm('+', DataType('image'))
89 | m_object = ModeTerm('-', DataType('object'))
90 | p_object = ModeTerm('+', DataType('object'))
91 | s_color = ModeTerm('#', DataType('color'))
92 | s_shape = ModeTerm('#', DataType('shape'))
93 | s_material = ModeTerm('#', DataType('material'))
94 | s_size = ModeTerm('#', DataType('size'))
95 |
96 | # modeh_1 = ModeDeclaration('head', 'kp', p_image)
97 |
98 | """
99 | kp1(X):-in(O1,X),in(O2,X),size(O1,large),shape(O1,cube),size(O2,large),shape(O2,cylinder).
100 | kp2(X):-in(O1,X),in(O2,X),size(O1,small),material(O1,metal),shape(O1,cube),size(O2,small),shape(O2,sphere).
101 | kp3(X):-in(O1,X),in(O2,X),size(O1,large),color(O1,blue),shape(O1,sphere),size(O2,small),color(O2,yellow),shape(O2,sphere)."""
102 |
103 | modeb_list = [
104 | #ModeDeclaration('body', obj_num, lang.get_pred_by_name(
105 | # 'in'), [m_object, p_image]),
106 | ModeDeclaration('body', 2, lang.get_pred_by_name(
107 | 'color'), [p_object, s_color]),
108 | ModeDeclaration('body', 2, lang.get_pred_by_name(
109 | 'shape'), [p_object, s_shape]),
110 | #ModeDeclaration('body', 1, lang.get_pred_by_name(
111 | # 'material'), [p_object, s_material]),
112 | ModeDeclaration('body', 2, lang.get_pred_by_name(
113 | 'size'), [p_object, s_size]),
114 | ]
115 | return modeb_list
116 |
117 |
118 | def get_mode_declarations_kandinsky(lang, obj_num):
119 | p_image = ModeTerm('+', DataType('image'))
120 | m_object = ModeTerm('-', DataType('object'))
121 | p_object = ModeTerm('+', DataType('object'))
122 | s_color = ModeTerm('#', DataType('color'))
123 | s_shape = ModeTerm('#', DataType('shape'))
124 |
125 | # modeh_1 = ModeDeclaration('head', 'kp', p_image)
126 |
127 | modeb_list = [
128 | #ModeDeclaration('body', obj_num, lang.get_pred_by_name(
129 | # 'in'), [m_object, p_image]),
130 | ModeDeclaration('body', 1, lang.get_pred_by_name(
131 | 'color'), [p_object, s_color]),
132 | ModeDeclaration('body', 1, lang.get_pred_by_name(
133 | 'shape'), [p_object, s_shape]),
134 | ModeDeclaration('body', 1, lang.get_pred_by_name(
135 | 'same_color_pair'), [p_object, p_object], ordered=False),
136 | ModeDeclaration('body', 2, lang.get_pred_by_name(
137 | 'same_shape_pair'), [p_object, p_object], ordered=False),
138 |         ModeDeclaration('body', 1, lang.get_pred_by_name(
139 | 'diff_color_pair'), [p_object, p_object], ordered=False),
140 | ModeDeclaration('body', 1, lang.get_pred_by_name(
141 | 'diff_shape_pair'), [p_object, p_object], ordered=False),
142 | ModeDeclaration('body', 1, lang.get_pred_by_name(
143 | 'closeby'), [p_object, p_object], ordered=False),
144 | ModeDeclaration('body', 1, lang.get_pred_by_name('online'), [
145 | p_object, p_object, p_object, p_object, p_object], ordered=False),
146 | # ModeDeclaration('body', 2, lang.get_pred_by_name('diff_shape_pair'), [p_object, p_object]),
147 | ]
148 | return modeb_list
149 |
150 | def get_mode_declarations_vilp(lang, dataset):
151 | p_colors = ModeTerm('+', DataType('colors'))
152 | p_color = ModeTerm('+', DataType('color'))
153 | # modeh_1 = ModeDeclaration('head', 'kp', p_image)
154 | if dataset=='member':
155 | modeb_list = [
156 | ModeDeclaration('body', 1, lang.get_pred_by_name(
157 | 'member'), [p_color, p_colors])]
158 | elif dataset == 'delete':
159 | modeb_list = [
160 | ModeDeclaration('body', 1, lang.get_pred_by_name(
161 | 'delete'), [p_color, p_colors, p_colors])]
162 | elif dataset == 'append':
163 | modeb_list = [
164 | ModeDeclaration('body', 1, lang.get_pred_by_name(
165 | 'append'), [p_colors, p_colors, p_colors])]
166 | elif dataset == 'reverse':
167 | modeb_list = [
168 | ModeDeclaration('body', 1, lang.get_pred_by_name(
169 | 'reverse'), [p_colors, p_colors, p_colors])
170 | #ModeDeclaration('body', 1, lang.get_pred_by_name(
171 | # 'append'), [p_colors, p_colors, p_colors]),
172 | ]
173 | elif dataset == 'sort':
174 | modeb_list = [
175 | ModeDeclaration('body', 1, lang.get_pred_by_name('perm'), [p_colors, p_colors]),
176 | ModeDeclaration('body', 1, lang.get_pred_by_name('is_sorted'), [p_colors]),
177 | ModeDeclaration('body', 1, lang.get_pred_by_name('smaller'), [p_color, p_color]),
178 | ]
179 | #ModeDeclaration('body', 1, lang.get_pred_by_name(
180 | # 'append'), [p_colors, p_colors, p_colors]),
181 | #ModeDeclaration('body', 1, lang.get_pred_by_name(
182 | # 'reverse'), [p_colors, p_colors]),
183 | #ModeDeclaration('body', 1, lang.get_pred_by_name(
184 | # 'sort'), [p_colors, p_colors])
185 | return modeb_list
186 |
187 | def get_mode_declarations(args, lang):
188 | if args.dataset_type == 'kandinsky':
189 | return get_mode_declarations_kandinsky(lang, args.num_objects)
190 | elif args.dataset_type == 'clevr-hans':
191 | return get_mode_declarations_clevr(lang, 10)
192 | elif args.dataset_type == 'vilp':
193 | return get_mode_declarations_vilp(lang, args.dataset)
194 | else:
195 | assert False, "Invalid data type."
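196 | 
197 | # A minimal usage sketch (illustrative only): `args` and `lang` are assumed to come
198 | # from the surrounding pipeline, e.g. via logic_utils.get_lang as in the solver scripts.
199 | #
200 | #     lang, clauses, bk, bk_clauses, terms, atoms = get_lang(
201 | #         lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth)
202 | #     modeb_list = get_mode_declarations(args, lang)
203 | #     # modeb_list constrains which predicates may appear in generated clause bodies.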
--------------------------------------------------------------------------------
/neumann/neumann/neural_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class LogisticRegression(torch.nn.Module):
7 | def __init__(self, input_dim, output_dim=1):
8 | super(LogisticRegression, self).__init__()
9 | self.linear = torch.nn.Linear(input_dim, output_dim)
10 |
11 | def forward(self, x):
12 | y_pred = torch.sigmoid(self.linear(x))
13 | return y_pred
14 |
15 | class MLP(nn.Module):
16 | def __init__(self, in_channels, out_channels, hidden_dim=256):
17 | super(MLP, self).__init__()
18 | # Number of input features is input_dim.
19 | self.layer_1 = nn.Linear(in_channels, hidden_dim)
20 | self.layer_2 = nn.Linear(hidden_dim, hidden_dim)
21 | self.layer_out = nn.Linear(hidden_dim, out_channels)
22 | self.relu = nn.ReLU()
23 |
24 | def forward(self, inputs):
25 | x = self.relu(self.layer_1(inputs))
26 | x = self.relu(self.layer_2(x))
27 | x = self.layer_out(x)
28 | #x = torch.sigmoid(x)
29 | return x
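30 | 
31 | 
32 | if __name__ == "__main__":
33 |     # Minimal smoke test (illustrative only): shapes below are arbitrary.
34 |     batch = torch.randn(8, 16)
35 |     clf = LogisticRegression(input_dim=16)
36 |     print(clf(batch).shape)  # torch.Size([8, 1]), values in (0, 1)
37 |     mlp = MLP(in_channels=16, out_channels=4)
38 |     print(mlp(batch).shape)  # torch.Size([8, 4]), raw (unnormalised) scores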
--------------------------------------------------------------------------------
/neumann/neumann/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import random
4 | import time
5 |
6 | import matplotlib.pyplot as plt
7 | import networkx as nx
8 | import numpy as np
9 | import torch
10 | from rtpt import RTPT
11 | from sklearn.metrics import accuracy_score, recall_score, roc_curve
12 | from torch.utils.tensorboard import SummaryWriter
13 | from tqdm import tqdm
14 |
15 | from .logic_utils import get_lang
16 | from .neumann_utils import (
17 | generate_captions,
18 | get_data_loader,
19 | get_model,
20 | get_prob,
21 | save_images_with_captions,
22 | to_plot_images_clevr,
23 | to_plot_images_kandinsky,
24 | )
25 | from .tensor_encoder import TensorEncoder
26 | from .tensor_utils import build_infer_module
27 | from .visualize import plot_proof_history
28 |
29 | random.seed(0)
30 |
31 |
32 | def get_args():
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument(
35 | "--batch-size", type=int, default=1, help="Batch size to infer with"
36 | )
37 | parser.add_argument(
38 | "--num-objects",
39 | type=int,
40 | default=3,
41 | help="The maximum number of objects in one image",
42 | )
43 | parser.add_argument("--dataset") # , choices=["member"])
44 |     # parser.add_argument("--dataset_type")  # superseded by the --dataset-type flag below
45 | parser.add_argument("--rtpt-name", default="") # , choices=["member"])
46 | parser.add_argument(
47 | "--dataset-type",
48 | choices=["vilp", "clevr-hans", "kandinsky"],
49 | help="vilp or kandinsky or clevr",
50 | )
51 | parser.add_argument("--device", default="cpu", help="cuda device, i.e. 0 or cpu")
52 | parser.add_argument(
53 | "--no-cuda",
54 | action="store_true",
55 | help="Run on CPU instead of GPU (not recommended)",
56 | )
57 | parser.add_argument(
58 | "--no-train",
59 | action="store_true",
60 | help="Perform prediction without training model",
61 | )
62 | parser.add_argument(
63 | "--small-data", action="store_true", help="Use small training data."
64 | )
65 | parser.add_argument(
66 | "--num-workers", type=int, default=4, help="Number of threads for data loader"
67 | )
68 | parser.add_argument(
69 | "--gamma",
70 | default=0.01,
71 | type=float,
72 | help="Smooth parameter in the softor function",
73 | )
74 | parser.add_argument(
75 | "--plot", action="store_true", help="Plot images with captions."
76 | )
77 | parser.add_argument(
78 | "--t-beam",
79 | type=int,
80 | default=4,
81 |         help="Number of rule expansion steps in clause generation.",
82 | )
83 | parser.add_argument("--n-beam", type=int, default=5, help="The size of the beam.")
84 | parser.add_argument(
85 | "--n-max", type=int, default=50, help="The maximum number of clauses."
86 | )
87 | parser.add_argument(
88 | "--program-size",
89 | "-m",
90 | type=int,
91 | default=1,
92 | help="The size of the logic program.",
93 | )
94 | # parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.")
95 | parser.add_argument("--epochs", type=int, default=20, help="The number of epochs.")
96 | parser.add_argument("--lr", type=float, default=1e-2, help="The learning rate.")
97 | parser.add_argument(
98 | "--n-data", type=float, default=200, help="The number of data to be used."
99 | )
100 | parser.add_argument(
101 | "--pre-searched", action="store_true", help="Using pre searched clauses."
102 | )
103 | parser.add_argument(
104 | "-T",
105 | "--infer-step",
106 | type=int,
107 | default=10,
108 | help="The number of steps of forward reasoning.",
109 | )
110 | parser.add_argument(
111 | "--term-depth",
112 | type=int,
113 | default=3,
114 |         help="The maximum depth of terms.",
115 | )
116 | args = parser.parse_args()
117 | return args
118 |
119 |
120 | def main():
121 | args = get_args()
122 | print("args ", args)
123 | if args.no_cuda:
124 | device = torch.device("cpu")
125 | elif len(args.device.split(",")) > 1:
126 | # multi gpu
127 | device = torch.device("cuda")
128 | else:
129 | device = torch.device("cuda:" + args.device)
130 |
131 | print("device: ", device)
132 |
133 | # Create RTPT object
134 | rtpt = RTPT(
135 | name_initials="",
136 | experiment_name="NEUMANN_{}".format(args.dataset),
137 | max_iterations=args.epochs,
138 | )
139 | # Start the RTPT tracking
140 | rtpt.start()
141 |
142 | # Get torch data loader
143 | # train_loader, val_loader, test_loader = get_data_loader(args, device)
144 |
145 | # Load logical representations
146 | lark_path = "src/lark/exp.lark"
147 | lang_base_path = "data/lang/"
148 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang(
149 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth
150 | )
151 | print(terms)
152 | print("{} Atoms:".format(len(atoms)))
153 | print(atoms)
154 |
155 | # Load the NEUMANN model
156 | NEUMANN = get_model(
157 | lang=lang,
158 | clauses=clauses,
159 | atoms=atoms,
160 | terms=terms,
161 | bk=bk,
162 | bk_clauses=bk_clauses,
163 | program_size=args.program_size,
164 | device=device,
165 | dataset=args.dataset,
166 | dataset_type=args.dataset_type,
167 | num_objects=args.num_objects,
168 | term_depth=args.term_depth,
169 | infer_step=args.infer_step,
170 | train=not (args.no_train),
171 | )
172 |
173 | x = torch.zeros((1, len(atoms))).to(device)
174 | x[:, 0] = 1.0
175 | ## x[:,1] = 0.8
176 | print("x: ", x)
177 | print(np.round(NEUMANN(x).detach().cpu().numpy(), 2))
178 | print(NEUMANN.rgm.edge_index)
179 | print("graph: ")
180 | print(NEUMANN.rgm.networkx_graph)
181 | NEUMANN.plot_reasoning_graph(name=args.dataset)
182 | plot_proof_history(
183 | NEUMANN.mpm.x_atom_list, atoms[1:], args.infer_step, args.dataset, mode="graph"
184 | )
185 | # nx.draw(NEUMANN.rgm.networkx_graph)
186 | # plt.savefig('imgs/{}.png'.format(args.dataset))
187 |
188 | print("==== tensor based reasoner")
189 |     from .logic_utils import false
190 |
191 | atoms = [false] + atoms
192 | rgm = NEUMANN.rgm
193 | rgm.facts = atoms
194 | x = torch.zeros((1, len(atoms))).to(device)
195 | x[:, 1] = 1.0
196 | for i, atom in enumerate(atoms):
197 | if atom in bk:
198 | x[:, i] += 1.0
199 | ## x[:,2] = 0.8
200 | IM = build_infer_module(
201 | clauses,
202 | atoms,
203 | lang,
204 | rgm,
205 | device,
206 | m=args.program_size,
207 | infer_step=args.infer_step,
208 | train=True,
209 | )
210 | IM(x)
211 | IM.W = NEUMANN.clause_weights
212 | print(IM.get_weights())
213 | plot_proof_history(
214 | IM.V_list, atoms[2:], args.infer_step, args.dataset, mode="tensor"
215 | )
216 |
217 |
218 | if __name__ == "__main__":
219 | main()
220 |
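221 | # Example invocation (illustrative only; because of the relative imports the module is
222 | # typically run with `python -m`, and the lark grammar and data/lang files referenced
223 | # above must exist relative to the working directory):
224 | #
225 | #     python -m neumann.predict --dataset-type vilp --dataset member \
226 | #         --device 0 --infer-step 10 --program-size 1 --plot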
--------------------------------------------------------------------------------
/neumann/neumann/scatter.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 |
6 |
7 |
8 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
9 | if dim < 0:
10 | dim = other.dim() + dim
11 | if src.dim() == 1:
12 | for _ in range(0, dim):
13 | src = src.unsqueeze(0)
14 | for _ in range(src.dim(), other.dim()):
15 | src = src.unsqueeze(-1)
16 | src = src.expand(other.size())
17 | return src
18 |
19 | def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
20 | out: Optional[torch.Tensor] = None,
21 | dim_size: Optional[int] = None) -> torch.Tensor:
22 | index = broadcast(index, src, dim)
23 | if out is None:
24 | size = list(src.size())
25 | if dim_size is not None:
26 | size[dim] = dim_size
27 | elif index.numel() == 0:
28 | size[dim] = 0
29 | else:
30 | size[dim] = int(index.max()) + 1
31 | out = torch.zeros(size, dtype=src.dtype, device=src.device)
32 | return out.scatter_add_(dim, index, src)
33 | else:
34 | return out.scatter_add_(dim, index, src)
35 |
36 |
37 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
38 | out: Optional[torch.Tensor] = None,
39 | dim_size: Optional[int] = None) -> torch.Tensor:
40 | return scatter_sum(src, index, dim, out, dim_size)
41 |
42 |
43 | def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
44 | out: Optional[torch.Tensor] = None,
45 | dim_size: Optional[int] = None) -> torch.Tensor:
46 | return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
47 |
48 |
49 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
50 | out: Optional[torch.Tensor] = None,
51 | dim_size: Optional[int] = None) -> torch.Tensor:
52 | out = scatter_sum(src, index, dim, out, dim_size)
53 | dim_size = out.size(dim)
54 |
55 | index_dim = dim
56 | if index_dim < 0:
57 | index_dim = index_dim + src.dim()
58 | if index.dim() <= index_dim:
59 | index_dim = index.dim() - 1
60 |
61 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
62 | count = scatter_sum(ones, index, index_dim, None, dim_size)
63 | count[count < 1] = 1
64 | count = broadcast(count, out, dim)
65 | if out.is_floating_point():
66 | out.true_divide_(count)
67 | else:
68 | out.div_(count, rounding_mode='floor')
69 | return out
70 |
71 |
72 | def scatter_min(
73 | src: torch.Tensor, index: torch.Tensor, dim: int = -1,
74 | out: Optional[torch.Tensor] = None,
75 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
76 | return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
77 |
78 |
79 | def scatter_max(
80 | src: torch.Tensor, index: torch.Tensor, dim: int = -1,
81 | out: Optional[torch.Tensor] = None,
82 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
83 | return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
84 |
85 |
86 | def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
87 | out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
88 | reduce: str = "sum") -> torch.Tensor:
89 | r"""
90 | |
91 | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
92 | master/docs/source/_figures/add.svg?sanitize=true
93 | :align: center
94 | :width: 400px
95 | |
96 | Reduces all values from the :attr:`src` tensor into :attr:`out` at the
97 | indices specified in the :attr:`index` tensor along a given axis
98 | :attr:`dim`.
99 | For each value in :attr:`src`, its output index is specified by its index
100 | in :attr:`src` for dimensions outside of :attr:`dim` and by the
101 | corresponding value in :attr:`index` for dimension :attr:`dim`.
102 | The applied reduction is defined via the :attr:`reduce` argument.
103 | Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
104 | tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
105 | and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
106 | tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
107 | Moreover, the values of :attr:`index` must be between :math:`0` and
108 | :math:`y - 1`, although no specific ordering of indices is required.
109 | The :attr:`index` tensor supports broadcasting in case its dimensions do
110 | not match with :attr:`src`.
111 | For one-dimensional tensors with :obj:`reduce="sum"`, the operation
112 | computes
113 | .. math::
114 | \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
115 | where :math:`\sum_j` is over :math:`j` such that
116 | :math:`\mathrm{index}_j = i`.
117 | .. note::
118 | This operation is implemented via atomic operations on the GPU and is
119 | therefore **non-deterministic** since the order of parallel operations
120 | to the same value is undetermined.
121 | For floating-point variables, this results in a source of variance in
122 | the result.
123 | :param src: The source tensor.
124 | :param index: The indices of elements to scatter.
125 | :param dim: The axis along which to index. (default: :obj:`-1`)
126 | :param out: The destination tensor.
127 | :param dim_size: If :attr:`out` is not given, automatically create output
128 | with size :attr:`dim_size` at dimension :attr:`dim`.
129 | If :attr:`dim_size` is not given, a minimal sized output tensor
130 | according to :obj:`index.max() + 1` is returned.
131 | :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
132 | :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
133 | :rtype: :class:`Tensor`
134 | .. code-block:: python
135 | from torch_scatter import scatter
136 | src = torch.randn(10, 6, 64)
137 | index = torch.tensor([0, 1, 0, 1, 2, 1])
138 | # Broadcasting in the first and last dim.
139 | out = scatter(src, index, dim=1, reduce="sum")
140 | print(out.size())
141 | .. code-block::
142 | torch.Size([10, 3, 64])
143 | """
144 | if reduce == 'sum' or reduce == 'add':
145 | return scatter_sum(src, index, dim, out, dim_size)
146 | if reduce == 'mul':
147 | return scatter_mul(src, index, dim, out, dim_size)
148 | elif reduce == 'mean':
149 | return scatter_mean(src, index, dim, out, dim_size)
150 | elif reduce == 'min':
151 | return scatter_min(src, index, dim, out, dim_size)[0]
152 | elif reduce == 'max':
153 | return scatter_max(src, index, dim, out, dim_size)[0]
154 | else:
155 |         raise ValueError(f"Unsupported reduce operation: '{reduce}'")
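156 | 
157 | 
158 | if __name__ == "__main__":
159 |     # Minimal smoke test (illustrative only); mirrors the docstring example but uses
160 |     # the local implementation rather than the torch_scatter package.
161 |     src = torch.randn(10, 6, 64)
162 |     index = torch.tensor([0, 1, 0, 1, 2, 1])
163 |     out = scatter(src, index, dim=1, reduce="sum")
164 |     print(out.size())  # torch.Size([10, 3, 64])
165 |     print(scatter(src, index, dim=1, reduce="mean").size())  # same shape, averaged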
--------------------------------------------------------------------------------
/neumann/neumann/soft_logic.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .torch_utils import softor
4 |
5 |
6 | class SoftLogic(object):
7 |     """A class of soft implementations of logic operations, i.e., logical-or and logical-and.
8 | """
9 |
10 | def __init__(self):
11 | pass
12 |
13 | def _or(self, x):
14 | return softor(x, dim=0, gamma=0.01)
15 |
16 | def _and(self, x):
17 | return torch.prod(x, dim=0)
18 |
19 |
20 | class LukasiewiczSoftLogic(SoftLogic):
21 | pass
22 |
23 |
24 | class LNNSoftLogic(SoftLogic):
25 | pass
26 |
27 |
28 | class aILPSoftLogic(SoftLogic):
29 | pass
30 |
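31 | # A minimal usage sketch (illustrative only): `_or` is a smooth maximum over dim 0
32 | # (via softor), `_and` is the product t-norm over dim 0.
33 | #
34 | #     sl = SoftLogic()
35 | #     x = torch.tensor([[0.9, 0.2], [0.1, 0.8]])
36 | #     sl._or(x)   # approximately the element-wise maximum, ~[0.9, 0.8]
37 | #     sl._and(x)  # element-wise product -> tensor([0.0900, 0.1600])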
--------------------------------------------------------------------------------
/neumann/neumann/solve_behind_the_scenes.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from rtpt import RTPT
8 | from sklearn.metrics import accuracy_score, recall_score, roc_curve
9 | from torch.utils.tensorboard import SummaryWriter
10 | from tqdm import tqdm
11 |
12 | from logic_utils import get_lang
13 | from mode_declaration import get_mode_declarations
14 | from neumann_utils import (get_behind_the_scenes_loader, get_clause_evaluator,
15 | get_model, get_prob)
16 | from tensor_encoder import TensorEncoder
17 |
18 | # from nsfr_utils import save_images_with_captions, to_plot_images_clevr, generate_captions
19 |
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--batch-size", type=int, default=2,
25 | help="Batch size to infer with")
26 | parser.add_argument("--batch-size-bs", type=int,
27 | default=1, help="Batch size in beam search")
28 | parser.add_argument("--num-objects", type=int, default=3,
29 | help="The maximum number of objects in one image")
30 | parser.add_argument("--dataset", default="delete") # , choices=["member"])
31 | parser.add_argument("--dataset-type", default="behind-the-scenes")
32 | parser.add_argument('--device', default='cpu',
33 | help='cuda device, i.e. 0 or cpu')
34 | parser.add_argument("--no-cuda", action="store_true",
35 | help="Run on CPU instead of GPU (not recommended)")
36 | parser.add_argument("--no-train", action="store_true",
37 | help="Perform prediction without training model")
38 | parser.add_argument("--small-data", action="store_true",
39 | help="Use small training data.")
40 | parser.add_argument("--num-workers", type=int, default=0,
41 | help="Number of threads for data loader")
42 | parser.add_argument('--gamma', default=0.01, type=float,
43 | help='Smooth parameter in the softor function')
44 | parser.add_argument("--plot", action="store_true",
45 | help="Plot images with captions.")
46 | parser.add_argument("--t-beam", type=int, default=4,
47 |                         help="Number of rule expansion steps in clause generation.")
48 | parser.add_argument("--n-beam", type=int, default=5,
49 | help="The size of the beam.")
50 | parser.add_argument("--n-max", type=int, default=50,
51 | help="The maximum number of clauses.")
52 | parser.add_argument("--program-size", type=int, default=1,
53 | help="The size of the logic program.")
54 | #parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.")
55 | parser.add_argument("--epochs", type=int, default=20,
56 | help="The number of epochs.")
57 | parser.add_argument("--lr", type=float, default=1e-2,
58 | help="The learning rate.")
59 | parser.add_argument("--n-data", type=float, default=200,
60 | help="The number of data to be used.")
61 | parser.add_argument("--pre-searched", action="store_true",
62 | help="Using pre searched clauses.")
63 | parser.add_argument("--infer-step", type=int, default=6,
64 | help="The number of steps of forward reasoning.")
65 | parser.add_argument("--term-depth", type=int, default=3,
66 |                         help="The maximum depth of terms.")
67 | parser.add_argument("--question-json-path", default="data/behind-the-scenes/BehindTheScenes_questions.json")
68 | args = parser.parse_args()
69 | return args
70 |
71 | # def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False):
72 |
73 |
74 | def discretise_NEUMANN(NEUMANN, args, device):
75 | lark_path = 'src/lark/exp.lark'
76 | lang_base_path = 'data/lang/'
77 |     lang, clauses_, bk, bk_clauses, terms, atoms = get_lang(
78 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth)
79 | # Discretise NEUMANN rules
80 | clauses = NEUMANN.get_clauses()
81 | return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False)
82 |
83 |
84 | def predict(NEUMANN, I2F, loader, args, device, th=None, split='train'):
85 | predicted_list = []
86 | target_list = []
87 | count = 0
88 |
89 | for i, sample in tqdm(enumerate(loader, start=0)):
90 | imgs, query, target_set = map(lambda x: x.to(device), sample)
91 |
92 | # to cuda
93 | target_set = target_set.float()
94 |
95 | V_0 = I2F(imgs, query)
96 | V_T = NEUMANN(V_0)
97 | predicted = get_prob(V_T, NEUMANN, args)
98 | predicted = to_one_label(predicted, target_set)
99 | predicted = torch.softmax(predicted * 10, dim=1)
100 | predicted_list.append(predicted.detach())
101 | target_list.append(target_set.detach())
102 | """
103 | if args.plot:
104 | imgs = to_plot_images_clevr(imgs.squeeze(1))
105 | captions = generate_captions(
106 | V_T, NEUMANN.atoms, I2F.pm.e, th=0.3)
107 | save_images_with_captions(
108 | imgs, captions, folder='result/kandinsky/' + args.dataset + '/' + split + '/', img_id_start=count, dataset=args.dataset)
109 | """
110 | count += V_T.size(0) # batch size
111 |
112 | predicted = torch.cat(predicted_list, dim=0).detach().cpu().numpy()
113 | target_set = torch.cat(target_list, dim=0).to(
114 | torch.int64).detach().cpu().numpy()
115 |
116 |     if th is None:
117 | fpr, tpr, thresholds = roc_curve(target_set, predicted, pos_label=1)
118 | accuracy_scores = []
119 | print('ths', thresholds)
120 | for thresh in thresholds:
121 | accuracy_scores.append(accuracy_score(
122 | target_set, [m > thresh for m in predicted]))
123 |
124 | accuracies = np.array(accuracy_scores)
125 | max_accuracy = accuracies.max()
126 | max_accuracy_threshold = thresholds[accuracies.argmax()]
127 | rec_score = recall_score(
128 |             target_set, [m > max_accuracy_threshold for m in predicted], average=None)
129 |
130 | print('target_set: ', target_set, target_set.shape)
131 | print('predicted: ', predicted, predicted.shape)
132 | print('accuracy: ', max_accuracy)
133 | print('threshold: ', max_accuracy_threshold)
134 | print('recall: ', rec_score)
135 |
136 | return max_accuracy, rec_score, max_accuracy_threshold
137 | else:
138 | accuracy = accuracy_score(target_set, [m > th for m in predicted])
139 | rec_score = recall_score(
140 | target_set, [m > th for m in predicted], average=None)
141 | return accuracy, rec_score, th
142 |
143 |
144 | def to_one_label(ys, labels, th=0.7):
145 | ys_new = []
146 | for i in range(len(ys)):
147 | y = ys[i]
148 | label = labels[i]
149 | # check in case answers are computed
150 | num_class = 0
151 | for p_j in y:
152 | if p_j > th:
153 | num_class += 1
154 | if num_class >= 2:
155 | # drop the value using label (the label is one-hot)
156 | drop_index = torch.argmin(label - y)
157 | y[drop_index] = y.min()
158 | ys_new.append(y)
159 | return torch.stack(ys_new)
160 |
161 |
162 | def main(n):
163 | args = get_args()
164 | #name = 'VILP'
165 | print('args ', args)
166 | if args.no_cuda:
167 | device = torch.device('cpu')
168 | elif len(args.device.split(',')) > 1:
169 | # multi gpu
170 | device = torch.device('cuda')
171 | else:
172 | device = torch.device('cuda:' + args.device)
173 |
174 | print('device: ', device)
175 | name = 'neumann/behind-the-scenes/' + str(n)
176 | writer = SummaryWriter(f"runs/{name}", purge_step=0)
177 |
178 | # Create RTPT object
179 | rtpt = RTPT(name_initials='HS', experiment_name=name,
180 | max_iterations=args.epochs)
181 | # Start the RTPT tracking
182 | rtpt.start()
183 |
184 |
185 | ## train_pos_loader, val_pos_loader, test_pos_loader = get_vilp_pos_loader(args)
186 | #####train_pos_loader, val_pos_loader, test_pos_loader = get_data_loader(args)
187 |
188 | # load logical representations
189 | lark_path = 'src/lark/exp.lark'
190 | lang_base_path = 'data/lang/'
191 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang(
192 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth)
193 |
194 | print("{} Atoms:".format(len(atoms)))
195 |
196 | NEUMANN, I2F = get_model(lang=lang, clauses=clauses, atoms=atoms, terms=terms, bk=bk, bk_clauses=bk_clauses,
197 | program_size=args.program_size, device=device, dataset=args.dataset, dataset_type=args.dataset_type,
198 | num_objects=args.num_objects, infer_step=args.infer_step, train=False)#train=not(args.no_train))
199 | print(NEUMANN.rgm)
200 | # get torch data loader
201 | question_json_path = 'data/behind-the-scenes/BehindTheScenes_questions_{}.json'.format(args.dataset)
202 | test_loader = get_behind_the_scenes_loader(question_json_path, args.batch_size, lang, device)
203 |
204 |
205 | writer.add_scalar("graph/num_atom_nodes", len(NEUMANN.rgm.atom_node_idxs))
206 | writer.add_scalar("graph/num_conj_nodes", len(NEUMANN.rgm.conj_node_idxs))
207 | num_nodes = len(NEUMANN.rgm.atom_node_idxs) + len(NEUMANN.rgm.conj_node_idxs)
208 | writer.add_scalar("graph/num_nodes", num_nodes)
209 |
210 | num_edges = NEUMANN.rgm.edge_index.size(1)
211 | writer.add_scalar("graph/num_edges", num_edges)
212 |
213 | writer.add_scalar("graph/memory_total", num_nodes + num_edges)
214 |
215 | print("=====================")
216 | print("NUM NODES: ", num_nodes)
217 | print("NUM EDGES: ", num_edges)
218 | print("MEMORY TOTAL: ", num_nodes + num_edges)
219 | print("=====================")
220 |
221 | params = list(NEUMANN.parameters())
222 | print('parameters: ', list(params))
223 |
224 | print("Predicting on test data set...")
225 | # test split
226 | acc_test, rec_test, th_test = predict(
227 | NEUMANN, I2F, test_loader, args, device, th=0.5, split='test')
228 |
229 |     print("test acc: ", acc_test, "threshold: ", th_test, "recall: ", rec_test)
230 |
231 |
232 | if __name__ == "__main__":
233 | for i in range(1):
234 | main(n=i)
235 |
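236 | # Example invocation (illustrative only; assumes the behind-the-scenes question JSON
237 | # and the lark/lang files exist under the paths used in main()):
238 | #
239 | #     python solve_behind_the_scenes.py --dataset delete --device 0 \
240 | #         --batch-size 2 --infer-step 6 --term-depth 3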
--------------------------------------------------------------------------------
/neumann/neumann/solve_kandinsky.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from rtpt import RTPT
8 | from sklearn.metrics import accuracy_score, recall_score, roc_curve
9 | from torch.utils.tensorboard import SummaryWriter
10 | from tqdm import tqdm
11 |
12 | from logic_utils import get_lang
13 | from mode_declaration import get_mode_declarations
14 | from neumann_utils import (get_clause_evaluator, get_data_loader, get_model,
15 | get_prob)
16 | from tensor_encoder import TensorEncoder
17 |
18 | # from nsfr_utils import save_images_with_captions, to_plot_images_clevr, generate_captions
19 |
20 | torch.set_num_threads(10)
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--batch-size", type=int, default=10,
25 | help="Batch size to infer with")
26 | parser.add_argument("--batch-size-bs", type=int,
27 | default=1, help="Batch size in beam search")
28 | parser.add_argument("--num-objects", type=int, default=6,
29 | help="The maximum number of objects in one image")
30 | parser.add_argument("--dataset", default="delete") # , choices=["member"])
31 | parser.add_argument("--dataset-type", default="behind-the-scenes")
32 | parser.add_argument('--device', default='cpu',
33 | help='cuda device, i.e. 0 or cpu')
34 | parser.add_argument("--no-cuda", action="store_true",
35 | help="Run on CPU instead of GPU (not recommended)")
36 | parser.add_argument("--no-train", action="store_true",
37 | help="Perform prediction without training model")
38 | parser.add_argument("--small-data", action="store_true",
39 | help="Use small training data.")
40 | parser.add_argument("--num-workers", type=int, default=0,
41 | help="Number of threads for data loader")
42 | parser.add_argument('--gamma', default=0.01, type=float,
43 | help='Smooth parameter in the softor function')
44 | parser.add_argument("--plot", action="store_true",
45 | help="Plot images with captions.")
46 | parser.add_argument("--t-beam", type=int, default=4,
47 |                         help="Number of rule expansion steps in clause generation.")
48 | parser.add_argument("--n-beam", type=int, default=5,
49 | help="The size of the beam.")
50 | parser.add_argument("--n-max", type=int, default=50,
51 | help="The maximum number of clauses.")
52 | parser.add_argument("--program-size", type=int, default=1,
53 | help="The size of the logic program.")
54 | #parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.")
55 | parser.add_argument("--epochs", type=int, default=20,
56 | help="The number of epochs.")
57 | parser.add_argument("--lr", type=float, default=1e-2,
58 | help="The learning rate.")
59 | parser.add_argument("--n-ratio", type=float, default=1.0,
60 | help="The ratio of data to be used.")
61 | parser.add_argument("--pre-searched", action="store_true",
62 | help="Using pre searched clauses.")
63 | parser.add_argument("--infer-step", type=int, default=6,
64 | help="The number of steps of forward reasoning.")
65 | parser.add_argument("--term-depth", type=int, default=3,
66 |                         help="The maximum depth of terms.")
67 | parser.add_argument("--question-json-path", default="data/behind-the-scenes/BehindTheScenes_questions.json")
68 | args = parser.parse_args()
69 | return args
70 |
71 | # def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False):
72 |
73 |
74 | def discretise_NEUMANN(NEUMANN, args, device):
75 | lark_path = 'src/lark/exp.lark'
76 | lang_base_path = 'data/lang/'
77 |     lang, clauses_, bk, bk_clauses, terms, atoms = get_lang(
78 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth)
79 | # Discretise NEUMANN rules
80 | clauses = NEUMANN.get_clauses()
81 | return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False)
82 |
83 |
84 | def predict(NEUMANN, I2F, loader, args, device, th=None, split='train'):
85 | predicted_list = []
86 | target_list = []
87 | count = 0
88 |
89 | start = time.time()
90 | for epoch in tqdm(range(args.epochs)):
91 | for i, sample in enumerate(tqdm(loader), start=0):
92 | imgs, target_set = map(lambda x: x.to(device), sample)
93 |
94 | # to cuda
95 | target_set = target_set.float()
96 |
97 | V_0 = I2F(imgs)
98 | V_T = NEUMANN(V_0)
99 |             # NEUMANN.print_valuation_batch(V_T)
100 | #predicted = get_prob(V_T, NEUMANN, args)
101 | # predicted = to_one_label(predicted, target_set)
102 | # predicted = torch.softmax(predicted * 10, dim=1)
103 | #predicted_list.append(predicted.detach())
104 | #target_list.append(target_set.detach())
105 | """
106 | if args.plot:
107 | imgs = to_plot_images_clevr(imgs.squeeze(1))
108 | captions = generate_captions(
109 | V_T, NEUMANN.atoms, I2F.pm.e, th=0.3)
110 | save_images_with_captions(
111 | imgs, captions, folder='result/kandinsky/' + args.dataset + '/' + split + '/', img_id_start=count, dataset=args.dataset)
112 | """
113 | count += V_T.size(0) # batch size
114 | reasoning_time = time.time() - start
115 | print('Reasoning Time: ', reasoning_time)
116 | return 0, 0, 0, reasoning_time
117 |
118 | predicted = torch.cat(predicted_list, dim=0).detach().cpu().numpy()
119 | target_set = torch.cat(target_list, dim=0).to(
120 | torch.int64).detach().cpu().numpy()
121 |
122 |     if th is None:
123 | fpr, tpr, thresholds = roc_curve(target_set, predicted, pos_label=1)
124 | accuracy_scores = []
125 | print('ths', thresholds)
126 | for thresh in thresholds:
127 | accuracy_scores.append(accuracy_score(
128 | target_set, [m > thresh for m in predicted]))
129 |
130 | accuracies = np.array(accuracy_scores)
131 | max_accuracy = accuracies.max()
132 | max_accuracy_threshold = thresholds[accuracies.argmax()]
133 | rec_score = recall_score(
134 |             target_set, [m > max_accuracy_threshold for m in predicted], average=None)
135 |
136 | print('target_set: ', target_set, target_set.shape)
137 | print('predicted: ', predicted, predicted.shape)
138 | print('accuracy: ', max_accuracy)
139 | print('threshold: ', max_accuracy_threshold)
140 | print('recall: ', rec_score)
141 |
142 | return max_accuracy, rec_score, max_accuracy_threshold, reasoning_time
143 | else:
144 | accuracy = accuracy_score(target_set, [m > th for m in predicted])
145 | rec_score = recall_score(
146 | target_set, [m > th for m in predicted], average=None)
147 | return accuracy, rec_score, th, reasoning_time
148 |
149 |
150 | def to_one_label(ys, labels, th=0.7):
151 | ys_new = []
152 | for i in range(len(ys)):
153 | y = ys[i]
154 | label = labels[i]
155 | # check in case answers are computed
156 | num_class = 0
157 | for p_j in y:
158 | if p_j > th:
159 | num_class += 1
160 | if num_class >= 2:
161 | # drop the value using label (the label is one-hot)
162 | drop_index = torch.argmin(label - y)
163 | y[drop_index] = y.min()
164 | ys_new.append(y)
165 | return torch.stack(ys_new)
166 |
167 |
168 | def main(n):
169 | args = get_args()
170 | #name = 'VILP'
171 | print('args ', args)
172 | if args.no_cuda:
173 | device = torch.device('cpu')
174 | elif len(args.device.split(',')) > 1:
175 | # multi gpu
176 | device = torch.device('cuda')
177 | else:
178 | device = torch.device('cuda:' + args.device)
179 |
180 | print('device: ', device)
181 | name = 'neumann/behind-the-scenes/' + str(n)
182 | writer = SummaryWriter(f"runs/{name}", purge_step=0)
183 |
184 | # Create RTPT object
185 | rtpt = RTPT(name_initials='HS', experiment_name=name,
186 | max_iterations=args.epochs)
187 | # Start the RTPT tracking
188 | rtpt.start()
189 |
190 |
191 | ## train_pos_loader, val_pos_loader, test_pos_loader = get_vilp_pos_loader(args)
192 | #####train_pos_loader, val_pos_loader, test_pos_loader = get_data_loader(args)
193 |
194 | # load logical representations
195 | lark_path = 'src/lark/exp.lark'
196 | lang_base_path = 'data/lang/'
197 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang(
198 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth, use_learned_clauses=True)
199 |
200 | print("{} Atoms:".format(len(atoms)))
201 |
202 | # get torch data loader
203 | #question_json_path = 'data/behind-the-scenes/BehindTheScenes_questions_{}.json'.format(args.dataset)
204 | # test_loader = get_behind_the_scenes_loader(question_json_path, args.batch_size, lang, args.n_data, device)
205 | train_loader, val_loader, test_loader = get_data_loader(args, device)
206 |
207 | NEUMANN, I2F = get_model(lang=lang, clauses=clauses, atoms=atoms, terms=terms, bk=bk, bk_clauses=bk_clauses,
208 | program_size=args.program_size, device=device, dataset=args.dataset, dataset_type=args.dataset_type,
209 | num_objects=args.num_objects, infer_step=args.infer_step, train=False)#train=not(args.no_train))
210 |
211 | writer.add_scalar("graph/num_atom_nodes", len(NEUMANN.rgm.atom_node_idxs))
212 | writer.add_scalar("graph/num_conj_nodes", len(NEUMANN.rgm.conj_node_idxs))
213 | num_nodes = len(NEUMANN.rgm.atom_node_idxs) + len(NEUMANN.rgm.conj_node_idxs)
214 | writer.add_scalar("graph/num_nodes", num_nodes)
215 |
216 | num_edges = NEUMANN.rgm.edge_index.size(1)
217 | writer.add_scalar("graph/num_edges", num_edges)
218 |
219 | writer.add_scalar("graph/memory_total", num_nodes + num_edges)
220 |
221 | print("=====================")
222 | print("NUM NODES: ", num_nodes)
223 | print("NUM EDGES: ", num_edges)
224 | print("MEMORY TOTAL: ", num_nodes + num_edges)
225 | print("=====================")
226 |
227 | params = list(NEUMANN.parameters())
228 | print('parameters: ', list(params))
229 |
230 | print("Predicting on train data set...")
231 | times = []
232 | # train split
233 | for j in range(n):
234 |         acc_test, rec_test, th_test, reasoning_time = predict(
235 |             NEUMANN, I2F, train_loader, args, device, th=0.5, split='test')
236 |         times.append(reasoning_time)
237 |
238 | with open('out/inference_time/time_{}_ratio_{}.txt'.format(args.dataset, args.n_ratio), 'w') as f:
239 | f.write("\n".join(str(item) for item in times))
240 |
241 |     print("train acc: ", acc_test, "threshold: ", th_test, "recall: ", rec_test)
242 |
243 | if __name__ == "__main__":
244 | main(n=5)
245 |
246 |
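247 | # Example invocation (illustrative only; this script benchmarks reasoning time and
248 | # writes per-run timings to out/inference_time/, which must already exist):
249 | #
250 | #     python solve_kandinsky.py --dataset-type kandinsky --dataset <name> \
251 | #         --device 0 --batch-size 10 --n-ratio 1.0 --infer-step 6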
--------------------------------------------------------------------------------
/neumann/neumann/torch_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def logsumexp(inputs, dim=None, keepdim=False):
6 | """Numerically stable logsumexp.
7 | from https://github.com/pytorch/pytorch/issues/2591#issuecomment-364474328
8 | Args:
9 | inputs: A Variable with any shape.
10 | dim: An integer.
11 | keepdim: A boolean.
12 |
13 | Returns:
14 | Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
15 | """
16 | # For a 1-D array x (any array along a single dimension),
17 | # log sum exp(x) = s + log sum exp(x - s)
18 | # with s = max(x) being a common choice.
19 | if dim is None:
20 | inputs = inputs.view(-1)
21 | dim = 0
22 | s, _ = torch.max(inputs, dim=dim, keepdim=True)
23 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
24 | if not keepdim:
25 | outputs = outputs.squeeze(dim)
26 | return outputs
27 |
28 |
29 | def weight_sum(W_l, H):
30 | # W: C
31 | # H: C * B * G
32 | W_ex = W_l.unsqueeze(dim=-1).unsqueeze(dim=-1).expand_as(H)
33 | # C * B * G
34 | WH = W_ex * H
35 | # B * G
36 | WH_sum = torch.sum(WH, dim=0)
37 | return WH_sum
38 |
39 |
40 | def softor(xs, dim=0, gamma=0.015):
41 | """The softor function.
42 |
43 | Args:
44 | xs (tensor or list(tensor)): The input tensor.
45 |         dim (int): The dimension along which the soft-or is taken.
46 |         gamma (float): The smoothing parameter for logsumexp.
47 |     Returns:
48 |         log_sum_exp (tensor): The result of taking the soft-or along dim.
49 | """
50 | # xs is List[Tensor] or Tensor
51 | if not torch.is_tensor(xs):
52 | xs = torch.stack(xs, dim)
53 | log_sum_exp = gamma*logsumexp(xs * (1/gamma), dim=dim)
54 | # log_sum_exp = gamma * torch.log(torch.sum(torch.exp(xs/gamma),dim=dim))
55 | if log_sum_exp.max() > 1.0:
56 | return log_sum_exp / log_sum_exp.max()
57 | else:
58 | return log_sum_exp
59 |
60 |
61 | def print_valuation(valuation, atoms, n=40):
62 | """Print the valuation tensor.
63 |
64 | Print the valuation tensor using given atoms.
65 | Args:
66 | valuation (tensor;(B*G)): A valuation tensor.
67 | atoms (list(atom)): The ground atoms.
68 | """
69 | for b in range(valuation.size(0)):
70 | print('===== BATCH: ', b, '=====')
71 | v = valuation[b].detach().cpu().numpy()
72 | idxs = np.argsort(-v)
73 | for i in idxs:
74 | if v[i] > 0.1:
75 | print(i, atoms[i], ': ', round(v[i], 3))
76 |
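77 | 
78 | if __name__ == "__main__":
79 |     # Minimal smoke test (illustrative only): softor approaches max() as gamma -> 0.
80 |     xs = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
81 |     print(softor(xs, dim=0, gamma=0.01))         # approx. tensor([0.9000, 0.8000])
82 |     print(logsumexp(xs, dim=0))                  # plain log-sum-exp along dim 0
83 |     print_valuation(xs, atoms=["p(a)", "q(b)"])  # prints entries with value > 0.1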
--------------------------------------------------------------------------------
/neumann/neumann/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import networkx as nx
3 | import numpy as np
4 | from matplotlib.backends.backend_pdf import PdfPages
5 | from sklearn.manifold import TSNE
6 |
7 |
8 | def plot_atoms(x_atom, atoms, path):
9 | x_atom = x_atom.detach().cpu().numpy()
10 | labels = [str(atom) for atom in atoms]
11 | X_reduced = TSNE(n_components=2, random_state=0).fit_transform(x_atom)
12 | fig, ax = plt.subplots(figsize=(30,30))
13 | ax.scatter(X_reduced[:, 0], X_reduced[:, 1])
14 | for i, label in enumerate(labels):
15 | ax.annotate(label, (X_reduced[i,0], X_reduced[i,1]))
16 | plt.savefig(path)
17 |
18 | def plot_infer_embeddings(x_atom_list, atoms):
19 | for i, x_atom in enumerate(x_atom_list):
20 | plot_atoms(x_atom, atoms, 'imgs/x_atom_' + str(i) + '.png')
21 |
22 |
23 | def plot_proof_history(V_list, atoms, infer_step, dataset, mode='graph'):
24 |     if mode == 'graph':
25 | fig, ax = plt.subplots(figsize=(12, 12))
26 | else:
27 | fig, ax = plt.subplots(figsize=(6, 6))
28 | # extract first batch
29 | vs_img = np.round(np.array([vs[0] for vs in V_list]),2)
30 | print(vs_img)
31 | vmax = vs_img.max()
32 | im = ax.imshow(vs_img, cmap="Blues")
33 | ##m = ax.imshow(vs_img, cmap="plasma")
34 | ax.set_xticks(np.arange(len(atoms)))
35 | ax.set_yticks(np.arange(infer_step+1))
36 | plt.yticks(fontname = "monospace", fontsize=12)
37 | plt.xticks(fontname = "monospace", fontsize=11)
38 | ax.set_xticklabels([str(x) for x in atoms])
39 | ax.set_yticklabels(["v_{}".format(i) for i in range(infer_step+1)])
40 | plt.setp(ax.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor")
41 | # Loop over data dimensions and create text annotations.
42 | plt.rcParams.update({'font.size': 10})
43 | for i in range(infer_step+1):
44 | for j in range(len(atoms)):
45 | if vs_img[i, j] > 0.1:
46 | #print(vs_img[i,j], vs_img[i, j] / vmax )
47 | if vs_img[i, j] / vmax < 0.4:
48 | text = ax.text(j, i, str(vs_img[i, j]).replace('0.', '.'), ha="center", va="center", color="gray")
49 | else:
50 | text = ax.text(j, i, str(vs_img[i, j]).replace('0.', '.'), ha="center", va="center", color="w")
51 | if mode == 'graph':
52 | ax.set_title("Proof history on {} dataset (NEUMANN)".format(dataset), fontsize=18)
53 | elif mode == 'tensor':
54 | ax.set_title("Proof history on {} dataset (Tensor-based Reasoner)".format(dataset), fontsize=18)
55 | fig.tight_layout()
56 | plt.show()
57 | folder_path = "plot"
58 | # plt.savefig(f"{folder_path}/{dataset}_infer_history.svg")
59 | plt.savefig(f"{folder_path}/{dataset}_{infer_step}_{mode}_history.svg")
60 | plt.savefig(f"{folder_path}/{dataset}_{infer_step}_{mode}_history.png")
61 |
62 | def plot_reasoning_graph(path, reasoning_graph_module):
63 | #pp = PdfPages(path)
64 |
65 | G = reasoning_graph_module.networkx_graph
66 | fig = plt.figure(1, figsize=(30, 30))
67 |
68 | first_partition_nodes = list(range(len(reasoning_graph_module.facts)))
69 | edges = G.edges()
70 | colors_rg = [G[u][v]['color'] for u, v in edges]
71 | colors = []
72 | for c in colors_rg:
73 | if c == 'r':
74 | colors.append('indianred')
75 | elif c == 'b':
76 | colors.append('royalblue')
77 | #weights = [G[u][v]['weight'] for u,v in edges]
78 |
79 | nx.draw_networkx(
80 | G,
81 | alpha=0.5,
82 | labels=reasoning_graph_module.node_labels,
83 | node_size=2, node_color='lightgray', edge_color=colors, font_size=10,
84 | pos=nx.drawing.layout.bipartite_layout(G, first_partition_nodes)) # Or whatever other display options you like
85 | plt.tight_layout()
86 |
87 | plt.show()
88 | plt.savefig(path)
89 | plt.close()
90 | #pp.savefig(fig)
91 | #pp.close()
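92 | 
93 | # A minimal usage sketch (illustrative only): V_list holds one valuation array per
94 | # reasoning step with shape (batch, num_atoms); the first batch element is plotted.
95 | # Note that the figure is written under plot/, which must exist.
96 | #
97 | #     V_list = [np.array([[1.0, 0.0, 0.0]]),
98 | #               np.array([[1.0, 0.7, 0.0]]),
99 | #               np.array([[1.0, 0.7, 0.6]])]
100 | #     plot_proof_history(V_list, atoms=["p", "q", "r"], infer_step=2,
101 | #                        dataset="member", mode="graph")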
--------------------------------------------------------------------------------
/neumann/requirements.txt:
--------------------------------------------------------------------------------
1 | ### NSFR
2 | torch==1.10.0
3 | lark-parser
4 | scikit-learn
5 | torchsummary
6 | setuptools
7 | tensorboard
8 | numpy>=1.18.5
9 | tqdm>=4.41.0
10 | matplotlib>=3.2.2
11 | opencv-python>=4.1.2
12 | Pillow
13 | PyYAML>=5.3.1
14 | scipy>=1.4.1
15 | seaborn
16 | pandas
17 | rtpt
18 |
--------------------------------------------------------------------------------
/neumann/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import find_packages, setup
4 |
5 | # import ipdb; ipdb.set_trace()
6 |
7 |
8 | def package_files(directory):
9 | paths = []
10 | for path, directories, filenames in os.walk(directory):
11 | for filename in filenames:
12 | paths.append(os.path.join("..", path, filename))
13 | return paths
14 |
15 |
16 | # extra_files = package_files('nsfr/data') + package_files('nsfr/lark')
17 | extra_files = package_files("neumann/src/lark")
18 | # import ipdb; ipdb.set_trace()
19 |
20 | setup(
21 | name="neumann",
22 | version="0.1.0",
23 | author="anonymous",
24 | author_email="anon.gouv",
25 | packages=find_packages(),
26 | package_data={"": extra_files},
27 | include_package_data=True,
28 | # package_dir={'':'src'},
29 | url="tba",
30 | description="GNN-based Differentiable Forward Reasoner",
31 | long_description=open("README.md").read(),
32 | install_requires=[
33 | "matplotlib",
34 | "numpy",
35 | "seaborn",
36 | "setuptools",
37 | # "torch",
38 | "tqdm",
39 | "lark",
40 | ],
41 | )
42 |
--------------------------------------------------------------------------------
/prompt/gen_constants.txt:
--------------------------------------------------------------------------------
1 | Given a deictic representation, generate constants in the format "const1,const2,...,constn".
2 | Use only lowercase characters.
3 |
4 | An object that is next to the keyboard.
5 | keyboard
6 |
7 | An object that is on the desk.
8 | desk
9 |
10 | An object that has the papers.
11 | papers
12 |
13 | An object which is on the desk and next to the phone.
14 | desk,phone
--------------------------------------------------------------------------------
/prompt/gen_predicates.txt:
--------------------------------------------------------------------------------
1 | Given a deictic representation, generate predicates in the format "pred1,pred2,...,predn".
2 | Use comma "," to separate predicates.
3 | Answer in one line.
4 |
5 | The available predicates are: on, has, wearing, of, in, near, behind, with, holding, above, sitting_on, wears, under, riding, in_front_of, standing_on, at, carrying, attached_to, walking_on, over, for, looking_at, watching, hanging_from, laying_on, eating, and, belonging_to, parked_on, playing_on, using, covering, between, along, covered_in, part_of, lying_on, on_back_of, to, walking_in, mounted_on, across, against, from, growing_on, painted_on, playing, made_of, says, flying_in
6 |
7 |
8 | Examples:
9 |
10 | an object that is next to keyboard.
11 | next_to
12 |
13 | an object that is on a desk.
14 | on
15 |
16 | an object that has a papers.
17 | has
18 |
19 | an object which is on a desk and next to a phone.
20 | on,next_to
21 |
22 | an object that is behind a couch and on floor.
23 | behind,on
24 |
25 | an object that is near a desk and against wall.
26 | near,against
27 |
28 | an object that is parked on a street and has back.
29 | parked_on,has
30 |
31 | an object that has sides, that is on a pole, and that is above a stop sign.
32 | has,on,above
--------------------------------------------------------------------------------
/prompt/gen_rules.txt:
--------------------------------------------------------------------------------
1 | Given a deictic representation and available predicates, generate rules in the following format.
2 | The rule's format is
3 | target(X):-cond1(X),...condn(X).
4 | cond1(X):-pred1(X,Y),type(Y,const1).
5 | ...
6 | condn(X):-predn(X,Y),type(Y,const2).
7 | Use predicates and constants that appear in the given sentence.
8 | Capitalize variables: X, Y, Z, W, etc.
9 |
10 | Examples:
11 |
12 | an object that is next to a keyboard.
13 | available predicates: next_to
14 | cond1(X):-next_to(X,Y),type(Y,keyboard).
15 | target(X):-cond1(X).
16 |
17 | an object that is on a desk.
18 | available predicates: on
19 | cond1(X):-on(X,Y),type(Y,desk).
20 | target(X):-cond1(X).
21 |
22 |
23 | an object that has papers.
24 | available predicates: has
25 | cond1(X):-has(X,Y),type(Y,papers).
26 | target(X):-cond1(X).
27 |
28 | an object that is on a white pillow.
29 | available predicates: on
30 | cond1(X):-on(X,Y),type(Y,white_pillow).
31 | target(X):-cond1(X).
32 |
33 | an object that has a fur.
34 | available predicate: has
35 | cond1(X):-has(X,Y),type(Y,fur).
36 | target(X):-cond1(X).
37 |
38 | an object that is on a ground, and that is behind a white line.
39 | available predicates: on,behind
40 | cond1(X):-on(X,Y),type(Y,ground).
41 | cond2(X):-behind(X,Y),type(Y,whiteline).
42 | target(X):-cond1(X),cond2(X).
43 |
44 | an object that is near a desk and against wall.
45 | available predicates: near,against
46 | cond1(X):-near(X,Y),type(Y,desk).
47 | cond2(X):-against(X,Y),type(Y,wall).
48 | target(X):-cond1(X),cond2(X).
49 |
50 | an object that has sides, that is on a pole, and that is above a stop sign.
51 | available predicates: has,on,above
52 | cond1(X):-has(X,Y),type(Y,sides).
53 | cond2(X):-on(X,Y),type(Y,pole).
54 | cond3(X):-above(X,Y),type(Y,stopsign).
55 | target(X):-cond1(X),cond2(X),cond3(X).
56 |
57 | an object that is wearing a shirt, that has a hair, and that is wearing shoes.
58 | available predicates: wearing,has,wearing
59 | cond1(X):-wearing(X,Y),type(Y,shirt).
60 | cond2(X):-has(X,Y),type(Y,hair).
61 | cond3(X):-wearing(X,Y),type(Y,shoes).
62 | target(X):-cond1(X),cond2(X),cond3(X).
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/src/__init__.py
--------------------------------------------------------------------------------
/src/data_vg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | import random
5 |
6 | import visual_genome.local as vg
7 | from groundingdino.util.inference import load_image
8 | from neumann.fol.language import DataType, Language
9 | from neumann.fol.logic import Atom, Const, NeuralPredicate, Predicate
10 |
11 |
12 | def to_object_type_atoms(atoms, lang):
13 | """Convert VG atoms to an object-type format."""
14 | unique_constants = set(term for atom in atoms for term in atom.terms)
15 | new_lang = Language(preds=lang.preds.copy(), funcs=[], consts=[])
16 | new_const_dic = {
17 | vg_const: Const(f"obj_{vg_const.name.split('_')[-1]}", dtype=DataType("object"))
18 | for vg_const in unique_constants
19 | }
20 |
21 | new_lang.consts.extend(new_const_dic.values())
22 | p_type = NeuralPredicate("type", 2, [DataType("object"), DataType("type")])
23 |
24 | new_atoms = [
25 | Atom(atom.pred, [new_const_dic[term] for term in atom.terms])
26 | for atom in atoms
27 | ]
28 |
29 | for type_const, obj_const in new_const_dic.items():
30 | base_type_const = Const(type_const.name.split("_")[0], dtype=DataType("type"))
31 | if base_type_const not in new_lang.consts:
32 | new_lang.consts.append(base_type_const)
33 | new_atoms.append(Atom(p_type, [obj_const, base_type_const]))
34 |
35 | new_lang = Language(
36 | preds=list(set(new_lang.preds)), funcs=[], consts=list(set(new_lang.consts))
37 | )
38 | return new_atoms, new_lang
39 |
40 |
41 | class BaseVisualGenomeDataset(torch.utils.data.Dataset):
42 | """Base class for Visual Genome datasets."""
43 |
44 | def __init__(self, json_path):
45 | self.json_data = self._load_json(json_path)
46 |
47 | def _load_json(self, path):
48 | with open(path, "r") as file:
49 | return json.load(file)
50 |
51 | def _parse_data(self, item):
52 | data = self.json_data["queries"][item]
53 | return data[0] + ".", data[1], data[2], data[3], data[4]
54 |
55 | def __len__(self):
56 | return len(self.json_data["queries"])
57 |
58 |
59 | class DeicticVisualGenome(BaseVisualGenomeDataset):
60 | """Deictic Visual Genome dataset."""
61 |
62 | def __getitem__(self, item):
63 | deictic_representation, answer, image_id, vg_data_index, id = self._parse_data(item)
64 | image_source, image = load_image(f"data/visual_genome/VG_100K/{image_id}.jpg")
65 | return id, vg_data_index, image_id, image_source, image, deictic_representation, answer
66 |
67 |
68 | class DeicticVisualGenomeSGGTraining(BaseVisualGenomeDataset):
69 | """Deictic Visual Genome dataset for scene graph generator training."""
70 |
71 | def __init__(self, args, mode="train"):
72 | filename = f"deictic_vg_comp{args.complexity}_sgg_{mode}.json"
73 | json_path = f"data/deivg_learning/{filename}"
74 | super().__init__(json_path)
75 |
76 | def __getitem__(self, item):
77 | deictic_representation, answer, image_id, vg_data_index, id = self._parse_data(item)
78 | image_source, image = load_image(f"data/visual_genome/VG_100K/{image_id}.jpg")
79 | return id, vg_data_index, image_id, image_source, image, deictic_representation, answer
80 |
81 |
82 | class VisualGenomeUtils:
83 | """A utility class for the Visual Genome dataset."""
84 |
85 | def __init__(self):
86 | self.all_relationships = self._load_json("data/visual_genome/relationships_deictic.json")
87 | self.all_objects = self._load_json("data/visual_genome/objects.json")
88 |
89 | def _load_json(self, path):
90 | print(f"Loading {path} ...")
91 | with open(path, "r") as file:
92 | data = json.load(file)
93 | print("Completed.")
94 | return data
95 |
96 | def load_scene_graphs(self, num_images=20):
97 | return vg.get_scene_graphs(
98 | start_index=0,
99 | end_index=num_images,
100 | min_rels=1,
101 | data_dir="data/visual_genome/",
102 | image_data_dir="data/visual_genome/by-id/",
103 | )
104 |
105 | def load_scene_graph_by_id(self, image_id):
106 | return vg.get_scene_graph(
107 | image_id=image_id,
108 | images="data/visual_genome/",
109 | image_data_dir="data/visual_genome/by-id/",
110 | synset_file="data/visual_genome/synsets.json",
111 | )
112 |
113 | def scene_graph_to_language(self, scene_graph, text, logic_generator, num_objects=3):
114 | """Generate FOL language from a scene graph."""
115 | objects = list(set(obj.replace(" ", "") for obj in scene_graph.objects))
116 | datatype = DataType("type")
117 | constants = [Const(obj, datatype) for obj in objects]
118 |
119 | const_response = f"Constants:\ntype:{','.join(objects)}"
120 | predicates, pred_response = logic_generator.generate_predicates(text, const_response)
121 | lang = Language(consts=list(set(constants)), preds=list(set(predicates)), funcs=[])
122 | return lang, const_response, pred_response
123 |
124 | def data_index_to_atoms(self, data_index, lang):
125 | """Generate atoms from data index."""
126 | relationships = self.all_relationships[data_index]["relationships"]
127 | atoms = [self._parse_relationship(rel, lang) for rel in relationships if self._parse_relationship(rel, lang)]
128 | return to_object_type_atoms(atoms, lang)
129 |
130 | def _parse_relationship(self, rel, lang):
131 | pred_name = rel["predicate"].replace(" ", "_").lower()
132 | dtype = DataType("object")
133 | pred = Predicate(pred_name, 2, [dtype, dtype])
134 |         # Either "name" or "names" is used as a key in the relationship dict.
135 |         # Normalize to "name" when only "names" is present.
136 | if "names" in rel["object"].keys():
137 | rel["object"]["name"] = rel["object"]["names"][0]
138 | if "names" in rel["subject"].keys():
139 | rel["subject"]["name"] = rel["subject"]["names"][0]
140 | consts = [
141 | Const(rel["subject"]["name"].replace(" ", "_").lower() + f"_{rel['subject']['object_id']}", dtype=dtype),
142 | Const(rel["object"]["name"].replace(" ", "").lower() + f"_{rel['object']['object_id']}", dtype=dtype),
143 | ]
144 | return Atom(pred, consts)
145 |
146 |
147 | class PredictedSceneGraphUtils:
148 | """Utils for predicted scene graphs."""
149 |
150 | def __init__(self, model="veto", base_path=""):
151 | self.model = model
152 | self.base_path = base_path
153 | self.scene_graphs = self._load_predicted_scene_graphs(model)
154 | self.all_relationships = self._preprocess_relationships(self.scene_graphs)
155 |
156 | def _load_predicted_scene_graphs(self, model):
157 | with open(f"{self.base_path}data/predicted_scene_graphs/{model}.pkl", "rb") as file:
158 | return pickle.load(file)
159 |
160 | def _preprocess_relationships(self, scene_graphs):
161 | return {
162 | int(graph_[0]["img_info"].split("/")[-1].split(".")[0]): graph_[0]["rel_info"]["spo"]
163 | for graph_ in scene_graphs
164 | }
165 |
166 | def _parse_relationship(self, rel):
167 | dtype = DataType("object")
168 | pred = Predicate(rel["p_str"].replace(" ", "_").lower(), 2, [dtype, dtype])
169 | consts = [
170 | Const(rel["s_str"].replace(" ", "_").lower() + f"_{rel['s_unique_id']}", dtype=dtype),
171 | Const(rel["o_str"].replace(" ", "_").lower() + f"_{rel['o_unique_id']}", dtype=dtype),
172 | ]
173 | return Atom(pred, consts), rel["score"]
174 |
175 | def load_scene_graph_by_id(self, image_id):
176 | return self.all_relationships.get(image_id, [])
177 |
178 | def data_index_to_atoms(self, data_index, lang):
179 | return self._generate_atoms(lang, self.scene_graphs[data_index][0]["rel_info"]["spo"])
180 |
181 | def image_id_to_atoms(self, image_id, lang):
182 | return self._generate_atoms(lang, self.all_relationships.get(image_id, []))
183 |
184 | def _generate_atoms(self, lang, relationships):
185 | atoms = [
186 | self._parse_relationship(rel)[0]
187 | for rel in relationships
188 | if self._parse_relationship(rel)[1] > 0.98
189 | ]
190 | return to_object_type_atoms(atoms, lang)
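191 | 
192 | 
193 | # A minimal usage sketch (illustrative only; the JSON filename below is a placeholder
194 | # and the Visual Genome files must be present under data/visual_genome/):
195 | #
196 | #     dataset = DeicticVisualGenome("data/deivg_learning/<queries>.json")
197 | #     id, vg_index, image_id, image_source, image, text, answer = dataset[0]
198 | #     vg_utils = VisualGenomeUtils()
199 | #     scene_graph = vg_utils.load_scene_graph_by_id(image_id)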
--------------------------------------------------------------------------------
/src/deisam_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import cv2
5 | import torch
6 | from PIL import Image
7 | from segment_anything import SamPredictor, build_sam
8 | from torchvision.ops import masks_to_boxes
9 | from huggingface_hub import hf_hub_download
10 |
11 | # Visualization and utility imports
12 | from groundingdino.util.inference import annotate, predict
13 | from visualization_utils import (
14 | answer_to_boxes,
15 | save_box_to_file,
16 | save_segmented_images,
17 | save_segmented_images_with_target_scores,
18 | show_mask,
19 | show_mask_with_alpha,
20 | )
21 |
22 | # Neumann library imports
23 | from neumann.facts_converter import FactsConverter
24 | from neumann.fol.data_utils import DataUtils
25 | from neumann.fol.language import DataType
26 | from neumann.fol.logic import Atom, Const, Predicate
27 | from neumann.logic_utils import generate_atoms
28 | from neumann.neumann_utils import get_neumann_model, get_trainable_neumann_model
29 |
30 |
31 | def load_neumann(lang, rules, atoms, device, infer_step=2):
32 | """Load a Neumann reasoner model with specified language, rules, and atoms."""
33 | atoms = add_target_cond_atoms(lang, atoms)
34 | fc = FactsConverter(lang, atoms, [], device)
35 | reasoner = get_neumann_model(
36 | rules, [], [], atoms, lang.consts, lang, 1, infer_step, device
37 | )
38 | return fc, reasoner
39 |
40 |
41 | def load_neumann_for_sgg_training(lang, rules_to_learn, rules_bk, atoms, device, infer_step=4):
42 | """Load a trainable Neumann model for scene graph generation (SGG) training."""
43 | atoms = add_target_cond_atoms_for_sgg_training(lang, atoms)
44 | fc = FactsConverter(lang, atoms, [], device)
45 | reasoner = get_trainable_neumann_model(
46 | rules_to_learn, rules_bk, [], atoms, lang.consts, lang, 1, infer_step, device
47 | )
48 | return fc, reasoner
49 |
50 |
51 | def load_sam_model(device):
52 | """Load SAM model from a checkpoint file."""
53 | sam_checkpoint = "sam_vit_h_4b8939.pth"
54 | sam = build_sam(checkpoint=sam_checkpoint)
55 | sam.to(device=device)
56 | return SamPredictor(sam)
57 |
58 |
59 | def crop_objects(img, masks):
60 | """Crop objects from the image using provided masks."""
61 | cropped_objects = []
62 |
63 | for mask in masks:
64 | x, y, width, height = mask["bbox"]
65 |
66 | if width * height > 2000:
67 | cropped_image = img[int(y):int(y + height), int(x):int(x + width), :]
68 | cropped_objects.append(cropped_image)
69 |
70 | return cropped_objects
71 |
72 |
73 | def load_image(path):
74 | """Load an image from the specified file path and convert it to RGB."""
75 | return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
76 |
77 |
78 | def extract_target_ids(reasoner, v_T, threshold):
79 |     """Extract target IDs from a reasoner's valuation tensor using a threshold."""
80 |     target_ids = []
81 |     obj_id = 0
82 | 
83 |     for atom in reasoner.atoms:
84 |         if atom.pred.name == "target":
85 |             index = reasoner.atoms.index(atom)
86 | if v_T[index] > threshold:
87 | target_ids.append(obj_id)
88 | obj_id += 1
89 |
90 | return target_ids
91 |
92 |
93 | def load_model_from_hf(repo_id, filename, config_filename, device="cpu"):
94 | """Load a model from Hugging Face Hub."""
95 | config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
96 | from groundingdino.models import build_model
97 | from groundingdino.util.slconfig import SLConfig
98 | from groundingdino.util.utils import clean_state_dict
99 |
100 | args = SLConfig.fromfile(config_path)
101 | model = build_model(args)
102 | args.device = device
103 |
104 | model_file = hf_hub_download(repo_id=repo_id, filename=filename)
105 | checkpoint = torch.load(model_file, map_location="cpu")
106 | model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
107 | print(f"Model loaded from {model_file}")
108 |
109 | return model.eval()
110 |
111 |
112 | def segment_by_grounded_sam(model, image, image_source, text_prompt, box_threshold=0.35, text_threshold=0.25):
113 |     """Segment an image using a grounded SAM (GroundingDINO) model."""
114 |     boxes, logits, phrases = predict(
115 |         model=model,
116 |         image=image,
117 | caption=text_prompt,
118 | box_threshold=box_threshold,
119 | text_threshold=text_threshold,
120 | )
121 | annotated_frame = annotate(
122 | image_source=image_source, boxes=boxes, logits=logits, phrases=phrases
123 | )[..., ::-1] # Convert BGR to RGB
124 | return boxes, logits, phrases, annotated_frame
125 |
126 |
127 | def save_box_results(args, masks, answer, file_id, counter):
128 | """Save bounding boxes of predicted segmentations with ground truth."""
129 | pr_boxes = masks_to_boxes(masks.squeeze(1).to(torch.int32))
130 | gt_boxes = answer_to_boxes(answer)
131 | save_box_to_file(pr_boxes, gt_boxes, file_id, counter, args)
132 | return pr_boxes, gt_boxes
133 |
134 |
135 | def save_segmentation_result(args, masks, answer, image_source, counter, vg_image_id, data_index, deictic_representation, is_failed):
136 | """Plot and save segmentation masks on original visual inputs."""
137 | annotated_frame = image_source
138 |
139 | for mask in masks:
140 | annotated_frame = show_mask(mask[0], annotated_frame)
141 |
142 | base_path = f"plot/{args.dataset}_comp{args.complexity}/DeiSAM{args.sgg_model}/"
143 | os.makedirs(base_path, exist_ok=True)
144 |
145 | save_segmented_images(
146 | counter=counter,
147 | vg_image_id=vg_image_id,
148 | image_with_mask=annotated_frame,
149 | description=deictic_representation,
150 | base_path=base_path,
151 | )
152 |
153 | if is_failed:
154 | failed_path = f"plot/{args.dataset}_comp{args.complexity}/DeiSAM{args.sgg_model}_failed/"
155 | os.makedirs(failed_path, exist_ok=True)
156 | save_segmented_images(
157 | counter=counter,
158 | vg_image_id=vg_image_id,
159 |             image_with_mask=annotated_frame,
160 | description=deictic_representation,
161 | base_path=failed_path,
162 | )
163 |
164 |
165 | def save_segmentation_result_with_alphas(args, masks, mask_probs, answer, image_source, counter, vg_image_id, data_index, deictic_representation, iter):
166 | """Plot and save segmentation masks with alpha values on visual inputs."""
167 | annotated_frame = image_source
168 |
169 | for i, mask in enumerate(masks):
170 | annotated_frame = show_mask_with_alpha(mask=mask[0], image=annotated_frame, alpha=mask_probs[i])
171 |
172 | base_path = f"learning_plot/{args.dataset}_comp{args.complexity}/seed_{args.seed}/iter_{iter}/DeiSAM{args.sgg_model}/"
173 | os.makedirs(base_path, exist_ok=True)
174 |
175 | save_segmented_images_with_target_scores(
176 | counter=counter,
177 | vg_image_id=vg_image_id,
178 | image_with_mask=annotated_frame,
179 |         description=deictic_representation,
180 |         target_scores=mask_probs,
181 | base_path=base_path,
182 | )
183 |
184 |
185 | def save_llm_response(args, pred_response, rule_response, counter, image_id, deictic_representation):
186 | """Save LLM responses (pred & rule) to files."""
187 | base_path = f"llm_output/{args.dataset}_comp{args.complexity}/DeiSAM{args.sgg_model}/"
188 | os.makedirs(f"{base_path}pred_response/", exist_ok=True)
189 | os.makedirs(f"{base_path}rule_response/", exist_ok=True)
190 |
191 | pred_file = f"{counter}_vg{image_id}_{deictic_representation.replace('.', '').replace('/', ' ')}.txt"
192 | rule_file = f"{counter}_vg{image_id}_{deictic_representation.replace('.', '').replace('/', ' ')}.txt"
193 |
194 | with open(f"{base_path}pred_response/{pred_file}", "w") as f:
195 | f.write(pred_response)
196 |
197 | with open(f"{base_path}rule_response/{rule_file}", "w") as f:
198 | f.write(rule_response)
199 |
200 |
201 | def get_random_masks(model):
202 | """Get random masks from the Neumann model's target atoms."""
203 | targets = [atom for atom in model.neumann.atoms if "target(obj_" in str(atom)]
204 | return [random.choice(targets)]
205 |
206 |
207 | def add_target_cond_atoms(lang, atoms):
208 | """Add target and condition atoms to the atoms list."""
209 | spec_predicate = Predicate(".", 1, [DataType("spec")])
210 | true_atom = Atom(spec_predicate, [Const("__T__", dtype=DataType("spec"))])
211 |
212 | target_atoms = generate_target_atoms(lang)
213 | cond_atoms = generate_cond_atoms(lang)
214 | return [true_atom] + sorted(set(atoms)) + target_atoms + cond_atoms
215 |
216 |
217 | def generate_target_atoms(lang):
218 | """Generate target atoms for each object constant in the language."""
219 | target_predicate = Predicate("target", 1, [DataType("object")])
220 | return [
221 | Atom(target_predicate, [const])
222 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name
223 | ]
224 |
225 |
228 | def generate_cond_atoms(lang):
229 | """Generate condition atoms for each object constant in the language."""
230 | cond_predicates = [Predicate(f"cond{i}", 1, [DataType("object")]) for i in range(1, 4)]
231 | cond_atoms = []
232 | for pred in cond_predicates:
233 | cond_atoms.extend(
234 | Atom(pred, [const])
235 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name
236 | )
237 | return sorted(set(cond_atoms))
238 |
239 |
240 | def add_target_cond_atoms_for_sgg_training(lang, atoms, num_sgg_models=2):
241 | """Add target and condition atoms for SGG training to the list of atoms."""
242 | spec_predicate = Predicate(".", 1, [DataType("spec")])
243 | true_atom = Atom(spec_predicate, [Const("__T__", dtype=DataType("spec"))])
244 |
245 | target_atoms = generate_target_atoms_for_sgg_training(lang, num_sgg_models)
246 | cond_atoms = generate_cond_atoms_for_sgg_training(lang, num_sgg_models)
247 | return [true_atom] + sorted(set(atoms)) + target_atoms + cond_atoms
248 |
249 |
250 | def generate_target_atoms_for_sgg_training(lang, num_sgg_models):
251 | """Generate target atoms for main and SGG models."""
252 | main_target_pred = Predicate("target", 1, [DataType("object")])
253 | sgg_target_preds = [
254 | Predicate(f"target_sgg{i}", 1, [DataType("object")]) for i in range(num_sgg_models)
255 | ]
256 |
257 | all_target_preds = [main_target_pred] + sgg_target_preds
258 | target_atoms = [
259 | Atom(pred, [const])
260 | for pred in all_target_preds
261 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name
262 | ]
263 |
264 | return sorted(set(target_atoms))
265 |
266 |
267 | def generate_cond_atoms_for_sgg_training(lang, num_sgg_models):
268 | """Generate conditional atoms for SGG training."""
269 | cond_atoms = []
270 |
271 | for i in range(num_sgg_models):
272 | cond_preds = [Predicate(f"cond{j}_sgg{i}", 1, [DataType("object")]) for j in range(1, 4)]
273 |
274 | for cond_pred in cond_preds:
275 | cond_atoms.extend(
276 | Atom(cond_pred, [const])
277 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name
278 | )
279 |
280 | return sorted(set(cond_atoms))
--------------------------------------------------------------------------------
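
A minimal, hedged example of the target/condition atom construction above; it assumes the repository's dependencies are installed and uses made-up object constants that follow the "obj_<id>" naming convention expected by get_random_masks:

from neumann.fol.language import DataType, Language
from neumann.fol.logic import Const

from deisam_utils import generate_cond_atoms, generate_target_atoms

obj_dtype = DataType("object")
lang = Language(
    consts=[Const("obj_1", dtype=obj_dtype), Const("obj_2", dtype=obj_dtype)],
    preds=[],
    funcs=[],
)
print(generate_target_atoms(lang))     # e.g. [target(obj_1), target(obj_2)]
print(len(generate_cond_atoms(lang)))  # 3 cond predicates x 2 objects = 6 atoms
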
/src/lark/exp.lark:
--------------------------------------------------------------------------------
1 | clause : atom ":-" body
2 |
3 | body : atom "," body
4 | | atom "."
5 | | "."
6 |
7 | atom : predicate "(" args ")"
8 |
9 | args : term "," args
10 | | term
11 |
12 | term : functor "(" args ")"
13 | | const
14 | | variable
15 |
16 | const : /[a-z0-9\*\_']+/
17 |
18 | variable : /[A-Z]+[A-Za-z0-9]*/
19 | | /\_/
20 |
21 | functor : /[a-z0-9]+/
22 |
23 | predicate : /[a-z0-9\_]+/
24 |
25 | var_name : /[A-Z]/
26 | small_chars : /[a-z0-9]+/
27 | chars : /[^\+\|\s\(\)']+/[/\n+/]
28 | allchars : /[^']+/[/\n+/]
29 |
--------------------------------------------------------------------------------
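
A hedged, runnable sketch of how this grammar is consumed, mirroring the parser setup in llm_logic_generator.py; the rule string is made up, spaces are stripped before parsing (as in _parse_rules), and the script is assumed to run from the repository root:

from lark import Lark

with open("src/lark/exp.lark", encoding="utf-8") as grammar:
    lp_clause = Lark(grammar.read(), start="clause")

rule = "target(X):-type(X,dog),on(X,Y),type(Y,sofa)."
tree = lp_clause.parse(rule.replace(" ", ""))
print(tree.pretty())
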
/src/learning_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | from torchvision.ops import masks_to_boxes
7 | import openai
8 |
9 | from data_vg import DeicticVisualGenomeSGGTraining, VisualGenomeUtils, PredictedSceneGraphUtils
10 | from deisam import TrainableDeiSAM
11 | from deisam_utils import get_random_masks, save_segmentation_result_with_alphas
12 | from learning_utils import to_bce_examples
13 | from rtpt import RTPT
14 | from visualization_utils import answer_to_boxes, save_box_to_file
15 |
16 | torch.set_num_threads(10)
17 |
18 |
19 | class DeiSAMTrainer:
20 | def __init__(self, args, device):
21 | self.args = args
22 | self.device = device
23 | self._set_random_seeds()
24 |
25 | self.data_loader = DeicticVisualGenomeSGGTraining(args, mode="train")
26 | self.val_data_loader = DeicticVisualGenomeSGGTraining(args, mode="val")
27 | self.test_data_loader = DeicticVisualGenomeSGGTraining(args, mode="test")
28 |
29 | self.vg_1 = VisualGenomeUtils()
30 | self.vg_2 = PredictedSceneGraphUtils(args.sgg_model)
31 |
32 | self.deisam = TrainableDeiSAM(api_key=args.api_key, device=device,
33 | vg_utils_list=[self.vg_1, self.vg_2],
34 | sem_uni=args.sem_uni)
35 |
36 | def _set_random_seeds(self):
37 | random.seed(self.args.seed)
38 | torch.manual_seed(self.args.seed)
39 | np.random.seed(self.args.seed)
40 |
41 | def train(self):
42 | rtpt = self._initialize_rtpt()
43 | optimizer = torch.optim.RMSprop(self.deisam.parameters(), lr=self.args.lr)
44 | bce_loss = torch.nn.BCELoss()
45 |
46 | os.makedirs(f"learning_logs/comp{self.args.complexity}", exist_ok=True)
47 |
48 | for epoch in range(self.args.epochs):
49 | for counter, (id, data_index, image_id, image_source, image,
50 | deictic_representation, answer) in enumerate(self.data_loader):
51 | if self.args.trained:
52 | self._load_trained_model()
53 | return
54 |
55 | if counter % 25 == 0:
56 | self._save_intermediate_model(counter)
57 |
58 | loss = self._process_training_step(counter, data_index, image_id,
59 | image_source, deictic_representation,
60 | answer, optimizer, bce_loss)
61 | rtpt.step(subtitle=f"Iter:{counter}")
62 |
63 | self._evaluate_on_test_data(counter)
64 |
65 | def _initialize_rtpt(self):
66 | rtpt = RTPT(name_initials="",
67 | experiment_name=f"LearnDeiSAM{self.args.complexity}",
68 | max_iterations=100)
69 | rtpt.start()
70 | return rtpt
71 |
72 | def _load_trained_model(self):
73 | save_path = f"models/comp{self.args.complexity}_iter100_seed{self.args.seed}.pth"
74 | saved_state = torch.load(save_path)
75 | trained_weights = saved_state["rule_weights"].to(self.device)
76 | self.deisam.rule_weights = torch.nn.Parameter(trained_weights).to(self.device)
77 |
78 | def _save_intermediate_model(self, counter):
79 | save_path = f"models/comp{self.args.complexity}_iter{counter}_seed{self.args.seed}.pth"
80 | torch.save(self.deisam.state_dict(), save_path)
81 | print(f"Intermediate model saved to {save_path}")
82 |
83 | def _process_training_step(self, counter, data_index, image_id, image_source,
84 | deictic_representation, answer, optimizer, bce_loss):
85 | print(f"=========== ID {counter}, IMAGE ID {image_id} ===========")
86 | print("Deictic representation:\n", deictic_representation)
87 |         optimizer.zero_grad()
88 |         try:
89 | graphs = [self.vg_1.load_scene_graph_by_id(image_id),
90 | self.vg_2.load_scene_graph_by_id(image_id)]
91 | masks, target_scores, _ = self.deisam.forward(
92 | data_index, image_id, graphs, deictic_representation, image_source)
93 |
94 | if masks is None:
95 | print("No targets segmented.. skipping..")
96 | return
97 |
98 | predicted_boxes = masks_to_boxes(masks.squeeze(1).float().to(torch.int32))
99 | answer_boxes = torch.tensor(answer_to_boxes(answer), device=self.device).to(torch.int32)
100 |
101 | box_probs, box_labels = to_bce_examples(predicted_boxes, target_scores, answer_boxes, self.device)
102 | loss = bce_loss(box_probs, box_labels)
103 | loss.backward()
104 | optimizer.step()
105 | return loss.item()
106 |
107 |         except (KeyError, openai.error.APIError, openai.error.InvalidRequestError, openai.error.ServiceUnavailableError):
108 | print(f"Skipped or error occurred for ID {counter}, IMAGE ID {image_id}")
109 |
110 | def _evaluate_on_test_data(self, counter):
111 | segment_testdata(self.args, self.deisam,
112 | self.vg_1, self.vg_2,
113 | self.test_data_loader,
114 | counter, self.device)
115 |
116 |
117 | def segment_testdata(args, deisam, vg_1, vg_2, data_loader, iter, device, n_data=400):
118 | for counter, (id, data_index, image_id, image_source, image, deictic_representation, answer) in enumerate(data_loader):
119 | if counter < args.start or counter > args.end or counter > n_data:
120 | continue
121 |
122 | print(f"===== TEST ID:{counter}, IMAGE ID:{image_id}")
123 | print("Deictic representation:\n", deictic_representation)
124 |
125 | try:
126 | graphs = [vg_1.load_scene_graph_by_id(image_id),
127 | vg_2.load_scene_graph_by_id(image_id)]
128 | masks, target_scores, _ = deisam.forward(
129 | data_index, image_id, graphs, deictic_representation, image_source)
130 |
131 | if masks is None:
132 | print(f"No targets found on image {counter}. Using random mask.")
133 | target_atoms = get_random_masks(deisam)
134 | masks = deisam.segment_objects_by_sam(image_source, target_atoms, image_id)
135 | target_scores = [torch.tensor(0.5).to(device)]
136 |
137 | predicted_boxes = masks_to_boxes(masks.squeeze(1).to(torch.int32))
138 | answer_boxes = torch.tensor(answer_to_boxes(answer), device=device).to(torch.int32)
139 |
140 | # TODO: specify the path not to mix up with the solve script results
141 |                 save_box_to_file(
142 |                     pr_boxes=predicted_boxes,
143 |                     gt_boxes=answer_boxes,
144 |                     id=id,
145 |                     counter=counter,
146 |                     args=args,
147 |                 )
150 |
151 | target_scores_cpu = [x.detach().cpu().numpy() for x in target_scores]
152 | save_segmentation_result_with_alphas(
153 | args,
154 | masks,
155 | target_scores_cpu,
156 | answer,
157 | image_source,
158 | counter,
159 | image_id,
160 | data_index,
161 | deictic_representation,
162 | iter,
163 | )
164 | except openai.error.APIError:
165 | print("OpenAI API error.. skipping..")
166 |
167 |
168 | def main():
169 | parser = argparse.ArgumentParser()
170 |
171 | parser.add_argument("-s", "--start", type=int, default=0, help="Start point (data index) for the inference.")
172 | parser.add_argument("-e", "--end", type=int, default=400, help="End point (data index) for the inference.")
173 | parser.add_argument("-ep", "--epochs", type=int, default=1, help="Training epochs.")
174 | parser.add_argument("-sd", "--seed", type=int, default=0, help="Random seed.")
175 | parser.add_argument("--lr", type=float, default=1e-2, help="The learning rate.")
176 | parser.add_argument("-c", "--complexity", type=int, default=2, choices=[1, 2, 3], help="Complexity level.")
177 | parser.add_argument("-d", "--dataset", default="deictic_visual_genome", choices=["deictic_visual_genome", "deictic_visual_genome_short"], help="Dataset.")
178 | parser.add_argument("-m", "--model", default="DeiSAM", choices=["DeiSAM", "GroundedSAM"], help="Model to use.")
179 | parser.add_argument("-sgg", "--sgg-model", default="", choices=["", "VETO"], help="Scene Graph Generation model.")
180 | parser.add_argument("-su", "--sem-uni", action="store_true", help="Use semantic unifier.")
181 | parser.add_argument("-tr", "--trained", action="store_true", help="Use trained model.")
182 | parser.add_argument("-k", "--api-key", type=str, required=True, help="An OpenAI API key.")
183 |
184 | args = parser.parse_args()
185 |
186 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
187 | trainer = DeiSAMTrainer(args, device)
188 |
189 | print("Starting training...")
190 | trainer.train()
191 | print("Training completed.")
192 |
193 | if __name__ == "__main__":
194 | main()
--------------------------------------------------------------------------------
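
A hedged, self-contained illustration of the box conversion used in _process_training_step and segment_testdata: torchvision's masks_to_boxes turns a stack of binary masks of shape (N, H, W) into (N, 4) boxes in xyxy format, which are then compared against the ground-truth boxes from answer_to_boxes (the mask below is made up):

import torch
from torchvision.ops import masks_to_boxes

masks = torch.zeros((1, 8, 8), dtype=torch.bool)
masks[0, 2:5, 3:7] = True            # one small rectangular mask
print(masks_to_boxes(masks))         # tensor([[3., 2., 6., 4.]]) -> xmin, ymin, xmax, ymax
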
/src/llm_logic_generator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import lark
4 | import openai
5 | from lark import Lark
6 | from neumann.fol.exp_parser import ExpTree
7 | from neumann.fol.language import DataType, Language
8 | from neumann.fol.logic import Atom, Clause, Const, NeuralPredicate, Predicate
9 |
10 |
11 | flatten = lambda x: [z for y in x for z in (flatten(y) if hasattr(y, '__iter__') and not isinstance(y, str) else (y,))]
12 |
13 |
14 | class LLMLogicGenerator:
15 | """
16 | A class to generate logic languages and rules from task descriptions in natural language
17 | using GPT and prepared textual prompts.
18 | """
19 |
20 | def __init__(self, api_key):
21 | self.constants_prompt_path = "prompt/gen_constants.txt"
22 | self.predicates_prompt_path = "prompt/gen_predicates.txt"
23 | self.rules_prompt_path = "prompt/gen_rules.txt"
24 |
25 | self.constants_prompt = self._load_prompt(self.constants_prompt_path)
26 | self.predicates_prompt = self._load_prompt(self.predicates_prompt_path)
27 | self.rules_prompt = self._load_prompt(self.rules_prompt_path)
28 |
29 | # Setup parser
30 | lark_path = "src/lark/exp.lark"
31 | with open(lark_path, encoding="utf-8") as grammar:
32 | grammar_content = grammar.read()
33 | self.lp_atom = Lark(grammar_content, start="atom")
34 | self.lp_clause = Lark(grammar_content, start="clause")
35 |
36 | # Setup OpenAI API
37 | openai.api_key = api_key
38 | openai.organization = None
39 |
40 | def _load_prompt(self, file_path):
41 | with open(file_path, "r") as file:
42 | return file.read()
43 |
44 | def _parse_response(self, response, parse_function):
45 | return flatten([parse_function(line) for line in response.split("\n") if line.strip()])
46 |
47 | def _parse_constants(self, line):
48 | try:
49 | dtype_name, const_names_str = line.replace(" ", "").split(":")
50 | dtype = DataType(dtype_name)
51 | const_names = const_names_str.split(",")
52 | return [Const(name, dtype) for name in const_names]
53 | except ValueError:
54 | return []
55 |
56 | def _parse_predicates(self, line):
57 | pred_names = line.replace(" ", "").split(",")
58 | return [
59 | NeuralPredicate(name, 2, [DataType("object"), DataType("object")])
60 | for name in pred_names
61 | ]
62 |
63 | def _parse_rules(self, line, language):
64 | if ":-" in line:
65 | tree = self.lp_clause.parse(line.replace(" ", ""))
66 | return ExpTree(language).transform(tree)
67 | return None
68 |
69 | def query_gpt(self, text):
70 | """Query GPT-3.5 with a textual prompt."""
71 | try:
72 | response = openai.ChatCompletion.create(
73 | model="gpt-3.5-turbo",
74 | messages=[{"role": "user", "content": text}],
75 | )
76 | return response.choices[0]["message"]["content"].strip()
77 | except (
78 | openai.error.APIError,
79 | openai.error.APIConnectionError,
80 | openai.error.ServiceUnavailableError,
81 | openai.error.Timeout,
82 | json.JSONDecodeError,
83 | ):
84 | time.sleep(3)
85 | print("OpenAI API error encountered. Retrying...")
86 | return self.query_gpt(text)
87 |
88 | def generate_constants(self, text):
89 | prompt = f"{self.constants_prompt}\n{text}\n Constants:"
90 | response = self.query_gpt(prompt)
91 | constants = self._parse_response(response, self._parse_constants)
92 | return constants, "Constants:\n" + response
93 |
94 | def generate_predicates(self, text):
95 | prompt = f"{self.predicates_prompt}\n{text}\n"
96 | response = self.query_gpt(prompt)
97 | predicates = self._parse_response(response, self._parse_predicates)
98 |
99 | required_predicates = [
100 | ("type", 2, [DataType("object"), DataType("type")]),
101 | ("target", 1, [DataType("object")]),
102 | ("cond1", 1, [DataType("object")]),
103 | ("cond2", 1, [DataType("object")]),
104 | ("cond3", 1, [DataType("object")]),
105 | ]
106 |
107 | for name, arity, datatypes in required_predicates:
108 | if not any(p.name == name for p in predicates):
109 | if arity == 2:
110 | predicates.append(NeuralPredicate(name, arity, datatypes))
111 | else:
112 | predicates.append(Predicate(name, arity, datatypes))
113 |
114 | return predicates, response
115 |
116 | def get_preds_string(self, preds):
117 | return ",".join(p.name for p in preds if p.name not in ["target", "type", "cond1", "cond2", "cond3"])
118 |
119 | def generate_rules(self, text, language):
120 | pred_response = self.get_preds_string(language.preds)
121 | self.pred_response = pred_response
122 | prompt = f"\n\n{self.rules_prompt}\n{text}\navailable predicates: {pred_response}\n"
123 | response = self.query_gpt(prompt)
124 | self.rule_response = response
125 |
126 | rules = self._parse_response(response, lambda line: self._parse_rules(line, language))
127 | return [rule for rule in rules if rule]
128 |
129 | def generate_logic(self, text):
130 | """Generate constants, predicates, and rules using LLMs."""
131 | constants, const_response = self.generate_constants(text)
132 | if not constants:
133 | raise ValueError("Error: No constants found in the prompt.")
134 | predicates, _ = self.generate_predicates(text)
135 | if not predicates:
136 | raise ValueError("Error: No predicates found in the prompt.")
137 |
138 | language = Language(consts=constants, preds=predicates, funcs=[])
139 | rules = self.generate_rules(text, language)
140 | return language, rules
--------------------------------------------------------------------------------
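
A hedged usage sketch of the generator above; the API key and the deictic description are placeholders, the prompt files under prompt/ are read relative to the repository root, and the calls issue real OpenAI requests:

from llm_logic_generator import LLMLogicGenerator

generator = LLMLogicGenerator(api_key="sk-...")  # placeholder key
lang, rules = generator.generate_logic("an object that is on the sofa next to a lamp")
print(lang.preds)  # generated predicates plus the required target/cond/type predicates
print(rules)       # parsed clauses, e.g. something like target(X):-on(X,Y),type(Y,sofa).
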
/src/sam_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from groundingdino.util import box_ops
3 | from visual_genome_utils import objdata_to_box
4 |
5 |
6 | def to_xyxy(boxes):
7 | """Convert boxes from xywh format to xyxy format.
8 |
9 | Args:
10 | boxes (torch.Tensor): A tensor of boxes in xywh format.
11 |
12 | Returns:
13 | torch.Tensor: A tensor of boxes in xyxy format.
14 | """
15 | xyxy_boxes = [
16 | torch.tensor([x, y, x + w, y + h])
17 | for x, y, w, h in boxes
18 | ]
19 | return torch.stack(xyxy_boxes)
20 |
21 |
22 | def transform_boxes(boxes, image_source):
23 | """Transform normalized xywh boxes to unnormalized xyxy boxes.
24 |
25 | Args:
26 | boxes (torch.Tensor): A tensor of boxes in xywh format.
27 | image_source (numpy.ndarray): The source image to obtain dimensions.
28 |
29 | Returns:
30 |         torch.Tensor or None: Transformed boxes in xyxy format, or None on failure.
31 | """
32 | try:
33 | return to_xyxy(boxes)
34 | except RuntimeError:
35 |         return None
36 |
37 |
38 | def to_object_ids(target_atoms):
39 | """Extract object IDs from FOL target atoms.
40 |
41 | Args:
42 | target_atoms (list): List of target atoms.
43 |
44 | Returns:
45 | list: A list of object IDs.
46 | """
47 | return [int(atom.terms[0].name.split("_")[-1]) for atom in target_atoms]
48 |
49 |
50 | def to_boxes(target_atoms, data_index, visual_genome_utils):
51 | """Extract bounding boxes for target atoms using Visual Genome data.
52 |
53 | Args:
54 | target_atoms (list): List of target atoms.
55 | data_index (int): Index in the Visual Genome dataset.
56 | visual_genome_utils (object): Utilities for accessing Visual Genome data.
57 |
58 | Returns:
59 | list: A list of bounding boxes.
60 | """
61 | object_ids = to_object_ids(target_atoms)
62 | relations = visual_genome_utils.all_relationships[data_index]["relationships"]
63 |
64 | boxes = []
65 | for obj_id in object_ids:
66 | for rel in relations:
67 | if rel["object"]["object_id"] == obj_id:
68 | boxes.append(objdata_to_box(rel["object"]))
69 | break
70 | elif rel["subject"]["object_id"] == obj_id:
71 | boxes.append(objdata_to_box(rel["subject"]))
72 | break
73 | return boxes
74 |
75 |
76 | def to_boxes_with_sgg(target_atoms, image_id, visual_genome_utils):
77 | """Extract bounding boxes for target atoms using Scene Graph Generation data.
78 |
79 | Args:
80 | target_atoms (list): List of target atoms.
81 | image_id (int): Image ID for accessing relationships.
82 | visual_genome_utils (object): Utilities for accessing Visual Genome data.
83 |
84 | Returns:
85 | list: A list of bounding boxes.
86 | """
87 | object_ids = to_object_ids(target_atoms)
88 | relations = visual_genome_utils.all_relationships[image_id]
89 |
90 | boxes = []
91 | for obj_id in object_ids:
92 | for rel in relations:
93 | if rel["o_unique_id"] == obj_id:
94 | boxes.append(rel["obox"])
95 | break
96 | if rel["s_unique_id"] == obj_id:
97 | boxes.append(rel["sbox"])
98 | break
99 | return boxes
100 |
101 |
102 | def to_transformed_boxes(boxes, image_source, sam_predictor, device):
103 | """Apply transformations to bounding boxes using a SAM predictor.
104 |
105 | Args:
106 | boxes (list): List of bounding boxes.
107 | image_source (numpy.ndarray): The source image.
108 | sam_predictor (object): SAM predictor object with transform capabilities.
109 | device (torch.device): The device to convert the boxes to.
110 |
111 | Returns:
112 | torch.Tensor: A tensor of transformed boxes.
113 | """
114 | try:
115 | boxes_xyxy = to_xyxy(boxes)
116 | transformed_boxes = sam_predictor.transform.apply_boxes_torch(
117 | boxes_xyxy, image_source.shape[:2]
118 | ).to(device)
119 | return transformed_boxes
120 | except RuntimeError:
121 | return []
--------------------------------------------------------------------------------
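
A hedged, runnable check of the xywh-to-xyxy conversion above (importing sam_utils assumes the repository's GroundingDINO and neumann dependencies are installed; the box values are made up):

from sam_utils import to_xyxy

boxes_xywh = [(10, 20, 30, 40)]  # x, y, width, height
print(to_xyxy(boxes_xywh))       # tensor([[10, 20, 40, 60]])
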
/src/semantic_unifier.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import openai
4 | import torch
9 |
10 | from neumann.fol.language import Language
11 |
12 |
13 | class SemanticUnifier:
14 | def __init__(self, graph_lang, device):
15 | # self.lang = lang
16 | self.graph_lang = graph_lang
17 | self.device = device
18 | self.const_mapping = {}
19 | self.pred_mapping = {}
20 |
21 | # set up embeddings
22 | # self.consts_embeddings = self._get_consts_embeddings()
23 | # self.preds_embeddings = self._get_preds_embeddings()
24 | # self.graph_consts_embeddings = self._get_graph_consts_embeddings()
25 | # self.graph_preds_embeddings = self._get_graph_preds_embeddings()
26 |
27 | def _init_scene_graph_consts_embeddings(self):
28 | self.graph_consts_embeddings = self._get_graph_consts_embeddings()
29 |
30 | def _init_scene_graph_preds_embeddings(self):
31 | self.graph_preds_embeddings = self._get_graph_preds_embeddings()
32 |
33 | def _get_consts_embeddings(self):
34 | dic = {}
35 | for c in self.lang.consts:
36 | c_embedding = self.get_embedding(c.name.replace("_", " "))
37 | dic[c] = c_embedding
38 | return dic
39 |
40 | def _get_preds_embeddings(self):
41 | dic = {}
42 | for p in self.lang.preds:
43 | p_embedding = self.get_embedding(p.name.replace("_", " "))
44 | dic[p] = p_embedding
45 | return dic
46 |
47 | def _get_graph_consts_embeddings(self):
48 | dic = {}
49 | for c in self.graph_lang.consts:
50 | c_embedding = self.get_embedding(c.name)
51 | dic[c] = c_embedding
52 | return dic
53 |
54 | def _get_graph_preds_embeddings(self):
55 | dic = {}
56 | for p in self.graph_lang.preds:
57 | p_embedding = self.get_embedding(p.name)
58 | dic[p] = p_embedding
59 | return dic
60 |
61 | def to_language(self, graph_atoms):
62 | """Generate a FOL language given atoms that represent a scene graph.
63 |
64 | Args:
65 | graph_atoms (Atom): a set of atoms represent a scene graph.
66 |
67 | Returns:
68 | Language : a language computed from graph atoms.
69 | """
70 | preds = set()
71 | consts = set()
72 |
73 | for atom in graph_atoms:
74 | preds.add(atom.pred)
75 | for c in atom.terms:
76 | consts.add(c)
77 | lang = Language(preds=list(preds), funcs=[], consts=list(consts))
78 | return lang
79 |
80 | def get_most_similar_index(self, x, ys):
81 | pass
82 |
83 | def get_most_similar_predicate_in_graph(self, pred):
84 | # num_graph_pred = len(self.graph_lang.preds)
85 | X = self.get_embedding(pred.name).unsqueeze(0) # .expand((num_graph_pred, -1))
86 | X_graph = torch.stack(list(self.graph_preds_embeddings.values()))
87 | # score = torch.dot(X.T, X_graph)
88 | score = torch.sum(X * X_graph, axis=-1)
89 | index = torch.argmax(score).item()
90 | return self.graph_lang.preds[index]
91 |
92 | def get_most_similar_constant_in_graph(self, const):
93 | X = self.get_embedding(const.name).unsqueeze(0)
94 | X_graph = torch.stack(list(self.graph_consts_embeddings.values()))
95 | score = torch.sum(X * X_graph, axis=-1)
96 | index = torch.argmax(score).item()
97 | # print(self.graph_lang.consts, len(self.graph_lang.consts))
98 | # print(score, score.shape)
99 | # print(index)
100 | return self.graph_lang.consts[index]
101 |
102 | def build_const_mapping(self, lang, graph_lang):
103 | dic = {}
104 | for c in lang.consts:
105 | if not c in graph_lang.consts:
106 | # find the most similar graph const
107 | return 0
108 | pass
109 |
110 | def build_pred_mapping(self, lang, graph_lang):
111 | pass
112 |
113 | def rewrite_lang(self, lang, graph_lang):
114 |         """Rewrite the language using only the vocabulary that exists in the graph.
115 | 
116 |         Args:
117 |             lang (Language): the language generated from the text prompt.
118 |             graph_lang (Language): the language extracted from the scene graph.
119 |         Returns:
120 |             tuple: (new_lang, const_mapping, predicate_mapping).
121 |         """
122 | new_lang = Language(
123 | preds=lang.preds.copy(), funcs=[], consts=lang.consts.copy()
124 | )
125 | const_mapping = {}
126 | for const in new_lang.consts:
127 | if const not in graph_lang.consts:
128 | new_const = self.get_most_similar_constant_in_graph(const)
129 | new_lang.consts.remove(const)
130 | new_lang.consts.append(new_const)
131 | const_mapping[const] = new_const
132 | predicate_mapping = {}
133 | for pred in new_lang.preds:
134 | if pred not in graph_lang.preds:
135 | new_pred = self.get_most_similar_predicate_in_graph(pred)
136 | new_lang.preds.remove(pred)
137 | new_lang.preds.append(new_pred)
138 | predicate_mapping[pred] = new_pred
139 | return new_lang, const_mapping, predicate_mapping
140 |
159 |
160 | def rewrite_rules(self, rules, lang, graph_lang, rewrite_pred=True):
161 | reserved_preds = ["target", "type", "cond1", "cond2", "cond3"]
162 | self._init_scene_graph_consts_embeddings()
163 | self._init_scene_graph_preds_embeddings()
164 | new_rules = []
165 | # new_lang = Language(preds=lang.preds.copy(), funcs=[], consts=lang.consts.copy())
166 | for rule in rules:
167 | new_rule = rule
168 | new_atoms = []
169 | for atom in [rule.head] + rule.body:
170 | # check / rewrite predicate
171 | pred = atom.pred
172 | if (
173 | rewrite_pred
174 | and pred.name not in reserved_preds
175 | and pred not in graph_lang.preds
176 | ):
177 | # replace the non-existing predicate by the most similar one
178 | new_pred = self.get_most_similar_predicate_in_graph(pred)
179 | atom.pred = new_pred
180 | self.pred_mapping[pred.name] = new_pred.name
181 | print(pred.name, " -> ", new_pred.name)
182 | # new_lang.preds.remove(pred)
183 | # new_lang.preds.append(new_pred)
184 | # check / rewrite const
185 | for i, const in enumerate(atom.terms):
186 | if (
187 | const.__class__.__name__ == "Const"
188 | and const not in graph_lang.consts
189 | ):
190 | # replace the non-existing constant by the most similar one
191 | new_const = self.get_most_similar_constant_in_graph(const)
192 | atom.terms[i] = new_const
193 | self.const_mapping[const.name] = new_const.name
194 | print(const.name, " -> ", new_const.name)
195 | # new_lang.consts.remove(const)
196 | # new_lang.consts.append(new_const)
197 | new_atoms.append(atom)
198 | new_rule.head = new_atoms[0]
199 | new_rule.body = new_atoms[1:]
200 | new_rules.append(new_rule)
201 | return new_rules # , #new_lang
202 |
203 | # def unify(self, lang, graph_lang, rules):
204 | # rewrite lang
205 | # internally overwrite self.lang
206 | # new_lang, const_mapping, pred_mapping = self.rewrite_lang(lang, graph_lang)
207 | # generate new rules using the refined language
208 | # new_rules new_lang = self.rewrite_rules(rules, lang, graph_lang)
209 | # eturn new_lang, new_rules
210 |
211 |     def get_embedding(self, text_to_embed):
212 |         """Embed a line of text with the OpenAI embedding API, retrying on API errors."""
213 |         try:
214 |             response = openai.Embedding.create(
215 |                 model="text-embedding-ada-002", input=[text_to_embed.replace("_", " ")]
216 |             )
217 |             # Extract the output embedding as a tensor of floats
218 |             return torch.tensor(response["data"][0]["embedding"]).to(self.device)
219 |         except (openai.error.InvalidRequestError, openai.error.ServiceUnavailableError):
220 |             print(
221 |                 "OpenAI API error for embeddings in semantic unification, waiting 3s and retrying..."
222 |             )
223 |             time.sleep(3)
224 |             return self.get_embedding(text_to_embed=text_to_embed)
--------------------------------------------------------------------------------
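
A hedged sketch of the semantic unification step: vocabulary produced by the LLM that does not occur in the scene-graph language is replaced by its most similar counterpart via OpenAI embeddings. The API key, the hand-written scene graph, and the "lying_on" predicate are placeholders, and _init_scene_graph_preds_embeddings is called explicitly only because rewrite_rules normally does it internally:

import openai
import torch

from neumann.fol.language import DataType
from neumann.fol.logic import Predicate
from semantic_unifier import SemanticUnifier
from visual_genome_utils import scene_graph_to_language_with_sgg

openai.api_key = "sk-..."  # placeholder key; embeddings are fetched from the OpenAI API

graph_lang = scene_graph_to_language_with_sgg(
    [{"s_str": "man", "p_str": "sitting on", "o_str": "couch"}]
)
unifier = SemanticUnifier(graph_lang, device=torch.device("cpu"))
unifier._init_scene_graph_preds_embeddings()

obj = DataType("object")
lying_on = Predicate("lying_on", 2, [obj, obj])               # not in the scene graph
print(unifier.get_most_similar_predicate_in_graph(lying_on))  # maps to sitting_on here
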
/src/visual_genome_utils.py:
--------------------------------------------------------------------------------
1 | import neumann
2 | import neumann.fol.logic as logic
3 | from neumann.fol.language import DataType, Language
4 | from neumann.fol.logic import Const
5 |
6 |
7 | def scene_graph_to_language(scene_graph, text, logic_generator, num_objects=2):
8 | """Extract a FOL language from a scene graph to parse rules later."""
9 |
10 | # Extract and sanitize object names from the scene graph
11 | objects = {str(obj).replace(" ", "") for obj in scene_graph.objects}
12 | datatype = DataType("type")
13 |
14 | # Define constants using the extracted object names
15 | constants = [Const(obj, datatype) for obj in objects]
16 |
17 | # Prepare constant response for predicates
18 | const_response = "Constants:\ntype:" + ",".join(objects)
19 |
20 | # Generate predicates using the logic generator based on the input text
21 | predicates, pred_response = logic_generator.generate_predicates(text)
22 | print(f"Predicate generator response:\n {pred_response}")
23 |
24 | # Formulate the language using constants and predicates
25 | lang = Language(consts=list(constants), preds=list(predicates), funcs=[])
26 | return lang
27 |
28 |
29 | def get_init_language_with_sgg(scene_graph, text, logic_generator):
30 | """Extract an initial FOL language from a predicted scene graph to parse rules later."""
31 |
32 | # Extract unique object and subject names from the scene graph
33 | objects = {rel["o_str"] for rel in scene_graph}
34 | subjects = {rel["s_str"] for rel in scene_graph}
35 | datatype = DataType("type")
36 |
37 | # Define constants using the extracted names
38 | constants = [Const(obj, datatype) for obj in objects | subjects]
39 |
40 | # Prepare constant response for predicates
41 | const_response = "Constants:\ntype:" + ",".join(objects)
42 |
43 | # Generate predicates using the logic generator based on the input text
44 |     predicates, pred_response = logic_generator.generate_predicates(text)
45 | print(f"Predicate generator response:\n {pred_response}")
46 |
47 | # Formulate the language using constants and predicates
48 | lang = Language(consts=list(constants), preds=list(predicates), funcs=[])
49 | return lang
50 |
51 |
52 | def scene_graph_to_language_with_sgg(scene_graph):
53 | """Extract a complete FOL language from a scene graph for use with a semantic unifier."""
54 |
55 | # Extract unique objects, subjects, and relationships from the scene graph
56 | objects = {rel["o_str"] for rel in scene_graph}
57 | subjects = {rel["s_str"] for rel in scene_graph}
58 | relationships = {rel["p_str"] for rel in scene_graph}
59 | datatype = DataType("type")
60 | obj_datatype = DataType("object")
61 |
62 | # Define constants using the extracted names
63 | constants = [Const(obj, datatype) for obj in objects | subjects]
64 |
65 | # Define predicates for each relationship
66 | predicates = [
67 | logic.Predicate(rel.replace(" ", "_").lower(), 2, [obj_datatype, obj_datatype])
68 | for rel in relationships
69 | ]
70 |
71 | # Formulate the language using constants and predicates
72 | lang = Language(consts=list(constants), preds=list(predicates), funcs=[])
73 | return lang
74 |
75 |
76 | def objdata_to_box(data):
77 | """Convert object data to a bounding box format."""
78 | x, y, w, h = data["x"], data["y"], data["w"], data["h"]
79 | return x, y, w, h
80 |
--------------------------------------------------------------------------------
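
A hedged, runnable example of the scene-graph-to-language extraction above on a tiny hand-written predicted scene graph (the dict keys match those used for the predicted relations elsewhere in this repository; the relations themselves are made up):

from visual_genome_utils import scene_graph_to_language_with_sgg

scene_graph = [
    {"s_str": "man", "p_str": "sitting on", "o_str": "bench"},
    {"s_str": "dog", "p_str": "next to", "o_str": "bench"},
]
lang = scene_graph_to_language_with_sgg(scene_graph)
print(lang.consts)  # man, dog, bench (dtype "type")
print(lang.preds)   # sitting_on/2, next_to/2 over object constants
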
/src/visualization_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | import torch
5 | from sam_utils import to_object_ids
6 |
7 |
8 | def apply_random_color():
9 | return np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
10 |
11 |
12 | def apply_default_color(alpha=0.75):
13 | return np.array([255 / 255, 10 / 255, 10 / 255, alpha])
14 |
15 |
16 | def get_colored_mask(mask, color):
17 | h, w = mask.shape[-2:]
18 | mask = mask.cpu() if mask.device.type != "cpu" else mask
19 | return mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
20 |
21 |
22 | def overlay_mask_on_image(mask, image, color):
23 | annotated_frame_pil = Image.fromarray(image).convert("RGBA")
24 | mask_image_pil = Image.fromarray(
25 | (mask.cpu().numpy() * 255).astype(np.uint8)
26 | ).convert("RGBA")
27 | return Image.alpha_composite(annotated_frame_pil, mask_image_pil)
28 |
29 |
30 | def show_mask(mask, image, random_color=False):
31 | color = apply_random_color() if random_color else apply_default_color()
32 | mask_image = get_colored_mask(mask, color)
33 | return np.array(overlay_mask_on_image(mask_image, image, color))
34 |
35 |
36 | def show_mask_with_alpha(mask, image, alpha, random_color=False):
37 | color = apply_random_color() if random_color else apply_default_color(alpha * 0.75)
38 | mask_image = get_colored_mask(mask, color)
39 | return np.array(overlay_mask_on_image(mask_image, image, color))
40 |
41 |
42 | def get_bbox_by_id(object_id, data_index, vg):
43 | objects = vg.all_objects[data_index]["objects"]
44 | target_object = next((o for o in objects if int(o["object_id"]) == object_id), None)
45 |
46 | if not target_object:
47 | raise ValueError(f"Object with ID {object_id} not found.")
48 |
49 | return target_object["x"], target_object["y"], target_object["w"], target_object["h"]
50 |
51 |
52 | def to_crops(image_source, boxes):
53 |     return [image_source[y:y + h, x:x + w] for x, y, w, h in boxes]
54 |
55 |
56 | def _to_boxes(target_atoms, data_index, vg):
57 | return vg.target_atoms_to_regions(target_atoms, data_index)
58 |
59 |
60 | def objdata_to_box(data):
61 | return data["x"], data["y"], data["w"], data["h"]
62 |
63 |
64 | def to_boxes(target_atoms, data_index, vg):
65 | object_ids = to_object_ids(target_atoms)
66 | relations = vg.all_relationships[data_index]["relationships"]
67 |
68 | return [
69 | objdata_to_box(rel["object"] if rel["object"]["object_id"] == id else rel["subject"])
70 | for id in object_ids
71 | for rel in relations
72 | if rel["object"]["object_id"] == id or rel["subject"]["object_id"] == id
73 | ]
74 |
89 |
90 | def to_xyxy(boxes):
91 | xyxy_boxes = [torch.tensor([x, y, x + w, y + h]) for x, y, w, h in boxes]
92 | return torch.stack(xyxy_boxes)
93 |
94 |
95 | def save_boxes_to_file(boxes, path, is_prediction=True):
96 | text = "\n".join(f"target 1.0 {box[0]} {box[1]} {box[2]} {box[3]}" for box in boxes)
97 | with open(path, "w") as f:
98 | f.write(text)
99 | print(f"File saved to {path}")
100 |
101 |
102 | def save_box_to_file(pr_boxes, gt_boxes, id, counter, args):
103 | dirs = [
104 | f"result/{args.dataset}_comp{args.complexity}/{args.model}/prediction",
105 | f"result/{args.dataset}_comp{args.complexity}/{args.model}/ground_truth",
106 | ]
107 |
108 | for dir_path in dirs:
109 | os.makedirs(dir_path, exist_ok=True)
110 |
111 | pr_path = f"{dirs[0]}/{counter}_vg{id}.txt"
112 | gt_path = f"{dirs[1]}/{counter}_vg{id}.txt"
113 |
114 | save_boxes_to_file(pr_boxes, pr_path)
115 | save_boxes_to_file(gt_boxes, gt_path, is_prediction=False)
116 |
117 |
118 | def answer_to_boxes(answers):
119 | if not isinstance(answers, list):
120 | answers = [answers]
121 |
122 | return [[answer["x"], answer["y"], answer["x"] + answer["w"], answer["y"] + answer["h"]] for answer in answers]
123 |
124 |
125 | def save_segmented_images(counter, vg_image_id, image_with_mask, description, base_path="imgs/"):
126 | image_with_mask = Image.fromarray(image_with_mask).convert("RGB")
127 | description = description.replace("/", " ").replace(".", "")
128 | save_path = f"{base_path}deicticVG_ID:{counter}_VGImID:{vg_image_id}_{description}.png"
129 | image_with_mask.save(save_path)
130 | print(f"Image saved to {save_path}")
131 |
132 |
133 | def save_segmented_images_with_target_scores(counter, vg_image_id, image_with_mask, description, target_scores, base_path="imgs/"):
134 | if len(target_scores) > 3:
135 | target_scores = target_scores[:3]
136 | scores_str = str(np.round(target_scores, 2).tolist())
137 | save_segmented_images(counter, vg_image_id, image_with_mask, f"{description}_scores_{scores_str}", base_path)
--------------------------------------------------------------------------------
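
A hedged, runnable illustration of answer_to_boxes above (assuming the repository's dependencies are installed): Visual Genome style region dictionaries (x, y, w, h) are converted into xyxy boxes, matching the format produced by masks_to_boxes for the predictions (the region values are made up):

from visualization_utils import answer_to_boxes

answer = {"x": 10, "y": 20, "w": 30, "h": 40}
print(answer_to_boxes(answer))  # [[10, 20, 40, 60]]
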