├── .devcontainer └── Dockerfile ├── .gitignore ├── LICENSE ├── README.md ├── imgs ├── deisam_logo.png ├── deisam_logo_eye.png └── deisam_task.png ├── neumann ├── LICENSE ├── README.md ├── neumann │ ├── __init__.py │ ├── clause_generator.py │ ├── data_behind_the_scenes.py │ ├── data_clevr.py │ ├── data_kandinsky.py │ ├── data_logic.py │ ├── data_vilp.py │ ├── explain_clevr.py │ ├── explanation_utils.py │ ├── facts_converter.py │ ├── fol │ │ ├── README.md │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── exp.lark │ │ ├── exp_parser.py │ │ ├── language.py │ │ ├── logic.py │ │ └── logic_ops.py │ ├── img2facts.py │ ├── lark │ │ └── exp.lark │ ├── logic_utils.py │ ├── message_passing.py │ ├── mode_declaration.py │ ├── neumann.py │ ├── neumann_utils.py │ ├── neural_utils.py │ ├── percept.py │ ├── predict.py │ ├── reasoning_graph.py │ ├── refinement.py │ ├── scatter.py │ ├── soft_logic.py │ ├── solve_behind_the_scenes.py │ ├── solve_kandinsky.py │ ├── torch_utils.py │ ├── train_neumann.py │ ├── valuation.py │ ├── valuation_func.py │ └── visualize.py ├── requirements.txt └── setup.py ├── prompt ├── gen_constants.txt ├── gen_predicates.txt └── gen_rules.txt └── src ├── __init__.py ├── data_vg.py ├── deisam.py ├── deisam_utils.py ├── lark └── exp.lark ├── learning_demo.py ├── learning_utils.py ├── llm_logic_generator.py ├── sam_utils.py ├── semantic_unifier.py ├── solve_deivg.py ├── test_sgg.py ├── visual_genome_utils.py └── visualization_utils.py /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # Select the base image 2 | 3 | ### To run with GPUs, use the following: 4 | FROM nvcr.io/nvidia/pytorch:21.10-py3 5 | ### 6 | 7 | 8 | ### To run without GPUs, use the following 9 | # FROM ubuntu:18.04 10 | # RUN apt-get update 11 | # RUN apt-get install -y software-properties-common 12 | # RUN add-apt-repository ppa:deadsnakes/ppa 13 | # RUN apt-get install -y python3.8-dev python3-pip 14 | # RUN rm /usr/bin/python3 && ln -s /usr/bin/python3.8 /usr/bin/python3 15 | # RUN python3 --version 16 | # RUN pip3 --version 17 | # RUN update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 18 | ### 19 | 20 | 21 | ### other setup 22 | ENV TZ=Europe/Berlin 23 | ARG DEBIAN_FRONTEND=noninteractive 24 | 25 | # Select the working directory 26 | WORKDIR /Workspace 27 | 28 | # Install system libraries required by OpenCV. 29 | RUN apt-get update \ 30 | && apt-get install -y libgl1-mesa-glx libgtk2.0-0 libsm6 libxext6 \ 31 | && rm -rf /var/lib/apt/lists/* 32 | 33 | RUN pip install --upgrade pip 34 | # Install Python requirements 35 | # RUN pip install opencv-python==4.5.5.64 36 | RUN pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html 37 | RUN pip install networkx==3.0 lark-parser joblib scikit-learn torchsummary setuptools tensorboard numpy>=1.18.5 tqdm>=4.41.0 matplotlib>=3.2.2 opencv-python>=4.1.2 Pillow PyYAML>=5.3.1 scipy>=1.4.1 seaborn pandas rtpt 38 | 39 | RUN pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-1.12.0+cu116.html 40 | RUN pip install torch-geometric 41 | RUN pip install wandb anytree 42 | 43 | # RUN python -m pip install -e Grounded-Segment-Anything/GroundingDINO 44 | #RUN cd GroundingDINO 45 | #RUN pip install -e . 46 | #RUN cd .. 
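# Note: Grounded-Segment-Anything, segment_anything, and GroundingDINO are not installed in this image;
# they are installed manually inside the container, as described in the Install section of the top-level README.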
47 | RUN pip install "opencv-python-headless<4.3" 48 | RUN pip install --upgrade diffusers[torch] 49 | RUN pip install openai==0.28.1 50 | RUN pip install pydantic==1.9.0 51 | RUN pip install visual_genome 52 | 53 | # fix opencv 54 | #RUN pip uninstall opencv-python 55 | #RUN pip uninstall opencv-contrib-python 56 | #RUN pip uninstall opencv-contrib-python-headless 57 | #RUN pip3 install opencv-contrib-python==4.5.5.62 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hikaru Shindo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### *Refactoring is undegoing. 2 | 5 | 6 | # DeiSAM: Segment Anything with Deictic Prompting (NeurIPS 2024) 7 | [Hikaru Shindo](https://www.hikarushindo.com/), Manuel Brack, Gopika Sudhakaran, Devendra Singh Dhami, Patrick Schramowski, Kristian Kersting 8 | 9 | [AI/ML Lab @ TU Darmstadt](https://ml-research.github.io/index.html) 10 | 11 |

12 | 13 |

14 | We propose DeiSAM, which integrates large pre-trained neural networks with differentiable logic reasoners. Given a complex, textual segmentation description, DeiSAM leverages Large Language Models (LLMs) to generate first-order logic rules and performs differentiable forward reasoning on generated scene graphs. 15 | 16 | 17 | # Install 18 | [Dockerfile](.devcontainer/Dockerfile) is available in the [.devcontainer](.devcontainer) folder. 19 | 20 | To install further dependencies, clone [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) and then: 21 | 22 | 23 | ``` 24 | cd neumann/ 25 | pip install -e . 26 | cd ../Grounded-Segment-Anything/ 27 | cd segment_anything 28 | pip install -e . 29 | cd ../GroundingDINO 30 | pip install -e . 31 | ``` 32 | 33 | If an error appears regarding OpenCV (circular import), try: 34 | ``` 35 | pip uninstall opencv-python 36 | pip uninstall opencv-contrib-python 37 | pip uninstall opencv-contrib-python-headless 38 | pip3 install opencv-contrib-python==4.5.5.62 39 | ``` 40 | 41 | Download the SAM ViT model: 42 | ``` 43 | wget https://huggingface.co/spaces/abhishek/StableSAM/resolve/main/sam_vit_h_4b8939.pth 44 | ``` 45 | 46 | # Dataset 47 | **DeiVG datasets can be downloaded here 48 | [link](https://hessenbox.tu-darmstadt.de/getlink/fiJwsDNjdY9HDrUMf3btjoHG/).** Please place the downloaded files in `data/` as follows (make sure you are in the home folder of this project): 49 | ``` 50 | mkdir data/ 51 | cd data 52 | wget https://hessenbox.tu-darmstadt.de/dl/fiJwsDNjdY9HDrUMf3btjoHG/.dir -O deivg.zip 53 | unzip deivg.zip 54 | cd visual_genome 55 | unzip by-id.zip 56 | ``` 57 | 58 | 59 | Please download the Visual Genome images [link](https://homes.cs.washington.edu/~ranjay/visualgenome/api.html) and place the downloaded files in `data/visual_genome/` as follows: 60 | ``` 61 | cd data/visual_genome 62 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip 63 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip 64 | unzip images.zip 65 | unzip images2.zip 66 | mv VG_100K_2/* VG_100K/ 67 | ``` 68 | 69 | 70 | # Experiments 71 | To solve DeiVG using DeiSAM: 72 | ``` 73 | python src/solve_deivg.py --api-key YOUR_OPENAI_API_KEY -c 1 74 | python src/solve_deivg.py --api-key YOUR_OPENAI_API_KEY -c 2 75 | python src/solve_deivg.py --api-key YOUR_OPENAI_API_KEY -c 3 76 | ``` 77 | 78 | 79 | A demonstration of learning can be performed by: 80 | ``` 81 | python src/learning_demo.py --api-key YOUR_OPENAI_API_KEY -c 1 -sgg VETO -su 82 | python src/learning_demo.py --api-key YOUR_OPENAI_API_KEY -c 2 -sgg VETO -su 83 | ``` 84 | *Note that DeiSAM is essentially a training-free model.* Learning here is a demonstration of the capability to learn via gradients. The best performance is always achieved by using the model with ground-truth scene graphs, which corresponds to `solve_deivg.py`. 85 | In other words, DeiSAM does not need to be trained when the scene graphs are available. A future plan is to mitigate the case where scene graphs are not available.
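For intuition, a single DeiVG example is processed roughly as sketched below. The prompt and the predicate names are purely hypothetical illustrations; the actual rule format is defined by the prompts in [prompt](./prompt) and the grammar in `src/lark/exp.lark`.
```
Deictic prompt : "an object that is on the table next to a laptop"
Generated rule : target(X):-on(X,Y),type(Y,table),next_to(X,Z),type(Z,laptop).
```
The LLM translates the deictic prompt into such a rule, differentiable forward reasoning grounds `target(X)` on the scene graph, and the grounded objects are passed to SAM to produce segmentation masks.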
86 | 87 | 88 | # Bibtex 89 | ``` 90 | @inproceedings{shindo24deisam, 91 | author = {Hikaru Shindo and 92 | Manuel Brack and 93 | Gopika Sudhakaran and 94 | Devendra Singh Dhami and 95 | Patrick Schramowski and 96 | Kristian Kersting}, 97 | title = {DeiSAM: Segment Anything with Deictic Prompting}, 98 | booktitle = {Proceedings of the Conference on Advances in Neural Information Processing Systems (NeurIPS)}, 99 | year = {2024}, 100 | } 101 | 102 | ``` 103 | 104 | 105 | 106 | # LICENSE 107 | See [LICENSE](./LICENSE). 108 | 109 | -------------------------------------------------------------------------------- /imgs/deisam_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/imgs/deisam_logo.png -------------------------------------------------------------------------------- /imgs/deisam_logo_eye.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/imgs/deisam_logo_eye.png -------------------------------------------------------------------------------- /imgs/deisam_task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/imgs/deisam_task.png -------------------------------------------------------------------------------- /neumann/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 anonymous 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /neumann/README.md: -------------------------------------------------------------------------------- 1 | # Learning Differentiable Logic Programs for Abstract Visual Reasoning 2 | 3 | 4 | 5 |

6 | 7 |

8 | 9 | # Abstract 10 | Visual reasoning is essential for building intelligent agents that understand the world and perform problem-solving beyond perception. Differentiable forward reasoning has been developed to integrate reasoning with gradient-based machine learning paradigms. 11 | However, due to the memory intensity, most existing approaches do not bring the best of the expressivity of first-order logic, excluding a crucial ability to solve *abstract visual reasoning*, where agents need to perform reasoning by using analogies on abstract concepts in different scenarios. 12 | To overcome this problem, we propose *NEUro-symbolic Message-pAssiNg reasoNer (NEUMANN)*, which is a graph-based differentiable forward reasoner, passing messages in a memory-efficient manner and handling structured programs with functors. 13 | Moreover, we propose a computationally-efficient structure learning algorithm to perform explanatory program induction on complex visual scenes. 14 | To evaluate, in addition to conventional visual reasoning tasks, we propose a new task, *visual reasoning behind-the-scenes*, where agents need to learn abstract programs and then answer queries by imagining scenes that are not observed. 15 | We empirically demonstrate that NEUMANN solves visual reasoning tasks efficiently, outperforming neural, symbolic, and neuro-symbolic baselines. 16 | 17 | 18 | ![neumann](./imgs/behind-the-scenes.png) 19 | 20 | **NEUMANN solves the Behind-the-Scenes task.** 21 | Reasoning behind the scenes: The goal of this task is to compute the answer to a query, e.g., *``What is the color of the second left-most object after deleting a gray object?''*, given a visual scene. To answer this query, the agent needs to reason behind the scenes and understand abstract operations on objects. In the first task, the agent needs to induce an explicit program given visual examples, where each example consists of several visual scenes that describe the input and the output of the operation to be learned. The abstract operations can be described and computed by first-order logic with functors. 22 | In the second task, the agent needs to apply the learned programs to new situations to solve queries that reason about non-observed scenes. 23 | 24 | ## How does it work? 25 | NEUMANN compiles *first-order logic* programs into a *graph neural network*. Logical entailment is computed on probabilistic atoms and weighted rules using fuzzy logic operations. 26 | ![neumann](./imgs/reasoning_graph.png) 27 | 28 | # Relevant Repositories 29 | [Visual ILP: A repository for the dataset generation of CLEVR images for abstract operations.](https://github.com/ml-research/visual-ilp) 30 | 31 | [Behind-the-Scenes: A repository for the generation of visual scenes and queries for the behind-the-scenes task.](https://github.com/ml-research/behind-the-scenes) 32 | 33 | # Experiments 34 | 35 | ## Prerequisites 36 | A Docker container is available in the [.devcontainer](./.devcontainer/Dockerfile) folder, 37 | which is compatible with the [packages](./pip_requirements.txt) (produced by pip freeze). 38 | The main dependencies are: 39 | ``` 40 | pytorch 41 | torch-geometric 42 | networkx 43 | ``` 44 | We used Python 3.8 for the experiments. 45 | See [Dockerfile](.devcontainer/Dockerfile) for more details. 46 | 47 | ## Build a Docker container 48 | Simply use VSCode to open the container, or build the container manually. 49 | To run on machines without GPUs: 50 | ``` 51 | cp .devcontainer/Dockerfile_nogpu ./Dockerfile 52 | docker build -t neumann . 
53 | docker run -it -v <local path to this repository>:/neumann --name neumann neumann 54 | ``` 55 | For example, the local path could be: `/Users/username/Workspace/github/neumann`. The path is where this repository has been cloned. 56 | 57 | For GPU-equipped machines, use: 58 | ``` 59 | cp .devcontainer/Dockerfile ./Dockerfile 60 | docker build -t neumann . 61 | docker run -it -v <local path to this repository>:/neumann --name neumann neumann 62 | ``` 63 | To open the container on machines without GPUs using VSCode, run: 64 | ``` 65 | cp .devcontainer/Dockerfile_nogpu .devcontainer/Dockerfile 66 | ``` 67 | and use the VSCode remote host extension (recommended). 68 | 69 | 70 | 71 | ## Perform learning 72 | For example, in the container, learning the red-triangle Kandinsky pattern using the demo dataset can be performed as follows: 73 | ``` 74 | cd /neumann 75 | python3 src/train_neumann.py --dataset-type kandinsky --dataset red-triangle --num-objects 6 --batch-size 12 --no-cuda --epochs 30 --infer-step 4 --trial 5 --n-sample 10 --program-size 1 --max-var 6 --min-body-len 6 --pos-ratio 1.0 --neg-ratio 1.0 76 | ``` 77 | An exemplary log can be found in [redtriangle_log.txt](./logs/redtriangle_log.txt). 78 | 79 | More scripts are available: 80 | 81 | [Learning kandinsky/clevr-hans patterns](./scripts/solve_kandinsky_clevr.sh) 82 | 83 | [Solving Behind-the-Scenes](./scripts/solve_behind-the-scenes.sh) 84 | 85 | # LICENSE 86 | See [LICENSE](./LICENSE). The [src/yolov5](./src/yolov5) folder follows the [GPL3](./src/yolov5/LICENSE) license. -------------------------------------------------------------------------------- /neumann/neumann/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/neumann/neumann/__init__.py -------------------------------------------------------------------------------- /neumann/neumann/clause_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from anytree import Node, PreOrderIter, RenderTree 5 | from anytree.search import find_by_attr, findall 6 | 7 | from .fol.logic import Clause 8 | from .logic_utils import add_true_atom, add_true_atoms, remove_true_atoms, true 9 | 10 | 11 | class ClauseGenerator(object): 12 | """Refinement-based clause generator that holds a tree representation of the generation steps. 
13 | """ 14 | def __init__(self, refinement_generator, root_clauses, th_depth, n_sample): 15 | self.refinement_generator = refinement_generator 16 | self.th_depth = th_depth 17 | self.root_clauses = add_true_atoms(root_clauses) 18 | self.tree = Node(name="root", clause=Clause(true,[])) 19 | for c in root_clauses: 20 | Node(name=str(c), clause=c, parent=self.tree) 21 | #Node(name=str(nc), clause=nc, =target_node) 22 | self.n_sample = n_sample 23 | self.is_root = True 24 | self.refinement_history = set() 25 | self.refinement_score_history = set() 26 | 27 | 28 | def generate(self, clauses, clause_scores): 29 | clauses_to_refine = self.sample_clauses_by_scores(clauses, clause_scores) 30 | print("=== CLAUSES TO BE REFINED ===") 31 | for i, cr in enumerate(clauses_to_refine): 32 | print(i, ': ', cr) 33 | 34 | self.refinement_history = self.refinement_history.union(set(clauses_to_refine)) 35 | # self.refinement_history = list(set(self.clause_history)) 36 | new_clauses = add_true_atoms(self.apply_refinement(remove_true_atoms(clauses_to_refine))) 37 | # prune already appeared clauses 38 | new_clauses = [c for c in new_clauses if not c in self.refinement_history] 39 | return list(set(new_clauses)) 40 | 41 | 42 | def sample_clauses_by_scores(self, clauses, clause_scores): 43 | clauses_to_refine = [] 44 | print("Logits for the sampling: ") 45 | print(np.round(clause_scores.cpu().numpy(), 2)) 46 | 47 | n_sampled = 0 48 | while n_sampled < self.n_sample: 49 | if len(clauses) == 0: 50 | # no more clauses to be sampled 51 | break 52 | i_sampled_onehot = F.gumbel_softmax(clause_scores, tau=1.0, hard=True) 53 | i_sampled = int(torch.argmax(i_sampled_onehot, dim=0).item()) 54 | # selected_clause_indices = [i for i, j in enumerate(selected_clause_indices)] 55 | # clauses_to_refine_i = [c for i, c in enumerate(clauses) if selected_clause_indices[i] > 0] 56 | sampled_clause = clauses[i_sampled] 57 | score = np.round(clause_scores[i_sampled].cpu().numpy(), 2) 58 | 59 | if score in self.refinement_score_history: 60 | # if a clause with the same score is already sampled, just skip this 61 | # renormalize clause scores 62 | if i_sampled != len(clauses)-1: 63 | clause_scores = torch.cat([clause_scores[:i_sampled], clause_scores[i_sampled + 1:]]) 64 | clauses.remove(sampled_clause) 65 | else: 66 | clause_scores = clause_scores[:i_sampled] 67 | clauses.remove(sampled_clause) 68 | else: 69 | # append to the result 70 | clauses_to_refine.append(sampled_clause) 71 | # update history 72 | self.refinement_score_history.add(score) 73 | self.refinement_history.add(sampled_clause) 74 | # renormalize socres 75 | if i_sampled != len(clauses)-1: 76 | clause_scores = torch.cat([clause_scores[:i_sampled], clause_scores[i_sampled + 1:]]) 77 | clauses.remove(sampled_clause) 78 | else: 79 | clause_scores = clause_scores[:i_sampled] 80 | clauses.remove(sampled_clause) 81 | 82 | n_sampled += 1 83 | 84 | clauses_to_refine = list(set(clauses_to_refine)) 85 | return clauses_to_refine 86 | 87 | def split_by_head_preds(self, clauses, clause_scores): 88 | head_pred_clauses_dic = {} 89 | head_pred_scores_dic = {} 90 | for i, c in enumerate(clauses): 91 | if c.head.pred in head_pred_clauses_dic: 92 | head_pred_clauses_dic[c.head.pred].append(c) 93 | head_pred_scores_dic[c.head.pred].append(clause_scores[i]) 94 | else: 95 | head_pred_clauses_dic[c.head.pred] = [c] 96 | head_pred_scores_dic[c.head.pred] = [clause_scores[i]] 97 | 98 | for p in head_pred_scores_dic.keys(): 99 | head_pred_scores_dic[p] = torch.tensor(head_pred_scores_dic[p]) 100 | 
return head_pred_clauses_dic, head_pred_scores_dic 101 | 102 | 103 | def apply_refinement(self, clauses): 104 | all_new_clauses = [] 105 | for clause in clauses: 106 | new_clauses = self.generate_clauses_by_refinement(clause) 107 | all_new_clauses.extend(new_clauses) 108 | # add to the refinement tree 109 | #self.print_tree() 110 | # the true atom for the clause to be refined has been removed 111 | target_node = find_by_attr(self.tree, name='clause', value=add_true_atom(clause)) 112 | #print(target_node) 113 | for nc in new_clauses: 114 | all_nodes =list(PreOrderIter(self.tree)) 115 | if not nc in [n.clause for n in all_nodes]: 116 | Node(name=str(nc), clause=nc, parent=target_node) 117 | """ 118 | clauses_exist = [n.clause for n in target_node.children] 119 | if not nc in clauses_exist: 120 | print(target_node) 121 | print('clauses_exist: ', clauses_exist) 122 | Node(name=str(nc), clause=nc, parent=target_node) 123 | """ 124 | # target_node.children.append(child_node) 125 | return all_new_clauses 126 | 127 | 128 | def generate_clauses_by_refinement(self, clause): 129 | return list(set(add_true_atoms(self.refinement_generator.refine_clause(clause)))) 130 | 131 | def print_tree(self): 132 | print("-- rule generation tree --") 133 | print(RenderTree(self.tree).by_attr('clause')) 134 | 135 | def get_clauses_by_th_depth(self, th_depth): 136 | """Get all clauses that are located deeper nodes than given threashold.""" 137 | nodes = findall(self.tree, filter_=lambda node: node.depth >= th_depth) 138 | return [node.clause for node in nodes] 139 | """ 140 | def __sample_clauses_by_scores(self, clauses, clause_scores): 141 | clauses_to_refine = [] 142 | print("Logits for the sampling: ") 143 | print(np.round(clause_scores.cpu().numpy(), 2)) 144 | for i in range(clause_scores.size(0)): 145 | clauses_dic, scores_dic = self.split_by_head_preds(clauses, clause_scores[i]) 146 | for p, clauses_p in clauses_dic.items(): 147 | selected_clause_indices = torch.stack([F.gumbel_softmax(scores_dic[p] * 100, tau=1.0, hard=True) for j in range(int(self.n_sample / len(clauses_dic.keys())))]) 148 | selected_clause_indices, _ = torch.max(selected_clause_indices, dim=0) 149 | # selected_clause_indices = [i for i, j in enumerate(selected_clause_indices)] 150 | clauses_to_refine_i = [c for i, c in enumerate(clauses_p) if selected_clause_indices[i] > 0] 151 | clauses_to_refine.extend(clauses_to_refine_i) 152 | clauses_to_refine = list(set(clauses_to_refine)) 153 | return clauses_to_refine 154 | 155 | 156 | def sample_clauses_by_scores(self, clauses, clause_scores): 157 | clauses_to_refine = [] 158 | print("Logits for the sampling: ") 159 | print(np.round(clause_scores.cpu().numpy(), 2)) 160 | for i in range(clause_scores.size(0)): 161 | selected_clause_indices = torch.stack([F.gumbel_softmax(clause_scores[i], tau=1.0, hard=True) for j in range(int(self.n_sample))]) 162 | selected_clause_indices, _ = torch.max(selected_clause_indices, dim=0) 163 | # selected_clause_indices = [i for i, j in enumerate(selected_clause_indices)] 164 | clauses_to_refine_i = [c for i, c in enumerate(clauses) if selected_clause_indices[i] > 0] 165 | clauses_to_refine.extend(clauses_to_refine_i) 166 | clauses_to_refine = list(set(clauses_to_refine)) 167 | return clauses_to_refine 168 | """ -------------------------------------------------------------------------------- /neumann/neumann/data_behind_the_scenes.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | 
import numpy as np 6 | import torch 7 | import torch.utils.data 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | 11 | 12 | def load_image_clevr(path): 13 | """Load an image using given path. 14 | """ 15 | img = cv2.imread(path) # BGR 16 | assert img is not None, 'Image Not Found ' + path 17 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW 18 | img = np.ascontiguousarray(img) 19 | return img 20 | 21 | def load_question_json(question_json_path): 22 | with open(question_json_path) as f: 23 | question = json.load(f) 24 | #questions = question["questions"] 25 | #return questions 26 | return question 27 | 28 | class BehindTheScenes(torch.utils.data.Dataset): 29 | def __init__(self, question_json_path, lang, n_data, device, img_size=128, base='data/behind-the-scenes/'): 30 | super().__init__() 31 | #self.colors = ["cyan", "blue", "yellow",\ 32 | # "purple", "red", "green", "gray", "brown"] 33 | self.colors = ["cyan", "gray", "red", "yellow"] 34 | self.query_types = ["delete", "append", "reverse", "sort"] 35 | self.positions = ["1st", "2nd", "3rd"] 36 | self.base = base 37 | self.lang = lang 38 | self.device = device 39 | self.questions = random.sample(load_question_json(question_json_path), int(n_data)) 40 | self.img_size = img_size 41 | self.transform = transforms.Compose( 42 | [transforms.Resize((img_size, img_size))] 43 | ) 44 | #self.image_paths, self.answer_paths = load_images_and_labels( 45 | # dataset=dataset, split=split, base=base) 46 | # {"program": "query2(sort,2nd)", "split": "train", "image_index": 45, "answer": "red", \ 47 | # "image": "BehindTheScenes_train_000045", "question_index": 10290, "question": ["sort", "2nd"], \ 48 | # "image_filename": "BehindTheScenes_train_000045.png"}, 49 | 50 | def __getitem__(self, item): 51 | question = self.questions[item] 52 | # print('question: ', question) 53 | image_path = self.base + 'images/' + question["image_filename"] 54 | image = Image.open(image_path).convert("RGB") 55 | image = self.image_preprocess(image) 56 | answer = self.to_onehot(self.colors.index(question["answer"]), len(self.colors)) 57 | query_tuple = question["question"] 58 | query = self.to_query_vector(query_tuple) 59 | # TODO: concate and return?? 60 | # answer = load_answer(self.answer_paths[item]) 61 | return image, query, answer 62 | 63 | def __len__(self): 64 | return len(self.questions) 65 | 66 | def to_onehot(self, index, size): 67 | onehot = torch.zeros(size, ).to(self.device) 68 | onehot[index] = 1.0 69 | return onehot 70 | 71 | def to_query_vector(self, query_tuple): 72 | if len(query_tuple) == 3: 73 | query_type, color, position = query_tuple 74 | # ("delete", "red", "1st") 75 | q_1 = self.to_onehot(self.query_types.index(query_type), len(self.query_types)) 76 | q_2 = self.to_onehot(self.colors.index(color), len(self.colors)) 77 | q_3 = self.to_onehot(self.positions.index(position), len(self.positions)) 78 | return torch.cat([q_1, q_2, q_3]) 79 | elif len(query_tuple) == 2: 80 | query_type, position = query_tuple 81 | # ("sort", "1st") 82 | q_1 = self.to_onehot(self.query_types.index(query_type), len(self.query_types)) 83 | q_2 = torch.zeros(len(self.colors, )).to(self.device) 84 | q_3 = self.to_onehot(self.positions.index(position), len(self.positions)) 85 | return torch.cat([q_1, q_2, q_3]) 86 | 87 | 88 | def image_preprocess(self, image): 89 | image = transforms.ToTensor()(image)[:3, :, :] 90 | image = self.transform(image) 91 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 
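        # Add a leading dimension so a single image has shape (1, 3, img_size, img_size).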
92 | return image.unsqueeze(0) 93 | 94 | -------------------------------------------------------------------------------- /neumann/neumann/data_clevr.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | 9 | 10 | def load_images_and_labels(dataset='clevr-hans3', split='train', pos_ratio=1.0, neg_ratio=1.0, base=None): 11 | """Load image paths and labels for clevr-hans dataset. 12 | """ 13 | image_paths = [] 14 | labels = [] 15 | folder = 'data/clevr-hans/' + dataset + '/' + split + '/' 16 | true_folder = folder + 'true/' 17 | false_folder = folder + 'false/' 18 | 19 | filenames = sorted(os.listdir(true_folder)) 20 | n = int(pos_ratio * len(filenames)) 21 | for filename in filenames[:n]: 22 | if filename != '.DS_Store': 23 | image_paths.append(os.path.join(true_folder, filename)) 24 | labels.append(1) 25 | 26 | filenames = sorted(os.listdir(false_folder)) 27 | n = int(neg_ratio * len(filenames)) 28 | for filename in filenames[:n]: 29 | if filename != '.DS_Store': 30 | image_paths.append(os.path.join(false_folder, filename)) 31 | labels.append(0) 32 | return image_paths, labels 33 | 34 | 35 | def load_image_clevr(path): 36 | """Load an image using given path. 37 | """ 38 | img = cv2.imread(path) # BGR 39 | assert img is not None, 'Image Not Found ' + path 40 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW 41 | img = np.ascontiguousarray(img) 42 | return img 43 | 44 | 45 | 46 | class CLEVRHans(torch.utils.data.Dataset): 47 | """CLEVRHans dataset. 48 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py. 49 | """ 50 | 51 | def __init__(self, dataset, split, pos_ratio=1.0, neg_ratio=1.0, img_size=128, base=None): 52 | super().__init__() 53 | self.img_size = img_size 54 | self.dataset = dataset 55 | assert split in { 56 | "train", 57 | "val", 58 | "test", 59 | } # note: test isn't very useful since it doesn't have ground-truth scene information 60 | self.split = split 61 | self.transform = transforms.Compose( 62 | [transforms.Resize((img_size, img_size))] 63 | ) 64 | self.image_paths, self.labels = load_images_and_labels( 65 | dataset=dataset, split=split, pos_ratio=pos_ratio, neg_ratio=neg_ratio) 66 | 67 | def __getitem__(self, item): 68 | path = self.image_paths[item] 69 | image = Image.open(path).convert("RGB") 70 | image = transforms.ToTensor()(image)[:3, :, :] 71 | image = self.transform(image) 72 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 73 | label = torch.tensor(self.labels[item], dtype=torch.float32) 74 | return image.unsqueeze(0), label 75 | 76 | 77 | def __old__getitem__(self, item): 78 | path = self.image_paths[item] 79 | image = Image.open(path).convert("RGB") 80 | image = transforms.ToTensor()(image)[:3, :, :] 81 | image = self.transform(image) 82 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 83 | if self.dataset == 'clevr-hans3': 84 | labels = torch.zeros((3, ), dtype=torch.float32) 85 | elif self.dataset == 'clevr-hans7': 86 | labels = torch.zeros((7, ), dtype=torch.float32) 87 | labels[self.labels[item]] = 1.0 88 | return image, labels 89 | 90 | def __len__(self): 91 | return len(self.labels) 92 | 93 | class CLEVRConcept(torch.utils.data.Dataset): 94 | """The Concept-learning dataset for CLEVR-Hans. 
95 | """ 96 | 97 | def __init__(self, dataset, split): 98 | self.dataset = dataset 99 | self.split = split 100 | self.data, self.labels = self.load_csv() 101 | print('concept data: ', self.data.shape, 'labels: ', len(self.labels)) 102 | 103 | def load_csv(self): 104 | data = [] 105 | labels = [] 106 | pos_csv_data = pd.read_csv( 107 | 'data/clevr/concept_data/' + self.split + '/' + self.dataset + '_pos' + '.csv', delimiter=' ') 108 | pos_data = pos_csv_data.values 109 | #pos_labels = np.ones((len(pos_data, ))) 110 | pos_labels = np.zeros((len(pos_data, ))) 111 | neg_csv_data = pd.read_csv( 112 | 'data/clevr/concept_data/' + self.split + '/' + self.dataset + '_neg' + '.csv', delimiter=' ') 113 | neg_data = neg_csv_data.values 114 | #neg_labels = np.zeros((len(neg_data, ))) 115 | neg_labels = np.ones((len(neg_data, ))) 116 | data = torch.tensor(np.concatenate( 117 | [pos_data, neg_data], axis=0), dtype=torch.float32) 118 | labels = torch.tensor(np.concatenate( 119 | [pos_labels, neg_labels], axis=0), dtype=torch.float32) 120 | return data, labels 121 | 122 | def __getitem__(self, item): 123 | return self.data[item], self.labels[item] 124 | 125 | def __len__(self): 126 | return len(self.data) 127 | -------------------------------------------------------------------------------- /neumann/neumann/data_kandinsky.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class KANDINSKY(torch.utils.data.Dataset): 11 | """Kandinsky Patterns dataset. 12 | """ 13 | 14 | def __init__(self, dataset, split, pos_ratio=1.0, neg_ratio=1.0, img_size=128): 15 | self.img_size = img_size 16 | assert split in { 17 | "train", 18 | "val", 19 | "test", 20 | } 21 | self.image_paths, self.labels = load_images_and_labels( 22 | dataset=dataset, split=split, pos_ratio=pos_ratio, neg_ratio=neg_ratio) 23 | print("{} {} images loaded!!".format(len(self.image_paths), split)) 24 | 25 | def __getitem__(self, item): 26 | image = load_image_yolo( 27 | self.image_paths[item], img_size=self.img_size) 28 | image = torch.from_numpy(image).type(torch.float32) / 255. 29 | 30 | label = torch.tensor(self.labels[item], dtype=torch.float32) 31 | 32 | # return image as one image, not two or more 33 | return image.unsqueeze(0), label 34 | 35 | def __len__(self): 36 | return len(self.labels) 37 | 38 | 39 | def load_images_and_labels(dataset='twopairs', split='train', pos_ratio=1.0, neg_ratio=1.0, img_size=128): 40 | """Load image paths and labels for kandinsky dataset. 
41 | """ 42 | image_paths = [] 43 | labels = [] 44 | folder = 'data/kandinsky/' + dataset + '/' + split + '/' 45 | true_folder = folder + 'true/' 46 | false_folder = folder + 'false/' 47 | 48 | filenames = sorted(os.listdir(true_folder)) 49 | #if split == 'train': 50 | # n = int(len(filenames)/10) 51 | #else: 52 | # n = len(filenames) 53 | n = int(pos_ratio * len(filenames)) 54 | for filename in filenames[:n]: 55 | if filename != '.DS_Store': 56 | image_paths.append(os.path.join(true_folder, filename)) 57 | labels.append(1) 58 | 59 | filenames = sorted(os.listdir(false_folder)) 60 | n = int(neg_ratio * len(filenames)) 61 | for filename in filenames[:n]: 62 | if filename != '.DS_Store': 63 | image_paths.append(os.path.join(false_folder, filename)) 64 | labels.append(0) 65 | return image_paths, labels 66 | 67 | 68 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): 69 | """A utilitiy function for yolov5 model to make predictions. The implementation is from the yolov5 repository. 70 | """ 71 | import cv2 72 | 73 | # Resize and pad image while meeting stride-multiple constraints 74 | shape = img.shape[:2] # current shape [height, width] 75 | if isinstance(new_shape, int): 76 | new_shape = (new_shape, new_shape) 77 | 78 | # Scale ratio (new / old) 79 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 80 | if not scaleup: # only scale down, do not scale up (for better test mAP) 81 | r = min(r, 1.0) 82 | 83 | # Compute padding 84 | ratio = r, r # width, height ratios 85 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 86 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - \ 87 | new_unpad[1] # wh padding 88 | if auto: # minimum rectangle 89 | dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding 90 | elif scaleFill: # stretch 91 | dw, dh = 0.0, 0.0 92 | new_unpad = (new_shape[1], new_shape[0]) 93 | ratio = new_shape[1] / shape[1], new_shape[0] / \ 94 | shape[0] # width, height ratios 95 | 96 | dw /= 2 # divide padding into 2 sides 97 | dh /= 2 98 | 99 | if shape[::-1] != new_unpad: # resize 100 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) 101 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 102 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 103 | img = cv2.copyMakeBorder( 104 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 105 | return img, ratio, (dw, dh) 106 | 107 | 108 | def load_image_yolo(path, img_size, stride=32): 109 | """Load an image using given path. 110 | """ 111 | img0 = cv2.imread(path) # BGR 112 | assert img0 is not None, 'Image Not Found ' + path 113 | img = cv2.resize(img0, (img_size, img_size)) 114 | 115 | # Convert 116 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW 117 | img = np.ascontiguousarray(img) 118 | return img 119 | -------------------------------------------------------------------------------- /neumann/neumann/data_logic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | class RandomValuation(torch.utils.data.Dataset): 10 | """CLEVRHans dataset. 11 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py. 
12 | """ 13 | 14 | def __init__(self, dataset, split, atoms, n_data=2000): 15 | super().__init__() 16 | self.atoms = atoms 17 | self.n_data = n_data 18 | self.dataset = dataset 19 | assert split in { 20 | "train", 21 | "val", 22 | "test", 23 | } # note: test isn't very useful since it doesn't have ground-truth scene information 24 | self.split = split 25 | 26 | 27 | def __getitem__(self, item): 28 | return torch.rand((len(self.atoms), )) 29 | 30 | def __len__(self): 31 | return self.n_data 32 | 33 | class ZeroValuation(torch.utils.data.Dataset): 34 | """CLEVRHans dataset. 35 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py. 36 | """ 37 | 38 | def __init__(self, dataset, split, atoms, n_data=2000): 39 | super().__init__() 40 | self.atoms = atoms 41 | self.n_data = n_data 42 | self.dataset = dataset 43 | assert split in { 44 | "train", 45 | "val", 46 | "test", 47 | } # note: test isn't very useful since it doesn't have ground-truth scene information 48 | self.split = split 49 | 50 | 51 | def __getitem__(self, item): 52 | return torch.zeros((len(self.atoms), )) 53 | 54 | def __len__(self): 55 | return self.n_data 56 | -------------------------------------------------------------------------------- /neumann/neumann/data_vilp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | 9 | 10 | def load_images_and_labels(dataset='member', split='train', base=None, pos_ratio=1.0, neg_ratio=1.0): 11 | """Load image paths and labels for clevr-hans dataset. 12 | """ 13 | image_paths_list = [] 14 | labels = [] 15 | base_folder = 'data/vilp/' + dataset + '/' + split + '/true/' 16 | folder_names = sorted(os.listdir(base_folder)) 17 | if '.DS_Store' in folder_names: 18 | folder_names.remove('.DS_Store') 19 | 20 | if split == 'train': 21 | n = int(len(folder_names) * pos_ratio) 22 | else: 23 | n = len(folder_names) 24 | for folder_name in folder_names[:n]: 25 | folder = base_folder + folder_name + '/' 26 | filenames = sorted(os.listdir(folder)) 27 | image_paths = [] 28 | for filename in filenames: 29 | if filename != '.DS_Store': 30 | image_paths.append(os.path.join(folder, filename)) 31 | image_paths_list.append(image_paths) 32 | labels.append(1.0) 33 | base_folder = 'data/vilp/' + dataset + '/' + split + '/false/' 34 | if split == 'train': 35 | n = int(len(folder_names) * neg_ratio) 36 | else: 37 | n = len(folder_names) 38 | folder_names = sorted(os.listdir(base_folder)) 39 | if '.DS_Store' in folder_names: 40 | folder_names.remove('.DS_Store') 41 | for folder_name in folder_names[:n]: 42 | folder = base_folder + folder_name + '/' 43 | filenames = sorted(os.listdir(folder)) 44 | image_paths = [] 45 | for filename in filenames: 46 | if filename != '.DS_Store': 47 | image_paths.append(os.path.join(folder, filename)) 48 | image_paths_list.append(image_paths) 49 | labels.append(0.0) 50 | return image_paths_list, labels 51 | 52 | 53 | def load_images_and_labels_positive(dataset='member', split='train', base=None): 54 | """Load image paths and labels for clevr-hans dataset. 
55 | """ 56 | image_paths_list = [] 57 | labels = [] 58 | base_folder = 'data/vilp/' + dataset + '/' + split + '/true/' 59 | folder_names = sorted(os.listdir(base_folder)) 60 | if '.DS_Store' in folder_names: 61 | folder_names.remove('.DS_Store') 62 | 63 | for folder_name in folder_names: 64 | folder = base_folder + folder_name + '/' 65 | filenames = sorted(os.listdir(folder)) 66 | image_paths = [] 67 | for filename in filenames: 68 | if filename != '.DS_Store': 69 | image_paths.append(os.path.join(folder, filename)) 70 | image_paths_list.append(image_paths) 71 | labels.append(1.0) 72 | return image_paths_list, labels 73 | 74 | def load_image_clevr(path): 75 | """Load an image using given path. 76 | """ 77 | img = cv2.imread(path) # BGR 78 | assert img is not None, 'Image Not Found ' + path 79 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW 80 | img = np.ascontiguousarray(img) 81 | return img 82 | 83 | 84 | class VisualILP(torch.utils.data.Dataset): 85 | """CLEVRHans dataset. 86 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py. 87 | """ 88 | 89 | def __init__(self, dataset, split, img_size=128, base=None, pos_ratio=1.0, neg_ratio=1.0): 90 | super().__init__() 91 | self.img_size = img_size 92 | self.dataset = dataset 93 | assert split in { 94 | "train", 95 | "val", 96 | "test", 97 | } # note: test isn't very useful since it doesn't have ground-truth scene information 98 | self.split = split 99 | self.transform = transforms.Compose( 100 | [transforms.Resize((img_size, img_size))] 101 | ) 102 | self.pos_ratio = pos_ratio 103 | self.neg_ratio = neg_ratio 104 | self.image_paths, self.labels = load_images_and_labels( 105 | dataset=dataset, split=split, base=base, pos_ratio=pos_ratio, neg_ratio=neg_ratio) 106 | 107 | def __getitem__(self, item): 108 | paths = self.image_paths[item] 109 | images = [] 110 | for path in paths: 111 | image = Image.open(path).convert("RGB") 112 | image = transforms.ToTensor()(image)[:3, :, :] 113 | image = self.transform(image) 114 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 115 | images.append(image) 116 | # TODO: concate and return?? 117 | image = torch.stack(images, dim=0) 118 | return image, self.labels[item] 119 | 120 | def __len__(self): 121 | return len(self.labels) 122 | 123 | class VisualILP_POSITIVE(torch.utils.data.Dataset): 124 | """CLEVRHans dataset. 125 | The implementations is mainly from https://github.com/ml-research/NeSyConceptLearner/blob/main/src/pretrain-slot-attention/data.py. 126 | """ 127 | 128 | def __init__(self, dataset, split, img_size=128, base=None): 129 | super().__init__() 130 | self.img_size = img_size 131 | self.dataset = dataset 132 | assert split in { 133 | "train", 134 | "val", 135 | "test", 136 | } # note: test isn't very useful since it doesn't have ground-truth scene information 137 | self.split = split 138 | self.transform = transforms.Compose( 139 | [transforms.Resize((img_size, img_size))] 140 | ) 141 | self.image_paths, self.labels = load_images_and_labels_positive( 142 | dataset=dataset, split=split, base=base) 143 | 144 | def __getitem__(self, item): 145 | paths = self.image_paths[item] 146 | images = [] 147 | for path in paths: 148 | image = Image.open(path).convert("RGB") 149 | image = transforms.ToTensor()(image)[:3, :, :] 150 | image = self.transform(image) 151 | image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 152 | images.append(image) 153 | # TODO: concate and return?? 
154 | image = torch.stack(images, dim=0) 155 | return image, self.labels[item] 156 | 157 | def __len__(self): 158 | return len(self.labels) 159 | -------------------------------------------------------------------------------- /neumann/neumann/explain_clevr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import time 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | from rtpt import RTPT 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | 13 | from explanation_utils import * 14 | from logic_utils import get_lang 15 | from neumann_utils import get_data_loader, get_model, get_prob 16 | 17 | torch.set_num_threads(10) 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--batch-size", type=int, default=10, 22 | help="Batch size to infer with") 23 | parser.add_argument("--batch-size-bs", type=int, 24 | default=1, help="Batch size in beam search") 25 | parser.add_argument("--num-objects", type=int, default=6, 26 | help="The maximum number of objects in one image") 27 | parser.add_argument("--dataset", default="delete") # , choices=["member"]) 28 | parser.add_argument("--dataset-type", default="behind-the-scenes") 29 | parser.add_argument('--device', default='cpu', 30 | help='cuda device, i.e. 0 or cpu') 31 | parser.add_argument("--no-cuda", action="store_true", 32 | help="Run on CPU instead of GPU (not recommended)") 33 | parser.add_argument("--no-train", action="store_true", 34 | help="Perform prediction without training model") 35 | parser.add_argument("--small-data", action="store_true", 36 | help="Use small training data.") 37 | parser.add_argument("--num-workers", type=int, default=0, 38 | help="Number of threads for data loader") 39 | parser.add_argument('--gamma', default=0.01, type=float, 40 | help='Smooth parameter in the softor function') 41 | parser.add_argument("--plot", action="store_true", 42 | help="Plot images with captions.") 43 | parser.add_argument("--t-beam", type=int, default=4, 44 | help="Number of rule expantion of clause generation.") 45 | parser.add_argument("--n-beam", type=int, default=5, 46 | help="The size of the beam.") 47 | parser.add_argument("--n-max", type=int, default=50, 48 | help="The maximum number of clauses.") 49 | parser.add_argument("--program-size", type=int, default=1, 50 | help="The size of the logic program.") 51 | #parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.") 52 | parser.add_argument("--epochs", type=int, default=20, 53 | help="The number of epochs.") 54 | parser.add_argument("--lr", type=float, default=1e-2, 55 | help="The learning rate.") 56 | parser.add_argument("--n-ratio", type=float, default=1.0, 57 | help="The ratio of data to be used.") 58 | parser.add_argument("--pre-searched", action="store_true", 59 | help="Using pre searched clauses.") 60 | parser.add_argument("--infer-step", type=int, default=6, 61 | help="The number of steps of forward reasoning.") 62 | parser.add_argument("--term-depth", type=int, default=3, 63 | help="The number of steps of forward reasoning.") 64 | parser.add_argument("--question-json-path", default="data/behind-the-scenes/BehindTheScenes_questions.json") 65 | args = parser.parse_args() 66 | return args 67 | 68 | # def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False): 69 | 70 | 71 | def discretise_NEUMANN(NEUMANN, args, device): 72 | 
lark_path = 'src/lark/exp.lark' 73 | lang_base_path = 'data/lang/' 74 | lang, clauses_, bk, terms, atoms = get_lang( 75 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth) 76 | # Discretise NEUMANN rules 77 | clauses = NEUMANN.get_clauses() 78 | return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False) 79 | 80 | def predict(NEUMANN, I2F, loader, args, device, th=None, split='train'): 81 | predicted_list = [] 82 | target_list = [] 83 | count = 0 84 | 85 | start = time.time() 86 | for epoch in tqdm(range(args.epochs)): 87 | for i, sample in enumerate(tqdm(loader), start=0): 88 | imgs, target_set = map(lambda x: x.to(device), sample) 89 | # to cuda 90 | target_set = target_set.float() 91 | 92 | #imgs: torch.Size([1, 1, 3, 128, 128]) 93 | V_0 = I2F(imgs) 94 | V_T = NEUMANN(V_0) 95 | #a NEUMANN.print_valuation_batch(V_T) 96 | predicted = get_prob(V_T, NEUMANN, args) 97 | 98 | # compute explanation for each input image 99 | for pred in predicted: 100 | if pred > 0.9: 101 | pred.backward(retain_graph=True) 102 | atom_grads = NEUMANN.mpm.dummy_zeros.grad.squeeze(-1).unsqueeze(0) 103 | attention_maps = I2F.pm.model.slot_attention.attention_maps.squeeze(0) 104 | target_attention_maps = get_target_maps(NEUMANN.atoms, atom_grads, attention_maps) 105 | #print(atom_grads, torch.max(atom_grads), atom_grads.shape) 106 | NEUMANN.print_valuation_batch(atom_grads) 107 | 108 | 109 | 110 | imgs_to_plot = to_plot_images_clevr(imgs.squeeze(0).detach().cpu()) 111 | captions = generate_captions(atom_grads, NEUMANN.atoms, args.num_objects, th=0.33) 112 | # + args.dataset + '/' + split + '/', \ 113 | save_images_with_captions_and_attention_maps(imgs_to_plot, target_attention_maps, captions, folder='explanation/clevr/', \ 114 | img_id=count, dataset=args.dataset) 115 | NEUMANN.mpm.dummy_zeros.grad.detach_() 116 | NEUMANN.mpm.dummy_zeros.grad.zero_() 117 | count += 1 118 | reasoning_time = time.time() - start 119 | print('Reasoning Time: ', reasoning_time) 120 | return 0, 0, 0, reasoning_time 121 | 122 | 123 | def to_one_label(ys, labels, th=0.7): 124 | ys_new = [] 125 | for i in range(len(ys)): 126 | y = ys[i] 127 | label = labels[i] 128 | # check in case answers are computed 129 | num_class = 0 130 | for p_j in y: 131 | if p_j > th: 132 | num_class += 1 133 | if num_class >= 2: 134 | # drop the value using label (the label is one-hot) 135 | drop_index = torch.argmin(label - y) 136 | y[drop_index] = y.min() 137 | ys_new.append(y) 138 | return torch.stack(ys_new) 139 | 140 | 141 | def main(n): 142 | seed_everything(n) 143 | args = get_args() 144 | assert args.batch_size == 1, "Set batch_size=1." 
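    # Explanations are computed one image at a time by backpropagating each prediction, so this script assumes a batch size of 1.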
145 | #name = 'VILP' 146 | print('args ', args) 147 | if args.no_cuda: 148 | device = torch.device('cpu') 149 | elif len(args.device.split(',')) > 1: 150 | # multi gpu 151 | device = torch.device('cuda') 152 | else: 153 | device = torch.device('cuda:' + args.device) 154 | 155 | print('device: ', device) 156 | name = 'neumann/behind-the-scenes/' + str(n) 157 | writer = SummaryWriter(f"runs/{name}", purge_step=0) 158 | 159 | # Create RTPT object 160 | rtpt = RTPT(name_initials='HS', experiment_name=name, 161 | max_iterations=args.epochs) 162 | # Start the RTPT tracking 163 | rtpt.start() 164 | 165 | 166 | ## train_pos_loader, val_pos_loader, test_pos_loader = get_vilp_pos_loader(args) 167 | #####train_pos_loader, val_pos_loader, test_pos_loader = get_data_loader(args) 168 | 169 | # load logical representations 170 | lark_path = 'src/lark/exp.lark' 171 | lang_base_path = 'data/lang/' 172 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang( 173 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth, use_learned_clauses=True) 174 | 175 | print("{} Atoms:".format(len(atoms))) 176 | 177 | # get torch data loader 178 | #question_json_path = 'data/behind-the-scenes/BehindTheScenes_questions_{}.json'.format(args.dataset) 179 | # test_loader = get_behind_the_scenes_loader(question_json_path, args.batch_size, lang, args.n_data, device) 180 | train_loader, val_loader, test_loader = get_data_loader(args, device) 181 | 182 | NEUMANN, I2F = get_model(lang=lang, clauses=clauses, atoms=atoms, terms=terms, bk=bk, bk_clauses=bk_clauses, 183 | program_size=args.program_size, device=device, dataset=args.dataset, dataset_type=args.dataset_type, 184 | num_objects=args.num_objects, infer_step=args.infer_step, train=False, explain=True)#train=not(args.no_train)) 185 | 186 | writer.add_scalar("graph/num_atom_nodes", len(NEUMANN.rgm.atom_node_idxs)) 187 | writer.add_scalar("graph/num_conj_nodes", len(NEUMANN.rgm.conj_node_idxs)) 188 | num_nodes = len(NEUMANN.rgm.atom_node_idxs) + len(NEUMANN.rgm.conj_node_idxs) 189 | writer.add_scalar("graph/num_nodes", num_nodes) 190 | 191 | num_edges = NEUMANN.rgm.edge_index.size(1) 192 | writer.add_scalar("graph/num_edges", num_edges) 193 | 194 | writer.add_scalar("graph/memory_total", num_nodes + num_edges) 195 | 196 | print("=====================") 197 | print("NUM NODES: ", num_nodes) 198 | print("NUM EDGES: ", num_edges) 199 | print("MEMORY TOTAL: ", num_nodes + num_edges) 200 | print("=====================") 201 | 202 | params = list(NEUMANN.parameters()) 203 | print('parameters: ', list(params)) 204 | 205 | print("Predicting on train data set...") 206 | times = [] 207 | # train split 208 | for j in range(n): 209 | acc_test, rec_test, th_test, time = predict( 210 | NEUMANN, I2F, test_loader, args, device, th=0.5, split='test') 211 | times.append(time) 212 | 213 | with open('out/inference_time/time_{}_ratio_{}.txt'.format(args.dataset, args.n_ratio), 'w') as f: 214 | f.write("\n".join(str(item) for item in times)) 215 | 216 | print("train acc: ", acc_test, "threashold: ", th_test, "recall: ", rec_test) 217 | 218 | 219 | def seed_everything(seed: int): 220 | import os 221 | import random 222 | 223 | import numpy as np 224 | import torch 225 | 226 | random.seed(seed) 227 | os.environ['PYTHONHASHSEED'] = str(seed) 228 | np.random.seed(seed) 229 | torch.manual_seed(seed) 230 | torch.cuda.manual_seed(seed) 231 | torch.backends.cudnn.deterministic = True 232 | torch.backends.cudnn.benchmark = True 233 | 234 | if __name__ == "__main__": 235 | main(n=1) 236 
| 237 | -------------------------------------------------------------------------------- /neumann/neumann/explanation_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | attrs = ['color', 'shape', 'material', 'size'] 9 | 10 | def get_target_maps(atoms, atom_grads, attention_maps): 11 | if atom_grads.size(0) == 1: 12 | atom_grads = atom_grads.squeeze(0) 13 | 14 | obj_ids = [] 15 | scores = [] 16 | for i, g in enumerate(atom_grads): 17 | atom = atoms[i] 18 | if g > 0.8 and 'obj' in str(atom): 19 | obj_id = str(atom).split(',')[0][-1] 20 | obj_id = int(obj_id) 21 | if not obj_id in obj_ids: 22 | obj_ids.append(obj_id) 23 | scores.append(g) 24 | return [attention_maps[id] * scores[i] for i, id in enumerate(obj_ids)] 25 | 26 | 27 | def valuation_to_attr_string(v, atoms, e, th=0.5): 28 | """Generate string explanations of the scene. 29 | """ 30 | 31 | st = '' 32 | for i in range(e): 33 | st_i = '' 34 | for j, atom in enumerate(atoms): 35 | #print(atom, [str(term) for term in atom.terms]) 36 | if 'obj' + str(i) in [str(term) for term in atom.terms] and atom.pred.name in attrs: 37 | if v[j] > th: 38 | prob = np.round(v[j].detach().cpu().numpy(), 2) 39 | st_i += str(prob) + ':' + str(atom) + ',' 40 | if st_i != '': 41 | st_i = st_i[:-1] 42 | st += st_i + '\n' 43 | return st 44 | 45 | 46 | def valuation_to_rel_string(v, atoms, th=0.5): 47 | l = 15 48 | st = '' 49 | n = 0 50 | for j, atom in enumerate(atoms): 51 | if v[j] > th and not (atom.pred.name in attrs+['in', '.']): 52 | prob = np.round(v[j].detach().cpu().numpy(), 2) 53 | st += str(prob) + ':' + str(atom) + ',' 54 | n += len(str(prob) + ':' + str(atom) + ',') 55 | if n > l: 56 | st += '\n' 57 | n = 0 58 | return st[:-1] + '\n' 59 | 60 | 61 | def valuation_to_string(v, atoms, e, th=0.5): 62 | return valuation_to_attr_string(v, atoms, e, th) + valuation_to_rel_string(v, atoms, th) 63 | 64 | 65 | def valuations_to_string(V, atoms, e, th=0.5): 66 | """Generate string explanation of the scenes. 67 | """ 68 | st = '' 69 | for i in range(V.size(0)): 70 | st += 'image ' + str(i) + '\n' 71 | # for each data in the batch 72 | st += valuation_to_string(V[i], atoms, e, th) 73 | return st 74 | 75 | 76 | def generate_captions(V, atoms, e, th): 77 | captions = [] 78 | for v in V: 79 | # for each data in the batch 80 | captions.append(valuation_to_string(v, atoms, e, th)) 81 | return captions 82 | 83 | 84 | def save_images_with_captions_and_attention_maps(imgs, attention_maps, captions, folder, img_id, dataset): 85 | if not os.path.exists(folder): 86 | os.makedirs(folder) 87 | 88 | figsize = (12, 6) 89 | # imgs should be denormalized. 
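    # (Illustrative note, not in the original code: callers are expected to pass images
    # already converted for plotting, e.g. via to_plot_images_clevr() defined later in
    # this module, which maps CLEVR tensors from [-1, 1] back to [0, 1].)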
90 | 91 | img_size = imgs[0].shape[0] 92 | attention_maps = np.array([m.cpu().detach().numpy() for m in attention_maps]) 93 | attention_map = np.zeros_like(attention_maps[0]) 94 | for am in attention_maps: 95 | attention_map += am 96 | attention_map = attention_map.reshape(32, 32) 97 | attention_map = cv2.resize(attention_map, (img_size, img_size)) 98 | #attention_map = torch.tensor(attention_map).extend() 99 | 100 | # apply attention maps to filter 101 | for i, img in enumerate(imgs): 102 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(4, 2)) #, sharex=True, sharey=True) 103 | ax1.axis('off') 104 | ax2.axis('off') 105 | 106 | #e=figsize, dpi=80) 107 | ax1.imshow(img) 108 | #ax1.xlabel(captions[i]) 109 | #ax1.tight_layout() 110 | #ax1.savefig(folder+str(img_id)+'_original.png') 111 | 112 | # ax2.figure(figsize=figsize, dpi=80) 113 | ax2.imshow(attention_map, cmap='cividis') 114 | #ax2.set_xlabel(captions[i], fontsize=12) 115 | #plt.axis('off') 116 | fig.tight_layout() 117 | fig.savefig(folder + dataset + '_img' + str(img_id) + '_explanation.svg') 118 | 119 | plt.close() 120 | 121 | def save_images_with_captions_and_attention_maps_indivisual(imgs, attention_maps, captions, folder, img_id, dataset): 122 | if not os.path.exists(folder): 123 | os.makedirs(folder) 124 | 125 | figsize = (12, 6) 126 | # imgs should be denormalized. 127 | 128 | img_size = imgs[0].shape[0] 129 | attention_maps = np.array([m.cpu().detach().numpy() for m in attention_maps]) 130 | attention_map = np.zeros_like(attention_maps[0]) 131 | for am in attention_maps: 132 | attention_map += am 133 | attention_map = attention_map.reshape(32,32) 134 | attention_map = cv2.resize(attention_map, (img_size, img_size)) 135 | #attention_map = torch.tensor(attention_map).extend() 136 | 137 | # apply attention maps to filter 138 | for i, img in enumerate(imgs): 139 | plt.figure(figsize=figsize, dpi=80) 140 | plt.imshow(img) 141 | plt.xlabel(captions[i]) 142 | plt.tight_layout() 143 | plt.savefig(folder+str(img_id)+'_original.png') 144 | 145 | plt.figure(figsize=figsize, dpi=80) 146 | plt.imshow(attention_map, cmap='cividis') 147 | plt.xlabel(captions[i]) 148 | plt.tight_layout() 149 | plt.savefig(folder+str(img_id)+'_attention_map.png') 150 | 151 | plt.close() 152 | 153 | def denormalize_clevr(imgs): 154 | """denormalize clevr images 155 | """ 156 | # normalizing: image = (image - 0.5) * 2.0 # Rescale to [-1, 1]. 157 | return (0.5 * imgs) + 0.5 158 | 159 | 160 | def denormalize_kandinsky(imgs): 161 | """denormalize kandinsky images 162 | """ 163 | return imgs 164 | 165 | 166 | def to_plot_images_clevr(imgs): 167 | return [img.permute(1, 2, 0).detach().numpy() for img in denormalize_clevr(imgs)] 168 | 169 | 170 | def to_plot_images_kandinsky(imgs): 171 | return [img.permute(1, 2, 0).detach().numpy() for img in denormalize_kandinsky(imgs)] 172 | -------------------------------------------------------------------------------- /neumann/neumann/facts_converter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from tqdm import tqdm 4 | 5 | from .fol.logic import NeuralPredicate 6 | 7 | 8 | class FactsConverter(nn.Module): 9 | """FactsConverter converts the output fromt the perception module to the valuation vector. 
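    Args (matching __init__ below):
        lang (Language): The first-order language.
        atoms (list): The set of ground atoms.
        bk (list): The background knowledge as ground atoms.
        device: The device on which valuation tensors are allocated.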
10 | """ 11 | 12 | def __init__(self, lang, atoms, bk, device=None): 13 | super(FactsConverter, self).__init__() 14 | # self.e = perception_module.e 15 | # self.d = perception_module.d 16 | self.lang = lang 17 | # self.vm = valuation_module # valuation functions 18 | self.device = device 19 | self.atoms = atoms 20 | self.bk = bk 21 | # init indices 22 | self.np_indices = self._get_np_atom_indices() 23 | self.bk_indices = self._get_bk_atom_indices() 24 | 25 | def __str__(self): 26 | return "FactsConverter(entities={}, dimension={})".format(self.e, self.d) 27 | 28 | def __repr__(self): 29 | return "FactsConverter(entities={}, dimension={})".format(self.e, self.d) 30 | 31 | def _get_np_atom_indices(self): 32 | """Pre compute the indices of atoms with neural predicats.""" 33 | indices = [] 34 | for i, atom in enumerate(self.atoms): 35 | if type(atom.pred) == NeuralPredicate: 36 | indices.append(i) 37 | return indices 38 | 39 | def _get_bk_atom_indices(self): 40 | """Pre compute the indices of atoms in background knowledge.""" 41 | indices = [] 42 | for i, atom in enumerate(self.atoms): 43 | if atom in self.bk: 44 | indices.append(i) 45 | return indices 46 | 47 | def forward(self, Z): 48 | return self.convert(Z) 49 | 50 | def get_params(self): 51 | return self.vm.get_params() 52 | 53 | def init_valuation(self, n, batch_size): 54 | v = torch.zeros((batch_size, n)).to(self.device) 55 | v[:, 1] = 1.0 56 | return v 57 | 58 | def filter_by_datatype(): 59 | pass 60 | 61 | def to_vec(self, term, zs): 62 | pass 63 | 64 | 65 | def convert(self, bk_atoms=None): 66 | V = torch.zeros(1, len(self.atoms)).to(self.device) 67 | 68 | # add background knowledge 69 | for i, atom in enumerate(self.atoms): 70 | if atom in bk_atoms: 71 | V[0, i] += 1.0 72 | 73 | V[0, 0] += 1.0 74 | return V 75 | 76 | 77 | class FactsConverterWithQuery(nn.Module): 78 | """FactsConverter converts the output fromt the perception module to the valuation vector. 
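    Compared to FactsConverter above, this variant additionally takes a perception
    module and a valuation module, and convert(Z, Q) scores each neural-predicate atom
    on the perceived objects Z conditioned on the query Q via the valuation module.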
79 | """ 80 | 81 | # def __init__(self, lang, perception_module, valuation_module, device=None): 82 | def __init__(self, lang, atoms, bk, perception_module, valuation_module, device=None): 83 | super(FactsConverterWithQuery, self).__init__() 84 | self.e = perception_module.e 85 | self.d = perception_module.d 86 | self.lang = lang 87 | self.vm = valuation_module # valuation functions 88 | self.device = device 89 | self.atoms = atoms 90 | self.bk = bk 91 | # init indices 92 | self.np_indices = self._get_np_atom_indices() 93 | self.bk_indices = self._get_bk_atom_indices() 94 | 95 | def __str__(self): 96 | return "FactsConverter(entities={}, dimension={})".format(self.e, self.d) 97 | 98 | def __repr__(self): 99 | return "FactsConverter(entities={}, dimension={})".format(self.e, self.d) 100 | 101 | def _get_np_atom_indices(self): 102 | """Pre compute the indices of atoms with neural predicats.""" 103 | indices = [] 104 | for i, atom in enumerate(self.atoms): 105 | if type(atom.pred) == NeuralPredicate: 106 | indices.append(i) 107 | return indices 108 | 109 | def _get_bk_atom_indices(self): 110 | """Pre compute the indices of atoms in background knowledge.""" 111 | indices = [] 112 | for i, atom in enumerate(self.atoms): 113 | if atom in self.bk: 114 | indices.append(i) 115 | return indices 116 | 117 | def forward(self, Z, Q): 118 | return self.convert(Z, Q) 119 | 120 | def get_params(self): 121 | return self.vm.get_params() 122 | 123 | def init_valuation(self, n, batch_size): 124 | v = torch.zeros((batch_size, n)).to(self.device) 125 | v[:, 1] = 1.0 126 | return v 127 | 128 | def filter_by_datatype(): 129 | pass 130 | 131 | def to_vec(self, term, zs): 132 | pass 133 | 134 | def convert(self, Z, Q): 135 | batch_size = Z.size(0) 136 | 137 | # V = self.init_valuation(len(G), Z.size(0)) 138 | #V = torch.zeros((batch_size, len(G))).to( 139 | # torch.float32).to(self.device) 140 | V = torch.zeros((batch_size, len(self.atoms))).to( 141 | torch.float32).to(self.device) 142 | 143 | # T to be 1.0 144 | V[:, 0] = 1.0 145 | 146 | for i in self.np_indices: 147 | V[:, i] = self.vm(Z, Q, self.atoms[i]) 148 | 149 | for i in self.bk_indices: 150 | V[:, i] += torch.ones((batch_size, )).to( 151 | torch.float32).to(self.device) 152 | """ 153 | for i, atom in enumerate(G): 154 | if type(atom.pred) == NeuralPredicate: 155 | V[:, i] = self.vm(Z, Q, atom) 156 | elif atom in B: 157 | # V[:, i] += 1.0 158 | V[:, i] += torch.ones((batch_size, )).to( 159 | torch.float32).to(self.device) 160 | """ 161 | return V -------------------------------------------------------------------------------- /neumann/neumann/fol/README.md: -------------------------------------------------------------------------------- 1 | # First-Order Logic 2 | The implementation of the first-order logic based on the [Differentiable Inductive Logic Programming](https://arxiv.org/abs/2103.01719) project. 
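
A minimal usage sketch (for illustration only; the language directory, dataset names,
and the clause string below are placeholders, and the import path assumes the package
is installed as `neumann`):

```python
from neumann.fol.data_utils import DataUtils

utils = DataUtils(lark_path='neumann/neumann/lark/exp.lark',
                  lang_base_path='data/lang/',
                  dataset_type='kandinsky', dataset='twopairs')
lang = utils.load_language()   # loads preds.txt, neural_preds.txt, consts.txt, funcs.txt
clause = utils.parse_clause('kp(X):-in(O1,X),color(O1,red).', lang)
print(clause)
```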
-------------------------------------------------------------------------------- /neumann/neumann/fol/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/neumann/neumann/fol/__init__.py -------------------------------------------------------------------------------- /neumann/neumann/fol/data_utils.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from lark import Lark 4 | 5 | from .exp_parser import ExpTree 6 | from .language import DataType, Language 7 | from .logic import Const, FuncSymbol, NeuralPredicate, Predicate 8 | 9 | 10 | class DataUtils(object): 11 | """Utilities about I/O of logic. 12 | """ 13 | 14 | def __init__(self, lark_path, lang_base_path, dataset_type='kandinsky', dataset=None): 15 | #if dataset == 'behind-the-scenes': 16 | # for behind the scenes 17 | # self.base_path = lang_base_path + dataset_type + '/' 18 | #if: 19 | self.base_path = lang_base_path + dataset_type + '/' + dataset + '/' 20 | with open(lark_path, encoding="utf-8") as grammar: 21 | self.lp_atom = Lark(grammar.read(), start="atom") 22 | with open(lark_path, encoding="utf-8") as grammar: 23 | self.lp_clause = Lark(grammar.read(), start="clause") 24 | 25 | def load_clauses(self, path, lang): 26 | """Read lines and parse to Atom objects. 27 | """ 28 | clauses = [] 29 | if os.path.isfile(path): 30 | with open(path) as f: 31 | for line in f: 32 | if line[-1] == '\n': 33 | line = line[:-1] 34 | if len(line) == 0 or line[0] == '#': 35 | continue 36 | tree = self.lp_clause.parse(line) 37 | clause = ExpTree(lang).transform(tree) 38 | clauses.append(clause) 39 | return clauses 40 | 41 | def load_atoms(self, path, lang): 42 | """Read lines and parse to Atom objects. 43 | """ 44 | atoms = [] 45 | 46 | if os.path.isfile(path): 47 | with open(path) as f: 48 | for line in f: 49 | if line[-1] == '\n': 50 | line = line[:-2] 51 | else: 52 | line = line[:-1] 53 | tree = self.lp_atom.parse(line) 54 | atom = ExpTree(lang).transform(tree) 55 | atoms.append(atom) 56 | return atoms 57 | 58 | def load_preds(self, path): 59 | f = open(path) 60 | lines = f.readlines() 61 | preds = [self.parse_pred(line) for line in lines] 62 | return preds 63 | 64 | def load_neural_preds(self, path): 65 | f = open(path) 66 | lines = f.readlines() 67 | preds = [self.parse_neural_pred(line) for line in lines] 68 | return preds 69 | 70 | def load_consts(self, path): 71 | f = open(path) 72 | lines = f.readlines() 73 | consts = [] 74 | for line in lines: 75 | consts.extend(self.parse_const(line)) 76 | return consts 77 | 78 | def load_funcs(self, path): 79 | funcs = [] 80 | if os.path.isfile(path): 81 | with open(path) as f: 82 | lines = f.readlines() 83 | for line in lines: 84 | funcs.append(self.parse_func(line)) 85 | return funcs 86 | 87 | def load_terms(self, path, lang): 88 | f = open(path) 89 | lines = f.readlines() 90 | terms = [] 91 | for line in lines: 92 | terms.extend(self.parse_term(line, lang)) 93 | return 94 | def parse_pred(self, line): 95 | """Parse string to predicates. 96 | """ 97 | line = line.replace('\n', '') 98 | pred, arity, dtype_names_str = line.split(':') 99 | dtype_names = dtype_names_str.split(',') 100 | dtypes = [DataType(dt) for dt in dtype_names] 101 | assert int(arity) == len( 102 | dtypes), 'Invalid arity and dtypes in ' + pred + '.' 
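        # Illustrative format of a line handled here (not taken from a shipped file):
        #   "closeby:2:object,object"  ->  name : arity : comma-separated dtypes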
103 | return Predicate(pred, int(arity), dtypes) 104 | 105 | def parse_neural_pred(self, line): 106 | """Parse string to predicates. 107 | """ 108 | line = line.replace('\n', '') 109 | pred, arity, dtype_names_str = line.split(':') 110 | dtype_names = dtype_names_str.split(',') 111 | dtypes = [DataType(dt) for dt in dtype_names] 112 | assert int(arity) == len( 113 | dtypes), 'Invalid arity and dtypes in ' + pred + '.' 114 | return NeuralPredicate(pred, int(arity), dtypes) 115 | 116 | def parse_func(self, line): 117 | """Parse string to function symbols. 118 | (Format) name:arity:input_type:output_type 119 | """ 120 | name, arity, in_dtypes, out_dtype = line.replace("\n", "").split(':') 121 | in_dtypes = in_dtypes.split(',') 122 | in_dtypes = [DataType(in_dtype) for in_dtype in in_dtypes] 123 | out_dtype = DataType(out_dtype) 124 | return FuncSymbol(name, int(arity), in_dtypes, out_dtype) 125 | 126 | 127 | 128 | def parse_const(self, line): 129 | """Parse string to constants. 130 | """ 131 | line = line.replace('\n', '') 132 | dtype_name, const_names_str = line.split(':') 133 | dtype = DataType(dtype_name) 134 | const_names = const_names_str.split(',') 135 | return [Const(const_name, dtype) for const_name in const_names] 136 | 137 | def parse_term(self, line, lang): 138 | """Parse string to func_terms. 139 | """ 140 | line = line.replace('\n', '') 141 | dtype_name, term_names_str = line.split(':') 142 | dtype = DataType(dtype_name) 143 | term_strs = term_names_str.split(',') 144 | terms = [] 145 | for term_str in term_strs: 146 | if not term_str == '': 147 | print("term_str: ", term_str) 148 | tree = self.lp_term.parse(term_str) 149 | terms.append(ExpTree(lang).transform(tree)) 150 | return terms 151 | #return [Const(const_name, dtype) for const_name in const_names] 152 | 153 | def parse_clause(self, clause_str, lang): 154 | tree = self.lp_clause.parse(clause_str) 155 | return ExpTree(lang).transform(tree) 156 | 157 | def get_clauses(self, lang): 158 | return self.load_clauses(self.base_path + 'clauses.txt', lang) 159 | 160 | def get_bk(self, lang): 161 | return self.load_atoms(self.base_path + 'bk.txt', lang) 162 | 163 | def get_facts(self, lang): 164 | return self.load_atoms(self.base_path + 'facts.txt', lang) 165 | 166 | def load_language(self): 167 | """Load language, background knowledge, and clauses from files. 168 | """ 169 | preds = self.load_preds(self.base_path + 'preds.txt') + \ 170 | self.load_neural_preds(self.base_path + 'neural_preds.txt') 171 | consts = self.load_consts(self.base_path + 'consts.txt') 172 | funcs = self.load_funcs(self.base_path + 'funcs.txt') 173 | #terms = self.load_terms(self.base_path + 'terms.txt') 174 | lang = Language(preds, funcs, consts) 175 | return lang 176 | -------------------------------------------------------------------------------- /neumann/neumann/fol/exp.lark: -------------------------------------------------------------------------------- 1 | clause : atom ":-" body 2 | 3 | body : atom "," body 4 | | atom "." 5 | | "." 
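// Example (added for illustration): "kp(X):-in(O1,X),color(O1,red)." is a clause whose
// body is the comma-separated atom list terminated by ".".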
6 | 7 | atom : predicate "(" args ")" 8 | 9 | args : term "," args 10 | | term 11 | 12 | term : functor "(" args ")" 13 | | const 14 | | variable 15 | 16 | const : /[a-z0-9\*]+/ 17 | 18 | variable : /[A-Z0-9]+/ 19 | 20 | functor : /[a-z0-9]+/ 21 | 22 | predicate : /[a-z0-9\_]+/ 23 | 24 | var_name : /[A-Z]/ 25 | small_chars : /[a-z0-9]+/ 26 | chars : /[^\+\|\s\(\)']+/[/\n+/] 27 | allchars : /[^']+/[/\n+/] 28 | -------------------------------------------------------------------------------- /neumann/neumann/fol/exp_parser.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from lark import Lark 3 | from lark import Transformer 4 | from .logic import * 5 | 6 | 7 | def flatten(x): return [z for y in x for z in ( 8 | flatten(y) if hasattr(y, '__iter__') and not isinstance(y, str) else (y,))] 9 | 10 | 11 | class ExpTree(Transformer): 12 | '''Functions to parse strings into logical objects using Lark 13 | 14 | Attrs: 15 | lang (language): the language of first-order logic. 16 | ''' 17 | 18 | def __init__(self, lang): 19 | self.lang = lang 20 | 21 | def clause(self, trees): 22 | head = trees[0] 23 | body = flatten([trees[1]]) 24 | return Clause(head, body) 25 | 26 | def body(self, trees): 27 | if len(trees) == 0: 28 | return [] 29 | elif len(trees) == 1: 30 | return trees[0] 31 | else: 32 | return [trees[0]] + trees[1:] 33 | 34 | def atom(self, trees): 35 | pred = trees[0] 36 | args = flatten([trees[1]]) 37 | return Atom(pred, args) 38 | 39 | def args(self, content): 40 | if len(content) == 1: 41 | return content[0] 42 | else: 43 | return [content[0]] + content[1:] 44 | 45 | def const(self, name): 46 | dtype = self.lang.get_by_dtype_name(name[0]) 47 | return Const(name[0], dtype) 48 | 49 | def variable(self, name): 50 | return Var(name[0]) 51 | 52 | def functor(self, name): 53 | func = [f for f in self.lang.funcs if f.name == name[0]][0] 54 | return func 55 | 56 | def predicate(self, alphas): 57 | pred = [p for p in self.lang.preds if p.name == alphas[0]][0] 58 | return pred 59 | 60 | def term(self, content): 61 | if type(content[0]) == FuncSymbol: 62 | func = content[0] 63 | args = flatten([content[1]]) 64 | return FuncTerm(func, args) 65 | else: 66 | return content[0] 67 | 68 | def small_chars(self, content): 69 | return content[0] 70 | -------------------------------------------------------------------------------- /neumann/neumann/fol/language.py: -------------------------------------------------------------------------------- 1 | from .logic import Var 2 | import itertools 3 | 4 | 5 | class Language(object): 6 | """Language of first-order logic. 7 | 8 | A class of languages in first-order logic. 9 | 10 | Args: 11 | preds (List[Predicate]): A set of predicate symbols. 12 | funcs (List[FunctionSymbol]): A set of function symbols. 13 | consts (List[Const]): A set of constants. 14 | 15 | Attrs: 16 | preds (List[Predicate]): A set of predicate symbols. 17 | funcs (List[FunctionSymbol]): A set of function symbols. 18 | consts (List[Const]): A set of constants. 
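        var_gen (VariableGenerator): A generator of fresh variables (created in __init__).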
19 | """ 20 | 21 | def __init__(self, preds, funcs, consts): 22 | self.preds = preds 23 | self.funcs = funcs 24 | self.consts = consts 25 | self.var_gen = VariableGenerator() 26 | 27 | 28 | def __str__(self): 29 | s = "===Predicates===\n" 30 | for pred in self.preds: 31 | s += pred.__str__() + '\n' 32 | s += "===Function Symbols===\n" 33 | for func in self.funcs: 34 | s += func.__str__() + '\n' 35 | s += "===Constants===\n" 36 | for const in self.consts: 37 | s += const.__str__() + '\n' 38 | return s 39 | 40 | def __repr__(self): 41 | return self.__str__() 42 | 43 | def get_var_and_dtype(self, atom): 44 | """Get all variables in an input atom with its dtypes by enumerating variables in the input atom. 45 | 46 | Note: 47 | with the assumption with function free atoms. 48 | 49 | Args: 50 | atom (Atom): The atom. 51 | 52 | Returns: 53 | List of tuples (var, dtype) 54 | """ 55 | var_dtype_list = [] 56 | for i, arg in enumerate(atom.terms): 57 | if arg.is_var(): 58 | dtype = atom.pred.dtypes[i] 59 | var_dtype_list.append((arg, dtype)) 60 | return var_dtype_list 61 | 62 | def get_by_dtype(self, dtype): 63 | """Get constants that match given dtypes. 64 | 65 | Args: 66 | dtype (DataType): The data type. 67 | 68 | Returns: 69 | List of constants whose data type is the given data type. 70 | """ 71 | return [c for c in self.consts if c.dtype == dtype] 72 | 73 | def get_by_dtype_name(self, dtype_name): 74 | """Get constants that match given dtype name. 75 | 76 | Args: 77 | dtype_name (str): The name of the data type to be used. 78 | 79 | Returns: 80 | List of constants whose datatype has the given name. 81 | """ 82 | return [c for c in self.consts if c.dtype.name == dtype_name] 83 | 84 | def term_index(self, term): 85 | """Get the index of a term in the language. 86 | 87 | Args: 88 | term (Term): The term to be used. 89 | 90 | Returns: 91 | int: The index of the term. 92 | """ 93 | terms = self.get_by_dtype(term.dtype) 94 | return terms.index(term) 95 | 96 | def get_const_by_name(self, const_name): 97 | """Get the constant by its name. 98 | 99 | Args: 100 | const_name (str): The name of the constant. 101 | 102 | Returns: 103 | Const: The matched constant with the given name. 104 | 105 | """ 106 | const = [c for c in self.consts if const_name == c.name] 107 | assert len(const) == 1, 'Too many match in ' + const_name 108 | return const[0] 109 | 110 | def get_pred_by_name(self, pred_name): 111 | """Get the predicate by its name. 112 | 113 | Args: 114 | pred_name (str): The name of the predicate. 115 | 116 | Returns: 117 | Predicate: The matched preicate with the given name. 118 | """ 119 | pred = [pred for pred in self.preds if pred.name == pred_name] 120 | assert len(pred) == 1, 'Too many or less match in ' + pred_name 121 | return pred[0] 122 | 123 | 124 | class DataType(object): 125 | """Data type in first-order logic. 126 | 127 | A class of data types in first-order logic. 128 | 129 | Args: 130 | name (str): The name of the data type. 131 | 132 | Attrs: 133 | name (str): The name of the data type. 
134 | """ 135 | 136 | def __init__(self, name): 137 | self.name = name 138 | 139 | def __eq__(self, other): 140 | if type(other) == str: 141 | return self.name == other 142 | else: 143 | return self.name == other.name 144 | 145 | def __str__(self): 146 | return self.name 147 | 148 | def __repr__(self): 149 | return self.__str__() 150 | 151 | def __hash__(self): 152 | return hash(self.__str__()) 153 | 154 | 155 | 156 | 157 | class VariableGenerator(): 158 | """ 159 | generator of variables 160 | Parameters 161 | __________ 162 | base_name : str 163 | base name of variables 164 | """ 165 | 166 | def __init__(self, base_name='x'): 167 | self.counter = 0 168 | self.base_name = base_name 169 | 170 | def generate(self): 171 | """ 172 | generate variable with new name 173 | Returns 174 | ------- 175 | generated_var : .logic.Var 176 | generated variable 177 | """ 178 | generated_var = Var(self.base_name + str(self.counter)) 179 | self.counter += 1 180 | return generated_var 181 | -------------------------------------------------------------------------------- /neumann/neumann/fol/logic_ops.py: -------------------------------------------------------------------------------- 1 | from .logic import Clause, Atom, FuncTerm, Const, Var 2 | 3 | 4 | def subs(exp, target_var, const): 5 | """ 6 | Substitute var = const 7 | 8 | Inputs 9 | ------ 10 | exp : .logic.CLause .logic.Atom .logic.FuncTerm .logic.Const .logic.Var 11 | logical expression 12 | atom, clause, or term 13 | target_var : .logic.Var 14 | target variable of the substitution 15 | const : .logic.Const 16 | constant to be substituted 17 | 18 | Returns 19 | ------- 20 | exp : .logic.CLause .logic.Atom .logic.FuncTerm .logic.Const .logic.Var 21 | result of the substitution 22 | logical expression 23 | atom, clause, or term 24 | """ 25 | if type(exp) == Clause: 26 | head = subs(exp.head, target_var, const) 27 | body = [subs(bi, target_var, const) for bi in exp.body] 28 | return Clause(head, body) 29 | elif type(exp) == Atom: 30 | terms = [subs(term, target_var, const) for term in exp.terms] 31 | return Atom(exp.pred, terms) 32 | elif type(exp) == FuncTerm: 33 | args = [subs(arg, target_var, const) for arg in exp.args] 34 | return FuncTerm(exp.func_symbol, args) 35 | elif type(exp) == Var: 36 | if exp.name == target_var.name: 37 | return const 38 | else: 39 | return exp 40 | elif type(exp) == Const: 41 | return exp 42 | else: 43 | assert 1 == 0, 'Unknown type in substitution: ' + str(exp) 44 | 45 | 46 | def subs_list(exp, theta_list): 47 | if type(exp) == Clause: 48 | head = exp.head 49 | body = exp.body 50 | for target_var, const in theta_list: 51 | head = subs(head, target_var, const) 52 | body = [subs(bi, target_var, const) for bi in body] 53 | return Clause(head, body) 54 | elif type(exp) == Atom: 55 | terms = exp.terms 56 | for target_var, const in theta_list: 57 | terms = [subs(term, target_var, const) for term in terms] 58 | return Atom(exp.pred, terms) 59 | #elif type(exp) == FuncTerm: 60 | # for target_var, const in theta_list: 61 | # args = [subs(arg, target_var, const) for arg in exp.args] 62 | # return FuncTerm(exp.func_symbol, args) 63 | #elif type(exp) == Var: 64 | # if exp.name == target_var.name: 65 | # return const 66 | # else: 67 | # return exp 68 | #elif type(exp) == Const: 69 | # return exp 70 | else: 71 | assert 1 == 0, 'Unknown type in substitution: ' + str(exp) 72 | 73 | 74 | 75 | def __subs_list(clause, theta_list): 76 | """ 77 | perform list of substitutions 78 | 79 | Inputs 80 | ------ 81 | clause : .logic.Clause 82 | 
target clause 83 | theta_list : List[(.logic.Var, .logic.Const)] 84 | list of substitute operations to be performed 85 | """ 86 | result = clause 87 | for theta in theta_list: 88 | result = subs(result, theta[0], theta[1]) 89 | return result 90 | 91 | 92 | def unify(atoms): 93 | """ 94 | Unification of first-order logic expressions 95 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.] 96 | 97 | Inputs 98 | ------ 99 | atoms : List[.logic.Atom] 100 | Returns 101 | ------- 102 | flag : bool 103 | unifiable or not 104 | unifier : List[(.logic.Var, .logic.Const)] 105 | unifiable - unifier (list of substitutions) 106 | not unifiable - empty list 107 | """ 108 | # empty set 109 | if len(atoms) == 0: 110 | return (1, []) 111 | # check predicates 112 | for i in range(len(atoms)-1): 113 | if atoms[i].pred != atoms[i+1].pred: 114 | return (0, []) 115 | 116 | # check all the same 117 | all_same_flag = True 118 | for i in range(len(atoms)-1): 119 | all_same_flag = all_same_flag and (atoms[i] == atoms[i+1]) 120 | if all_same_flag: 121 | return (1, []) 122 | 123 | k = 0 124 | theta_list = [] 125 | 126 | atoms_ = atoms 127 | while(True): 128 | # check terms from left 129 | for i in range(atoms_[0].pred.arity): 130 | # atom_1(term_1, ..., term_i, ...), ..., atom_j(term_1, ..., term_i, ...), ... 131 | terms_i = [atoms_[j].terms[i] for j in range(len(atoms_))] 132 | disagree_flag, disagree_set = get_disagreements(terms_i) 133 | if not disagree_flag: 134 | continue 135 | var_list = [x for x in disagree_set if type(x) == Var] 136 | if len(var_list) == 0: 137 | return (0, []) 138 | else: 139 | # substitute 140 | subs_var = var_list[0] 141 | # find term where the var does not occur 142 | subs_flag, subs_term = find_subs_term( 143 | subs_var, disagree_set) 144 | if subs_flag: 145 | k += 1 146 | theta_list.append((subs_var, subs_term)) 147 | subs_flag = True 148 | # UNIFICATION SUCCESS 149 | atoms_ = [subs(atom, subs_var, subs_term) 150 | for atom in atoms_] 151 | if is_singleton(atoms_): 152 | return (1, theta_list) 153 | else: 154 | # UNIFICATION FAILED 155 | return (0, []) 156 | 157 | 158 | def get_disagreements(terms): 159 | """ 160 | get desagreements in the unification algorithm 161 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.] 162 | 163 | Inputs 164 | ------ 165 | temrs : List[Term] 166 | Term : .logic.FuncTerm .logic.Const .logic.Var 167 | list of terms 168 | 169 | Returns 170 | ------- 171 | disagree_flag : bool 172 | flag of disagreement 173 | disagree_terms : List[Term] 174 | Term : .logic.FuncTerm .logic.Const .logic.Var 175 | terms of disagreement 176 | """ 177 | disagree_flag, disagree_index = get_disagree_index(terms) 178 | if disagree_flag: 179 | disagree_terms = [term.get_ith_term( 180 | disagree_index) for term in terms] 181 | return disagree_flag, disagree_terms 182 | else: 183 | return disagree_flag, [] 184 | 185 | 186 | def get_disagree_index(terms): 187 | """ 188 | get the desagreement index in the unification algorithm 189 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.] 
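    (Illustration, not in the original docstring: for the terms f(x) and f(b) the
    flattened symbol lists are [f, x] and [f, b], so the disagreement index is 1.)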
190 | 191 | Inputs 192 | ------ 193 | terms : List[Term] 194 | Term : .logic.FuncTerm .logic.Const .logic.Var 195 | list of terms 196 | 197 | Returns 198 | ------- 199 | disagree_flag : bool 200 | flag of disagreement 201 | disagree_index : int 202 | index of the disagreement term in the args of predicates 203 | """ 204 | symbols_list = [term.to_list() for term in terms] 205 | n = min([len(symbols) for symbols in symbols_list]) 206 | for i in range(n): 207 | ith_symbols = [symbols[i] for symbols in symbols_list] 208 | for j in range(len(ith_symbols)-1): 209 | if ith_symbols[j] != ith_symbols[j+1]: 210 | return (True, i) 211 | # all the same terms 212 | return (False, 0) 213 | 214 | 215 | def occur_check(variable, term): 216 | """ 217 | occur check function 218 | details in [Foundations of Inductive Logic Programming. Nienhuys-Cheng, S.-H. et.al. 1997.] 219 | 220 | Inputs 221 | ------ 222 | variable : .logic.Var 223 | term : Term 224 | Term : .logic.FuncTerm .logic.Const .logic.Var 225 | 226 | Returns 227 | ------- 228 | occur_flag : bool 229 | flag ofthe occurance of the variable 230 | """ 231 | if type(term) == Const: 232 | return False 233 | elif type(term) == Var: 234 | return variable.name == term.name 235 | else: 236 | # func term case 237 | for arg in term.args: 238 | if occur_check(variable, arg): 239 | return True 240 | return False 241 | 242 | 243 | def find_subs_term(subs_var, disagree_set): 244 | """ 245 | Find term where the var does not occur 246 | 247 | Inputs 248 | ------ 249 | subs_var : .logic.Var 250 | disagree_set : List[.logic.Term] 251 | 252 | Returns 253 | ------- 254 | flag : bool 255 | term : .logic.Term 256 | """ 257 | for term in disagree_set: 258 | if not occur_check(subs_var, term): 259 | return True, term 260 | return False, Term() 261 | 262 | 263 | def is_singleton(atoms): 264 | """ 265 | returns whether all the input atoms are the same or not 266 | 267 | Inputs 268 | ------ 269 | atoms: List[.logic.Atom] 270 | [a_1, a_2, ..., a_n] 271 | 272 | Returns 273 | ------- 274 | flag : bool 275 | a_1 == a_2 == ... 
== a_n 276 | """ 277 | result = True 278 | for i in range(len(atoms)-1): 279 | result = result and (atoms[i] == atoms[i+1]) 280 | return result 281 | 282 | 283 | def is_entailed(e, clause, facts, n): 284 | """ 285 | decision function of ground atom is entailed by a clause and facts by n-step inference 286 | 287 | Inputs 288 | ------ 289 | e : .logic.Atom 290 | ground atom 291 | clause : .logic.Clause 292 | clause 293 | facts : List[.logic.Atom] 294 | set of facts 295 | n : int 296 | infer step 297 | 298 | Returns 299 | ------- 300 | flag : bool 301 | ${clause} \cup facts \models e$ 302 | """ 303 | if len(clause.body) == 0: 304 | flag, thetas = unify([e, clause.head]) 305 | return flag 306 | if len(clause.body) == 1: 307 | return e in t_p_n(clause, facts, n) 308 | 309 | 310 | def t_p_n(clause, facts, n): 311 | """ 312 | applying the T_p operator n-times taking union of results 313 | 314 | Inputs 315 | ------ 316 | clause : .logic.Clause 317 | clause 318 | facts : List[.logic.Atom] 319 | set of facts 320 | n : int 321 | infer step 322 | 323 | Returns 324 | ------- 325 | G : Set[.logic.Atom] 326 | set of ground atoms entailed by ${clause} \cup facts$ 327 | """ 328 | G = set(facts) 329 | for i in range(n): 330 | G = G.union(t_p(clause, G)) 331 | return G 332 | 333 | 334 | def t_p(clause, facts): 335 | """ 336 | T_p operator 337 | limited to clauses with one body atom 338 | 339 | Inputs 340 | ------ 341 | clause : .logic.Clause 342 | clause 343 | facts : List[.logic.Atom] 344 | set of facts 345 | 346 | Returns 347 | ------- 348 | S : List[.logic.Atom] 349 | set of ground atoms entailed by one step forward-chaining inference 350 | """ 351 | # |body| == 1 352 | S = [] 353 | unify_dic = {} 354 | for fact in facts: 355 | flag, thetas = unify([clause.body[0], fact]) 356 | if flag: 357 | head_fact = subs_list(clause.head, thetas) 358 | S = S + [head_fact] 359 | return list(set(S)) 360 | -------------------------------------------------------------------------------- /neumann/neumann/img2facts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Img2Facts(nn.Module): 6 | """Img2Facts module converts raw images into a form of probabilistic facts. Each image is fed to the perception module, and the result is concatenated and fed to facts converter. 7 | """ 8 | 9 | def __init__(self, perception_module, facts_converter, atoms, bk, device): 10 | super().__init__() 11 | self.pm = perception_module 12 | self.fc = facts_converter 13 | self.atoms = atoms 14 | self.bk = bk 15 | self.device = device 16 | 17 | def forward(self, x): 18 | # x: batch_size * num_iamges * C * W * H 19 | num_images = x.size(1) 20 | # feed each input image 21 | zs_list = [] 22 | # TODO: concat image ids 23 | for i in range(num_images): 24 | # B * E * D 25 | zs = self.pm(x[:, i, :, :]) 26 | # image_ids = torch.tensor([i for i in range(x.size(0))]).unsqueeze(0).unsqueeze(0).to(self.device) 27 | image_ids = torch.tensor(i).expand( 28 | (zs.size(0), zs.size(1), 1)).to(self.device) 29 | zs = torch.cat((zs, image_ids), dim=2) 30 | zs_list.append(zs) 31 | zs = torch.cat(zs_list, dim=1) 32 | # zs: batch_size * num_images * num_objects * num_attributes 33 | return self.fc(zs) 34 | 35 | class Img2FactsWithQuery(nn.Module): 36 | """Img2Facts module converts raw images into a form of probabilistic facts. Each image is fed to the perception module, and the result is concatenated and fed to facts converter. 
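    Compared to Img2Facts above, the query is forwarded unchanged to the facts
    converter, i.e. forward(x, query) returns self.fc(zs, query) instead of self.fc(zs).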
37 | """ 38 | 39 | def __init__(self, perception_module, facts_converter, atoms, bk, device): 40 | super().__init__() 41 | self.pm = perception_module 42 | self.fc = facts_converter 43 | self.atoms = atoms 44 | self.bk = bk 45 | self.device = device 46 | 47 | def forward(self, x, query): 48 | # x: batch_size * num_iamges * C * W * H 49 | num_images = x.size(1) 50 | # feed each input image 51 | zs_list = [] 52 | # TODO: concat image ids 53 | for i in range(num_images): 54 | # B * E * D 55 | zs = self.pm(x[:, i, :, :]) 56 | # image_ids = torch.tensor([i for i in range(x.size(0))]).unsqueeze(0).unsqueeze(0).to(self.device) 57 | image_ids = torch.tensor(i).expand( 58 | (zs.size(0), zs.size(1), 1)).to(self.device) 59 | zs = torch.cat((zs, image_ids), dim=2) 60 | zs_list.append(zs) 61 | zs = torch.cat(zs_list, dim=1) 62 | # zs: batch_size * num_images * num_objects * num_attributes 63 | return self.fc(zs, query) -------------------------------------------------------------------------------- /neumann/neumann/lark/exp.lark: -------------------------------------------------------------------------------- 1 | clause : atom ":-" body 2 | 3 | body : atom "," body 4 | | atom "." 5 | | "." 6 | 7 | atom : predicate "(" args ")" 8 | 9 | args : term "," args 10 | | term 11 | 12 | term : functor "(" args ")" 13 | | const 14 | | variable 15 | 16 | const : /[a-z0-9\*\_]+/ 17 | 18 | variable : /[A-Z]+[A-Za-z0-9]*/ 19 | 20 | functor : /[a-z0-9]+/ 21 | 22 | predicate : /[a-z0-9\_]+/ 23 | 24 | var_name : /[A-Z]/ 25 | small_chars : /[a-z0-9]+/ 26 | chars : /[^\+\|\s\(\)']+/[/\n+/] 27 | allchars : /[^']+/[/\n+/] 28 | -------------------------------------------------------------------------------- /neumann/neumann/message_passing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric.nn 4 | from torch import Tensor 5 | from torch.nn import Linear, ReLU, Sequential 6 | from torch_geometric.nn import MessagePassing 7 | from torch_scatter import gather_csr, scatter, segment_csr 8 | 9 | # from neural_utils import MLP 10 | from .scatter import * 11 | 12 | 13 | class Atom2ConjConv(MessagePassing): 14 | """Message Passing class for messages from atom nodes to conjunction nodes. 15 | 16 | Args: 17 | soft_logic (softlogic): An implementation of the soft-logic operations. 18 | device (device): A device. 19 | """ 20 | 21 | def __init__(self, soft_logic, device): 22 | super().__init__(aggr='add') 23 | self.soft_logic = soft_logic 24 | self.device = device 25 | 26 | def forward(self, x, edge_index, conj_node_idxs, batch_size): 27 | """Perform atom2conj message-passing. 28 | Args: 29 | x (tensor): A data (node features). 30 | edge_index (tensor): An edge index. 31 | conj_node_idxs (tensor): A list of indicies of conjunction nodes extended for the batch size. 32 | batch_size (int): A batch size. 33 | """ 34 | return self.propagate(edge_index, x=x, conj_node_idxs=conj_node_idxs, batch_size=batch_size).view((batch_size, -1)) 35 | 36 | def message(self, x_j): 37 | """Compute the message. 38 | """ 39 | return x_j 40 | 41 | def update(self, message, x, conj_node_idxs): 42 | """Update the node features. 43 | Args: 44 | message (tensor, [node_num, node_dim]): Messages aggregated by `aggregate`. 45 | x (tensor, [node_num, node_dim]): Node features on the previous step. 
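        Returns:
            The soft disjunction (soft_logic._or) of the previous conjunction-node
            features x[conj_node_idxs] and the aggregated product messages.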
46 | """ 47 | return self.soft_logic._or(torch.stack([x[conj_node_idxs], message[conj_node_idxs]])) 48 | 49 | def aggregate(self, inputs, index): 50 | """Aggregate the messages. 51 | Args: 52 | inputs (tensor, [num_edges, num_features]): The values come from each edge. 53 | index (tensor, [2, num_of_edges]): The indices of the terminal nodes (conjunction nodes). 54 | """ 55 | return scatter_mul(inputs, index, dim=0) 56 | 57 | 58 | class Conj2AtomConv(MessagePassing): 59 | """Message Passing class for messages from atom nodes to conjunction nodes. 60 | 61 | Args: 62 | soft_logic (softlogic): An implementation of the soft-logic operations. 63 | device (device): A device. 64 | """ 65 | 66 | def __init__(self, soft_logic, device): 67 | super().__init__(aggr='add') 68 | self.soft_logic = soft_logic 69 | self.device = device 70 | self.eps = 1e-4 71 | # self.linear = nn.Linear() 72 | 73 | def forward(self, x, edge_index, edge_weight, edge_clause_index, atom_node_idxs, n_nodes, batch_size): 74 | """Perform conj2atom message-passing. 75 | Args: 76 | edge_index (tensor): An edge index. 77 | x (tensor): A data (node features). 78 | edge_weight (tensor): The edge weights. 79 | edge_clause_index (tensor): A list of indices of clauses representing which clause produced each edge in the reasoning graph. 80 | atom_node_idxs (tensor): A list of indicies of atom nodes extended for the batch size. 81 | n_nodes (int): The number of nodes in the reasoning graph. 82 | batch_size (int): A batch size. 83 | """ 84 | return self.propagate(edge_index, x=x, edge_weight=edge_weight, edge_clause_index=edge_clause_index, 85 | atom_node_idxs=atom_node_idxs, n_nodes=n_nodes, batch_size=batch_size).view(batch_size, -1) 86 | 87 | def message(self, x_j, edge_weight, edge_clause_index): 88 | """Compute the message. 89 | """ 90 | return edge_weight.view(-1, 1) * x_j 91 | 92 | def update(self, message, x, atom_node_idxs): 93 | """Update the node features. 94 | Args: 95 | message (tensor, [node_num, node_dim]): Messages aggregated by `aggregate`. 96 | x (tensor, [node_num, node_dim]): Node features on the previous step. 97 | """ 98 | return self.soft_logic._or(torch.stack([x[atom_node_idxs], message[atom_node_idxs]])) 99 | 100 | def aggregate(self, inputs, index, n_nodes): 101 | """Aggregate the messages. 102 | Args: 103 | inputs (tensor, [num_edges, num_features]): The values come from each edge. 104 | index (tensor, [2, num_of_edges]): The indices of the terminal nodes (conjunction nodes). 105 | n_nodes (int): The number of nodes in the reasoning graph. 106 | """ 107 | # softor 108 | # gamma = 0.05 109 | gamma = 0.015 110 | log_sum_exp = gamma * \ 111 | self._logsumexp((inputs) * (1/gamma), index, n_nodes) 112 | if log_sum_exp.max() > 1.0: 113 | return log_sum_exp / log_sum_exp.max() 114 | else: 115 | return log_sum_exp 116 | 117 | def _logsumexp(self, inputs, index, n_nodes): 118 | return torch.log(scatter(src=torch.exp(inputs), index=index, dim=0, dim_size=n_nodes, reduce='sum') + self.eps) 119 | 120 | 121 | class MessagePassingModule(torch.nn.Module): 122 | """The bi-directional message-passing module. 123 | 124 | Args: 125 | soft_logic (softlogic): An implementation of the soft-logic operations. 126 | device (device): A device. 127 | T (int): The number of steps for reasoning. 
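    Each of the T reasoning steps runs an atom-to-conjunction pass followed by a
    conjunction-to-atom pass (see forward below); the intermediate atom features are
    stored in x_atom_list for later inspection.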
128 | """ 129 | 130 | def __init__(self, soft_logic, device, T): 131 | super().__init__() 132 | self.soft_logic = soft_logic 133 | self.device = device 134 | self.T = T 135 | self.atom2conj = Atom2ConjConv(soft_logic, device) 136 | self.conj2atom = Conj2AtomConv(soft_logic, device) 137 | 138 | def forward(self, data, clause_weights, edge_clause_index, edge_type, atom_node_idxs, conj_node_idxs, batch_size, explain=False): 139 | """ 140 | Args: 141 | data (torch_geometric.Data): A logic progam and probabilistic facts as a graph data. 142 | clause_weights (torch.Tensor): Weights for clauses. 143 | edge_clause_index (torch.Tensor): A clause indices for each edge, representing which clause produces the edge. 144 | edge_type (torch.Tensor): Edge types (atom2conj:0 or conj2atom:1) 145 | atom_node_idxs (torch.Tensor): The indices of atom nodes. 146 | conj_node_idxs (torch.Tensor): The indices of conjunction nodes. 147 | batch_size (int): The batch size of input. 148 | """ 149 | x_atom = data.x[atom_node_idxs] 150 | x_conj = data.x[conj_node_idxs] 151 | 152 | 153 | # filter the edge index using the edge type obtaining the set of atom2conj edges and the set of conj2atom edges 154 | atom2conj_edge_index = self._filter_edge_index( 155 | data.edge_index, edge_type, 'atom2conj', batch_size).to(self.device) 156 | conj2atom_edge_index = self._filter_edge_index( 157 | data.edge_index, edge_type, 'conj2atom', batch_size).to(self.device) 158 | 159 | # filter the edge-clause index using the edge type obtaining the set of atom2conj edge-clause indices and the set of conj2atom edge-clause indices 160 | conj2atom_edge_clause_index = self._filter_edge_clause_index( 161 | edge_clause_index, edge_type, batch_size).to(self.device) 162 | edge_weight = torch.gather( 163 | input=clause_weights, dim=0, index=conj2atom_edge_clause_index) 164 | 165 | n_nodes = data.x.size(0) 166 | 167 | self.x_atom_list = [data.x[atom_node_idxs].view((batch_size, -1)).detach().cpu().numpy()[:,1:]] 168 | x = data.x 169 | 170 | 171 | # dummy variable to compute inpute gradients 172 | if explain: 173 | #print(x[atom_node_idxs], x[atom_node_idxs].shape) 174 | self.dummy_zeros = torch.zeros_like(x[atom_node_idxs], requires_grad=True).to(torch.float32).to(self.device) 175 | self.dummy_zeros.requires_grad_() 176 | self.dummy_zeros.retain_grad() 177 | #print(self.dummy_zeros) 178 | # add dummy zeros to get input gradients 179 | x[atom_node_idxs] = x[atom_node_idxs] + self.dummy_zeros 180 | 181 | # iterate message passing T times 182 | for t in range(self.T): 183 | 184 | # step 1: Atom -> Conj 185 | x_conj_new = self.atom2conj( 186 | x, atom2conj_edge_index, conj_node_idxs, batch_size) 187 | 188 | # create new tensor (node features) by updating conjunction embeddings 189 | x = self._cat_x_atom_x_conj( 190 | x[atom_node_idxs].view((batch_size, -1)), x_conj_new) 191 | 192 | # step 2: Conj -> Atom 193 | x_atom_new = self.conj2atom(x=x, edge_weight=edge_weight, edge_index=conj2atom_edge_index, edge_clause_index=edge_clause_index, atom_node_idxs=atom_node_idxs, 194 | n_nodes=n_nodes, batch_size=batch_size) 195 | self.x_atom_list.append(x_atom_new.detach().cpu().numpy()[:,1:]) 196 | 197 | x = self._cat_x_atom_x_conj(x_atom_new, x_conj_new) 198 | self.x_atom_final = x_atom_new 199 | return x_atom_new 200 | 201 | def _cat_x_atom_x_conj(self, x_atom, x_conj): 202 | """Concatenate the features of atom ndoes and those of conj nodes. 
203 | Args: 204 | x_atom : batch_size * n_atom 205 | x_conj : batch_size * n_conj 206 | 207 | Returns: 208 | [x_atom_1, x_conj_1, x_atom_2, x_conj_2, ...] 209 | """ 210 | xs = [] 211 | for i in range(x_atom.size(0)): 212 | x_i = torch.cat([x_atom[i], x_conj[i]]) 213 | xs.append(x_i) 214 | return torch.cat(xs).unsqueeze(-1) 215 | 216 | def _filter_edge_clause_index(self, edge_clause_index, edge_type, batch_size): 217 | """Filter the edge index by the edge type. 218 | """ 219 | edge_clause_index = torch.stack( 220 | [edge_clause_index for i in range(batch_size)]).view((-1)) 221 | edge_type = torch.stack([edge_type for i in range(batch_size)]) 222 | mask = (edge_type == 1).view((-1)) 223 | return edge_clause_index[mask] 224 | 225 | def _filter_edge_index(self, edge_index, edge_type, mode, batch_size): 226 | """Filter the edge index by the edge type. 227 | """ 228 | edge_type = torch.stack([edge_type for i in range(batch_size)]) 229 | if mode == 'atom2conj': 230 | mask = (edge_type == 0).view((-1)) 231 | return edge_index[:, mask] 232 | elif mode == 'conj2atom': 233 | mask = (edge_type == 1).view((-1)) 234 | return edge_index[:, mask] 235 | else: 236 | assert 0, "Invalid mode in _filter_edge_index" 237 | -------------------------------------------------------------------------------- /neumann/neumann/mode_declaration.py: -------------------------------------------------------------------------------- 1 | from fol.language import DataType 2 | 3 | 4 | class ModeDeclaration(object): 5 | """from https://www.cs.ox.ac.uk/activities/programinduction/Aleph/aleph.html 6 | p(ModeType, ModeType,...) 7 | Here are some examples of how they appear in a file: 8 | :- mode(1,mem(+number,+list)). 9 | :- mode(1,dec(+integer,-integer)). 10 | :- mode(1,mult(+integer,+integer,-integer)). 11 | :- mode(1,plus(+integer,+integer,-integer)). 12 | :- mode(1,(+integer)=(#integer)). 13 | :- mode(*,has_car(+train,-car)). 14 | Each ModeType is either (a) simple; or (b) structured. 15 | A simple ModeType is one of: 16 | (a) +T specifying that when a literal with predicate symbol p appears in a 17 | hypothesised clause, the corresponding argument should be an "input" variable of type T; 18 | (b) -T specifying that the argument is an "output" variable of type T; or 19 | (c) #T specifying that it should be a constant of type T. 20 | All the examples above have simple modetypes. 21 | A structured ModeType is of the form f(..) where f is a function symbol, 22 | each argument of which is either a simple or structured ModeType. 23 | Here is an example containing a structured ModeType: 24 | To make this more clear, here is an example for the mode declarations for 25 | the grandfather task from 26 | above::- modeh(1, grandfather(+human, +human)).:- 27 | modeb(*, parent(-human, +human)).:- 28 | modeb(*, male(+human)). 29 | The first mode states that the head of the rule 30 | (and therefore the targetpredicate) will be the atomgrandfather. 31 | Its parameters have to be of the typehuman. 32 | The + annotation says that the rule head needs two variables. 33 | Thesecond mode declaration states theparentatom and declares again 34 | that theparameters have to be of type human. 35 | Here, the + at the second parametertells, that the system is only allowed to 36 | introduce the atomparentin the clauseif it already contains a variable of type human. 37 | The first attribute introduces a new variable into the clause. 
38 | The modes consist of a recall n that states how many versions of the 39 | literal are allowed in a rule and an atom with place-markers that state the literal to-gether 40 | with annotations on input- and output-variables as well as constants (see[Mug95]). 41 | 42 | Args: 43 | recall (int): The recall number i.e. how many times the declaration can be instanciated 44 | pred (Predicate): The predicate. 45 | mode_terms (ModeTerm): Terms for mode declarations. 46 | """ 47 | 48 | def __init__(self, mode_type, recall, pred, mode_terms, ordered=True): 49 | self.mode_type = mode_type # head or body 50 | self.recall = recall 51 | self.pred = pred 52 | self.mode_terms = mode_terms 53 | self.ordered = ordered 54 | 55 | def __str__(self): 56 | s = 'mode_' + self.mode_type + '(' 57 | for mt in self.mode_terms: 58 | s += str(mt) 59 | s += ',' 60 | s = s[0:-1] 61 | s += ')' 62 | return s 63 | 64 | def __repr__(self): 65 | return self.__str__() 66 | 67 | def __hash__(self): 68 | return hash(self.__str__()) 69 | 70 | 71 | class ModeTerm(object): 72 | """Terms for mode declarations. It has mode (+, -, #) and data types. 73 | """ 74 | 75 | def __init__(self, mode, dtype): 76 | self.mode = mode 77 | assert mode in ['+', '-', '#'], "Invalid mode declaration." 78 | self.dtype = dtype 79 | 80 | def __str__(self): 81 | return self.mode + self.dtype.name 82 | 83 | def __repr__(self): 84 | return self.__str__() 85 | 86 | 87 | def get_mode_declarations_clevr(lang, obj_num): 88 | p_image = ModeTerm('+', DataType('image')) 89 | m_object = ModeTerm('-', DataType('object')) 90 | p_object = ModeTerm('+', DataType('object')) 91 | s_color = ModeTerm('#', DataType('color')) 92 | s_shape = ModeTerm('#', DataType('shape')) 93 | s_material = ModeTerm('#', DataType('material')) 94 | s_size = ModeTerm('#', DataType('size')) 95 | 96 | # modeh_1 = ModeDeclaration('head', 'kp', p_image) 97 | 98 | """ 99 | kp1(X):-in(O1,X),in(O2,X),size(O1,large),shape(O1,cube),size(O2,large),shape(O2,cylinder). 100 | kp2(X):-in(O1,X),in(O2,X),size(O1,small),material(O1,metal),shape(O1,cube),size(O2,small),shape(O2,sphere). 
101 | kp3(X):-in(O1,X),in(O2,X),size(O1,large),color(O1,blue),shape(O1,sphere),size(O2,small),color(O2,yellow),shape(O2,sphere).""" 102 | 103 | modeb_list = [ 104 | #ModeDeclaration('body', obj_num, lang.get_pred_by_name( 105 | # 'in'), [m_object, p_image]), 106 | ModeDeclaration('body', 2, lang.get_pred_by_name( 107 | 'color'), [p_object, s_color]), 108 | ModeDeclaration('body', 2, lang.get_pred_by_name( 109 | 'shape'), [p_object, s_shape]), 110 | #ModeDeclaration('body', 1, lang.get_pred_by_name( 111 | # 'material'), [p_object, s_material]), 112 | ModeDeclaration('body', 2, lang.get_pred_by_name( 113 | 'size'), [p_object, s_size]), 114 | ] 115 | return modeb_list 116 | 117 | 118 | def get_mode_declarations_kandinsky(lang, obj_num): 119 | p_image = ModeTerm('+', DataType('image')) 120 | m_object = ModeTerm('-', DataType('object')) 121 | p_object = ModeTerm('+', DataType('object')) 122 | s_color = ModeTerm('#', DataType('color')) 123 | s_shape = ModeTerm('#', DataType('shape')) 124 | 125 | # modeh_1 = ModeDeclaration('head', 'kp', p_image) 126 | 127 | modeb_list = [ 128 | #ModeDeclaration('body', obj_num, lang.get_pred_by_name( 129 | # 'in'), [m_object, p_image]), 130 | ModeDeclaration('body', 1, lang.get_pred_by_name( 131 | 'color'), [p_object, s_color]), 132 | ModeDeclaration('body', 1, lang.get_pred_by_name( 133 | 'shape'), [p_object, s_shape]), 134 | ModeDeclaration('body', 1, lang.get_pred_by_name( 135 | 'same_color_pair'), [p_object, p_object], ordered=False), 136 | ModeDeclaration('body', 2, lang.get_pred_by_name( 137 | 'same_shape_pair'), [p_object, p_object], ordered=False), 138 | ModeDeclaration('b1ody', 1, lang.get_pred_by_name( 139 | 'diff_color_pair'), [p_object, p_object], ordered=False), 140 | ModeDeclaration('body', 1, lang.get_pred_by_name( 141 | 'diff_shape_pair'), [p_object, p_object], ordered=False), 142 | ModeDeclaration('body', 1, lang.get_pred_by_name( 143 | 'closeby'), [p_object, p_object], ordered=False), 144 | ModeDeclaration('body', 1, lang.get_pred_by_name('online'), [ 145 | p_object, p_object, p_object, p_object, p_object], ordered=False), 146 | # ModeDeclaration('body', 2, lang.get_pred_by_name('diff_shape_pair'), [p_object, p_object]), 147 | ] 148 | return modeb_list 149 | 150 | def get_mode_declarations_vilp(lang, dataset): 151 | p_colors = ModeTerm('+', DataType('colors')) 152 | p_color = ModeTerm('+', DataType('color')) 153 | # modeh_1 = ModeDeclaration('head', 'kp', p_image) 154 | if dataset=='member': 155 | modeb_list = [ 156 | ModeDeclaration('body', 1, lang.get_pred_by_name( 157 | 'member'), [p_color, p_colors])] 158 | elif dataset == 'delete': 159 | modeb_list = [ 160 | ModeDeclaration('body', 1, lang.get_pred_by_name( 161 | 'delete'), [p_color, p_colors, p_colors])] 162 | elif dataset == 'append': 163 | modeb_list = [ 164 | ModeDeclaration('body', 1, lang.get_pred_by_name( 165 | 'append'), [p_colors, p_colors, p_colors])] 166 | elif dataset == 'reverse': 167 | modeb_list = [ 168 | ModeDeclaration('body', 1, lang.get_pred_by_name( 169 | 'reverse'), [p_colors, p_colors, p_colors]) 170 | #ModeDeclaration('body', 1, lang.get_pred_by_name( 171 | # 'append'), [p_colors, p_colors, p_colors]), 172 | ] 173 | elif dataset == 'sort': 174 | modeb_list = [ 175 | ModeDeclaration('body', 1, lang.get_pred_by_name('perm'), [p_colors, p_colors]), 176 | ModeDeclaration('body', 1, lang.get_pred_by_name('is_sorted'), [p_colors]), 177 | ModeDeclaration('body', 1, lang.get_pred_by_name('smaller'), [p_color, p_color]), 178 | ] 179 | #ModeDeclaration('body', 1, 
lang.get_pred_by_name( 180 | # 'append'), [p_colors, p_colors, p_colors]), 181 | #ModeDeclaration('body', 1, lang.get_pred_by_name( 182 | # 'reverse'), [p_colors, p_colors]), 183 | #ModeDeclaration('body', 1, lang.get_pred_by_name( 184 | # 'sort'), [p_colors, p_colors]) 185 | return modeb_list 186 | 187 | def get_mode_declarations(args, lang): 188 | if args.dataset_type == 'kandinsky': 189 | return get_mode_declarations_kandinsky(lang, args.num_objects) 190 | elif args.dataset_type == 'clevr-hans': 191 | return get_mode_declarations_clevr(lang, 10) 192 | elif args.dataset_type == 'vilp': 193 | return get_mode_declarations_vilp(lang, args.dataset) 194 | else: 195 | assert False, "Invalid data type." -------------------------------------------------------------------------------- /neumann/neumann/neural_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class LogisticRegression(torch.nn.Module): 7 | def __init__(self, input_dim, output_dim=1): 8 | super(LogisticRegression, self).__init__() 9 | self.linear = torch.nn.Linear(input_dim, output_dim) 10 | 11 | def forward(self, x): 12 | y_pred = torch.sigmoid(self.linear(x)) 13 | return y_pred 14 | 15 | class MLP(nn.Module): 16 | def __init__(self, in_channels, out_channels, hidden_dim=256): 17 | super(MLP, self).__init__() 18 | # Number of input features is input_dim. 19 | self.layer_1 = nn.Linear(in_channels, hidden_dim) 20 | self.layer_2 = nn.Linear(hidden_dim, hidden_dim) 21 | self.layer_out = nn.Linear(hidden_dim, out_channels) 22 | self.relu = nn.ReLU() 23 | 24 | def forward(self, inputs): 25 | x = self.relu(self.layer_1(inputs)) 26 | x = self.relu(self.layer_2(x)) 27 | x = self.layer_out(x) 28 | #x = torch.sigmoid(x) 29 | return x -------------------------------------------------------------------------------- /neumann/neumann/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import random 4 | import time 5 | 6 | import matplotlib.pyplot as plt 7 | import networkx as nx 8 | import numpy as np 9 | import torch 10 | from rtpt import RTPT 11 | from sklearn.metrics import accuracy_score, recall_score, roc_curve 12 | from torch.utils.tensorboard import SummaryWriter 13 | from tqdm import tqdm 14 | 15 | from .logic_utils import get_lang 16 | from .neumann_utils import ( 17 | generate_captions, 18 | get_data_loader, 19 | get_model, 20 | get_prob, 21 | save_images_with_captions, 22 | to_plot_images_clevr, 23 | to_plot_images_kandinsky, 24 | ) 25 | from .tensor_encoder import TensorEncoder 26 | from .tensor_utils import build_infer_module 27 | from .visualize import plot_proof_history 28 | 29 | random.seed(0) 30 | 31 | 32 | def get_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | "--batch-size", type=int, default=1, help="Batch size to infer with" 36 | ) 37 | parser.add_argument( 38 | "--num-objects", 39 | type=int, 40 | default=3, 41 | help="The maximum number of objects in one image", 42 | ) 43 | parser.add_argument("--dataset") # , choices=["member"]) 44 | parser.add_argument("--dataset_type") # , choices=["member"]) 45 | parser.add_argument("--rtpt-name", default="") # , choices=["member"]) 46 | parser.add_argument( 47 | "--dataset-type", 48 | choices=["vilp", "clevr-hans", "kandinsky"], 49 | help="vilp or kandinsky or clevr", 50 | ) 51 | parser.add_argument("--device", default="cpu", help="cuda 
device, i.e. 0 or cpu") 52 | parser.add_argument( 53 | "--no-cuda", 54 | action="store_true", 55 | help="Run on CPU instead of GPU (not recommended)", 56 | ) 57 | parser.add_argument( 58 | "--no-train", 59 | action="store_true", 60 | help="Perform prediction without training model", 61 | ) 62 | parser.add_argument( 63 | "--small-data", action="store_true", help="Use small training data." 64 | ) 65 | parser.add_argument( 66 | "--num-workers", type=int, default=4, help="Number of threads for data loader" 67 | ) 68 | parser.add_argument( 69 | "--gamma", 70 | default=0.01, 71 | type=float, 72 | help="Smooth parameter in the softor function", 73 | ) 74 | parser.add_argument( 75 | "--plot", action="store_true", help="Plot images with captions." 76 | ) 77 | parser.add_argument( 78 | "--t-beam", 79 | type=int, 80 | default=4, 81 | help="Number of rule expantion of clause generation.", 82 | ) 83 | parser.add_argument("--n-beam", type=int, default=5, help="The size of the beam.") 84 | parser.add_argument( 85 | "--n-max", type=int, default=50, help="The maximum number of clauses." 86 | ) 87 | parser.add_argument( 88 | "--program-size", 89 | "-m", 90 | type=int, 91 | default=1, 92 | help="The size of the logic program.", 93 | ) 94 | # parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.") 95 | parser.add_argument("--epochs", type=int, default=20, help="The number of epochs.") 96 | parser.add_argument("--lr", type=float, default=1e-2, help="The learning rate.") 97 | parser.add_argument( 98 | "--n-data", type=float, default=200, help="The number of data to be used." 99 | ) 100 | parser.add_argument( 101 | "--pre-searched", action="store_true", help="Using pre searched clauses." 102 | ) 103 | parser.add_argument( 104 | "-T", 105 | "--infer-step", 106 | type=int, 107 | default=10, 108 | help="The number of steps of forward reasoning.", 109 | ) 110 | parser.add_argument( 111 | "--term-depth", 112 | type=int, 113 | default=3, 114 | help="The number of steps of forward reasoning.", 115 | ) 116 | args = parser.parse_args() 117 | return args 118 | 119 | 120 | def main(): 121 | args = get_args() 122 | print("args ", args) 123 | if args.no_cuda: 124 | device = torch.device("cpu") 125 | elif len(args.device.split(",")) > 1: 126 | # multi gpu 127 | device = torch.device("cuda") 128 | else: 129 | device = torch.device("cuda:" + args.device) 130 | 131 | print("device: ", device) 132 | 133 | # Create RTPT object 134 | rtpt = RTPT( 135 | name_initials="", 136 | experiment_name="NEUMANN_{}".format(args.dataset), 137 | max_iterations=args.epochs, 138 | ) 139 | # Start the RTPT tracking 140 | rtpt.start() 141 | 142 | # Get torch data loader 143 | # train_loader, val_loader, test_loader = get_data_loader(args, device) 144 | 145 | # Load logical representations 146 | lark_path = "src/lark/exp.lark" 147 | lang_base_path = "data/lang/" 148 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang( 149 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth 150 | ) 151 | print(terms) 152 | print("{} Atoms:".format(len(atoms))) 153 | print(atoms) 154 | 155 | # Load the NEUMANN model 156 | NEUMANN = get_model( 157 | lang=lang, 158 | clauses=clauses, 159 | atoms=atoms, 160 | terms=terms, 161 | bk=bk, 162 | bk_clauses=bk_clauses, 163 | program_size=args.program_size, 164 | device=device, 165 | dataset=args.dataset, 166 | dataset_type=args.dataset_type, 167 | num_objects=args.num_objects, 168 | term_depth=args.term_depth, 169 | infer_step=args.infer_step, 170 | 
train=not (args.no_train), 171 | ) 172 | 173 | x = torch.zeros((1, len(atoms))).to(device) 174 | x[:, 0] = 1.0 175 | ## x[:,1] = 0.8 176 | print("x: ", x) 177 | print(np.round(NEUMANN(x).detach().cpu().numpy(), 2)) 178 | print(NEUMANN.rgm.edge_index) 179 | print("graph: ") 180 | print(NEUMANN.rgm.networkx_graph) 181 | NEUMANN.plot_reasoning_graph(name=args.dataset) 182 | plot_proof_history( 183 | NEUMANN.mpm.x_atom_list, atoms[1:], args.infer_step, args.dataset, mode="graph" 184 | ) 185 | # nx.draw(NEUMANN.rgm.networkx_graph) 186 | # plt.savefig('imgs/{}.png'.format(args.dataset)) 187 | 188 | print("==== tensor based reasoner") 189 | from logic_utils import false 190 | 191 | atoms = [false] + atoms 192 | rgm = NEUMANN.rgm 193 | rgm.facts = atoms 194 | x = torch.zeros((1, len(atoms))).to(device) 195 | x[:, 1] = 1.0 196 | for i, atom in enumerate(atoms): 197 | if atom in bk: 198 | x[:, i] += 1.0 199 | ## x[:,2] = 0.8 200 | IM = build_infer_module( 201 | clauses, 202 | atoms, 203 | lang, 204 | rgm, 205 | device, 206 | m=args.program_size, 207 | infer_step=args.infer_step, 208 | train=True, 209 | ) 210 | IM(x) 211 | IM.W = NEUMANN.clause_weights 212 | print(IM.get_weights()) 213 | plot_proof_history( 214 | IM.V_list, atoms[2:], args.infer_step, args.dataset, mode="tensor" 215 | ) 216 | 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /neumann/neumann/scatter.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | 6 | 7 | 8 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 9 | if dim < 0: 10 | dim = other.dim() + dim 11 | if src.dim() == 1: 12 | for _ in range(0, dim): 13 | src = src.unsqueeze(0) 14 | for _ in range(src.dim(), other.dim()): 15 | src = src.unsqueeze(-1) 16 | src = src.expand(other.size()) 17 | return src 18 | 19 | def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 20 | out: Optional[torch.Tensor] = None, 21 | dim_size: Optional[int] = None) -> torch.Tensor: 22 | index = broadcast(index, src, dim) 23 | if out is None: 24 | size = list(src.size()) 25 | if dim_size is not None: 26 | size[dim] = dim_size 27 | elif index.numel() == 0: 28 | size[dim] = 0 29 | else: 30 | size[dim] = int(index.max()) + 1 31 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 32 | return out.scatter_add_(dim, index, src) 33 | else: 34 | return out.scatter_add_(dim, index, src) 35 | 36 | 37 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 38 | out: Optional[torch.Tensor] = None, 39 | dim_size: Optional[int] = None) -> torch.Tensor: 40 | return scatter_sum(src, index, dim, out, dim_size) 41 | 42 | 43 | def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 44 | out: Optional[torch.Tensor] = None, 45 | dim_size: Optional[int] = None) -> torch.Tensor: 46 | return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) 47 | 48 | 49 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 50 | out: Optional[torch.Tensor] = None, 51 | dim_size: Optional[int] = None) -> torch.Tensor: 52 | out = scatter_sum(src, index, dim, out, dim_size) 53 | dim_size = out.size(dim) 54 | 55 | index_dim = dim 56 | if index_dim < 0: 57 | index_dim = index_dim + src.dim() 58 | if index.dim() <= index_dim: 59 | index_dim = index.dim() - 1 60 | 61 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 62 | count = 
scatter_sum(ones, index, index_dim, None, dim_size) 63 | count[count < 1] = 1 64 | count = broadcast(count, out, dim) 65 | if out.is_floating_point(): 66 | out.true_divide_(count) 67 | else: 68 | out.div_(count, rounding_mode='floor') 69 | return out 70 | 71 | 72 | def scatter_min( 73 | src: torch.Tensor, index: torch.Tensor, dim: int = -1, 74 | out: Optional[torch.Tensor] = None, 75 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 76 | return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) 77 | 78 | 79 | def scatter_max( 80 | src: torch.Tensor, index: torch.Tensor, dim: int = -1, 81 | out: Optional[torch.Tensor] = None, 82 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 83 | return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) 84 | 85 | 86 | def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 87 | out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, 88 | reduce: str = "sum") -> torch.Tensor: 89 | r""" 90 | | 91 | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ 92 | master/docs/source/_figures/add.svg?sanitize=true 93 | :align: center 94 | :width: 400px 95 | | 96 | Reduces all values from the :attr:`src` tensor into :attr:`out` at the 97 | indices specified in the :attr:`index` tensor along a given axis 98 | :attr:`dim`. 99 | For each value in :attr:`src`, its output index is specified by its index 100 | in :attr:`src` for dimensions outside of :attr:`dim` and by the 101 | corresponding value in :attr:`index` for dimension :attr:`dim`. 102 | The applied reduction is defined via the :attr:`reduce` argument. 103 | Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional 104 | tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` 105 | and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional 106 | tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. 107 | Moreover, the values of :attr:`index` must be between :math:`0` and 108 | :math:`y - 1`, although no specific ordering of indices is required. 109 | The :attr:`index` tensor supports broadcasting in case its dimensions do 110 | not match with :attr:`src`. 111 | For one-dimensional tensors with :obj:`reduce="sum"`, the operation 112 | computes 113 | .. math:: 114 | \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j 115 | where :math:`\sum_j` is over :math:`j` such that 116 | :math:`\mathrm{index}_j = i`. 117 | .. note:: 118 | This operation is implemented via atomic operations on the GPU and is 119 | therefore **non-deterministic** since the order of parallel operations 120 | to the same value is undetermined. 121 | For floating-point variables, this results in a source of variance in 122 | the result. 123 | :param src: The source tensor. 124 | :param index: The indices of elements to scatter. 125 | :param dim: The axis along which to index. (default: :obj:`-1`) 126 | :param out: The destination tensor. 127 | :param dim_size: If :attr:`out` is not given, automatically create output 128 | with size :attr:`dim_size` at dimension :attr:`dim`. 129 | If :attr:`dim_size` is not given, a minimal sized output tensor 130 | according to :obj:`index.max() + 1` is returned. 131 | :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`, 132 | :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) 133 | :rtype: :class:`Tensor` 134 | .. 
code-block:: python 135 | from torch_scatter import scatter 136 | src = torch.randn(10, 6, 64) 137 | index = torch.tensor([0, 1, 0, 1, 2, 1]) 138 | # Broadcasting in the first and last dim. 139 | out = scatter(src, index, dim=1, reduce="sum") 140 | print(out.size()) 141 | .. code-block:: 142 | torch.Size([10, 3, 64]) 143 | """ 144 | if reduce == 'sum' or reduce == 'add': 145 | return scatter_sum(src, index, dim, out, dim_size) 146 | if reduce == 'mul': 147 | return scatter_mul(src, index, dim, out, dim_size) 148 | elif reduce == 'mean': 149 | return scatter_mean(src, index, dim, out, dim_size) 150 | elif reduce == 'min': 151 | return scatter_min(src, index, dim, out, dim_size)[0] 152 | elif reduce == 'max': 153 | return scatter_max(src, index, dim, out, dim_size)[0] 154 | else: 155 | raise ValueError -------------------------------------------------------------------------------- /neumann/neumann/soft_logic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .torch_utils import softor 4 | 5 | 6 | class SoftLogic(object): 7 | """An class of the soft-implementation of logic operations, i.e., logical-or and logical-and. 8 | """ 9 | 10 | def __init__(self): 11 | pass 12 | 13 | def _or(self, x): 14 | return softor(x, dim=0, gamma=0.01) 15 | 16 | def _and(self, x): 17 | return torch.prod(x, dim=0) 18 | 19 | 20 | class LukasiewiczSoftLogic(SoftLogic): 21 | pass 22 | 23 | 24 | class LNNSoftLogic(SoftLogic): 25 | pass 26 | 27 | 28 | class aILPSoftLogic(SoftLogic): 29 | pass 30 | -------------------------------------------------------------------------------- /neumann/neumann/solve_behind_the_scenes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from rtpt import RTPT 8 | from sklearn.metrics import accuracy_score, recall_score, roc_curve 9 | from torch.utils.tensorboard import SummaryWriter 10 | from tqdm import tqdm 11 | 12 | from logic_utils import get_lang 13 | from mode_declaration import get_mode_declarations 14 | from neumann_utils import (get_behind_the_scenes_loader, get_clause_evaluator, 15 | get_model, get_prob) 16 | from tensor_encoder import TensorEncoder 17 | 18 | # from nsfr_utils import save_images_with_captions, to_plot_images_clevr, generate_captions 19 | 20 | 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--batch-size", type=int, default=2, 25 | help="Batch size to infer with") 26 | parser.add_argument("--batch-size-bs", type=int, 27 | default=1, help="Batch size in beam search") 28 | parser.add_argument("--num-objects", type=int, default=3, 29 | help="The maximum number of objects in one image") 30 | parser.add_argument("--dataset", default="delete") # , choices=["member"]) 31 | parser.add_argument("--dataset-type", default="behind-the-scenes") 32 | parser.add_argument('--device', default='cpu', 33 | help='cuda device, i.e. 
0 or cpu') 34 | parser.add_argument("--no-cuda", action="store_true", 35 | help="Run on CPU instead of GPU (not recommended)") 36 | parser.add_argument("--no-train", action="store_true", 37 | help="Perform prediction without training model") 38 | parser.add_argument("--small-data", action="store_true", 39 | help="Use small training data.") 40 | parser.add_argument("--num-workers", type=int, default=0, 41 | help="Number of threads for data loader") 42 | parser.add_argument('--gamma', default=0.01, type=float, 43 | help='Smooth parameter in the softor function') 44 | parser.add_argument("--plot", action="store_true", 45 | help="Plot images with captions.") 46 | parser.add_argument("--t-beam", type=int, default=4, 47 | help="Number of rule expantion of clause generation.") 48 | parser.add_argument("--n-beam", type=int, default=5, 49 | help="The size of the beam.") 50 | parser.add_argument("--n-max", type=int, default=50, 51 | help="The maximum number of clauses.") 52 | parser.add_argument("--program-size", type=int, default=1, 53 | help="The size of the logic program.") 54 | #parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.") 55 | parser.add_argument("--epochs", type=int, default=20, 56 | help="The number of epochs.") 57 | parser.add_argument("--lr", type=float, default=1e-2, 58 | help="The learning rate.") 59 | parser.add_argument("--n-data", type=float, default=200, 60 | help="The number of data to be used.") 61 | parser.add_argument("--pre-searched", action="store_true", 62 | help="Using pre searched clauses.") 63 | parser.add_argument("--infer-step", type=int, default=6, 64 | help="The number of steps of forward reasoning.") 65 | parser.add_argument("--term-depth", type=int, default=3, 66 | help="The number of steps of forward reasoning.") 67 | parser.add_argument("--question-json-path", default="data/behind-the-scenes/BehindTheScenes_questions.json") 68 | args = parser.parse_args() 69 | return args 70 | 71 | # def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False): 72 | 73 | 74 | def discretise_NEUMANN(NEUMANN, args, device): 75 | lark_path = 'src/lark/exp.lark' 76 | lang_base_path = 'data/lang/' 77 | lang, clauses_, bk, terms, atoms = get_lang( 78 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth) 79 | # Discretise NEUMANN rules 80 | clauses = NEUMANN.get_clauses() 81 | return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False) 82 | 83 | 84 | def predict(NEUMANN, I2F, loader, args, device, th=None, split='train'): 85 | predicted_list = [] 86 | target_list = [] 87 | count = 0 88 | 89 | for i, sample in tqdm(enumerate(loader, start=0)): 90 | imgs, query, target_set = map(lambda x: x.to(device), sample) 91 | 92 | # to cuda 93 | target_set = target_set.float() 94 | 95 | V_0 = I2F(imgs, query) 96 | V_T = NEUMANN(V_0) 97 | predicted = get_prob(V_T, NEUMANN, args) 98 | predicted = to_one_label(predicted, target_set) 99 | predicted = torch.softmax(predicted * 10, dim=1) 100 | predicted_list.append(predicted.detach()) 101 | target_list.append(target_set.detach()) 102 | """ 103 | if args.plot: 104 | imgs = to_plot_images_clevr(imgs.squeeze(1)) 105 | captions = generate_captions( 106 | V_T, NEUMANN.atoms, I2F.pm.e, th=0.3) 107 | save_images_with_captions( 108 | imgs, captions, folder='result/kandinsky/' + args.dataset + '/' + split + '/', img_id_start=count, dataset=args.dataset) 109 | """ 110 | count += V_T.size(0) # batch size 111 | 112 | predicted = 
torch.cat(predicted_list, dim=0).detach().cpu().numpy() 113 | target_set = torch.cat(target_list, dim=0).to( 114 | torch.int64).detach().cpu().numpy() 115 | 116 | if th == None: 117 | fpr, tpr, thresholds = roc_curve(target_set, predicted, pos_label=1) 118 | accuracy_scores = [] 119 | print('ths', thresholds) 120 | for thresh in thresholds: 121 | accuracy_scores.append(accuracy_score( 122 | target_set, [m > thresh for m in predicted])) 123 | 124 | accuracies = np.array(accuracy_scores) 125 | max_accuracy = accuracies.max() 126 | max_accuracy_threshold = thresholds[accuracies.argmax()] 127 | rec_score = recall_score( 128 | target_set, [m > thresh for m in predicted], average=None) 129 | 130 | print('target_set: ', target_set, target_set.shape) 131 | print('predicted: ', predicted, predicted.shape) 132 | print('accuracy: ', max_accuracy) 133 | print('threshold: ', max_accuracy_threshold) 134 | print('recall: ', rec_score) 135 | 136 | return max_accuracy, rec_score, max_accuracy_threshold 137 | else: 138 | accuracy = accuracy_score(target_set, [m > th for m in predicted]) 139 | rec_score = recall_score( 140 | target_set, [m > th for m in predicted], average=None) 141 | return accuracy, rec_score, th 142 | 143 | 144 | def to_one_label(ys, labels, th=0.7): 145 | ys_new = [] 146 | for i in range(len(ys)): 147 | y = ys[i] 148 | label = labels[i] 149 | # check in case answers are computed 150 | num_class = 0 151 | for p_j in y: 152 | if p_j > th: 153 | num_class += 1 154 | if num_class >= 2: 155 | # drop the value using label (the label is one-hot) 156 | drop_index = torch.argmin(label - y) 157 | y[drop_index] = y.min() 158 | ys_new.append(y) 159 | return torch.stack(ys_new) 160 | 161 | 162 | def main(n): 163 | args = get_args() 164 | #name = 'VILP' 165 | print('args ', args) 166 | if args.no_cuda: 167 | device = torch.device('cpu') 168 | elif len(args.device.split(',')) > 1: 169 | # multi gpu 170 | device = torch.device('cuda') 171 | else: 172 | device = torch.device('cuda:' + args.device) 173 | 174 | print('device: ', device) 175 | name = 'neumann/behind-the-scenes/' + str(n) 176 | writer = SummaryWriter(f"runs/{name}", purge_step=0) 177 | 178 | # Create RTPT object 179 | rtpt = RTPT(name_initials='HS', experiment_name=name, 180 | max_iterations=args.epochs) 181 | # Start the RTPT tracking 182 | rtpt.start() 183 | 184 | 185 | ## train_pos_loader, val_pos_loader, test_pos_loader = get_vilp_pos_loader(args) 186 | #####train_pos_loader, val_pos_loader, test_pos_loader = get_data_loader(args) 187 | 188 | # load logical representations 189 | lark_path = 'src/lark/exp.lark' 190 | lang_base_path = 'data/lang/' 191 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang( 192 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth) 193 | 194 | print("{} Atoms:".format(len(atoms))) 195 | 196 | NEUMANN, I2F = get_model(lang=lang, clauses=clauses, atoms=atoms, terms=terms, bk=bk, bk_clauses=bk_clauses, 197 | program_size=args.program_size, device=device, dataset=args.dataset, dataset_type=args.dataset_type, 198 | num_objects=args.num_objects, infer_step=args.infer_step, train=False)#train=not(args.no_train)) 199 | print(NEUMANN.rgm) 200 | # get torch data loader 201 | question_json_path = 'data/behind-the-scenes/BehindTheScenes_questions_{}.json'.format(args.dataset) 202 | test_loader = get_behind_the_scenes_loader(question_json_path, args.batch_size, lang, device) 203 | 204 | 205 | writer.add_scalar("graph/num_atom_nodes", len(NEUMANN.rgm.atom_node_idxs)) 206 | 
writer.add_scalar("graph/num_conj_nodes", len(NEUMANN.rgm.conj_node_idxs)) 207 | num_nodes = len(NEUMANN.rgm.atom_node_idxs) + len(NEUMANN.rgm.conj_node_idxs) 208 | writer.add_scalar("graph/num_nodes", num_nodes) 209 | 210 | num_edges = NEUMANN.rgm.edge_index.size(1) 211 | writer.add_scalar("graph/num_edges", num_edges) 212 | 213 | writer.add_scalar("graph/memory_total", num_nodes + num_edges) 214 | 215 | print("=====================") 216 | print("NUM NODES: ", num_nodes) 217 | print("NUM EDGES: ", num_edges) 218 | print("MEMORY TOTAL: ", num_nodes + num_edges) 219 | print("=====================") 220 | 221 | params = list(NEUMANN.parameters()) 222 | print('parameters: ', list(params)) 223 | 224 | print("Predicting on test data set...") 225 | # test split 226 | acc_test, rec_test, th_test = predict( 227 | NEUMANN, I2F, test_loader, args, device, th=0.5, split='test') 228 | 229 | print("test acc: ", acc_test, "threashold: ", th_test, "recall: ", rec_test) 230 | 231 | 232 | if __name__ == "__main__": 233 | for i in range(1): 234 | main(n=i) 235 | -------------------------------------------------------------------------------- /neumann/neumann/solve_kandinsky.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from rtpt import RTPT 8 | from sklearn.metrics import accuracy_score, recall_score, roc_curve 9 | from torch.utils.tensorboard import SummaryWriter 10 | from tqdm import tqdm 11 | 12 | from logic_utils import get_lang 13 | from mode_declaration import get_mode_declarations 14 | from neumann_utils import (get_clause_evaluator, get_data_loader, get_model, 15 | get_prob) 16 | from tensor_encoder import TensorEncoder 17 | 18 | # from nsfr_utils import save_images_with_captions, to_plot_images_clevr, generate_captions 19 | 20 | torch.set_num_threads(10) 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--batch-size", type=int, default=10, 25 | help="Batch size to infer with") 26 | parser.add_argument("--batch-size-bs", type=int, 27 | default=1, help="Batch size in beam search") 28 | parser.add_argument("--num-objects", type=int, default=6, 29 | help="The maximum number of objects in one image") 30 | parser.add_argument("--dataset", default="delete") # , choices=["member"]) 31 | parser.add_argument("--dataset-type", default="behind-the-scenes") 32 | parser.add_argument('--device', default='cpu', 33 | help='cuda device, i.e. 
0 or cpu') 34 | parser.add_argument("--no-cuda", action="store_true", 35 | help="Run on CPU instead of GPU (not recommended)") 36 | parser.add_argument("--no-train", action="store_true", 37 | help="Perform prediction without training model") 38 | parser.add_argument("--small-data", action="store_true", 39 | help="Use small training data.") 40 | parser.add_argument("--num-workers", type=int, default=0, 41 | help="Number of threads for data loader") 42 | parser.add_argument('--gamma', default=0.01, type=float, 43 | help='Smooth parameter in the softor function') 44 | parser.add_argument("--plot", action="store_true", 45 | help="Plot images with captions.") 46 | parser.add_argument("--t-beam", type=int, default=4, 47 | help="Number of rule expantion of clause generation.") 48 | parser.add_argument("--n-beam", type=int, default=5, 49 | help="The size of the beam.") 50 | parser.add_argument("--n-max", type=int, default=50, 51 | help="The maximum number of clauses.") 52 | parser.add_argument("--program-size", type=int, default=1, 53 | help="The size of the logic program.") 54 | #parser.add_argument("--n-obj", type=int, default=2, help="The number of objects to be focused.") 55 | parser.add_argument("--epochs", type=int, default=20, 56 | help="The number of epochs.") 57 | parser.add_argument("--lr", type=float, default=1e-2, 58 | help="The learning rate.") 59 | parser.add_argument("--n-ratio", type=float, default=1.0, 60 | help="The ratio of data to be used.") 61 | parser.add_argument("--pre-searched", action="store_true", 62 | help="Using pre searched clauses.") 63 | parser.add_argument("--infer-step", type=int, default=6, 64 | help="The number of steps of forward reasoning.") 65 | parser.add_argument("--term-depth", type=int, default=3, 66 | help="The number of steps of forward reasoning.") 67 | parser.add_argument("--question-json-path", default="data/behind-the-scenes/BehindTheScenes_questions.json") 68 | args = parser.parse_args() 69 | return args 70 | 71 | # def get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False): 72 | 73 | 74 | def discretise_NEUMANN(NEUMANN, args, device): 75 | lark_path = 'src/lark/exp.lark' 76 | lang_base_path = 'data/lang/' 77 | lang, clauses_, bk, terms, atoms = get_lang( 78 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth) 79 | # Discretise NEUMANN rules 80 | clauses = NEUMANN.get_clauses() 81 | return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False) 82 | 83 | 84 | def predict(NEUMANN, I2F, loader, args, device, th=None, split='train'): 85 | predicted_list = [] 86 | target_list = [] 87 | count = 0 88 | 89 | start = time.time() 90 | for epoch in tqdm(range(args.epochs)): 91 | for i, sample in enumerate(tqdm(loader), start=0): 92 | imgs, target_set = map(lambda x: x.to(device), sample) 93 | 94 | # to cuda 95 | target_set = target_set.float() 96 | 97 | V_0 = I2F(imgs) 98 | V_T = NEUMANN(V_0) 99 | #a NEUMANN.print_valuation_batch(V_T) 100 | #predicted = get_prob(V_T, NEUMANN, args) 101 | # predicted = to_one_label(predicted, target_set) 102 | # predicted = torch.softmax(predicted * 10, dim=1) 103 | #predicted_list.append(predicted.detach()) 104 | #target_list.append(target_set.detach()) 105 | """ 106 | if args.plot: 107 | imgs = to_plot_images_clevr(imgs.squeeze(1)) 108 | captions = generate_captions( 109 | V_T, NEUMANN.atoms, I2F.pm.e, th=0.3) 110 | save_images_with_captions( 111 | imgs, captions, folder='result/kandinsky/' + args.dataset + '/' + split + '/', img_id_start=count, 
dataset=args.dataset) 112 | """ 113 | count += V_T.size(0) # batch size 114 | reasoning_time = time.time() - start 115 | print('Reasoning Time: ', reasoning_time) 116 | return 0, 0, 0, reasoning_time 117 | 118 | predicted = torch.cat(predicted_list, dim=0).detach().cpu().numpy() 119 | target_set = torch.cat(target_list, dim=0).to( 120 | torch.int64).detach().cpu().numpy() 121 | 122 | if th == None: 123 | fpr, tpr, thresholds = roc_curve(target_set, predicted, pos_label=1) 124 | accuracy_scores = [] 125 | print('ths', thresholds) 126 | for thresh in thresholds: 127 | accuracy_scores.append(accuracy_score( 128 | target_set, [m > thresh for m in predicted])) 129 | 130 | accuracies = np.array(accuracy_scores) 131 | max_accuracy = accuracies.max() 132 | max_accuracy_threshold = thresholds[accuracies.argmax()] 133 | rec_score = recall_score( 134 | target_set, [m > thresh for m in predicted], average=None) 135 | 136 | print('target_set: ', target_set, target_set.shape) 137 | print('predicted: ', predicted, predicted.shape) 138 | print('accuracy: ', max_accuracy) 139 | print('threshold: ', max_accuracy_threshold) 140 | print('recall: ', rec_score) 141 | 142 | return max_accuracy, rec_score, max_accuracy_threshold, reasoning_time 143 | else: 144 | accuracy = accuracy_score(target_set, [m > th for m in predicted]) 145 | rec_score = recall_score( 146 | target_set, [m > th for m in predicted], average=None) 147 | return accuracy, rec_score, th, reasoning_time 148 | 149 | 150 | def to_one_label(ys, labels, th=0.7): 151 | ys_new = [] 152 | for i in range(len(ys)): 153 | y = ys[i] 154 | label = labels[i] 155 | # check in case answers are computed 156 | num_class = 0 157 | for p_j in y: 158 | if p_j > th: 159 | num_class += 1 160 | if num_class >= 2: 161 | # drop the value using label (the label is one-hot) 162 | drop_index = torch.argmin(label - y) 163 | y[drop_index] = y.min() 164 | ys_new.append(y) 165 | return torch.stack(ys_new) 166 | 167 | 168 | def main(n): 169 | args = get_args() 170 | #name = 'VILP' 171 | print('args ', args) 172 | if args.no_cuda: 173 | device = torch.device('cpu') 174 | elif len(args.device.split(',')) > 1: 175 | # multi gpu 176 | device = torch.device('cuda') 177 | else: 178 | device = torch.device('cuda:' + args.device) 179 | 180 | print('device: ', device) 181 | name = 'neumann/behind-the-scenes/' + str(n) 182 | writer = SummaryWriter(f"runs/{name}", purge_step=0) 183 | 184 | # Create RTPT object 185 | rtpt = RTPT(name_initials='HS', experiment_name=name, 186 | max_iterations=args.epochs) 187 | # Start the RTPT tracking 188 | rtpt.start() 189 | 190 | 191 | ## train_pos_loader, val_pos_loader, test_pos_loader = get_vilp_pos_loader(args) 192 | #####train_pos_loader, val_pos_loader, test_pos_loader = get_data_loader(args) 193 | 194 | # load logical representations 195 | lark_path = 'src/lark/exp.lark' 196 | lang_base_path = 'data/lang/' 197 | lang, clauses, bk, bk_clauses, terms, atoms = get_lang( 198 | lark_path, lang_base_path, args.dataset_type, args.dataset, args.term_depth, use_learned_clauses=True) 199 | 200 | print("{} Atoms:".format(len(atoms))) 201 | 202 | # get torch data loader 203 | #question_json_path = 'data/behind-the-scenes/BehindTheScenes_questions_{}.json'.format(args.dataset) 204 | # test_loader = get_behind_the_scenes_loader(question_json_path, args.batch_size, lang, args.n_data, device) 205 | train_loader, val_loader, test_loader = get_data_loader(args, device) 206 | 207 | NEUMANN, I2F = get_model(lang=lang, clauses=clauses, atoms=atoms, terms=terms, bk=bk, 
bk_clauses=bk_clauses, 208 | program_size=args.program_size, device=device, dataset=args.dataset, dataset_type=args.dataset_type, 209 | num_objects=args.num_objects, infer_step=args.infer_step, train=False)#train=not(args.no_train)) 210 | 211 | writer.add_scalar("graph/num_atom_nodes", len(NEUMANN.rgm.atom_node_idxs)) 212 | writer.add_scalar("graph/num_conj_nodes", len(NEUMANN.rgm.conj_node_idxs)) 213 | num_nodes = len(NEUMANN.rgm.atom_node_idxs) + len(NEUMANN.rgm.conj_node_idxs) 214 | writer.add_scalar("graph/num_nodes", num_nodes) 215 | 216 | num_edges = NEUMANN.rgm.edge_index.size(1) 217 | writer.add_scalar("graph/num_edges", num_edges) 218 | 219 | writer.add_scalar("graph/memory_total", num_nodes + num_edges) 220 | 221 | print("=====================") 222 | print("NUM NODES: ", num_nodes) 223 | print("NUM EDGES: ", num_edges) 224 | print("MEMORY TOTAL: ", num_nodes + num_edges) 225 | print("=====================") 226 | 227 | params = list(NEUMANN.parameters()) 228 | print('parameters: ', list(params)) 229 | 230 | print("Predicting on train data set...") 231 | times = [] 232 | # train split 233 | for j in range(n): 234 | acc_test, rec_test, th_test, time = predict( 235 | NEUMANN, I2F, train_loader, args, device, th=0.5, split='test') 236 | times.append(time) 237 | 238 | with open('out/inference_time/time_{}_ratio_{}.txt'.format(args.dataset, args.n_ratio), 'w') as f: 239 | f.write("\n".join(str(item) for item in times)) 240 | 241 | print("train acc: ", acc_test, "threashold: ", th_test, "recall: ", rec_test) 242 | 243 | if __name__ == "__main__": 244 | main(n=5) 245 | 246 | -------------------------------------------------------------------------------- /neumann/neumann/torch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def logsumexp(inputs, dim=None, keepdim=False): 6 | """Numerically stable logsumexp. 7 | from https://github.com/pytorch/pytorch/issues/2591#issuecomment-364474328 8 | Args: 9 | inputs: A Variable with any shape. 10 | dim: An integer. 11 | keepdim: A boolean. 12 | 13 | Returns: 14 | Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). 15 | """ 16 | # For a 1-D array x (any array along a single dimension), 17 | # log sum exp(x) = s + log sum exp(x - s) 18 | # with s = max(x) being a common choice. 19 | if dim is None: 20 | inputs = inputs.view(-1) 21 | dim = 0 22 | s, _ = torch.max(inputs, dim=dim, keepdim=True) 23 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() 24 | if not keepdim: 25 | outputs = outputs.squeeze(dim) 26 | return outputs 27 | 28 | 29 | def weight_sum(W_l, H): 30 | # W: C 31 | # H: C * B * G 32 | W_ex = W_l.unsqueeze(dim=-1).unsqueeze(dim=-1).expand_as(H) 33 | # C * B * G 34 | WH = W_ex * H 35 | # B * G 36 | WH_sum = torch.sum(WH, dim=0) 37 | return WH_sum 38 | 39 | 40 | def softor(xs, dim=0, gamma=0.015): 41 | """The softor function. 42 | 43 | Args: 44 | xs (tensor or list(tensor)): The input tensor. 45 | dim (int): The dimension to be removed. 46 | gamma (float: The smooth parameter for logsumexp. 47 | Returns: 48 | log_sum_exp (tensor): The result of taking or along dim. 
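    Example (illustrative; the input values below are arbitrary). softor acts
    as a smooth approximation of the element-wise maximum over the stacked
    inputs::

        a = torch.tensor([0.1, 0.9])
        b = torch.tensor([0.8, 0.2])
        softor([a, b], dim=0)  # approximately the element-wise max, ~[0.8, 0.9]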
49 | """ 50 | # xs is List[Tensor] or Tensor 51 | if not torch.is_tensor(xs): 52 | xs = torch.stack(xs, dim) 53 | log_sum_exp = gamma*logsumexp(xs * (1/gamma), dim=dim) 54 | # log_sum_exp = gamma * torch.log(torch.sum(torch.exp(xs/gamma),dim=dim)) 55 | if log_sum_exp.max() > 1.0: 56 | return log_sum_exp / log_sum_exp.max() 57 | else: 58 | return log_sum_exp 59 | 60 | 61 | def print_valuation(valuation, atoms, n=40): 62 | """Print the valuation tensor. 63 | 64 | Print the valuation tensor using given atoms. 65 | Args: 66 | valuation (tensor;(B*G)): A valuation tensor. 67 | atoms (list(atom)): The ground atoms. 68 | """ 69 | for b in range(valuation.size(0)): 70 | print('===== BATCH: ', b, '=====') 71 | v = valuation[b].detach().cpu().numpy() 72 | idxs = np.argsort(-v) 73 | for i in idxs: 74 | if v[i] > 0.1: 75 | print(i, atoms[i], ': ', round(v[i], 3)) 76 | -------------------------------------------------------------------------------- /neumann/neumann/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import numpy as np 4 | from matplotlib.backends.backend_pdf import PdfPages 5 | from sklearn.manifold import TSNE 6 | 7 | 8 | def plot_atoms(x_atom, atoms, path): 9 | x_atom = x_atom.detach().cpu().numpy() 10 | labels = [str(atom) for atom in atoms] 11 | X_reduced = TSNE(n_components=2, random_state=0).fit_transform(x_atom) 12 | fig, ax = plt.subplots(figsize=(30,30)) 13 | ax.scatter(X_reduced[:, 0], X_reduced[:, 1]) 14 | for i, label in enumerate(labels): 15 | ax.annotate(label, (X_reduced[i,0], X_reduced[i,1])) 16 | plt.savefig(path) 17 | 18 | def plot_infer_embeddings(x_atom_list, atoms): 19 | for i, x_atom in enumerate(x_atom_list): 20 | plot_atoms(x_atom, atoms, 'imgs/x_atom_' + str(i) + '.png') 21 | 22 | 23 | def plot_proof_history(V_list, atoms, infer_step, dataset, mode='graph'): 24 | if dataset == 'graph': 25 | fig, ax = plt.subplots(figsize=(12, 12)) 26 | else: 27 | fig, ax = plt.subplots(figsize=(6, 6)) 28 | # extract first batch 29 | vs_img = np.round(np.array([vs[0] for vs in V_list]),2) 30 | print(vs_img) 31 | vmax = vs_img.max() 32 | im = ax.imshow(vs_img, cmap="Blues") 33 | ##m = ax.imshow(vs_img, cmap="plasma") 34 | ax.set_xticks(np.arange(len(atoms))) 35 | ax.set_yticks(np.arange(infer_step+1)) 36 | plt.yticks(fontname = "monospace", fontsize=12) 37 | plt.xticks(fontname = "monospace", fontsize=11) 38 | ax.set_xticklabels([str(x) for x in atoms]) 39 | ax.set_yticklabels(["v_{}".format(i) for i in range(infer_step+1)]) 40 | plt.setp(ax.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor") 41 | # Loop over data dimensions and create text annotations. 
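    # Only atoms whose valuation exceeds 0.1 are labeled; values that are small
    # relative to the overall maximum are drawn in gray and the rest in white,
    # so the numbers stay readable on the "Blues" colormap.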
42 | plt.rcParams.update({'font.size': 10}) 43 | for i in range(infer_step+1): 44 | for j in range(len(atoms)): 45 | if vs_img[i, j] > 0.1: 46 | #print(vs_img[i,j], vs_img[i, j] / vmax ) 47 | if vs_img[i, j] / vmax < 0.4: 48 | text = ax.text(j, i, str(vs_img[i, j]).replace('0.', '.'), ha="center", va="center", color="gray") 49 | else: 50 | text = ax.text(j, i, str(vs_img[i, j]).replace('0.', '.'), ha="center", va="center", color="w") 51 | if mode == 'graph': 52 | ax.set_title("Proof history on {} dataset (NEUMANN)".format(dataset), fontsize=18) 53 | elif mode == 'tensor': 54 | ax.set_title("Proof history on {} dataset (Tensor-based Reasoner)".format(dataset), fontsize=18) 55 | fig.tight_layout() 56 | plt.show() 57 | folder_path = "plot" 58 | # plt.savefig(f"{folder_path}/{dataset}_infer_history.svg") 59 | plt.savefig(f"{folder_path}/{dataset}_{infer_step}_{mode}_history.svg") 60 | plt.savefig(f"{folder_path}/{dataset}_{infer_step}_{mode}_history.png") 61 | 62 | def plot_reasoning_graph(path, reasoning_graph_module): 63 | #pp = PdfPages(path) 64 | 65 | G = reasoning_graph_module.networkx_graph 66 | fig = plt.figure(1, figsize=(30, 30)) 67 | 68 | first_partition_nodes = list(range(len(reasoning_graph_module.facts))) 69 | edges = G.edges() 70 | colors_rg = [G[u][v]['color'] for u, v in edges] 71 | colors = [] 72 | for c in colors_rg: 73 | if c == 'r': 74 | colors.append('indianred') 75 | elif c == 'b': 76 | colors.append('royalblue') 77 | #weights = [G[u][v]['weight'] for u,v in edges] 78 | 79 | nx.draw_networkx( 80 | G, 81 | alpha=0.5, 82 | labels=reasoning_graph_module.node_labels, 83 | node_size=2, node_color='lightgray', edge_color=colors, font_size=10, 84 | pos=nx.drawing.layout.bipartite_layout(G, first_partition_nodes)) # Or whatever other display options you like 85 | plt.tight_layout() 86 | 87 | plt.show() 88 | plt.savefig(path) 89 | plt.close() 90 | #pp.savefig(fig) 91 | #pp.close() -------------------------------------------------------------------------------- /neumann/requirements.txt: -------------------------------------------------------------------------------- 1 | ### NSFR 2 | torch==1.10.0 3 | lark-parser 4 | sklearn 5 | torchsummary 6 | setuptools 7 | tensorboard 8 | numpy>=1.18.5 9 | tqdm>=4.41.0 10 | matplotlib>=3.2.2 11 | opencv-python>=4.1.2 12 | Pillow 13 | PyYAML>=5.3.1 14 | scipy>=1.4.1 15 | seaborn 16 | pandas 17 | rtpt 18 | -------------------------------------------------------------------------------- /neumann/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # import ipdb; ipdb.set_trace() 6 | 7 | 8 | def package_files(directory): 9 | paths = [] 10 | for path, directories, filenames in os.walk(directory): 11 | for filename in filenames: 12 | paths.append(os.path.join("..", path, filename)) 13 | return paths 14 | 15 | 16 | # extra_files = package_files('nsfr/data') + package_files('nsfr/lark') 17 | extra_files = package_files("neumann/src/lark") 18 | # import ipdb; ipdb.set_trace() 19 | 20 | setup( 21 | name="neumann", 22 | version="0.1.0", 23 | author="anonymous", 24 | author_email="anon.gouv", 25 | packages=find_packages(), 26 | package_data={"": extra_files}, 27 | include_package_data=True, 28 | # package_dir={'':'src'}, 29 | url="tba", 30 | description="GNN-based Differentiable Forward Reasoner", 31 | long_description=open("README.md").read(), 32 | install_requires=[ 33 | "matplotlib", 34 | "numpy", 35 | "seaborn", 36 | "setuptools", 37 | # 
"torch", 38 | "tqdm", 39 | "lark", 40 | ], 41 | ) 42 | -------------------------------------------------------------------------------- /prompt/gen_constants.txt: -------------------------------------------------------------------------------- 1 | Given a deictic representation, generate predicates in the format "const1,const2,...,constn". 2 | Use only small characters, 3 | 4 | An object that is next to the keyboard. 5 | keyboard 6 | 7 | An object that is on the desk. 8 | desk 9 | 10 | An object that has the papers. 11 | papers 12 | 13 | An object which is on the desk and next to the phone. 14 | desk,phone -------------------------------------------------------------------------------- /prompt/gen_predicates.txt: -------------------------------------------------------------------------------- 1 | Given a deictic representation, generate predicates in the format "pred1,pred2,...,predn". 2 | Use comma "," to separate predicates. 3 | Answer in one line. 4 | 5 | The available predicates are: on, has, wearing, of, in, near, behind, with, holding, above, sitting_on, wears, under, riding, in_front_of, standing_on, at, carrying, attached_to, walking_on, over, for, looking_at, watching, hanging_from, laying_on, eating, and, belonging_to, parked_on, playing_on, using, covering, between, along, covered_in, part_of, lying_on, on_back_of, to, walking_in, mounted_on, across, against, from, growing_on, painted_on, playing, made_of, says, flying_in 6 | 7 | 8 | Examples: 9 | 10 | an object that is next to keyboard. 11 | next_to 12 | 13 | an object that is on a desk. 14 | on 15 | 16 | an object that has a papers. 17 | has 18 | 19 | an object which is on a desk and next to a phone. 20 | on,next_to 21 | 22 | an object that is behind a couch and on floor. 23 | behind,on 24 | 25 | an object that is near a desk and against wall. 26 | near,against 27 | 28 | an object that is parked on a street and has back. 29 | parked_on,has 30 | 31 | an object that has sides, that is on a pole, and that is above a stop sign. 32 | has,on,above -------------------------------------------------------------------------------- /prompt/gen_rules.txt: -------------------------------------------------------------------------------- 1 | Given a deictic representation and available predicates, generate rules in the format. 2 | The rule's format is 3 | target(X):-cond1(X),...condn(X). 4 | cond1(X):-pred1(X,Y),type(Y,const1). 5 | ... 6 | condn(X):-predn(X,Y),type(Y,const2). 7 | Use predicates and constants that appear in the given sentence. 8 | Capitalize variables: X, Y, Z, W, etc. 9 | 10 | Examples: 11 | 12 | an object that is next to a keyboard. 13 | available predicates: next_to 14 | cond1(X):-next_to(X,Y),type(Y,keyboard). 15 | target(X):-cond1(X). 16 | 17 | an object that is on a desk. 18 | available predicates: on 19 | cond1(X):-on(X,Y),type(Y,desk). 20 | target(X):-cond1(X). 21 | 22 | 23 | an object that has papers. 24 | available predicates: has 25 | cond1(X):-has(X,Y),type(Y,papers). 26 | target(X):-cond1(X). 27 | 28 | an object that is on a white pillow. 29 | available predicates: on 30 | cond1(X):-on(X,Y),type(Y,white_pillow). 31 | target(X):-cond1(X). 32 | 33 | an object that has a fur. 34 | available predicate: has 35 | cond1(X):-has(X,Y),type(Y,fur). 36 | target(X):-cond1(X). 37 | 38 | an object that is on a ground, and that is behind a white line. 39 | available predicates: on,behind 40 | cond1(X):-on(X,Y),type(Y,ground). 41 | cond2(X):-behind(X,Y),type(Y,whiteline). 
42 | target(X):-cond1(X),cond2(X) 43 | 44 | an object that is near a desk and against wall. 45 | available predicates: near,against 46 | cond1(X):-near(X,Y),type(Y,desk). 47 | cond2(X):-against(X,Y),type(Y,wall). 48 | target(X):-cond1(X),cond2(X). 49 | 50 | an object that has sides, that is on a pole, and that is above a stop sign. 51 | available predicates: has,on,above 52 | cond1(X):-has(X,Y),type(Y,sides). 53 | cond2(X):-on(X,Y),type(Y,pole). 54 | cond3(X):-above(X,Y),type(Y,stopsign). 55 | target(X):-cond1(X),cond2(X),cond3(X). 56 | 57 | an object that is wearing a shirt, that has a hair, and that is wearing shoes. 58 | available predicates: wearing,has,wearing 59 | cond1(X):-wearing(X,Y),type(Y,shirt). 60 | cond2(X):-has(X,Y),type(Y,hair). 61 | cond3(X):-wearing(X,Y),type(Y,shoes). 62 | target(X):-cond1(X),cond2(X),cond3(X). -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/deictic-segment-anything/e7a014546350bf5c9e41342fd368f24488ae8acb/src/__init__.py -------------------------------------------------------------------------------- /src/data_vg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | import random 5 | 6 | import visual_genome.local as vg 7 | from groundingdino.util.inference import load_image 8 | from neumann.fol.language import DataType, Language 9 | from neumann.fol.logic import Atom, Const, NeuralPredicate, Predicate 10 | 11 | 12 | def to_object_type_atoms(atoms, lang): 13 | """Convert VG atoms to an object-type format.""" 14 | unique_constants = set(term for atom in atoms for term in atom.terms) 15 | new_lang = Language(preds=lang.preds.copy(), funcs=[], consts=[]) 16 | new_const_dic = { 17 | vg_const: Const(f"obj_{vg_const.name.split('_')[-1]}", dtype=DataType("object")) 18 | for vg_const in unique_constants 19 | } 20 | 21 | new_lang.consts.extend(new_const_dic.values()) 22 | p_type = NeuralPredicate("type", 2, [DataType("object"), DataType("type")]) 23 | 24 | new_atoms = [ 25 | Atom(atom.pred, [new_const_dic[term] for term in atom.terms]) 26 | for atom in atoms 27 | ] 28 | 29 | for type_const, obj_const in new_const_dic.items(): 30 | base_type_const = Const(type_const.name.split("_")[0], dtype=DataType("type")) 31 | if base_type_const not in new_lang.consts: 32 | new_lang.consts.append(base_type_const) 33 | new_atoms.append(Atom(p_type, [obj_const, base_type_const])) 34 | 35 | new_lang = Language( 36 | preds=list(set(new_lang.preds)), funcs=[], consts=list(set(new_lang.consts)) 37 | ) 38 | return new_atoms, new_lang 39 | 40 | 41 | class BaseVisualGenomeDataset(torch.utils.data.Dataset): 42 | """Base class for Visual Genome datasets.""" 43 | 44 | def __init__(self, json_path): 45 | self.json_data = self._load_json(json_path) 46 | 47 | def _load_json(self, path): 48 | with open(path, "r") as file: 49 | return json.load(file) 50 | 51 | def _parse_data(self, item): 52 | data = self.json_data["queries"][item] 53 | return data[0] + ".", data[1], data[2], data[3], data[4] 54 | 55 | def __len__(self): 56 | return len(self.json_data["queries"]) 57 | 58 | 59 | class DeicticVisualGenome(BaseVisualGenomeDataset): 60 | """Deictic Visual Genome dataset.""" 61 | 62 | def __getitem__(self, item): 63 | deictic_representation, answer, image_id, vg_data_index, id = self._parse_data(item) 64 | image_source, image = 
load_image(f"data/visual_genome/VG_100K/{image_id}.jpg") 65 | return id, vg_data_index, image_id, image_source, image, deictic_representation, answer 66 | 67 | 68 | class DeicticVisualGenomeSGGTraining(BaseVisualGenomeDataset): 69 | """Deictic Visual Genome dataset for scene graph generator training.""" 70 | 71 | def __init__(self, args, mode="train"): 72 | filename = f"deictic_vg_comp{args.complexity}_sgg_{mode}.json" 73 | json_path = f"data/deivg_learning/{filename}" 74 | super().__init__(json_path) 75 | 76 | def __getitem__(self, item): 77 | deictic_representation, answer, image_id, vg_data_index, id = self._parse_data(item) 78 | image_source, image = load_image(f"data/visual_genome/VG_100K/{image_id}.jpg") 79 | return id, vg_data_index, image_id, image_source, image, deictic_representation, answer 80 | 81 | 82 | class VisualGenomeUtils: 83 | """A utility class for the Visual Genome dataset.""" 84 | 85 | def __init__(self): 86 | self.all_relationships = self._load_json("data/visual_genome/relationships_deictic.json") 87 | self.all_objects = self._load_json("data/visual_genome/objects.json") 88 | 89 | def _load_json(self, path): 90 | print(f"Loading {path} ...") 91 | with open(path, "r") as file: 92 | data = json.load(file) 93 | print("Completed.") 94 | return data 95 | 96 | def load_scene_graphs(self, num_images=20): 97 | return vg.get_scene_graphs( 98 | start_index=0, 99 | end_index=num_images, 100 | min_rels=1, 101 | data_dir="data/visual_genome/", 102 | image_data_dir="data/visual_genome/by-id/", 103 | ) 104 | 105 | def load_scene_graph_by_id(self, image_id): 106 | return vg.get_scene_graph( 107 | image_id=image_id, 108 | images="data/visual_genome/", 109 | image_data_dir="data/visual_genome/by-id/", 110 | synset_file="data/visual_genome/synsets.json", 111 | ) 112 | 113 | def scene_graph_to_language(self, scene_graph, text, logic_generator, num_objects=3): 114 | """Generate FOL language from a scene graph.""" 115 | objects = list(set(obj.replace(" ", "") for obj in scene_graph.objects)) 116 | datatype = DataType("type") 117 | constants = [Const(obj, datatype) for obj in objects] 118 | 119 | const_response = f"Constants:\ntype:{','.join(objects)}" 120 | predicates, pred_response = logic_generator.generate_predicates(text, const_response) 121 | lang = Language(consts=list(set(constants)), preds=list(set(predicates)), funcs=[]) 122 | return lang, const_response, pred_response 123 | 124 | def data_index_to_atoms(self, data_index, lang): 125 | """Generate atoms from data index.""" 126 | relationships = self.all_relationships[data_index]["relationships"] 127 | atoms = [self._parse_relationship(rel, lang) for rel in relationships if self._parse_relationship(rel, lang)] 128 | return to_object_type_atoms(atoms, lang) 129 | 130 | def _parse_relationship(self, rel, lang): 131 | pred_name = rel["predicate"].replace(" ", "_").lower() 132 | dtype = DataType("object") 133 | pred = Predicate(pred_name, 2, [dtype, dtype]) 134 | # Either of name of names is used as a key 135 | # Add key "name" if it is names 136 | if "names" in rel["object"].keys(): 137 | rel["object"]["name"] = rel["object"]["names"][0] 138 | if "names" in rel["subject"].keys(): 139 | rel["subject"]["name"] = rel["subject"]["names"][0] 140 | consts = [ 141 | Const(rel["subject"]["name"].replace(" ", "_").lower() + f"_{rel['subject']['object_id']}", dtype=dtype), 142 | Const(rel["object"]["name"].replace(" ", "").lower() + f"_{rel['object']['object_id']}", dtype=dtype), 143 | ] 144 | return Atom(pred, consts) 145 | 146 | 147 | class 
PredictedSceneGraphUtils: 148 | """Utils for predicted scene graphs.""" 149 | 150 | def __init__(self, model="veto", base_path=""): 151 | self.model = model 152 | self.base_path = base_path 153 | self.scene_graphs = self._load_predicted_scene_graphs(model) 154 | self.all_relationships = self._preprocess_relationships(self.scene_graphs) 155 | 156 | def _load_predicted_scene_graphs(self, model): 157 | with open(f"{self.base_path}data/predicted_scene_graphs/{model}.pkl", "rb") as file: 158 | return pickle.load(file) 159 | 160 | def _preprocess_relationships(self, scene_graphs): 161 | return { 162 | int(graph_[0]["img_info"].split("/")[-1].split(".")[0]): graph_[0]["rel_info"]["spo"] 163 | for graph_ in scene_graphs 164 | } 165 | 166 | def _parse_relationship(self, rel): 167 | dtype = DataType("object") 168 | pred = Predicate(rel["p_str"].replace(" ", "_").lower(), 2, [dtype, dtype]) 169 | consts = [ 170 | Const(rel["s_str"].replace(" ", "_").lower() + f"_{rel['s_unique_id']}", dtype=dtype), 171 | Const(rel["o_str"].replace(" ", "_").lower() + f"_{rel['o_unique_id']}", dtype=dtype), 172 | ] 173 | return Atom(pred, consts), rel["score"] 174 | 175 | def load_scene_graph_by_id(self, image_id): 176 | return self.all_relationships.get(image_id, []) 177 | 178 | def data_index_to_atoms(self, data_index, lang): 179 | return self._generate_atoms(lang, self.scene_graphs[data_index][0]["rel_info"]["spo"]) 180 | 181 | def image_id_to_atoms(self, image_id, lang): 182 | return self._generate_atoms(lang, self.all_relationships.get(image_id, [])) 183 | 184 | def _generate_atoms(self, lang, relationships): 185 | atoms = [ 186 | self._parse_relationship(rel)[0] 187 | for rel in relationships 188 | if self._parse_relationship(rel)[1] > 0.98 189 | ] 190 | return to_object_type_atoms(atoms, lang) -------------------------------------------------------------------------------- /src/deisam_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import torch 6 | from PIL import Image 7 | from segment_anything import SamPredictor, build_sam 8 | from torchvision.ops import masks_to_boxes 9 | from huggingface_hub import hf_hub_download 10 | 11 | # Visualization and utility imports 12 | from groundingdino.util.inference import annotate 13 | from visualization_utils import ( 14 | answer_to_boxes, 15 | save_box_to_file, 16 | save_segmented_images, 17 | save_segmented_images_with_target_scores, 18 | show_mask, 19 | show_mask_with_alpha, 20 | ) 21 | 22 | # Neumann library imports 23 | from neumann.facts_converter import FactsConverter 24 | from neumann.fol.data_utils import DataUtils 25 | from neumann.fol.language import DataType 26 | from neumann.fol.logic import Atom, Const, Predicate 27 | from neumann.logic_utils import generate_atoms 28 | from neumann.neumann_utils import get_neumann_model, get_trainable_neumann_model 29 | 30 | 31 | def load_neumann(lang, rules, atoms, device, infer_step=2): 32 | """Load a Neumann reasoner model with specified language, rules, and atoms.""" 33 | atoms = add_target_cond_atoms(lang, atoms) 34 | fc = FactsConverter(lang, atoms, [], device) 35 | reasoner = get_neumann_model( 36 | rules, [], [], atoms, lang.consts, lang, 1, infer_step, device 37 | ) 38 | return fc, reasoner 39 | 40 | 41 | def load_neumann_for_sgg_training(lang, rules_to_learn, rules_bk, atoms, device, infer_step=4): 42 | """Load a trainable Neumann model for scene graph generation (SGG) training.""" 43 | atoms = 
add_target_cond_atoms_for_sgg_training(lang, atoms) 44 | fc = FactsConverter(lang, atoms, [], device) 45 | reasoner = get_trainable_neumann_model( 46 | rules_to_learn, rules_bk, [], atoms, lang.consts, lang, 1, infer_step, device 47 | ) 48 | return fc, reasoner 49 | 50 | 51 | def load_sam_model(device): 52 | """Load SAM model from a checkpoint file.""" 53 | sam_checkpoint = "sam_vit_h_4b8939.pth" 54 | sam = build_sam(checkpoint=sam_checkpoint) 55 | sam.to(device=device) 56 | return SamPredictor(sam) 57 | 58 | 59 | def crop_objects(img, masks): 60 | """Crop objects from the image using provided masks.""" 61 | cropped_objects = [] 62 | 63 | for mask in masks: 64 | x, y, width, height = mask["bbox"] 65 | 66 | if width * height > 2000: 67 | cropped_image = img[int(y):int(y + height), int(x):int(x + width), :] 68 | cropped_objects.append(cropped_image) 69 | 70 | return cropped_objects 71 | 72 | 73 | def load_image(path): 74 | """Load an image from the specified file path and convert it to RGB.""" 75 | return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 76 | 77 | 78 | def extract_target_ids(v_T, threshold): 79 | """Extract target IDs from a tensor of valuations using a threshold.""" 80 | target_ids = [] 81 | obj_id = 0 82 | 83 | for atom in self.reasoner.atoms: 84 | if atom.pred.name == "target": 85 | index = self.reasoner.atoms.index(atom) 86 | if v_T[index] > threshold: 87 | target_ids.append(obj_id) 88 | obj_id += 1 89 | 90 | return target_ids 91 | 92 | 93 | def load_model_from_hf(repo_id, filename, config_filename, device="cpu"): 94 | """Load a model from Hugging Face Hub.""" 95 | config_path = hf_hub_download(repo_id=repo_id, filename=config_filename) 96 | from groundingdino.models import build_model 97 | from groundingdino.util.slconfig import SLConfig 98 | from groundingdino.util.utils import clean_state_dict 99 | 100 | args = SLConfig.fromfile(config_path) 101 | model = build_model(args) 102 | args.device = device 103 | 104 | model_file = hf_hub_download(repo_id=repo_id, filename=filename) 105 | checkpoint = torch.load(model_file, map_location="cpu") 106 | model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 107 | print(f"Model loaded from {model_file}") 108 | 109 | return model.eval() 110 | 111 | 112 | def segment_by_grounded_sam(self, image, image_source, text_prompt, box_threshold=0.35, text_threshold=0.25): 113 | """Segment an image using a grounded SAM model.""" 114 | boxes, logits, phrases = predict( 115 | model=self.model, 116 | image=image, 117 | caption=text_prompt, 118 | box_threshold=box_threshold, 119 | text_threshold=text_threshold, 120 | ) 121 | annotated_frame = annotate( 122 | image_source=image_source, boxes=boxes, logits=logits, phrases=phrases 123 | )[..., ::-1] # Convert BGR to RGB 124 | return boxes, logits, phrases, annotated_frame 125 | 126 | 127 | def save_box_results(args, masks, answer, file_id, counter): 128 | """Save bounding boxes of predicted segmentations with ground truth.""" 129 | pr_boxes = masks_to_boxes(masks.squeeze(1).to(torch.int32)) 130 | gt_boxes = answer_to_boxes(answer) 131 | save_box_to_file(pr_boxes, gt_boxes, file_id, counter, args) 132 | return pr_boxes, gt_boxes 133 | 134 | 135 | def save_segmentation_result(args, masks, answer, image_source, counter, vg_image_id, data_index, deictic_representation, is_failed): 136 | """Plot and save segmentation masks on original visual inputs.""" 137 | annotated_frame = image_source 138 | 139 | for mask in masks: 140 | annotated_frame = show_mask(mask[0], annotated_frame) 141 | 
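    # Write the overlaid masks to a per-dataset output folder; failed cases are
    # additionally saved under a separate "_failed" directory for later inspection.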
142 |     base_path = f"plot/{args.dataset}_comp{args.complexity}/DeiSAM{args.sgg_model}/"
143 |     os.makedirs(base_path, exist_ok=True)
144 | 
145 |     save_segmented_images(
146 |         counter=counter,
147 |         vg_image_id=vg_image_id,
148 |         image_with_mask=annotated_frame,
149 |         description=deictic_representation,
150 |         base_path=base_path,
151 |     )
152 | 
153 |     if is_failed:
154 |         failed_path = f"plot/{args.dataset}_comp{args.complexity}/DeiSAM{args.sgg_model}_failed/"
155 |         os.makedirs(failed_path, exist_ok=True)
156 |         save_segmented_images(
157 |             counter=counter,
158 |             vg_image_id=vg_image_id,
159 |             image_with_mask=annotated_frame,
160 |             description=deictic_representation,
161 |             base_path=failed_path,
162 |         )
163 | 
164 | 
165 | def save_segmentation_result_with_alphas(args, masks, mask_probs, answer, image_source, counter, vg_image_id, data_index, deictic_representation, iter):
166 |     """Plot and save segmentation masks with alpha values on visual inputs."""
167 |     annotated_frame = image_source
168 | 
169 |     for i, mask in enumerate(masks):
170 |         annotated_frame = show_mask_with_alpha(mask=mask[0], image=annotated_frame, alpha=mask_probs[i])
171 | 
172 |     base_path = f"learning_plot/{args.dataset}_comp{args.complexity}/seed_{args.seed}/iter_{iter}/DeiSAM{args.sgg_model}/"
173 |     os.makedirs(base_path, exist_ok=True)
174 | 
175 |     save_segmented_images_with_target_scores(
176 |         counter=counter,
177 |         vg_image_id=vg_image_id,
178 |         image_with_mask=annotated_frame,
179 |         description=deictic_representation,
180 |         target_scores=mask_probs,
181 |         base_path=base_path,
182 |     )
183 | 
184 | 
185 | def save_llm_response(args, pred_response, rule_response, counter, image_id, deictic_representation):
186 |     """Save LLM responses (pred & rule) to files."""
187 |     base_path = f"llm_output/{args.dataset}_comp{args.complexity}/DeiSAM{args.sgg_model}/"
188 |     os.makedirs(f"{base_path}pred_response/", exist_ok=True)
189 |     os.makedirs(f"{base_path}rule_response/", exist_ok=True)
190 | 
191 |     pred_file = f"{counter}_vg{image_id}_{deictic_representation.replace('.', '').replace('/', ' ')}.txt"
192 |     rule_file = f"{counter}_vg{image_id}_{deictic_representation.replace('.', '').replace('/', ' ')}.txt"
193 | 
194 |     with open(f"{base_path}pred_response/{pred_file}", "w") as f:
195 |         f.write(pred_response)
196 | 
197 |     with open(f"{base_path}rule_response/{rule_file}", "w") as f:
198 |         f.write(rule_response)
199 | 
200 | 
201 | def get_random_masks(model):
202 |     """Get random masks from the Neumann model's target atoms."""
203 |     targets = [atom for atom in model.neumann.atoms if "target(obj_" in str(atom)]
204 |     return [random.choice(targets)]
205 | 
206 | 
207 | def add_target_cond_atoms(lang, atoms):
208 |     """Add target and condition atoms to the atoms list."""
209 |     spec_predicate = Predicate(".", 1, [DataType("spec")])
210 |     true_atom = Atom(spec_predicate, [Const("__T__", dtype=DataType("spec"))])
211 | 
212 |     target_atoms = generate_target_atoms(lang)
213 |     cond_atoms = generate_cond_atoms(lang)
214 |     return [true_atom] + sorted(set(atoms)) + target_atoms + cond_atoms
215 | 
216 | 
217 | def generate_target_atoms(lang):
218 |     """Generate target atoms for each object constant in the language."""
219 |     target_predicate = Predicate("target", 1, [DataType("object")])
220 |     return [
221 |         Atom(target_predicate, [const])
222 |         for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name
223 |     ]
224 | 
225 | 
226 | 
227 | 
228 | def generate_cond_atoms(lang):
229 |     """Generate condition atoms for each object constant 
in the language.""" 230 | cond_predicates = [Predicate(f"cond{i}", 1, [DataType("object")]) for i in range(1, 4)] 231 | cond_atoms = [] 232 | for pred in cond_predicates: 233 | cond_atoms.extend( 234 | Atom(pred, [const]) 235 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name 236 | ) 237 | return sorted(set(cond_atoms)) 238 | 239 | 240 | def add_target_cond_atoms_for_sgg_training(lang, atoms, num_sgg_models=2): 241 | """Add target and condition atoms for SGG training to the list of atoms.""" 242 | spec_predicate = Predicate(".", 1, [DataType("spec")]) 243 | true_atom = Atom(spec_predicate, [Const("__T__", dtype=DataType("spec"))]) 244 | 245 | target_atoms = generate_target_atoms_for_sgg_training(lang, num_sgg_models) 246 | cond_atoms = generate_cond_atoms_for_sgg_training(lang, num_sgg_models) 247 | return [true_atom] + sorted(set(atoms)) + target_atoms + cond_atoms 248 | 249 | 250 | def generate_target_atoms_for_sgg_training(lang, num_sgg_models): 251 | """Generate target atoms for main and SGG models.""" 252 | main_target_pred = Predicate("target", 1, [DataType("object")]) 253 | sgg_target_preds = [ 254 | Predicate(f"target_sgg{i}", 1, [DataType("object")]) for i in range(num_sgg_models) 255 | ] 256 | 257 | all_target_preds = [main_target_pred] + sgg_target_preds 258 | target_atoms = [ 259 | Atom(pred, [const]) 260 | for pred in all_target_preds 261 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name 262 | ] 263 | 264 | return sorted(set(target_atoms)) 265 | 266 | 267 | def generate_cond_atoms_for_sgg_training(lang, num_sgg_models): 268 | """Generate conditional atoms for SGG training.""" 269 | cond_atoms = [] 270 | 271 | for i in range(num_sgg_models): 272 | cond_preds = [Predicate(f"cond{j}_sgg{i}", 1, [DataType("object")]) for j in range(1, 4)] 273 | 274 | for cond_pred in cond_preds: 275 | cond_atoms.extend( 276 | Atom(cond_pred, [const]) 277 | for const in lang.consts if const.dtype.name == "object" and "obj_" in const.name 278 | ) 279 | 280 | return sorted(set(cond_atoms)) -------------------------------------------------------------------------------- /src/lark/exp.lark: -------------------------------------------------------------------------------- 1 | clause : atom ":-" body 2 | 3 | body : atom "," body 4 | | atom "." 5 | | "." 
6 | 7 | atom : predicate "(" args ")" 8 | 9 | args : term "," args 10 | | term 11 | 12 | term : functor "(" args ")" 13 | | const 14 | | variable 15 | 16 | const : /[a-z0-9\*\_']+/ 17 | 18 | variable : /[A-Z]+[A-Za-z0-9]*/ 19 | | /\_/ 20 | 21 | functor : /[a-z0-9]+/ 22 | 23 | predicate : /[a-z0-9\_]+/ 24 | 25 | var_name : /[A-Z]/ 26 | small_chars : /[a-z0-9]+/ 27 | chars : /[^\+\|\s\(\)']+/[/\n+/] 28 | allchars : /[^']+/[/\n+/] 29 | -------------------------------------------------------------------------------- /src/learning_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | from torchvision.ops import masks_to_boxes 7 | import openai 8 | 9 | from data_vg import DeicticVisualGenomeSGGTraining, VisualGenomeUtils, PredictedSceneGraphUtils 10 | from deisam import TrainableDeiSAM 11 | from deisam_utils import get_random_masks, save_segmentation_result_with_alphas 12 | from learning_utils import to_bce_examples 13 | from rtpt import RTPT 14 | from visualization_utils import answer_to_boxes, save_box_to_file 15 | 16 | torch.set_num_threads(10) 17 | 18 | 19 | class DeiSAMTrainer: 20 | def __init__(self, args, device): 21 | self.args = args 22 | self.device = device 23 | self._set_random_seeds() 24 | 25 | self.data_loader = DeicticVisualGenomeSGGTraining(args, mode="train") 26 | self.val_data_loader = DeicticVisualGenomeSGGTraining(args, mode="val") 27 | self.test_data_loader = DeicticVisualGenomeSGGTraining(args, mode="test") 28 | 29 | self.vg_1 = VisualGenomeUtils() 30 | self.vg_2 = PredictedSceneGraphUtils(args.sgg_model) 31 | 32 | self.deisam = TrainableDeiSAM(api_key=args.api_key, device=device, 33 | vg_utils_list=[self.vg_1, self.vg_2], 34 | sem_uni=args.sem_uni) 35 | 36 | def _set_random_seeds(self): 37 | random.seed(self.args.seed) 38 | torch.manual_seed(self.args.seed) 39 | np.random.seed(self.args.seed) 40 | 41 | def train(self): 42 | rtpt = self._initialize_rtpt() 43 | optimizer = torch.optim.RMSprop(self.deisam.parameters(), lr=self.args.lr) 44 | bce_loss = torch.nn.BCELoss() 45 | 46 | os.makedirs(f"learning_logs/comp{self.args.complexity}", exist_ok=True) 47 | 48 | for epoch in range(self.args.epochs): 49 | for counter, (id, data_index, image_id, image_source, image, 50 | deictic_representation, answer) in enumerate(self.data_loader): 51 | if self.args.trained: 52 | self._load_trained_model() 53 | return 54 | 55 | if counter % 25 == 0: 56 | self._save_intermediate_model(counter) 57 | 58 | loss = self._process_training_step(counter, data_index, image_id, 59 | image_source, deictic_representation, 60 | answer, optimizer, bce_loss) 61 | rtpt.step(subtitle=f"Iter:{counter}") 62 | 63 | self._evaluate_on_test_data(counter) 64 | 65 | def _initialize_rtpt(self): 66 | rtpt = RTPT(name_initials="", 67 | experiment_name=f"LearnDeiSAM{self.args.complexity}", 68 | max_iterations=100) 69 | rtpt.start() 70 | return rtpt 71 | 72 | def _load_trained_model(self): 73 | save_path = f"models/comp{self.args.complexity}_iter100_seed{self.args.seed}.pth" 74 | saved_state = torch.load(save_path) 75 | trained_weights = saved_state["rule_weights"].to(self.device) 76 | self.deisam.rule_weights = torch.nn.Parameter(trained_weights).to(self.device) 77 | 78 | def _save_intermediate_model(self, counter): 79 | save_path = f"models/comp{self.args.complexity}_iter{counter}_seed{self.args.seed}.pth" 80 | torch.save(self.deisam.state_dict(), save_path) 81 | print(f"Intermediate model 
saved to {save_path}")
82 | 
83 |     def _process_training_step(self, counter, data_index, image_id, image_source,
84 |                                 deictic_representation, answer, optimizer, bce_loss):
85 |         print(f"=========== ID {counter}, IMAGE ID {image_id} ===========")
86 |         print("Deictic representation:\n", deictic_representation)
87 | 
88 |         try:
89 |             graphs = [self.vg_1.load_scene_graph_by_id(image_id),
90 |                       self.vg_2.load_scene_graph_by_id(image_id)]
91 |             masks, target_scores, _ = self.deisam.forward(
92 |                 data_index, image_id, graphs, deictic_representation, image_source)
93 | 
94 |             if masks is None:
95 |                 print("No targets segmented.. skipping..")
96 |                 return
97 | 
98 |             predicted_boxes = masks_to_boxes(masks.squeeze(1).float().to(torch.int32))
99 |             answer_boxes = torch.tensor(answer_to_boxes(answer), device=self.device).to(torch.int32)
100 |             optimizer.zero_grad()  # clear gradients accumulated in the previous step
101 |             box_probs, box_labels = to_bce_examples(predicted_boxes, target_scores, answer_boxes, self.device)
102 |             loss = bce_loss(box_probs, box_labels)
103 |             loss.backward()
104 |             optimizer.step()
105 |             return loss.item()
106 | 
107 |         except (KeyError, openai.error.APIError, openai.InvalidRequestError, openai.error.ServiceUnavailableError):
108 |             print(f"Skipped or error occurred for ID {counter}, IMAGE ID {image_id}")
109 | 
110 |     def _evaluate_on_test_data(self, counter):
111 |         segment_testdata(self.args, self.deisam,
112 |                          self.vg_1, self.vg_2,
113 |                          self.test_data_loader,
114 |                          counter, self.device)
115 | 
116 | 
117 | def segment_testdata(args, deisam, vg_1, vg_2, data_loader, iter, device, n_data=400):
118 |     for counter, (id, data_index, image_id, image_source, image, deictic_representation, answer) in enumerate(data_loader):
119 |         if counter < args.start or counter > args.end or counter > n_data:
120 |             continue
121 | 
122 |         print(f"===== TEST ID:{counter}, IMAGE ID:{image_id}")
123 |         print("Deictic representation:\n", deictic_representation)
124 | 
125 |         try:
126 |             graphs = [vg_1.load_scene_graph_by_id(image_id),
127 |                       vg_2.load_scene_graph_by_id(image_id)]
128 |             masks, target_scores, _ = deisam.forward(
129 |                 data_index, image_id, graphs, deictic_representation, image_source)
130 | 
131 |             if masks is None:
132 |                 print(f"No targets found on image {counter}. Using random mask.")
133 |                 target_atoms = get_random_masks(deisam)
134 |                 masks = deisam.segment_objects_by_sam(image_source, target_atoms, image_id)
135 |                 target_scores = [torch.tensor(0.5).to(device)]
136 | 
137 |             predicted_boxes = masks_to_boxes(masks.squeeze(1).to(torch.int32))
138 |             answer_boxes = torch.tensor(answer_to_boxes(answer), device=device).to(torch.int32)
139 | 
140 |             # TODO: specify the path not to mix up with the solve script results
141 |             # NOTE: only the arguments accepted by visualization_utils.save_box_to_file
142 |             # are passed here; the per-box target scores are not written to disk.
143 |             save_box_to_file(
144 |                 pr_boxes=predicted_boxes,
145 |                 gt_boxes=answer_boxes,
146 |                 id=id,
147 |                 counter=counter,
148 |                 args=args,
149 |             )
150 | 
151 |             target_scores_cpu = [x.detach().cpu().numpy() for x in target_scores]
152 |             save_segmentation_result_with_alphas(
153 |                 args,
154 |                 masks,
155 |                 target_scores_cpu,
156 |                 answer,
157 |                 image_source,
158 |                 counter,
159 |                 image_id,
160 |                 data_index,
161 |                 deictic_representation,
162 |                 iter,
163 |             )
164 |         except openai.error.APIError:
165 |             print("OpenAI API error.. 
skipping..") 166 | 167 | 168 | def main(): 169 | parser = argparse.ArgumentParser() 170 | 171 | parser.add_argument("-s", "--start", type=int, default=0, help="Start point (data index) for the inference.") 172 | parser.add_argument("-e", "--end", type=int, default=400, help="End point (data index) for the inference.") 173 | parser.add_argument("-ep", "--epochs", type=int, default=1, help="Training epochs.") 174 | parser.add_argument("-sd", "--seed", type=int, default=0, help="Random seed.") 175 | parser.add_argument("--lr", type=float, default=1e-2, help="The learning rate.") 176 | parser.add_argument("-c", "--complexity", type=int, default=2, choices=[1, 2, 3], help="Complexity level.") 177 | parser.add_argument("-d", "--dataset", default="deictic_visual_genome", choices=["deictic_visual_genome", "deictic_visual_genome_short"], help="Dataset.") 178 | parser.add_argument("-m", "--model", default="DeiSAM", choices=["DeiSAM", "GroundedSAM"], help="Model to use.") 179 | parser.add_argument("-sgg", "--sgg-model", default="", choices=["", "VETO"], help="Scene Graph Generation model.") 180 | parser.add_argument("-su", "--sem-uni", action="store_true", help="Use semantic unifier.") 181 | parser.add_argument("-tr", "--trained", action="store_true", help="Use trained model.") 182 | parser.add_argument("-k", "--api-key", type=str, required=True, help="An OpenAI API key.") 183 | 184 | args = parser.parse_args() 185 | 186 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 187 | trainer = DeiSAMTrainer(args, device) 188 | 189 | print("Starting training...") 190 | trainer.train() 191 | print("Training completed.") 192 | 193 | if __name__ == "__main__": 194 | main() -------------------------------------------------------------------------------- /src/llm_logic_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import lark 4 | import openai 5 | from lark import Lark 6 | from neumann.fol.exp_parser import ExpTree 7 | from neumann.fol.language import DataType, Language 8 | from neumann.fol.logic import Atom, Clause, Const, NeuralPredicate, Predicate 9 | 10 | 11 | flatten = lambda x: [z for y in x for z in (flatten(y) if hasattr(y, '__iter__') and not isinstance(y, str) else (y,))] 12 | 13 | 14 | class LLMLogicGenerator: 15 | """ 16 | A class to generate logic languages and rules from task descriptions in natural language 17 | using GPT and prepared textual prompts. 
18 | """ 19 | 20 | def __init__(self, api_key): 21 | self.constants_prompt_path = "prompt/gen_constants.txt" 22 | self.predicates_prompt_path = "prompt/gen_predicates.txt" 23 | self.rules_prompt_path = "prompt/gen_rules.txt" 24 | 25 | self.constants_prompt = self._load_prompt(self.constants_prompt_path) 26 | self.predicates_prompt = self._load_prompt(self.predicates_prompt_path) 27 | self.rules_prompt = self._load_prompt(self.rules_prompt_path) 28 | 29 | # Setup parser 30 | lark_path = "src/lark/exp.lark" 31 | with open(lark_path, encoding="utf-8") as grammar: 32 | grammar_content = grammar.read() 33 | self.lp_atom = Lark(grammar_content, start="atom") 34 | self.lp_clause = Lark(grammar_content, start="clause") 35 | 36 | # Setup OpenAI API 37 | openai.api_key = api_key 38 | openai.organization = None 39 | 40 | def _load_prompt(self, file_path): 41 | with open(file_path, "r") as file: 42 | return file.read() 43 | 44 | def _parse_response(self, response, parse_function): 45 | return flatten([parse_function(line) for line in response.split("\n") if line.strip()]) 46 | 47 | def _parse_constants(self, line): 48 | try: 49 | dtype_name, const_names_str = line.replace(" ", "").split(":") 50 | dtype = DataType(dtype_name) 51 | const_names = const_names_str.split(",") 52 | return [Const(name, dtype) for name in const_names] 53 | except ValueError: 54 | return [] 55 | 56 | def _parse_predicates(self, line): 57 | pred_names = line.replace(" ", "").split(",") 58 | return [ 59 | NeuralPredicate(name, 2, [DataType("object"), DataType("object")]) 60 | for name in pred_names 61 | ] 62 | 63 | def _parse_rules(self, line, language): 64 | if ":-" in line: 65 | tree = self.lp_clause.parse(line.replace(" ", "")) 66 | return ExpTree(language).transform(tree) 67 | return None 68 | 69 | def query_gpt(self, text): 70 | """Query GPT-3.5 with a textual prompt.""" 71 | try: 72 | response = openai.ChatCompletion.create( 73 | model="gpt-3.5-turbo", 74 | messages=[{"role": "user", "content": text}], 75 | ) 76 | return response.choices[0]["message"]["content"].strip() 77 | except ( 78 | openai.error.APIError, 79 | openai.error.APIConnectionError, 80 | openai.error.ServiceUnavailableError, 81 | openai.error.Timeout, 82 | json.JSONDecodeError, 83 | ): 84 | time.sleep(3) 85 | print("OpenAI API error encountered. 
Retrying...")
86 |             return self.query_gpt(text)
87 | 
88 |     def generate_constants(self, text):
89 |         prompt = f"{self.constants_prompt}\n{text}\n Constants:"
90 |         response = self.query_gpt(prompt)
91 |         constants = self._parse_response(response, self._parse_constants)
92 |         return constants, "Constants:\n" + response
93 | 
94 |     def generate_predicates(self, text, const_response=None):
95 |         # const_response: optional constants context (e.g., from a scene graph); when provided, it is assumed to be appended to the prompt.
96 |         prompt = f"{self.predicates_prompt}\n{text}\n" + (f"{const_response}\n" if const_response else "")
97 |         response = self.query_gpt(prompt)
98 |         predicates = self._parse_response(response, self._parse_predicates)
99 |         required_predicates = [
100 |             ("type", 2, [DataType("object"), DataType("type")]),
101 |             ("target", 1, [DataType("object")]),
102 |             ("cond1", 1, [DataType("object")]),
103 |             ("cond2", 1, [DataType("object")]),
104 |             ("cond3", 1, [DataType("object")]),
105 |         ]
106 | 
107 |         for name, arity, datatypes in required_predicates:
108 |             if not any(p.name == name for p in predicates):
109 |                 if arity == 2:
110 |                     predicates.append(NeuralPredicate(name, arity, datatypes))
111 |                 else:
112 |                     predicates.append(Predicate(name, arity, datatypes))
113 | 
114 |         return predicates, response
115 | 
116 |     def get_preds_string(self, preds):
117 |         return ",".join(p.name for p in preds if p.name not in ["target", "type", "cond1", "cond2", "cond3"])
118 | 
119 |     def generate_rules(self, text, language):
120 |         pred_response = self.get_preds_string(language.preds)
121 |         self.pred_response = pred_response
122 |         prompt = f"\n\n{self.rules_prompt}\n{text}\navailable predicates: {pred_response}\n"
123 |         response = self.query_gpt(prompt)
124 |         self.rule_response = response
125 | 
126 |         rules = self._parse_response(response, lambda line: self._parse_rules(line, language))
127 |         return [rule for rule in rules if rule]
128 | 
129 |     def generate_logic(self, text):
130 |         """Generate constants, predicates, and rules using LLMs."""
131 |         constants, const_response = self.generate_constants(text)
132 |         if not constants:
133 |             raise ValueError("Error: No constants found in the prompt.")
134 |         predicates, _ = self.generate_predicates(text)
135 |         if not predicates:
136 |             raise ValueError("Error: No predicates found in the prompt.")
137 | 
138 |         language = Language(consts=constants, preds=predicates, funcs=[])
139 |         rules = self.generate_rules(text, language)
140 |         return language, rules
--------------------------------------------------------------------------------
/src/sam_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from groundingdino.util import box_ops
3 | from visual_genome_utils import objdata_to_box
4 | 
5 | 
6 | def to_xyxy(boxes):
7 |     """Convert boxes from xywh format to xyxy format.
8 | 
9 |     Args:
10 |         boxes (torch.Tensor): A tensor of boxes in xywh format.
11 | 
12 |     Returns:
13 |         torch.Tensor: A tensor of boxes in xyxy format.
14 |     """
15 |     xyxy_boxes = [
16 |         torch.tensor([x, y, x + w, y + h])
17 |         for x, y, w, h in boxes
18 |     ]
19 |     return torch.stack(xyxy_boxes)
20 | 
21 | 
22 | def transform_boxes(boxes, image_source):
23 |     """Convert xywh boxes to xyxy format, returning None if the conversion fails.
24 | 
25 |     Args:
26 |         boxes (torch.Tensor): A tensor of boxes in xywh format.
27 |         image_source (numpy.ndarray): The source image (currently unused).
28 | 
29 |     Returns:
30 |         torch.Tensor or None: Transformed boxes in xyxy format, or None on failure.
31 |     """
32 |     try:
33 |         return to_xyxy(boxes)
34 |     except RuntimeError:
35 |         return None
36 | 
37 | 
38 | def to_object_ids(target_atoms):
39 |     """Extract object IDs from FOL target atoms. 
40 | 41 | Args: 42 | target_atoms (list): List of target atoms. 43 | 44 | Returns: 45 | list: A list of object IDs. 46 | """ 47 | return [int(atom.terms[0].name.split("_")[-1]) for atom in target_atoms] 48 | 49 | 50 | def to_boxes(target_atoms, data_index, visual_genome_utils): 51 | """Extract bounding boxes for target atoms using Visual Genome data. 52 | 53 | Args: 54 | target_atoms (list): List of target atoms. 55 | data_index (int): Index in the Visual Genome dataset. 56 | visual_genome_utils (object): Utilities for accessing Visual Genome data. 57 | 58 | Returns: 59 | list: A list of bounding boxes. 60 | """ 61 | object_ids = to_object_ids(target_atoms) 62 | relations = visual_genome_utils.all_relationships[data_index]["relationships"] 63 | 64 | boxes = [] 65 | for obj_id in object_ids: 66 | for rel in relations: 67 | if rel["object"]["object_id"] == obj_id: 68 | boxes.append(objdata_to_box(rel["object"])) 69 | break 70 | elif rel["subject"]["object_id"] == obj_id: 71 | boxes.append(objdata_to_box(rel["subject"])) 72 | break 73 | return boxes 74 | 75 | 76 | def to_boxes_with_sgg(target_atoms, image_id, visual_genome_utils): 77 | """Extract bounding boxes for target atoms using Scene Graph Generation data. 78 | 79 | Args: 80 | target_atoms (list): List of target atoms. 81 | image_id (int): Image ID for accessing relationships. 82 | visual_genome_utils (object): Utilities for accessing Visual Genome data. 83 | 84 | Returns: 85 | list: A list of bounding boxes. 86 | """ 87 | object_ids = to_object_ids(target_atoms) 88 | relations = visual_genome_utils.all_relationships[image_id] 89 | 90 | boxes = [] 91 | for obj_id in object_ids: 92 | for rel in relations: 93 | if rel["o_unique_id"] == obj_id: 94 | boxes.append(rel["obox"]) 95 | break 96 | if rel["s_unique_id"] == obj_id: 97 | boxes.append(rel["sbox"]) 98 | break 99 | return boxes 100 | 101 | 102 | def to_transformed_boxes(boxes, image_source, sam_predictor, device): 103 | """Apply transformations to bounding boxes using a SAM predictor. 104 | 105 | Args: 106 | boxes (list): List of bounding boxes. 107 | image_source (numpy.ndarray): The source image. 108 | sam_predictor (object): SAM predictor object with transform capabilities. 109 | device (torch.device): The device to convert the boxes to. 110 | 111 | Returns: 112 | torch.Tensor: A tensor of transformed boxes. 
113 |     """
114 |     try:
115 |         boxes_xyxy = to_xyxy(boxes)
116 |         transformed_boxes = sam_predictor.transform.apply_boxes_torch(
117 |             boxes_xyxy, image_source.shape[:2]
118 |         ).to(device)
119 |         return transformed_boxes
120 |     except RuntimeError:
121 |         return []
--------------------------------------------------------------------------------
/src/semantic_unifier.py:
--------------------------------------------------------------------------------
1 | import time
2 | # from cgitb import text  # unused accidental import
3 | # from hmac import new  # unused accidental import
4 | 
5 | import openai
6 | import torch
7 | 
8 | # from diffusers.scripts.convert_kakao_brain_unclip_to_diffusers import text_encoder
9 | 
10 | from neumann.fol.language import Language
11 | 
12 | 
13 | class SemanticUnifier:
14 |     def __init__(self, graph_lang, device):
15 |         # self.lang = lang
16 |         self.graph_lang = graph_lang
17 |         self.device = device
18 |         self.const_mapping = {}
19 |         self.pred_mapping = {}
20 | 
21 |         # set up embeddings
22 |         # self.consts_embeddings = self._get_consts_embeddings()
23 |         # self.preds_embeddings = self._get_preds_embeddings()
24 |         # self.graph_consts_embeddings = self._get_graph_consts_embeddings()
25 |         # self.graph_preds_embeddings = self._get_graph_preds_embeddings()
26 | 
27 |     def _init_scene_graph_consts_embeddings(self):
28 |         self.graph_consts_embeddings = self._get_graph_consts_embeddings()
29 | 
30 |     def _init_scene_graph_preds_embeddings(self):
31 |         self.graph_preds_embeddings = self._get_graph_preds_embeddings()
32 | 
33 |     def _get_consts_embeddings(self):
34 |         dic = {}
35 |         for c in self.lang.consts:
36 |             c_embedding = self.get_embedding(c.name.replace("_", " "))
37 |             dic[c] = c_embedding
38 |         return dic
39 | 
40 |     def _get_preds_embeddings(self):
41 |         dic = {}
42 |         for p in self.lang.preds:
43 |             p_embedding = self.get_embedding(p.name.replace("_", " "))
44 |             dic[p] = p_embedding
45 |         return dic
46 | 
47 |     def _get_graph_consts_embeddings(self):
48 |         dic = {}
49 |         for c in self.graph_lang.consts:
50 |             c_embedding = self.get_embedding(c.name)
51 |             dic[c] = c_embedding
52 |         return dic
53 | 
54 |     def _get_graph_preds_embeddings(self):
55 |         dic = {}
56 |         for p in self.graph_lang.preds:
57 |             p_embedding = self.get_embedding(p.name)
58 |             dic[p] = p_embedding
59 |         return dic
60 | 
61 |     def to_language(self, graph_atoms):
62 |         """Generate a FOL language given atoms that represent a scene graph.
63 | 
64 |         Args:
65 |             graph_atoms (Atom): a set of atoms that represent a scene graph.
66 | 
67 |         Returns:
68 |             Language : a language computed from graph atoms. 
69 |         """
70 |         preds = set()
71 |         consts = set()
72 | 
73 |         for atom in graph_atoms:
74 |             preds.add(atom.pred)
75 |             for c in atom.terms:
76 |                 consts.add(c)
77 |         lang = Language(preds=list(preds), funcs=[], consts=list(consts))
78 |         return lang
79 | 
80 |     def get_most_similar_index(self, x, ys):
81 |         pass
82 | 
83 |     def get_most_similar_predicate_in_graph(self, pred):
84 |         # num_graph_pred = len(self.graph_lang.preds)
85 |         X = self.get_embedding(pred.name).unsqueeze(0)  # .expand((num_graph_pred, -1))
86 |         X_graph = torch.stack(list(self.graph_preds_embeddings.values()))
87 |         # score = torch.dot(X.T, X_graph)
88 |         score = torch.sum(X * X_graph, axis=-1)
89 |         index = torch.argmax(score).item()
90 |         return self.graph_lang.preds[index]
91 | 
92 |     def get_most_similar_constant_in_graph(self, const):
93 |         X = self.get_embedding(const.name).unsqueeze(0)
94 |         X_graph = torch.stack(list(self.graph_consts_embeddings.values()))
95 |         score = torch.sum(X * X_graph, axis=-1)
96 |         index = torch.argmax(score).item()
97 |         # print(self.graph_lang.consts, len(self.graph_lang.consts))
98 |         # print(score, score.shape)
99 |         # print(index)
100 |         return self.graph_lang.consts[index]
101 | 
102 |     def build_const_mapping(self, lang, graph_lang):
103 |         dic = {}
104 |         for c in lang.consts:
105 |             if not c in graph_lang.consts:
106 |                 # find the most similar graph const
107 |                 return 0
108 |         pass
109 | 
110 |     def build_pred_mapping(self, lang, graph_lang):
111 |         pass
112 | 
113 |     def rewrite_lang(self, lang, graph_lang):
114 |         """Rewrite the language using only existing vocabulary in the graph.
115 | 
116 |         Args:
117 |             lang (Language): the LLM-generated language to align with the scene graph vocabulary.
118 | 
119 |         Returns:
120 |             tuple: the rewritten language together with the constant and predicate mappings.
121 |         """
122 |         new_lang = Language(
123 |             preds=lang.preds.copy(), funcs=[], consts=lang.consts.copy()
124 |         )
125 |         const_mapping = {}
126 |         for const in new_lang.consts:
127 |             if const not in graph_lang.consts:
128 |                 new_const = self.get_most_similar_constant_in_graph(const)
129 |                 new_lang.consts.remove(const)
130 |                 new_lang.consts.append(new_const)
131 |                 const_mapping[const] = new_const
132 |         predicate_mapping = {}
133 |         for pred in new_lang.preds:
134 |             if pred not in graph_lang.preds:
135 |                 new_pred = self.get_most_similar_predicate_in_graph(pred)
136 |                 new_lang.preds.remove(pred)
137 |                 new_lang.preds.append(new_pred)
138 |                 predicate_mapping[pred] = new_pred
139 |         return new_lang, const_mapping, predicate_mapping
140 | 
141 |     # def rewrite_rules(self, rules, const_mapping, predicate_mapping):
142 |     #     new_rules = []
143 |     #     for rule in rules:
144 |     #         new_rule = rule
145 |     #         new_atoms = []
146 |     #         for atom in [rule.head] + rule.body:
147 |     #             # check / rewrite predicate
148 |     #             if atom.pred in predicate_mapping.keys():
149 |     #                 atom.pred = predicate_mapping[atom.pred]
150 |     #             # check / rewrite const
151 |     #             for i, const in enumerate(atom.terms):
152 |     #                 if const in const_mapping.keys():
153 |     #                     atom.terms[i] = const_mapping[const]
154 |     #             new_atoms.append(atom)
155 |     #         new_rule.head = new_atoms[0]
156 |     #         new_rule.body = new_atoms[1:]
157 |     #         new_rules.append(new_rule)
158 |     #     return new_rules
159 | 
160 |     def rewrite_rules(self, rules, lang, graph_lang, rewrite_pred=True):
161 |         reserved_preds = ["target", "type", "cond1", "cond2", "cond3"]
162 |         self._init_scene_graph_consts_embeddings()
163 |         self._init_scene_graph_preds_embeddings()
164 |         new_rules = []
165 |         # new_lang = Language(preds=lang.preds.copy(), funcs=[], consts=lang.consts.copy())
166 |         for rule in rules:
167 |             new_rule = rule
168 |             new_atoms = []
169 |             for atom in 
[rule.head] + rule.body: 170 | # check / rewrite predicate 171 | pred = atom.pred 172 | if ( 173 | rewrite_pred 174 | and pred.name not in reserved_preds 175 | and pred not in graph_lang.preds 176 | ): 177 | # replace the non-existing predicate by the most similar one 178 | new_pred = self.get_most_similar_predicate_in_graph(pred) 179 | atom.pred = new_pred 180 | self.pred_mapping[pred.name] = new_pred.name 181 | print(pred.name, " -> ", new_pred.name) 182 | # new_lang.preds.remove(pred) 183 | # new_lang.preds.append(new_pred) 184 | # check / rewrite const 185 | for i, const in enumerate(atom.terms): 186 | if ( 187 | const.__class__.__name__ == "Const" 188 | and const not in graph_lang.consts 189 | ): 190 | # replace the non-existing constant by the most similar one 191 | new_const = self.get_most_similar_constant_in_graph(const) 192 | atom.terms[i] = new_const 193 | self.const_mapping[const.name] = new_const.name 194 | print(const.name, " -> ", new_const.name) 195 | # new_lang.consts.remove(const) 196 | # new_lang.consts.append(new_const) 197 | new_atoms.append(atom) 198 | new_rule.head = new_atoms[0] 199 | new_rule.body = new_atoms[1:] 200 | new_rules.append(new_rule) 201 | return new_rules # , #new_lang 202 | 203 | # def unify(self, lang, graph_lang, rules): 204 | # rewrite lang 205 | # internally overwrite self.lang 206 | # new_lang, const_mapping, pred_mapping = self.rewrite_lang(lang, graph_lang) 207 | # generate new rules using the refined language 208 | # new_rules new_lang = self.rewrite_rules(rules, lang, graph_lang) 209 | # eturn new_lang, new_rules 210 | 211 | def get_embedding(self, text_to_embed): 212 | response = openai.Embedding.create( 213 | model="text-embedding-ada-002", input=[text_to_embed.replace("_", " ")] 214 | ) 215 | # Extract the AI output embedding as a list of floats 216 | embedding = torch.tensor(response["data"][0]["embedding"]).to(self.device) 217 | return embedding 218 | 219 | # try: 220 | # # Embed a line of text 221 | # response = openai.Embedding.create( 222 | # model="text-embedding-ada-002", input=[text_to_embed.replace("_", " ")] 223 | # ) 224 | # # Extract the AI output embedding as a list of floats 225 | # embedding = torch.tensor(response["data"][0]["embedding"]).to(self.device) 226 | # return embedding 227 | # # except (openai.ServiceUnavailableError, openai.InvalidRequestError): 228 | # except (openai.InvalidRequestError, openai.error.ServiceUnavailableError): 229 | # print( 230 | # "Got openai.InvalidRequestError or openai.error.ServiceUnavailableError for embeddings in Semantic Unification, waiting for 3s and try again..." 
231 | # ) 232 | # time.sleep(3) 233 | # return self.get_embedding(text_to_embed=text_to_embed) 234 | -------------------------------------------------------------------------------- /src/visual_genome_utils.py: -------------------------------------------------------------------------------- 1 | import neumann 2 | import neumann.fol.logic as logic 3 | from neumann.fol.language import DataType, Language 4 | from neumann.fol.logic import Const 5 | 6 | 7 | def scene_graph_to_language(scene_graph, text, logic_generator, num_objects=2): 8 | """Extract a FOL language from a scene graph to parse rules later.""" 9 | 10 | # Extract and sanitize object names from the scene graph 11 | objects = {str(obj).replace(" ", "") for obj in scene_graph.objects} 12 | datatype = DataType("type") 13 | 14 | # Define constants using the extracted object names 15 | constants = [Const(obj, datatype) for obj in objects] 16 | 17 | # Prepare constant response for predicates 18 | const_response = "Constants:\ntype:" + ",".join(objects) 19 | 20 | # Generate predicates using the logic generator based on the input text 21 | predicates, pred_response = logic_generator.generate_predicates(text) 22 | print(f"Predicate generator response:\n {pred_response}") 23 | 24 | # Formulate the language using constants and predicates 25 | lang = Language(consts=list(constants), preds=list(predicates), funcs=[]) 26 | return lang 27 | 28 | 29 | def get_init_language_with_sgg(scene_graph, text, logic_generator): 30 | """Extract an initial FOL language from a predicted scene graph to parse rules later.""" 31 | 32 | # Extract unique object and subject names from the scene graph 33 | objects = {rel["o_str"] for rel in scene_graph} 34 | subjects = {rel["s_str"] for rel in scene_graph} 35 | datatype = DataType("type") 36 | 37 | # Define constants using the extracted names 38 | constants = [Const(obj, datatype) for obj in objects | subjects] 39 | 40 | # Prepare constant response for predicates 41 | const_response = "Constants:\ntype:" + ",".join(objects) 42 | 43 | # Generate predicates using the logic generator based on the input text 44 | predicates, pred_response = logic_generator.generate_predicates(text, const_response) 45 | print(f"Predicate generator response:\n {pred_response}") 46 | 47 | # Formulate the language using constants and predicates 48 | lang = Language(consts=list(constants), preds=list(predicates), funcs=[]) 49 | return lang 50 | 51 | 52 | def scene_graph_to_language_with_sgg(scene_graph): 53 | """Extract a complete FOL language from a scene graph for use with a semantic unifier.""" 54 | 55 | # Extract unique objects, subjects, and relationships from the scene graph 56 | objects = {rel["o_str"] for rel in scene_graph} 57 | subjects = {rel["s_str"] for rel in scene_graph} 58 | relationships = {rel["p_str"] for rel in scene_graph} 59 | datatype = DataType("type") 60 | obj_datatype = DataType("object") 61 | 62 | # Define constants using the extracted names 63 | constants = [Const(obj, datatype) for obj in objects | subjects] 64 | 65 | # Define predicates for each relationship 66 | predicates = [ 67 | logic.Predicate(rel.replace(" ", "_").lower(), 2, [obj_datatype, obj_datatype]) 68 | for rel in relationships 69 | ] 70 | 71 | # Formulate the language using constants and predicates 72 | lang = Language(consts=list(constants), preds=list(predicates), funcs=[]) 73 | return lang 74 | 75 | 76 | def objdata_to_box(data): 77 | """Convert object data to a bounding box format.""" 78 | x, y, w, h = data["x"], data["y"], data["w"], data["h"] 
79 |     return x, y, w, h
80 | 
--------------------------------------------------------------------------------
/src/visualization_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | import torch
5 | from sam_utils import to_object_ids
6 | 
7 | 
8 | def apply_random_color():
9 |     return np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
10 | 
11 | 
12 | def apply_default_color(alpha=0.75):
13 |     return np.array([255 / 255, 10 / 255, 10 / 255, alpha])
14 | 
15 | 
16 | def get_colored_mask(mask, color):
17 |     h, w = mask.shape[-2:]
18 |     mask = mask.cpu() if mask.device.type != "cpu" else mask
19 |     return mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
20 | 
21 | 
22 | def overlay_mask_on_image(mask, image, color):
23 |     annotated_frame_pil = Image.fromarray(image).convert("RGBA")
24 |     mask_image_pil = Image.fromarray(
25 |         (mask.cpu().numpy() * 255).astype(np.uint8)
26 |     ).convert("RGBA")
27 |     return Image.alpha_composite(annotated_frame_pil, mask_image_pil)
28 | 
29 | 
30 | def show_mask(mask, image, random_color=False):
31 |     color = apply_random_color() if random_color else apply_default_color()
32 |     mask_image = get_colored_mask(mask, color)
33 |     return np.array(overlay_mask_on_image(mask_image, image, color))
34 | 
35 | 
36 | def show_mask_with_alpha(mask, image, alpha, random_color=False):
37 |     color = apply_random_color() if random_color else apply_default_color(alpha * 0.75)
38 |     mask_image = get_colored_mask(mask, color)
39 |     return np.array(overlay_mask_on_image(mask_image, image, color))
40 | 
41 | 
42 | def get_bbox_by_id(object_id, data_index, vg):
43 |     objects = vg.all_objects[data_index]["objects"]
44 |     target_object = next((o for o in objects if int(o["object_id"]) == object_id), None)
45 | 
46 |     if not target_object:
47 |         raise ValueError(f"Object with ID {object_id} not found.")
48 | 
49 |     return target_object["x"], target_object["y"], target_object["w"], target_object["h"]
50 | 
51 | 
52 | def to_crops(image_source, boxes):
53 |     return [image_source[y:y + h, x:x + w] for x, y, w, h in boxes]  # image arrays are indexed [row (y), col (x)]
54 | 
55 | 
56 | def _to_boxes(target_atoms, data_index, vg):
57 |     return vg.target_atoms_to_regions(target_atoms, data_index)
58 | 
59 | 
60 | def objdata_to_box(data):
61 |     return data["x"], data["y"], data["w"], data["h"]
62 | 
63 | 
64 | def to_boxes(target_atoms, data_index, vg):
65 |     object_ids = to_object_ids(target_atoms)
66 |     relations = vg.all_relationships[data_index]["relationships"]
67 | 
68 |     return [
69 |         objdata_to_box(rel["object"] if rel["object"]["object_id"] == id else rel["subject"])
70 |         for id in object_ids
71 |         for rel in relations
72 |         if rel["object"]["object_id"] == id or rel["subject"]["object_id"] == id
73 |     ]
74 | 
75 | # def to_boxes(target_atoms, data_index, vg):
76 | #     # get box from relations!! 
not objects 77 | # object_ids = to_object_ids(target_atoms) 78 | # relations = vg.all_relationships[data_index]["relationships"] 79 | # boxes = [] 80 | # for id in object_ids: 81 | # for rel in relations: 82 | # if rel["object"]["object_id"] == id: 83 | # boxes.append(objdata_to_box(rel["object"])) 84 | # break 85 | # elif rel["subject"]["object_id"] == id: 86 | # boxes.append(objdata_to_box(rel["subject"])) 87 | # break 88 | # return boxes 89 | 90 | def to_xyxy(boxes): 91 | xyxy_boxes = [torch.tensor([x, y, x + w, y + h]) for x, y, w, h in boxes] 92 | return torch.stack(xyxy_boxes) 93 | 94 | 95 | def save_boxes_to_file(boxes, path, is_prediction=True): 96 | text = "\n".join(f"target 1.0 {box[0]} {box[1]} {box[2]} {box[3]}" for box in boxes) 97 | with open(path, "w") as f: 98 | f.write(text) 99 | print(f"File saved to {path}") 100 | 101 | 102 | def save_box_to_file(pr_boxes, gt_boxes, id, counter, args): 103 | dirs = [ 104 | f"result/{args.dataset}_comp{args.complexity}/{args.model}/prediction", 105 | f"result/{args.dataset}_comp{args.complexity}/{args.model}/ground_truth", 106 | ] 107 | 108 | for dir_path in dirs: 109 | os.makedirs(dir_path, exist_ok=True) 110 | 111 | pr_path = f"{dirs[0]}/{counter}_vg{id}.txt" 112 | gt_path = f"{dirs[1]}/{counter}_vg{id}.txt" 113 | 114 | save_boxes_to_file(pr_boxes, pr_path) 115 | save_boxes_to_file(gt_boxes, gt_path, is_prediction=False) 116 | 117 | 118 | def answer_to_boxes(answers): 119 | if not isinstance(answers, list): 120 | answers = [answers] 121 | 122 | return [[answer["x"], answer["y"], answer["x"] + answer["w"], answer["y"] + answer["h"]] for answer in answers] 123 | 124 | 125 | def save_segmented_images(counter, vg_image_id, image_with_mask, description, base_path="imgs/"): 126 | image_with_mask = Image.fromarray(image_with_mask).convert("RGB") 127 | description = description.replace("/", " ").replace(".", "") 128 | save_path = f"{base_path}deicticVG_ID:{counter}_VGImID:{vg_image_id}_{description}.png" 129 | image_with_mask.save(save_path) 130 | print(f"Image saved to {save_path}") 131 | 132 | 133 | def save_segmented_images_with_target_scores(counter, vg_image_id, image_with_mask, description, target_scores, base_path="imgs/"): 134 | if len(target_scores) > 3: 135 | target_scores = target_scores[:3] 136 | scores_str = str(np.round(target_scores, 2).tolist()) 137 | save_segmented_images(counter, vg_image_id, image_with_mask, f"{description}_scores_{scores_str}", base_path) --------------------------------------------------------------------------------
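
Taken together, the helpers in src/visualization_utils.py cover the export path used by the solve and learning scripts: predicted masks are converted to xyxy boxes with torchvision's masks_to_boxes, ground-truth regions are converted with answer_to_boxes, both are written out by save_box_to_file, and show_mask produces the RGBA overlay that gets saved as an image. The following is a minimal sketch of that flow on synthetic inputs; the dummy mask, answer dictionary, image, and argparse namespace are hypothetical stand-ins rather than values from the repository, and the flat imports assume the script is run from the src/ directory.

import argparse

import numpy as np
import torch
from torchvision.ops import masks_to_boxes

from visualization_utils import answer_to_boxes, save_box_to_file, show_mask

# Hypothetical inputs: one 4x4 binary mask (shape [N, 1, H, W], as SAM returns)
# and a single ground-truth region in Visual Genome's xywh dictionary format.
masks = torch.zeros((1, 1, 4, 4), dtype=torch.bool)
masks[0, 0, 1:3, 1:3] = True
answer = {"x": 1, "y": 1, "w": 2, "h": 2}
image_source = np.zeros((4, 4, 3), dtype=np.uint8)

# Minimal stand-in for the CLI arguments that determine the result directory.
args = argparse.Namespace(dataset="deictic_visual_genome", complexity=2, model="DeiSAM")

pr_boxes = masks_to_boxes(masks.squeeze(1).to(torch.int32))  # predicted boxes, xyxy
gt_boxes = answer_to_boxes(answer)                           # ground-truth boxes, xyxy
save_box_to_file(pr_boxes, gt_boxes, id=123, counter=0, args=args)

# Overlay the mask on the image; the result is an RGBA numpy array.
overlay = show_mask(masks[0, 0].float(), image_source)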