├── arlc
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── iraven.py
│   ├── utils
│   │   ├── const.py
│   │   ├── checkpath.py
│   │   ├── averagemeter.py
│   │   ├── general.py
│   │   ├── vsa.py
│   │   ├── raven
│   │   │   ├── scene.py
│   │   │   ├── env.py
│   │   │   ├── raven_one_hot.py
│   │   │   └── extraction.py
│   │   └── parsing.py
│   ├── losses.py
│   ├── selection.py
│   ├── execution.py
│   └── rule_templates.py
├── requirements.txt
├── figs
│   └── arlc_preview.png
├── .gitignore
├── .pre-commit-config.yaml
├── experiments
│   ├── arlc_learn.sh
│   ├── arlc_progr_to_learn.sh
│   ├── ablations
│   │   ├── context.sh
│   │   └── learnvrf_nopn_2x2.sh
│   ├── uncertainty
│   │   ├── train_noisy.sh
│   │   ├── arlc_learn_noisy.sh
│   │   ├── train_confounders.sh
│   │   ├── arlc_eval_noisy.sh
│   │   ├── exp_dist_inference.sh
│   │   ├── exp_confounders_inference.sh
│   │   └── eval_noisy.sh
│   ├── arlc_progr.sh
│   ├── arlc_ood.sh
│   └── iravenx
│       ├── arlc_program_eval.sh
│       └── arlc_learn_iravenx_50.sh
├── setup.py
├── data
│   └── README.md
├── results.py
├── README.md
├── main.py
└── LICENSE
/arlc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | matplotlib
3 | tensorboard
4 | progressbar
5 | scikit-learn
6 | pre-commit
7 |
--------------------------------------------------------------------------------
/arlc/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .iraven import GeneralIRAVENDataset
2 |
3 | __all__ = ["GeneralIRAVENDataset"]
4 |
--------------------------------------------------------------------------------
/figs/arlc_preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/abductive-rule-learner-with-context-awareness/main/figs/arlc_preview.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # MacOS related files
2 | *.DS_Store
3 | *.icloud
4 |
5 | # Python files
6 | **__pycache__**
7 | **.pyc
8 | *.out
9 | *.vscode
10 | **.egg-info
11 | *.py[cod]
12 | *$py.class
13 | .ruff_cache
14 | *.log
15 | .pytest_cache
16 | results
17 |
18 | # Jupyter notebooks
19 | *.ipynb_checkpoints
20 |
21 | # Credentials
22 | **/.config
23 |
24 | **.code-workspace
25 |
26 | models
27 | data
28 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 22.3.0
4 | hooks:
5 | - id: black
6 | language_version: python3.10
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v4.4.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: end-of-file-fixer
12 | - id: check-yaml
13 | - id: check-json
14 | - id: check-merge-conflict
15 | - id: check-case-conflict
16 | - id: mixed-line-ending
17 | - id: fix-byte-order-marker
18 |
--------------------------------------------------------------------------------
/experiments/arlc_learn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="distribute_four"
4 | RUN="arlc_learn"
5 | EXP_DIR="models"
6 | EPOCHS=25
7 | NTEST=10
8 | NRULES=5
9 |
10 | for SEED in $(seq 1 $NTEST);
11 | do
12 | echo "Running training with seed $SEED"
13 | python main.py --epochs $EPOCHS \
14 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --data_dir data \
15 | --batch_size 8 --num_workers 1 --num_rules $NRULES --seed $SEED --run_name $RUN --exp_dir $EXP_DIR
16 | done
17 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
18 |
--------------------------------------------------------------------------------
/experiments/arlc_progr_to_learn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="distribute_four"
4 | RUN="arlc_progr"
5 | EXP_DIR="models"
6 | EPOCHS=25
7 | NTEST=10
8 | NRULES=5
9 |
10 |
11 | for SEED in $(seq 1 $NTEST);
12 | do
13 | echo "Running training with seed $SEED"
14 | python main.py --epochs $EPOCHS \
15 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --data_dir data \
16 | --batch_size 8 --num_workers 1 --num_rules $NRULES --seed $SEED --run_name $RUN --exp_dir $EXP_DIR --program
17 | done
18 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
19 |
--------------------------------------------------------------------------------
/experiments/ablations/context.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="distribute_four"
4 | RUN="ablation_np_pn_superpos"
5 | EXP_DIR="models"
6 | EPOCHS=25
7 | NTEST=10
8 | NRULES=5
9 |
10 | for SEED in $(seq 1 $NTEST);
11 | do
12 | echo "Running training with seed $SEED"
13 | python main.py --epochs $EPOCHS \
14 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --data_dir /dccstor/saentis/data/I-RAVEN \
15 | --batch_size 8 --num_workers 1 --num_rules $NRULES --seed $SEED --run_name $RUN --num_terms 6 --rule_type arlc \
16 | --exp_dir $EXP_DIR
17 | done
18 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
19 |
--------------------------------------------------------------------------------
/experiments/ablations/learnvrf_nopn_2x2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="distribute_four"
4 | RUN="ablation_np_pn_superpos"
5 | EXP_DIR="models"
6 | EPOCHS=25
7 | NTEST=10
8 | NRULES=5
9 |
10 | for SEED in $(seq 1 $NTEST);
11 | do
12 | echo "Running training with seed $SEED"
13 | python main.py --epochs $EPOCHS \
14 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --data_dir /dccstor/saentis/data/I-RAVEN \
15 | --batch_size 8 --num_workers 1 --num_rules $NRULES --seed $SEED --run_name $RUN --num_terms 6 --rule_type learnvrf \
16 | --exp_dir $EXP_DIR
17 | done
18 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
19 |
--------------------------------------------------------------------------------
/experiments/uncertainty/train_noisy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="center_single"
4 | EXP_DIR="models"
5 | NTEST=5
6 | NRULES=5
7 | DATA="iravenx"
8 | EPOCHS=10
9 | SIGMA=0.7
10 | RUN="iravenx_noisy_$SIGMA"
11 |
12 | if [ "$DATA" = "iraven" ]; then
13 | DATA_DIR="data/I-RAVEN"
14 | else
15 | DATA_DIR="data/I-RAVEN-X"
16 | fi
17 |
18 | for SEED in $(seq 1 $NTEST);
19 | do
20 | echo $SEED
21 | python main.py --epochs $EPOCHS --dyn_range 100 --n 10 \
22 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --entropy --sigma $SIGMA \
23 | --batch_size 8 --num_workers 1 --num_rules $NRULES --num_terms 26 --seed $SEED --run_name $RUN --exp_dir $EXP_DIR --partition _shuffle
24 | done
25 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
26 |
--------------------------------------------------------------------------------
/experiments/uncertainty/arlc_learn_noisy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="center_single"
4 | SIGMA=0.7
5 | RUN="arlc_learn_noisy_$SIGMA"
6 | EXP_DIR="models"
7 | EPOCHS=15
8 | NTEST=5
9 | NRULES=5
10 | DATA="iravenx"
11 | if [ "$DATA" = "iraven" ]; then
12 | DATA_DIR="data/I-RAVEN"
13 | else
14 | DATA_DIR="data/I-RAVEN-X"
15 | fi
16 |
17 |
18 |
19 | for SEED in $(seq 1 $NTEST);
20 | do
21 | echo $SEED
22 | python main.py --epochs $EPOCHS --n 3 --dyn_range 10 --num_terms 12 \
23 | --vsa_conversion --vsa_selection --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --partition _shuffle --entropy \
24 | --batch_size 8 --num_workers 1 --num_rules $NRULES --seed $SEED --run_name $RUN --exp_dir $EXP_DIR --orientation-confounder 0 --sigma $SIGMA
25 | done
26 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | from setuptools import setup, find_packages
7 |
8 | setup(
9 | name="arlc",
10 | version="1.0.0",
11 | description="Abductive Rule Learner with Context-awareness",
12 | url="https://research.ibm.com/people/giacomo-camposampiero--1",
13 | author="Giacomo Camposampiero",
14 | author_email="giacomo.camposampiero1@ibm.com",
15 | license="GPL-3.0",
16 | packages=find_packages(
17 | where="arlc",
18 | ),
19 | include_package_data=True,
20 | zip_safe=False,
21 | )
22 |
--------------------------------------------------------------------------------
/experiments/uncertainty/train_confounders.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="center_single"
4 | CONF=5
5 | RUN="iravenx_confounders_entropy_$CONF"
6 | EXP_DIR="models"
7 | NTEST=5
8 | NRULES=5
9 | DEBUG=0
10 | DATA="iravenx"
11 | EPOCHS=10
12 |
13 | if [ "$DATA" = "iraven" ]; then
14 | DATA_DIR="data/I-RAVEN"
15 | else
16 | DATA_DIR="data/I-RAVEN-X"
17 | fi
18 |
19 |
20 | for SEED in $(seq 1 $NTEST);
21 | do
22 | echo $SEED
23 | python main.py --epochs $EPOCHS --dyn_range 100 --n 10 \
24 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --entropy \
25 | --batch_size 8 --num_workers 1 --num_rules $NRULES --num_terms 26 --seed $SEED --run_name $RUN --exp_dir $EXP_DIR --partition _shuffle --orientation-confounder $CONF
26 | done
27 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
28 |
--------------------------------------------------------------------------------
/arlc/utils/const.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | DIM_EXIST = 2
7 | DIM_POSITION_2x2 = 15
8 | DIM_POSITION_3x3 = 511
9 | DIM_NUMBER_2x2 = 4
10 | DIM_NUMBER_3x3 = 9
11 | DIM_ONEHOT = 1001
12 | LOG_EPSILON = 1e-39
13 | NORM_SCALE = 1e15
14 |
15 | NUMPOS = {
16 | "center_single": 1,
17 | "distribute_four": 4,
18 | "distribute_nine": 9,
19 | "left_center_single_right_center_single": 2,
20 | "up_center_single_down_center_single": 2,
21 | "in_center_single_out_center_single": 2,
22 | "in_distribute_four_out_center_single": 5,
23 | }
24 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Get the Data
2 | Generate the I-RAVEN dataset with the instructions proveded [here](https://github.com/husheng12345/SRAN) and save it in this folder.
3 |
4 | ```bash
5 | git clone https://github.com/husheng12345/SRAN
6 | pip2 install --user -r SRAN/I-RAVEN/requirements.txt
7 | python2 SRAN/I-RAVEN/main.py --save-dir .
8 | ```
9 |
10 | ## Prepare the Data
11 |
12 | Run the rule preprocessing script:
13 | ```bash
14 | python arlc/utils/raven/extraction.py --data_path data
15 | ```
16 |
17 | In the latest version of the code we migrated from the original numpy-based dataset to a JSON-based following the approach of [Hu et al.](https://github.com/hxiaoyang/lm-raven).
18 | To convert the original dataset to the JSON files required by the new dataloader, use the script provided [here](https://github.com/IBM/raven-large-language-models/blob/main/src/datasets/generation/iraven_task.py).
19 |
--------------------------------------------------------------------------------
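Once the JSON files are in place, they are consumed by the `GeneralIRAVENDataset` class defined later in this dump (`arlc/datasets/iraven.py`). A minimal loading sketch, assuming `data/I-RAVEN-X/center_single_shuffle_n_10_maxval_100.json` has been generated; the parameters mirror the `__main__` example in `iraven.py`:

```python
import torch
from arlc.datasets import GeneralIRAVENDataset

# Assumes data/I-RAVEN-X/center_single_shuffle_n_10_maxval_100.json exists.
dataset = GeneralIRAVENDataset(
    dataset_type="train",
    data_dir="data/I-RAVEN-X",
    constellation_filter="center_single",
    partition="_shuffle",
    n=10,
    maxval=100,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
panels, target, rules = next(iter(loader))  # panel tensor, answer index, rule labels
```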
/experiments/uncertainty/arlc_eval_noisy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL=$1
4 |
5 | CONFIG="center_single"
6 | EXP_DIR="models"
7 | EPOCHS=10
8 | NTEST=5
9 | NRULES=5
10 | DEBUG=1
11 | DATA="iraven"
12 | SIGMA=0.3
13 |
14 | if [ "$DATA" = "iraven" ]; then
15 | DATA_DIR="data/I-RAVEN"
16 | else
17 | DATA_DIR="data/I-RAVEN-X"
18 | fi
19 |
20 | for SEED in $(seq 1 $NTEST);
21 | do
22 | python main.py --resume models/$MODEL/$SEED/ckpt --n 3 --dyn_range 10 --num_terms 12 \
23 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR \
24 | --batch_size 512 --num_workers 1 --num_rules $NRULES --seed $SEED --run_name $MODEL --orientation-confounder 0 \
25 | --exp_dir $EXP_DIR --partition _shuffle --mode test --evaluate-rule --entropy --sigma $SIGMA
26 | done
27 | python results.py --path models/$MODEL --seeds $NTEST
28 |
--------------------------------------------------------------------------------
/experiments/uncertainty/exp_dist_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | RUN=$1
4 |
5 | CONFIG="center_single"
6 | EXP_DIR="models"
7 | EPOCHS=10
8 | NTEST=5
9 | NRULES=5
10 | DEBUG=1
11 | DATA="iravenx"
12 | SIGMAS=(-0.7 -0.51)
13 |
14 | if [ "$DATA" = "iraven" ]; then
15 | DATA_DIR="data/I-RAVEN"
16 | else
17 | DATA_DIR="data/I-RAVEN-X"
18 | fi
19 |
20 | for SIGMA in "${SIGMAS[@]}";
21 | do
22 | echo "****************** sigma $SIGMA"
23 | for SEED in $(seq 1 $NTEST);
24 | do
25 | python main.py --dyn_range 100 --n 10 --resume models/$RUN/$SEED/ckpt \
26 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR \
27 | --batch_size 128 --num_workers 1 --num_rules $NRULES --num_terms 26 --seed $SEED --run_name $RUN \
28 | --exp_dir $EXP_DIR --partition _shuffle --mode test --evaluate-rule --orientation-confounder 0 --entropy --sigma $SIGMA
29 | done
30 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
31 | done
32 |
--------------------------------------------------------------------------------
/experiments/uncertainty/exp_confounders_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | RUN=$1
4 |
5 | CONFIG="center_single"
6 | EXP_DIR="models"
7 | EPOCHS=10
8 | NTEST=3
9 | NRULES=5
10 | DEBUG=1
11 | DATA="iravenx"
12 | N_CONFOUNDERS=(0 1 3 5 10 30 300)
13 |
14 | if [ "$DATA" = "iraven" ]; then
15 | DATA_DIR="data/I-RAVEN"
16 | else
17 | DATA_DIR="data/I-RAVEN-X"
18 | fi
19 |
20 | for N_CONF in "${N_CONFOUNDERS[@]}";
21 | do
22 | echo "****************** N_CONF $N_CONF"
23 | for SEED in $(seq 1 $NTEST);
24 | do
25 | python main.py --dyn_range 1000 --n 10 --resume models/$RUN/$SEED/ckpt \
26 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR \
27 | --batch_size 64 --num_workers 1 --num_rules $NRULES --num_terms 26 --seed $SEED --run_name $RUN \
28 | --exp_dir $EXP_DIR --partition _shuffle --mode test --evaluate-rule --orientation-confounder $N_CONF --entropy
29 | done
30 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
31 | done
32 |
--------------------------------------------------------------------------------
/experiments/arlc_progr.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python main.py --vsa_conversion --vsa_selection --shared_rules --config center_single --data_dir data --batch_size 256 --num_workers 1 --mode test --program
4 | python main.py --vsa_conversion --vsa_selection --shared_rules --config distribute_four --data_dir data --batch_size 256 --num_workers 1 --mode test --program
5 | python main.py --vsa_conversion --vsa_selection --shared_rules --config distribute_nine --data_dir data --batch_size 256 --num_workers 1 --mode test --program
6 | python main.py --vsa_conversion --vsa_selection --shared_rules --config left_right --data_dir data --batch_size 256 --num_workers 1 --mode test --program
7 | python main.py --vsa_conversion --vsa_selection --shared_rules --config up_down --data_dir data --batch_size 256 --num_workers 1 --mode test --program
8 | python main.py --vsa_conversion --vsa_selection --shared_rules --config in_out_single --data_dir data --batch_size 256 --num_workers 1 --mode test --program
9 | python main.py --vsa_conversion --vsa_selection --shared_rules --config in_out_four --data_dir data --batch_size 256 --num_workers 1 --mode test --program
10 |
--------------------------------------------------------------------------------
/arlc/utils/checkpath.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import os, sys, time, torch, shutil
7 |
8 |
9 | def check_paths(args):
10 | try:
11 | if not os.path.exists(args.save_dir):
12 | os.makedirs(args.save_dir)
13 | if not os.path.exists(args.log_dir):
14 | os.makedirs(args.log_dir)
15 | new_log_dir = os.path.join(args.log_dir, time.ctime().replace(" ", "-"))
16 | args.log_dir = new_log_dir
17 | if not os.path.exists(args.log_dir):
18 | os.makedirs(args.log_dir)
19 | if not os.path.exists(args.checkpoint_dir):
20 | os.makedirs(args.checkpoint_dir)
21 | except OSError as e:
22 | print(e)
23 | sys.exit(1)
24 |
25 |
26 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar", savedir=""):
27 | save_name = os.path.join(savedir, filename)
28 | torch.save(state, save_name)
29 | if is_best:
30 | save_name = os.path.join(savedir, "model_best.pth.tar")
31 | shutil.copyfile(os.path.join(savedir, filename), save_name)
32 |
--------------------------------------------------------------------------------
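For reference, `save_checkpoint` writes the checkpoint and, when `is_best` is set, copies it to `model_best.pth.tar` in the same directory. A minimal sketch with a stand-in `nn.Linear` model (the state-dict keys here are illustrative, not fixed by the repo):

```python
import torch
import torch.nn as nn
from arlc.utils.checkpath import save_checkpoint

model = nn.Linear(4, 2)  # stand-in model, not from the repo
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
state = {
    "epoch": 1,
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
# Writes ./checkpoint.pth.tar and, since is_best=True, also ./model_best.pth.tar
save_checkpoint(state, is_best=True, savedir=".")
```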
/results.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import numpy as np
7 | from collections import defaultdict
8 | import os
9 | from arlc.utils.parsing import eval_parse_args
10 | import json
11 |
12 |
13 | def main():
14 | print("\n")
15 | args = eval_parse_args()
16 | res = defaultdict(list)
17 | for i in range(1, args.seeds + 1):
18 | with open(os.path.join(args.path, f"{i}/ckpt/eval.json")) as f:
19 | dat = json.load(f)
20 | for k, v in dat.items():
21 | res[k].append(v)
22 | for k, v in res.items():
23 | print(f"{k}\t\t{np.mean(v)} ({np.std(v)})")
24 | print(f"{k}\t\t{np.min(v)} {np.max(v)}")
25 | print(f"{k}\t\t{v}")
26 |
27 | mean = np.mean(sum(res.values(), []))
28 | std = np.mean([np.std(x) for x in res.values()])
29 |     print("\nLaTeX table entry:")
30 |     print(
31 |         " & ".join([rf"${np.mean(v):.1f}^{{\pm{np.std(v):.1f}}}$" for v in res.values()])
32 |         + rf" & ${mean:.1f}^{{\pm{std:.1f}}}$"
33 |     )
34 | print("\n")
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
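`results.py` expects one `eval.json` per seed under `<path>/<seed>/ckpt/` and averages each metric across seeds. A sketch that fabricates two seed results and aggregates them (the `"acc"` key is illustrative; the real keys depend on what `main.py` writes to `eval.json`):

```python
import json
import os
import subprocess

# Fabricate eval.json for seeds 1 and 2 under models/demo/<seed>/ckpt/.
for seed in (1, 2):
    os.makedirs(f"models/demo/{seed}/ckpt", exist_ok=True)
    with open(f"models/demo/{seed}/ckpt/eval.json", "w") as f:
        json.dump({"acc": 94.0 + seed}, f)  # "acc" is a placeholder metric name

# Aggregate across the two seeds, as the experiment scripts do.
subprocess.run(["python", "results.py", "--path", "models/demo", "--seeds", "2"])
```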
/experiments/uncertainty/eval_noisy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | RUN=$1
4 | SIGMA=$2
5 |
6 | CONFIG="center_single"
7 | EXP_DIR="models"
8 | EPOCHS=10
9 | NTEST=5
10 | NRULES=5
11 | DATA="iravenx"
12 | N_CONFOUNDERS=0
13 |
14 | if [ "$DATA" = "iraven" ]; then
15 | DATA_DIR="data/I-RAVEN"
16 | else
17 | DATA_DIR="data/I-RAVEN-X"
18 | fi
19 |
20 | echo "****************** EVAL 100"
21 | for SEED in $(seq 1 $NTEST);
22 | do
23 | python main.py --dyn_range 100 --n 10 --resume models/$RUN/$SEED/ckpt \
24 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR \
25 | --batch_size 128 --num_workers 1 --num_rules $NRULES --num_terms 26 --seed $SEED --run_name $RUN \
26 | --exp_dir $EXP_DIR --partition _shuffle --mode test --evaluate-rule --entropy --sigma $SIGMA
27 | done
28 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
29 |
30 |
31 | echo "****************** EVAL 1000"
32 | for SEED in $(seq 1 $NTEST);
33 | do
34 | python main.py --dyn_range 1000 --n 10 --resume models/$RUN/$SEED/ckpt \
35 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR \
36 | --batch_size 128 --num_workers 1 --num_rules $NRULES --num_terms 26 --seed $SEED --run_name $RUN \
37 | --exp_dir $EXP_DIR --partition _shuffle --mode test --evaluate-rule --entropy --sigma $SIGMA
38 | done
39 | python results.py --path $EXP_DIR/$RUN --seeds $NTEST
40 |
--------------------------------------------------------------------------------
/experiments/arlc_ood.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Type' --gen_rule 'Constant'
4 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Type' --gen_rule 'Progression'
5 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Type' --gen_rule 'Distribute_Three'
6 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Size' --gen_rule 'Constant'
7 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Size' --gen_rule 'Progression'
8 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Size' --gen_rule 'Distribute_Three'
9 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Size' --gen_rule 'Arithmetic'
10 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Color' --gen_rule 'Constant'
11 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Color' --gen_rule 'Progression'
12 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Color' --gen_rule 'Distribute_Three'
13 | python main.py --data_dir data --vsa_conversion --vsa_selection --shared_rules --gen_attribute 'Color' --gen_rule 'Arithmetic'
14 |
--------------------------------------------------------------------------------
/arlc/utils/averagemeter.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 |
7 | class ProgressMeter(object):
8 | def __init__(self, num_batches, meters, prefix=""):
9 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
10 | self.meters = meters
11 | self.prefix = prefix
12 |
13 | def display(self, batch):
14 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
15 | entries += [str(meter) for meter in self.meters]
16 | print("\t".join(entries), flush=True)
17 |
18 | def _get_batch_fmtstr(self, num_batches):
19 |         num_digits = len(str(num_batches))
20 | fmt = "{:" + str(num_digits) + "d}"
21 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
22 |
23 |
24 | class AverageMeter(object): # Computes and stores the average and current value
25 | def __init__(self, name, fmt=":f"):
26 | self.name = name
27 | self.fmt = fmt
28 | self.reset()
29 |
30 | def reset(self):
31 | self.val = 0
32 | self.avg = 0
33 | self.sum = 0
34 | self.count = 0
35 |
36 | def update(self, val, n=1):
37 | if (self.count + n) != 0:
38 | self.val = val
39 | self.sum += val * n
40 | self.count += n
41 | self.avg = self.sum / self.count
42 |
43 | def __str__(self):
44 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
45 | return fmtstr.format(**self.__dict__)
46 |
--------------------------------------------------------------------------------
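Typical use of the two meters, for reference: `AverageMeter` keeps a running average weighted by batch size, and `ProgressMeter` pretty-prints a set of meters per batch.

```python
from arlc.utils.averagemeter import AverageMeter, ProgressMeter

loss_meter = AverageMeter("loss", ":.3f")
progress = ProgressMeter(num_batches=100, meters=[loss_meter], prefix="Epoch [1]")

loss_meter.update(0.50, n=8)  # val is the batch statistic, n the batch size
loss_meter.update(0.30, n=8)
progress.display(2)           # Epoch [1][  2/100]  loss 0.300 (0.400)
```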
/arlc/utils/general.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import torch
7 |
8 |
9 | iravenx_rule_map = {
10 | "Constant": 0,
11 | "Progression": 1,
12 | "Arithmetic": 2,
13 | "Distribute_Three": 3,
14 | }
15 | iravenx_index_map = {
16 | 0: "Constant",
17 | 1: "Progression",
18 | 2: "Arithmetic",
19 | 3: "Distribute_Three",
20 | }
21 |
22 |
23 | LOG_EPSILON = 1e-39
24 | NORM_SCALE = 1e15
25 |
26 |
27 | def normalize(unnorm_prob, dim=-1):
28 | unnorm_prob = unnorm_prob * NORM_SCALE
29 | sum_dim = torch.sum(unnorm_prob, dim=dim, keepdim=True)
30 | norm_prob = unnorm_prob / sum_dim
31 | return norm_prob, sum_dim
32 |
33 |
34 | def to_n_bit_string(n, number):
35 | format_string = "{" + "0:0{}b".format(n) + "}"
36 | return format_string.format(number)
37 |
38 |
39 | def left_rotate(number, steps, num_bits):
40 | offset = steps % num_bits
41 | index = ((number << offset) | (number >> (num_bits - offset))) & (2**num_bits - 1)
42 | return index
43 |
44 |
45 | def right_rotate(number, steps, num_bits):
46 | offset = steps % num_bits
47 | index = ((number >> offset) | (number << (num_bits - offset))) & (2**num_bits - 1)
48 | return index
49 |
50 |
51 | def count_1(n):
52 | return bin(n).count("1")
53 |
54 |
55 | def sample_action(prob, sample=True):
56 | if sample:
57 | temp = torch.ones_like(prob) * 10 ** (-7)
58 | prob = torch.where(prob < 0, temp, prob)
59 | action = torch.distributions.Categorical(prob).sample()
60 | else:
61 | action = torch.argmax(prob, dim=-1)
62 | logprob = torch.log(torch.gather(prob, -1, action.unsqueeze(-1))).squeeze(-1)
63 | return action, logprob
64 |
65 |
66 | def log(x):
67 | return torch.log(x + LOG_EPSILON)
68 |
--------------------------------------------------------------------------------
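The bit utilities treat integers as fixed-width occupancy masks, and `normalize` rescales by `NORM_SCALE` before summing so that tiny probabilities do not underflow. A few worked calls:

```python
import torch
from arlc.utils.general import (
    to_n_bit_string, left_rotate, right_rotate, count_1, normalize,
)

x = 0b0011
print(to_n_bit_string(4, x))                      # 0011
print(to_n_bit_string(4, left_rotate(x, 1, 4)))   # 0110
print(to_n_bit_string(4, right_rotate(x, 1, 4)))  # 1001
print(count_1(x))                                 # 2

# Tiny probabilities survive normalization thanks to the rescaling.
probs, _ = normalize(torch.tensor([[1e-20, 3e-20]]))
print(probs)                                      # tensor([[0.2500, 0.7500]])
```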
/experiments/iravenx/arlc_program_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # All rules
4 | python main.py --epochs 10 --dyn_range 50 --vsa_conversion --vsa_selection --shared_rules --config "center_single" --dataset "iravenx" --data_dir "/dccstor/saentis/data/I-RAVEN-X/shuffle" --batch_size 128 --num_workers 1 --num_rules 5 --num_terms 22 --mode test --program --seed 0 --run_name "iravenx_50_program" --exp_dir "debug" --annealing 8 --partition _shuffle
5 | python main.py --epochs 10 --dyn_range 100 --vsa_conversion --vsa_selection --shared_rules --config "center_single" --dataset "iravenx" --data_dir "/dccstor/saentis/data/I-RAVEN-X/shuffle" --batch_size 128 --num_workers 1 --num_rules 5 --num_terms 22 --mode test --program --seed 0 --run_name "iravenx_50_program" --exp_dir "debug" --annealing 8 --partition _shuffle
6 | python main.py --epochs 10 --dyn_range 1000 --vsa_conversion --vsa_selection --shared_rules --config "center_single" --dataset "iravenx" --data_dir "/dccstor/saentis/data/I-RAVEN-X/shuffle" --batch_size 128 --num_workers 1 --num_rules 5 --num_terms 22 --mode test --program --seed 0 --run_name "iravenx_50_program" --exp_dir "debug" --annealing 8 --partition _shuffle
7 |
8 | # Arithmetic rules
9 | python main.py --epochs 10 --dyn_range 50 --vsa_conversion --vsa_selection --shared_rules --config "center_single" --dataset "iravenx" --data_dir "/dccstor/saentis/data/I-RAVEN-X/shuffle" --batch_size 128 --num_workers 1 --num_rules 5 --num_terms 22 --mode test --program --seed 0 --run_name "iravenx_50_program" --exp_dir "debug" --annealing 8 --partition Arithmetic_shuffle
10 | python main.py --epochs 10 --dyn_range 100 --vsa_conversion --vsa_selection --shared_rules --config "center_single" --dataset "iravenx" --data_dir "/dccstor/saentis/data/I-RAVEN-X/shuffle" --batch_size 128 --num_workers 1 --num_rules 5 --num_terms 22 --mode test --program --seed 0 --run_name "iravenx_50_program" --exp_dir "debug" --annealing 8 --partition Arithmetic_shuffle
11 | python main.py --epochs 10 --dyn_range 1000 --vsa_conversion --vsa_selection --shared_rules --config "center_single" --dataset "iravenx" --data_dir "/dccstor/saentis/data/I-RAVEN-X/shuffle" --batch_size 128 --num_workers 1 --num_rules 5 --num_terms 22 --mode test --program --seed 0 --run_name "iravenx_50_program" --exp_dir "debug" --annealing 8 --partition Arithmetic_shuffle
12 |
--------------------------------------------------------------------------------
/arlc/losses.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class CosineLoss(nn.Module):
12 | def __init__(self):
13 | super(CosineLoss, self).__init__()
14 |
15 | def loss(self, output, target):
16 | loss = 1 - F.cosine_similarity(output, target, dim=-1)
17 | return loss
18 |
19 | def forward(self, output, target):
20 | loss = self.loss(output, target)
21 | loss = loss.mean(dim=-1)
22 | return loss
23 |
24 | def score(self, output, targets):
25 | losses = self.loss(output, targets)
26 | score = -losses
27 | return score
28 |
29 |
30 | class CrossEntropyLoss(nn.Module):
31 | def __init__(self):
32 | super(CrossEntropyLoss, self).__init__()
33 |
34 | def loss(self, output, target):
35 | loss = torch.sum(-target * torch.log(output), dim=-1)
36 | return loss
37 |
38 | def forward(self, output, target):
39 | loss = self.loss(output, target)
40 | loss = loss.mean(dim=-1)
41 | return loss
42 |
43 | def score(self, output, targets):
44 | losses = self.loss(output, targets)
45 | score = -losses
46 | return score
47 |
48 |
49 | class KLDivLoss(nn.Module):
50 | def __init__(self):
51 | super(KLDivLoss, self).__init__()
52 |
53 | def loss(self, output, target):
54 | output_normalized = output / output.sum(dim=-1, keepdim=True)
55 | target_normalized = target / target.sum(dim=-1, keepdim=True)
56 | epsilon = 1e-10
57 | loss = torch.sum(
58 | output_normalized
59 | * (
60 | torch.log(output_normalized + epsilon)
61 | - torch.log(target_normalized + epsilon)
62 | ),
63 | dim=-1,
64 | )
65 | return loss
66 |
67 | def forward(self, output, target):
68 | loss = self.loss(output, target)
69 | loss = loss.mean(dim=-1)
70 | return loss
71 |
72 | def score(self, output, targets):
73 | losses = self.loss(output, targets)
74 | scores = -losses
75 | return scores
76 |
--------------------------------------------------------------------------------
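All three losses share the same interface: `forward` returns a per-sample loss and `score` returns the negated loss, which is used downstream to rank rule outputs and answer candidates. A quick shape check for `CosineLoss` (the tensor shapes are illustrative):

```python
import torch
from arlc.losses import CosineLoss

criterion = CosineLoss()
output = torch.randn(8, 4, 256)  # (batch, terms, vector dim) -- illustrative shapes
target = torch.randn(8, 4, 256)

loss = criterion(output, target)  # 1 - cosine similarity, then mean over terms
print(loss.shape)                 # torch.Size([8])

candidates = torch.randn(8, 8, 256)
scores = criterion.score(output[:, :1], candidates)  # broadcast against 8 candidates
print(scores.shape)               # torch.Size([8, 8])
```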
/experiments/iravenx/arlc_learn_iravenx_50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG="center_single"
4 | RUN="iravenx_50"
5 | EXP_DIR="models"
6 | EPOCHS=10
7 | NTEST=5
8 | NRULES=5
9 | DATA="iravenx"
10 |
11 | if [ "$DATA" = "iraven" ]; then
12 | DATA_DIR="path_to_standard_raven"
13 | else
14 | DATA_DIR="path_to_iravenx"
15 | fi
16 |
17 | # Train the model
18 | python main.py --epochs $EPOCHS --dyn_range 50 --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --annealing $EPOCHS --batch_size 256 --num_workers 1 --num_rules $NRULES --num_terms 22 --seed 2 --run_name $RUN --exp_dir $EXP_DIR --partition _shuffle
19 |
20 | # Eval on unseen dynamic ranges
21 | python main.py --epochs $EPOCHS --dyn_range 100 --mode test --resume models/iravenx_50/2/ckpt --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --annealing $EPOCHS --batch_size 256 --num_workers 1 --num_rules $NRULES --num_terms 22 --seed 0 --run_name $RUN --exp_dir $EXP_DIR --partition _shuffle
22 |
23 | python main.py --epochs $EPOCHS --dyn_range 1000 --mode test --resume models/iravenx_50/2/ckpt \
24 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --annealing $EPOCHS \
25 | --batch_size 256 --num_workers 1 --num_rules $NRULES --num_terms 22 --seed 0 --run_name $RUN --exp_dir $EXP_DIR --partition _shuffle
26 |
27 |
28 |
29 | # Test arithmetic accuracies
30 | python main.py --epochs $EPOCHS --dyn_range 50 --mode test --resume models/iravenx_50/2/ckpt \
31 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --annealing $EPOCHS \
32 | --batch_size 256 --num_workers 1 --num_rules $NRULES --num_terms 22 --seed 0 --run_name $RUN --exp_dir $EXP_DIR --partition Arithmetic_shuffle
33 |
34 | python main.py --epochs $EPOCHS --dyn_range 100 --mode test --resume models/iravenx_50/2/ckpt \
35 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --annealing $EPOCHS \
36 | --batch_size 256 --num_workers 1 --num_rules $NRULES --num_terms 22 --seed 0 --run_name $RUN --exp_dir $EXP_DIR --partition Arithmetic_shuffle
37 |
38 | python main.py --epochs $EPOCHS --dyn_range 1000 --mode test --resume models/iravenx_50/2/ckpt \
39 | --vsa_conversion --vsa_selection --shared_rules --config $CONFIG --dataset $DATA --data_dir $DATA_DIR --annealing $EPOCHS \
40 | --batch_size 256 --num_workers 1 --num_rules $NRULES --num_terms 22 --seed 0 --run_name $RUN --exp_dir $EXP_DIR --partition Arithmetic_shuffle
--------------------------------------------------------------------------------
/arlc/selection.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | from collections import namedtuple
7 | import torch as t
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import numpy as np
11 |
12 |
13 | class RuleSelector(nn.Module):
14 | def __init__(self, loss_fn, temperature, rule_selector="sample"):
15 | super(RuleSelector, self).__init__()
16 | self.loss_fn = loss_fn
17 | self.temperature = temperature
18 | self.train_mode = True
19 |
20 | def train(self):
21 | self.train_mode = True
22 |
23 | def eval(self):
24 | self.train_mode = False
25 |
26 | def attribute_forward(self, outputs, tests, candidates=None, targets=None):
27 | if self.train_mode:
28 | tests = (
29 | t.cat(
30 | (
31 | tests,
32 | candidates[t.arange(candidates.shape[0]), targets].unsqueeze(1),
33 | ),
34 | dim=1,
35 | )
36 | .unsqueeze(1)
37 | .expand(-1, outputs.shape[1], -1, -1)
38 | )
39 | scores = self.loss_fn.score(outputs, tests).mean(dim=-1)
40 | weights = F.softmax(scores / self.temperature, dim=-1)
41 | else:
42 | tests = tests.unsqueeze(1).expand(-1, outputs.shape[1], -1, -1)
43 | scores = self.loss_fn.score(outputs[:, :, :2], tests).mean(dim=-1)
44 | weights = F.softmax(scores / self.temperature, dim=-1)
45 | outputs = t.einsum("ijkh,ij->ikh", outputs, weights)
46 | return outputs, weights
47 |
48 | def _entropy(self, dist):
49 | dist = dist.detach().cpu().numpy()
50 | entropy = -(dist * np.log(dist))
51 | return entropy[~np.isnan(entropy)].sum() / dist.shape[0]
52 |
53 | def forward(self, outputs, tests, candidates=None, targets=None, use_position=True):
54 | res = {}
55 | rules = {}
56 | weights = {}
57 | for attr in outputs._fields:
58 | if attr in ["position", "number"] and (
59 | not use_position or outputs.position is None
60 | ):
61 | res[attr] = rules[attr] = None
62 | continue
63 | res[attr], weights[attr] = self.attribute_forward(
64 | getattr(outputs, attr),
65 | getattr(tests, attr),
66 | getattr(candidates, attr, None),
67 | targets,
68 | )
69 | res = type(outputs)(**res)
70 | entropy_attr = {k: self._entropy(v) for k, v in weights.items()}
71 | return res, entropy_attr
72 |
--------------------------------------------------------------------------------
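The heart of `attribute_forward` is a soft rule selection: each rule's outputs are scored against the tests, the scores pass through a temperature softmax, and the per-rule outputs are blended with the resulting weights via the `ijkh,ij->ikh` einsum. A standalone sketch with illustrative sizes:

```python
import torch as t
import torch.nn.functional as F

batch, n_rules, n_terms, d = 2, 5, 3, 16       # illustrative sizes
outputs = t.randn(batch, n_rules, n_terms, d)  # per-rule predicted vectors
scores = t.randn(batch, n_rules)               # per-rule fit to the context
temperature = 0.01

weights = F.softmax(scores / temperature, dim=-1)    # soft selection over rules
blended = t.einsum("ijkh,ij->ikh", outputs, weights) # weighted sum of rule outputs
print(blended.shape)  # torch.Size([2, 3, 16])
```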
/arlc/utils/vsa.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import torch.nn as nn
7 | from nvsa.reasoning.vsa_block_utils import (
8 |     pmf2vec,
9 |     block_discrete_codebook,
10 |     block_continuous_codebook,
11 | )
12 | from arlc.utils.const import (
13 |     DIM_POSITION_2x2,
14 |     DIM_POSITION_3x3,
15 |     DIM_NUMBER_2x2,
16 |     DIM_NUMBER_3x3,
17 |     DIM_ONEHOT,
18 | )
24 |
25 |
26 | def generate_nvsa_codebooks(args, rng):
27 | """
28 | Generate the codebooks for NVSA frontend and backend.
29 | The codebook can also be loaded if it is stored under args.resume/
30 | """
31 | backend_cb_cont, _ = block_continuous_codebook(
32 | device=args.device,
33 | scene_dim=1024,
34 | d=args.nvsa_backend_d,
35 | k=args.nvsa_backend_k,
36 | rng=rng,
37 | fully_orthogonal=False,
38 | )
39 | backend_cb_discrete, _ = block_discrete_codebook(
40 | device=args.device, d=args.nvsa_backend_d, k=args.nvsa_backend_k, rng=rng
41 | )
42 | return backend_cb_cont, backend_cb_discrete
43 |
44 |
45 | class VSAConverter(nn.Module):
46 | def __init__(
47 | self,
48 | device,
49 | constellation,
50 | dictionary,
51 | dictionary_type="Discrete",
52 | context_dim=8,
53 | attributes_superposition=False,
54 | ):
55 | super(VSAConverter, self).__init__()
56 | self.device = device
57 | self.constellation = constellation
58 | self.d = dictionary.shape[1] * dictionary.shape[2]
59 | self.k = dictionary.shape[1]
60 | self.dictionary = dictionary
61 | self.dictionary_type = dictionary_type
62 | self.compute_attribute_dicts()
63 | self.context_dim = context_dim
64 | self.attributes_superposition = attributes_superposition
65 | if self.attributes_superposition:
66 | attribute_keys, _ = block_discrete_codebook(
67 | device=device, d=self.d, k=self.k, scene_dim=5
68 | )
69 | self.attribute_keys = nn.Parameter(attribute_keys)
70 |
71 | def compute_attribute_dicts(self):
72 | if "distribute" in self.constellation or "in_out_four" == self.constellation:
73 | if "four" in self.constellation:
74 | DIM_POSITION = DIM_POSITION_2x2
75 | DIM_NUMBER = DIM_NUMBER_2x2
76 | else:
77 | DIM_POSITION = DIM_POSITION_3x3
78 | DIM_NUMBER = DIM_NUMBER_3x3
79 | self.position_dictionary = self.dictionary[:DIM_POSITION]
80 | self.number_dictionary = self.dictionary[:DIM_NUMBER]
81 |
82 | def compute_values(self, scene_prob):
83 | vsas = {}
84 | for attr in scene_prob._fields:
85 | if attr == "position" and (
86 | "distribute" in self.constellation
87 | or "in_out_four" == self.constellation
88 | ):
89 | vsas[attr] = pmf2vec(self.position_dictionary, scene_prob.position)
90 | elif attr == "number" and (
91 | "distribute" in self.constellation
92 | or "in_out_four" == self.constellation
93 | ):
94 | vsas[attr] = pmf2vec(self.number_dictionary, scene_prob.number)
95 | elif attr in ["position", "number"]:
96 | vsas[attr] = None
97 | else:
98 | vsas[attr] = pmf2vec(
99 | self.dictionary[: DIM_ONEHOT + 1], getattr(scene_prob, attr)
100 | )
101 | return type(scene_prob)(**vsas)
102 |
103 | def forward(self, scene_prob):
104 | return self.compute_values(scene_prob)
105 |
--------------------------------------------------------------------------------
/arlc/utils/raven/scene.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | from collections import OrderedDict, namedtuple
7 | from itertools import product
8 | import numpy as np
9 | import torch
10 | import arlc.utils.general as utils
11 |
12 |
13 | class SceneEngine:
14 | def __init__(self, number_slots, device):
15 | self.device = device
16 | self.num_slots = number_slots
17 | self.positions = list(product(range(2), repeat=self.num_slots))
18 | # assume nonempty
19 | start_index = 1
20 | position2number = np.sum(self.positions[start_index:], axis=1)
21 | # note the correspondence of positions: first digit from the left corresponds to part one
22 | self.positions = torch.tensor(self.positions[start_index:], dtype=torch.int).to(
23 | self.device
24 | )
25 | self.dim_position = self.positions.shape[0]
26 | self.num_pos_index_map = OrderedDict()
27 | for i in range(start_index, self.num_slots + 1):
28 | self.num_pos_index_map[i] = torch.tensor(
29 | list(
30 | filter(
31 | lambda idx: position2number[idx] == i,
32 | range(len(position2number)),
33 | )
34 | ),
35 | dtype=torch.long,
36 | ).to(self.device)
37 |
38 | def compute_scene_prob(self, **attribute_logprobs):
39 | position_prob, position_logprob = self.compute_position_prob(
40 | attribute_logprobs.pop("exist")
41 | )
42 | number_prob, number_logprob = self.compute_number_prob(position_prob)
43 | SceneProb = namedtuple(
44 | "SceneProb", ["position", "number"] + [k for k in attribute_logprobs.keys()]
45 | )
46 | SceneLogProb = namedtuple(
47 | "SceneLogProb",
48 | ["position", "number"] + [k for k in attribute_logprobs.keys()],
49 | )
50 | attr_probs = {
51 | k: self.compute_attribute_prob(v, position_logprob)
52 | for k, v in attribute_logprobs.items()
53 | }
54 | att_logprobs = {k: utils.log(v) for k, v in attr_probs.items()}
55 | return (
56 | SceneProb(position_prob, number_prob, **attr_probs),
57 | SceneLogProb(position_logprob, number_logprob, **att_logprobs),
58 | )
59 |
60 | def compute_position_prob(self, exist_logprob):
61 | batch_size = exist_logprob.shape[0]
62 | num_panels = exist_logprob.shape[1]
63 | exist_logprob = exist_logprob.unsqueeze(2).expand(
64 | -1, -1, self.dim_position, -1, -1
65 | )
66 | index = (
67 | self.positions.unsqueeze(0)
68 | .unsqueeze(0)
69 | .expand(batch_size, num_panels, -1, -1)
70 | .unsqueeze(-1)
71 | .type(torch.long)
72 | )
73 | position_logprob = torch.gather(
74 | exist_logprob, -1, index
75 | ) # (batch_size, num_panels, self.dim_position, slots, 1)
76 | position_logprob = torch.sum(
77 | position_logprob.squeeze(-1), dim=-1
78 | ) # (batch_size, num_panels, self.dim_position)
79 | position_prob = torch.exp(position_logprob)
80 | # assume nonempty: all zero state is filtered out
81 | position_prob = utils.normalize(position_prob)[0]
82 | position_logprob = utils.log(position_prob)
83 | return position_prob, position_logprob
84 |
85 | def compute_number_prob(self, position_prob):
86 | all_num_prob = []
87 | for _, indices in self.num_pos_index_map.items():
88 | num_prob = torch.sum(position_prob[:, :, indices], dim=-1, keepdim=True)
89 | all_num_prob.append(num_prob)
90 | number_prob = torch.cat(all_num_prob, dim=-1)
91 | return number_prob, utils.log(number_prob)
92 |
93 | def compute_attribute_prob(self, logprob, position_logprob):
94 | batch_size = logprob.shape[0]
95 | num_panels = logprob.shape[1]
96 | index = (
97 | self.positions.unsqueeze(0)
98 | .unsqueeze(0)
99 | .expand(batch_size, num_panels, -1, -1)
100 | .unsqueeze(-1)
101 | .type(torch.float)
102 | )
103 | logprob = logprob.unsqueeze(2).expand(-1, -1, self.dim_position, -1, -1)
104 | logprob = (
105 | index * logprob
106 | ) # (batch_size, num_panels, self.dim_position, slots, DIM_TYPE)
107 | logprob = torch.sum(logprob, dim=3) + position_logprob.unsqueeze(-1)
108 | prob = torch.exp(logprob)
109 | prob = torch.sum(prob, dim=2)
110 | inconsist_prob = 1.0 - torch.clamp(
111 | torch.sum(prob, dim=-1, keepdim=True), max=1.0
112 | ) # clamp for numerical stability
113 | prob = torch.cat([prob, inconsist_prob], dim=-1)
114 | return torch.nan_to_num(prob, nan=0.0)
115 |
--------------------------------------------------------------------------------
/arlc/datasets/iraven.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import os
7 | import numpy as np
8 | import torch as t
9 | from torch.utils.data import Dataset
10 | import random
11 | import json
12 |
13 | rule_map = {"Constant": 0, "Progression": 1, "Arithmetic": 2, "Distribute_Three": 3}
14 |
15 |
16 | class GeneralIRAVENDataset(Dataset):
17 | def __init__(
18 | self,
19 | dataset_type,
20 | data_dir,
21 | constellation_filter,
22 | rule_filter="",
23 | attribute_filter="",
24 | n_train=None,
25 | in_memory=False,
26 | partition="",
27 | n=10,
28 | n_show=3,
29 | maxval=1000,
30 | n_confounders=0,
31 | ):
32 |
33 | self.n = n
34 | self.n_show = n_show
35 | self.n_tot = n_show * n - 1 + 8
36 | self.n_confounders = n_confounders
37 | self.maxval = maxval
38 |
39 |         if dataset_type == "train":
40 |             self.filtered_indices = np.arange(6000)
41 |         elif dataset_type == "val":
42 |             self.filtered_indices = np.arange(6000, 8000)
43 |         elif dataset_type == "test":
44 |             self.filtered_indices = np.arange(8000, 10000)
45 |
46 | if rule_filter != "" or attribute_filter != "":
47 | raise ValueError("Rule filtering not implemented")
48 |
49 |         if n_train:
50 |             self.filtered_indices = self.filtered_indices[:n_train]
51 |
52 |         self.old_raven = "I-RAVEN-" not in data_dir
53 | data_file = (
54 | f"{constellation_filter}{partition}_n_{n}_maxval_{maxval}.json"
55 | )
56 | print(f"Number of confounders: {n_confounders}")
57 | self.constellation = constellation_filter
58 | # load entire dataset from
59 | with open(os.path.join(data_dir, data_file), "r") as f:
60 | self.dataset = json.load(f)
61 |
62 | def __len__(self):
63 |         return len(self.filtered_indices)
64 |
65 | def _get_panel_number(self, x, y):
66 | if not (0 <= x <= 1 and 0 <= y <= 1):
67 | raise ValueError("Point is outside the 1x1 box")
68 | if self.constellation == "distribute_nine":
69 | div = 1 / 3
70 | ppr = 3
71 | elif self.constellation == "distribute_four":
72 | div = 1 / 2
73 | ppr = 2
74 | col = int(x / div)
75 | row = int((1 - y) / div)
76 | panel_number = row * ppr + col
77 | return panel_number
78 |
79 | def __getitem__(self, index):
80 |         valid_index = self.filtered_indices[index % len(self.filtered_indices)]
81 | data = self.dataset[str(valid_index)]
82 | # dimension panel, slots, attributes
83 | input_tensor = t.ones((self.n_tot, 9, 5 + self.n_confounders)).float() * (-1)
84 | for i in range(self.n_tot):
85 | panels = data["rpm"][i + (self.n - self.n_show) * self.n][0]
86 | if self.constellation == "center_single":
87 | input_tensor[:, 0, 0] = 0 # Fix position in center constellation
88 |                 panel = panels  # reuse the panel dict fetched above
89 | input_tensor[i, 0, 2] = int(panel["Color"])
90 | input_tensor[i, 0, 3] = int(panel["Size"]) + self.old_raven * 1
91 | input_tensor[i, 0, 4] = int(panel["Type"]) + self.old_raven * 2
92 | input_tensor[i, 0, 1] = int(panel["Angle"])
93 | for n in range(self.n_confounders):
94 | input_tensor[i, 0, 5 + n] = int(panel[f"Confounder{n}"])
95 | else:
96 | for pidx, (pos, ent) in enumerate(
97 | zip(panels["positions"], panels["entities"])
98 | ):
99 | input_tensor[i, pidx, 2] = int(ent["Color"])
100 | input_tensor[i, pidx, 3] = int(ent["Size"])
101 | input_tensor[i, pidx, 4] = int(ent["Type"])
102 | input_tensor[i, pidx, 1] = int(ent["Angle"])
103 | input_tensor[i, pidx, 0] = int(self._get_panel_number(*pos[:2]))
104 | for n in range(self.n_confounders):
105 | input_tensor[i, 0, 5 + n] = random.randint(0, self.maxval)
106 |
107 | label_tensor = t.tensor(int(data["target"])).long()
108 | rules = data["rules"][0]
109 | if "Number/Position" in rules:
110 | num_pos = "Number/Position"
111 | elif "Number" in rules:
112 | num_pos = "Number"
113 | else:
114 | num_pos = "Position"
115 | pos_num_rule = t.tensor(np.array(rule_map[rules[num_pos]])).float()
116 | color_rule = t.tensor(np.array(rule_map[rules["Color"]])).float()
117 | size_rule = t.tensor(np.array(rule_map[rules["Size"]])).float()
118 | type_rule = t.tensor(np.array(rule_map[rules["Type"]])).float()
119 | rules_tensor = t.stack([pos_num_rule, color_rule, size_rule, type_rule])
120 | return input_tensor, label_tensor, rules_tensor
121 |
122 |
123 | if __name__ == "__main__":
124 | dataset = GeneralIRAVENDataset(
125 | "train",
126 | "/dccstor/saentis/data/I-RAVEN-X",
127 | "center_single",
128 | n=10,
129 | maxval=100,
130 | partition="_shuffle",
131 | n_confounders=10,
132 | )
133 | print(dataset.__getitem__(0))
134 |
--------------------------------------------------------------------------------
/arlc/utils/raven/env.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | from arlc.utils.raven.scene import SceneEngine
7 |
8 |
9 | def get_env(env_name, device, **kwargs):
10 | if env_name == "center_single":
11 | return CenterSingle(device, **kwargs)
12 | if env_name == "distribute_four":
13 | return DistributeFour(device, **kwargs)
14 | if env_name == "distribute_nine":
15 | return DistributeNine(device, **kwargs)
16 | if env_name == "in_center_single_out_center_single":
17 | return InCenterSingleOutCenterSingle(device, **kwargs)
18 | if env_name == "in_distribute_four_out_center_single":
19 | return InDistributeFourOutCenterSingle(device, **kwargs)
20 | if env_name == "left_center_single_right_center_single":
21 | return LeftCenterSingleRightCenterSingle(device, **kwargs)
22 | if env_name == "up_center_single_down_center_single":
23 | return UpCenterSingleDownCenterSingle(device, **kwargs)
24 | return None
25 |
26 |
27 | class GeneralEnv(object):
28 | def __init__(self, num_slots, device, **kwargs):
29 | self.num_slots = num_slots
30 | self.device = device
31 | self.scene_engine = SceneEngine(self.num_slots, device)
32 |
33 | def prepare(self, model_output):
34 | return self.scene_engine.compute_scene_prob(**model_output)
35 |
36 |
37 | class CenterSingle(GeneralEnv):
38 | def __init__(self, device, **kwargs):
39 | super(CenterSingle, self).__init__(1, device, **kwargs)
40 |
41 |
42 | class DistributeFour(GeneralEnv):
43 | def __init__(self, device, **kwargs):
44 | super(DistributeFour, self).__init__(4, device, **kwargs)
45 |
46 |
47 | class DistributeNine(GeneralEnv):
48 | def __init__(self, device, **kwargs):
49 | super(DistributeNine, self).__init__(9, device, **kwargs)
50 |
51 |
52 | class OutCenterSingle(GeneralEnv):
53 | def __init__(self, device, **kwargs):
54 | super(OutCenterSingle, self).__init__(1, device, **kwargs)
55 |
56 |
57 | class InCenterSingleOutCenterSingle(object):
58 | def __init__(self, device, **kwargs):
59 | self.in_center_single = CenterSingle(device, **kwargs)
60 | self.out_center_single = OutCenterSingle(device, **kwargs)
61 |
62 | def prepare(self, model_output):
63 | in_component = []
64 | out_component = []
65 | for element in model_output:
66 | in_component.append(element[:, :, 1:, :])
67 | out_component.append(element[:, :, :1, :])
68 | in_scene_prob, in_scene_logprob = self.in_center_single.prepare(in_component)
69 | out_scene_prob, out_scene_logprob = self.out_center_single.prepare(
70 | out_component
71 | )
72 | return (in_scene_prob, out_scene_prob), (in_scene_logprob, out_scene_logprob)
73 |
74 |
75 | class InDistributeFourOutCenterSingle(object):
76 | def __init__(self, device, **kwargs):
77 | self.in_distribute_four = DistributeFour(device, **kwargs)
78 | self.out_center_single = OutCenterSingle(device, **kwargs)
79 |
80 | def prepare(self, model_output):
81 | in_component = []
82 | out_component = []
83 | for element in model_output:
84 | in_component.append(element[:, :, 1:5, :])
85 | out_component.append(element[:, :, :1, :])
86 | in_scene_prob, in_scene_logprob = self.in_distribute_four.prepare(in_component)
87 | out_scene_prob, out_scene_logprob = self.out_center_single.prepare(
88 | out_component
89 | )
90 | return (in_scene_prob, out_scene_prob), (in_scene_logprob, out_scene_logprob)
91 |
92 |
93 | class LeftCenterSingleRightCenterSingle(object):
94 | def __init__(self, device, **kwargs):
95 | self.left_center_single = CenterSingle(device, **kwargs)
96 | self.right_center_single = CenterSingle(device, **kwargs)
97 |
98 | def prepare(self, model_output):
99 | left_component = []
100 | right_component = []
101 | for element in model_output:
102 | left_component.append(element[:, :, :1, :])
103 | right_component.append(element[:, :, 1:, :])
104 | left_scene_prob, left_scene_logprob = self.left_center_single.prepare(
105 | left_component
106 | )
107 | right_scene_prob, right_scene_logprob = self.right_center_single.prepare(
108 | right_component
109 | )
110 | return (left_scene_prob, right_scene_prob), (
111 | left_scene_logprob,
112 | right_scene_logprob,
113 | )
114 |
115 |
116 | class UpCenterSingleDownCenterSingle(object):
117 | def __init__(self, device, **kwargs):
118 | self.up_center_single = CenterSingle(device, **kwargs)
119 | self.down_center_single = CenterSingle(device, **kwargs)
120 |
121 | def prepare(self, model_output):
122 | up_component = []
123 | down_component = []
124 | for element in model_output:
125 | up_component.append(element[:, :, :1, :])
126 | down_component.append(element[:, :, 1:, :])
127 | up_scene_prob, up_scene_logprob = self.up_center_single.prepare(up_component)
128 | down_scene_prob, down_scene_logprob = self.down_center_single.prepare(
129 | down_component
130 | )
131 | return (up_scene_prob, down_scene_prob), (up_scene_logprob, down_scene_logprob)
132 |
--------------------------------------------------------------------------------
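`get_env` returns a constellation-specific environment whose `prepare` turns per-slot attribute log-probabilities into scene-level distributions. A smoke test for `distribute_four` (the 5-way `type` attribute and the tensor sizes are illustrative; `exist` must have 2 states per slot):

```python
import torch
from arlc.utils.raven.env import get_env

env = get_env("distribute_four", device=torch.device("cpu"))
batch, panels, slots = 2, 16, 4

def logp(*shape):
    # random, properly normalized log-probabilities
    return torch.log_softmax(torch.randn(*shape), dim=-1)

scene_prob, scene_logprob = env.prepare(
    {"exist": logp(batch, panels, slots, 2), "type": logp(batch, panels, slots, 5)}
)
print(scene_prob.position.shape)  # torch.Size([2, 16, 15]) -- the 2^4 - 1 occupancy patterns
```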
/README.md:
--------------------------------------------------------------------------------
1 | # Abductive Rule Learner with Context Awareness (ARLC)
2 |
3 |
4 | 
5 |
6 |
7 | This repo contains the code for the Abductive Rule Learner with Context Awareness (ARLC), a probabilistic abductive reasoner for solving Raven's Progressive Matrices (RPM).
8 | The repo features the code used to run experiments in three publications:
9 | - Giacomo Camposampiero, Michael Hersche, Aleksandar Terzić, Roger Wattenhofer, Abu Sebastian and Abbas Rahimi. *Towards Learning Abductive Reasoning using VSA Distributed Representations*. 18th International Conference on Neural-Symbolic Learning and Reasoning (NeSy) **[Spotlight]**, 2024. [[Paper]](http://arxiv.org/abs/2406.19121)
10 | - Michael Hersche, Giacomo Camposampiero, Roger Wattenhofer, Abu Sebastian and Abbas Rahimi. *Towards Learning to Reason: Comparing LLMs with Neuro-Symbolic on Arithmetic Relations in Abstract Reasoning*. Neural Reasoning and Mathematical Discovery (NEURMAD) @ AAAI, 2025. [[Paper]](https://arxiv.org/pdf/2412.05586)
11 | - Giacomo Camposampiero‡, Michael Hersche‡, Roger Wattenhofer, Abu Sebastian and Abbas Rahimi. *Can Large Reasoning Models do Analogical Reasoning under Perceptual Uncertainty?*. arXiv, 2025. [[Paper]]()
12 |
13 | ‡ These authors contributed equally.
14 |
15 | ## Build the Environment 🛠️
16 |
17 | #### Hardware
18 | You will need a machine with a CUDA-enabled GPU and the NVIDIA SDK installed to compile the CUDA kernels. We tested our methods on an NVIDIA Tesla V100 GPU with CUDA version 11.3.1.
19 |
20 | #### Installing Dependencies
21 |
22 | The `mamba` package manager is required to run the code. You can create a new environment with
23 |
24 | ```bash
25 | mamba create --name arlc python=3.7
26 | mamba activate arlc
27 | ```
28 |
29 | To install PyTorch 1.11 and CUDA, use
30 | ```bash
31 | mamba install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=11.3 -c pytorch -c conda-forge
32 | ```
33 |
34 | Clone and install the [neuro-vsa repo](https://github.com/IBM/neuro-vector-symbolic-architectures) (some of its utilities are reused in this project)
35 | ```bash
36 | git clone https://github.com/IBM/neuro-vector-symbolic-architectures.git
37 | cd neuro-vector-symbolic-architectures
38 | pip install -e . --no-deps
39 | ```
40 |
41 | Finally, clone and install this repo
42 | ```bash
43 | git clone https://github.com/IBM/abductive-rule-learner-with-context-awareness.git
44 | cd abductive-rule-learner-with-context-awareness
45 | pip install -r requirements.txt
46 | pip install -e .
47 | pre-commit install
48 | ```
49 |
50 | We suggest formatting the code of the entire repository to improve its readability.
51 | To do so, install and run `black`:
52 | ```bash
53 | pip install black
54 | black abductive-rule-learner-with-context-awareness/
55 | ```
56 |
57 |
58 | #### I-RAVEN Dataset
59 | You can find the instructions to download and pre-process the data in the `data` folder.
60 |
61 |
62 | ## Run our Experiments 🔬
63 | You can replicate the main experiments shown in the paper with the following scripts
64 | ```bash
65 | # ARLC learned from data
66 | ./experiments/arlc_learn.sh
67 | # ARLC initialized with programming, then learned
68 | ./experiments/arlc_progr_to_learn.sh
69 | # ARLC programmed and evaluated
70 | ./experiments/arlc_progr.sh
71 | ```
72 |
73 | To replicate our ablations on the introduced contributions, run
74 | ```bash
75 | # row 1 of the ablation table
76 | # obtained with the code from https://github.com/IBM/learn-vector-symbolic-architectures-rule-formulations, modified to run with multiple random seeds
77 |
78 | # row 2 of the ablation table
79 | ./experiments/ablations/learnvrf_nopn_2x2.sh
80 |
81 | # row 3 of the ablation table
82 | ./experiments/ablations/context.sh
83 |
84 | # row 4 of the ablation table
85 | # same as ./experiments/arlc_learn.sh
86 | ```
87 |
88 | To replicate our OOD experiments, run
89 | ```bash
90 | ./experiments/arlc_ood.sh
91 | ```
92 |
93 | ## I-RAVEN-X dataset evaluation
94 | To replicate the results reported in _Towards Learning to Reason: Comparing LLMs with Neuro-Symbolic on Arithmetic Relations in Abstract Reasoning_ (Hersche et al., 2025) on the novel I-RAVEN-X dataset,
95 | run the experiments in the `experiments/iravenx` folder (an example invocation is given below the list).
96 | - `arlc_learn_iravenx_50.sh` trains an ARLC model from scratch and evaluates it on both the full I-RAVEN-X dataset and the subset of Arithmetic rules.
97 | - `arlc_program_eval.sh` evaluates the programmed ARLC on both the full I-RAVEN-X dataset and the subset of Arithmetic rules.
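
For instance, assuming the repository root as the working directory, the two evaluations can be launched as:

```bash
# train ARLC from scratch, then evaluate on I-RAVEN-X and its Arithmetic subset
./experiments/iravenx/arlc_learn_iravenx_50.sh
# evaluate the programmed ARLC on the same splits
./experiments/iravenx/arlc_program_eval.sh
```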
98 |
99 | ## I-RAVEN-X with perceptual uncertainty dataset evaluation
100 | To replicate the results reported in *Can Large Reasoning Models do Analogical Reasoning under Perceptual Uncertainty?* (Camposampiero, Hersche, et al., 2025) on the novel I-RAVEN-X dataset with perceptual uncertainty, run the experiments in the `experiments/uncertainty` folder, for example as sketched below.
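
A minimal sketch that launches every script in that folder in sequence (individual experiments can also be run directly):

```bash
for script in experiments/uncertainty/*.sh; do
    bash "$script"
done
```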
101 |
102 | ## Citation 📚
103 | If you use the work released here for your research, please consider citing our paper:
104 | ```
105 | @inproceedings{camposampiero2024towards,
106 | title={Towards Learning Abductive Reasoning using VSA Distributed Representations},
107 | author={Camposampiero, Giacomo and Hersche, Michael and Terzi{\'c}, Aleksandar and Wattenhofer, Roger and Sebastian, Abu and Rahimi, Abbas},
108 | booktitle={18th International Conference on Neural-Symbolic Learning and Reasoning (NeSy)},
109 | year={2024},
110 | month={sep}
111 | }
112 | ```
113 |
114 |
115 | ## License 🔏
116 | Please refer to the LICENSE file for the licensing of our code. Our implementation relies on [PrAE](https://github.com/WellyZhang/PrAE) and [Learn-VRF](https://github.com/IBM/learn-vector-symbolic-architectures-rule-formulations), both released under GPL v3.0, as well as on [In-Context Analogical Reasoning with Pre-Trained Language Models](https://github.com/hxiaoyang/lm-raven), distributed under the MIT license.
117 |
--------------------------------------------------------------------------------
/arlc/utils/parsing.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import argparse
7 |
8 |
9 | def eval_parse_args():
10 | arg_parser = argparse.ArgumentParser()
11 | arg_parser.add_argument("--path", type=str)
12 | arg_parser.add_argument("--seeds", type=int)
13 | args = arg_parser.parse_args()
14 | return args
15 |
16 |
17 | def parse_args():
18 | arg_parser = argparse.ArgumentParser(
19 |         description="NVSA learnable backend training and evaluation on RAVEN"
20 | )
21 | arg_parser.add_argument("--n", type=int)
22 | arg_parser.add_argument("--run_name", type=str)
23 | arg_parser.add_argument("--mode", type=str, default="train", help="Train/test")
24 | arg_parser.add_argument("--exp_dir", type=str, default="results/")
25 | arg_parser.add_argument("--dataset", type=str, default="iraven")
26 | arg_parser.add_argument("--data_dir", type=str, default="dataset/")
27 | arg_parser.add_argument("--dyn_range", type=int, default=-1)
28 | arg_parser.add_argument("--rule_type", type=str, default="arlc")
29 | arg_parser.add_argument("--num_terms", type=int, default=12)
30 | arg_parser.add_argument(
31 |         "--resume", type=str, default="", help="Resume from an initialized model"
32 | )
33 | arg_parser.add_argument("--seed", type=int, default=1234, help="Random number seed")
34 | arg_parser.add_argument("--run", type=int, default=0, help="Run id")
35 |
36 | # Dataset
37 | arg_parser.add_argument(
38 | "--partition",
39 | type=str,
40 | default="",
41 | )
42 | arg_parser.add_argument(
43 | "--config",
44 | type=str,
45 | default="center_single",
46 | help="The configuration used for training",
47 | )
48 | arg_parser.add_argument(
49 | "--gen_attribute",
50 | type=str,
51 | default="",
52 | help="Generalization experiment [Type, Size, Color]",
53 | )
54 | arg_parser.add_argument(
55 | "--gen_rule",
56 | type=str,
57 | default="",
58 | help="Generalization experiment [Arithmetic, Constant, Progression, Distribute_Three]",
59 | )
60 | arg_parser.add_argument("--n-train", type=int, default=None)
61 |
62 | # Training hyperparameters
63 | arg_parser.add_argument(
64 | "--model",
65 | type=str,
66 | default="LearnableFormula",
67 | help="Model used in the reasoner (LearnableFormula, MLP)",
68 | )
69 | arg_parser.add_argument(
70 | "--epochs", type=int, default=50, help="The number of training epochs"
71 | )
72 | arg_parser.add_argument("--batch_size", type=int, default=4, help="Size of batch")
73 | arg_parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
74 | arg_parser.add_argument(
75 | "--weight-decay",
76 | type=float,
77 | default=0,
78 |         help="Weight decay of the optimizer (equivalent to L2 regularization)",
79 | )
80 | arg_parser.add_argument(
81 | "--num_workers", type=int, default=8, help="Number of workers for data loader"
82 | )
83 | arg_parser.add_argument(
84 | "--clip",
85 | type=float,
86 | default=10,
87 |         help="Max norm for gradient clipping (L2 norm)",
88 | )
89 | arg_parser.add_argument(
90 | "--vsa_conversion",
91 | action="store_true",
92 | default=False,
93 |         help="Whether to use the VSA converter",
94 | )
95 | arg_parser.add_argument(
96 | "--vsa_selection",
97 | action="store_true",
98 | default=False,
99 |         help="Whether to use the VSA selector",
100 | )
101 | arg_parser.add_argument(
102 | "--context_superposition",
103 | action="store_true",
104 | default=False,
105 |         help="Whether to use context superposition",
106 | )
107 | arg_parser.add_argument(
108 | "--program",
109 | action="store_true",
110 | default=False,
111 | help="Program the model with golden weights",
112 | )
113 | arg_parser.add_argument("--evaluate-rule", action="store_true")
114 | arg_parser.add_argument(
115 | "--loss_fn", type=str, default="CosineLoss", help="Loss to use in the training"
116 | )
117 | arg_parser.add_argument(
118 | "--num_rules", type=int, default=5, help="Number of rules per each attribute"
119 | )
120 | arg_parser.add_argument("--annealing", type=int, default=-1)
121 | arg_parser.add_argument(
122 | "--rule_selector_temperature",
123 | type=float,
124 | default=0.01,
125 | help="Temperature used in the rule selector's softmax",
126 | )
127 | arg_parser.add_argument(
128 | "--rule_selector", type=str, default="weight", help="Can be sample or weight"
129 | )
130 | arg_parser.add_argument(
131 | "--shared_rules",
132 | action="store_true",
133 | default=False,
134 | help="Share the same rules across different attributes",
135 | )
136 | arg_parser.add_argument(
137 | "--hidden_layers",
138 | type=int,
139 | default=3,
140 | help="Number of hidden MLP layers to use in the neural model",
141 | )
142 |
143 | # NVSA backend settings
144 | arg_parser.add_argument(
145 | "--nvsa-backend-d", type=int, default=1024, help="VSA dimension in backend"
146 | )
147 | arg_parser.add_argument(
148 | "--nvsa-backend-k", type=int, default=4, help="Number of blocks in VSA vectors"
149 | )
150 | arg_parser.add_argument(
151 | "--orientation-confounder",
152 | type=int,
153 | default=0,
154 | )
155 | arg_parser.add_argument(
156 | "--entropy",
157 | action="store_true",
158 | default=False,
159 | )
160 | arg_parser.add_argument(
161 | "--sigma",
162 | type=float,
163 | default=0.1,
164 | )
165 | args = arg_parser.parse_args()
166 | return args
167 |
--------------------------------------------------------------------------------
/arlc/execution.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import torch as t
7 | import torch.nn as nn
8 | from arlc.rule_templates import (
9 | ExtendedGeneralLearnableFormula,
10 | IravenxGeneralLearnableFormula,
11 | IravenVGeneralLearnableFormula,
12 | )
13 | from arlc.utils.vsa import VSAConverter
14 |
15 |
16 | class RuleLevelReasoner(nn.Module):
17 | def __init__(
18 | self,
19 | device,
20 | constellation,
21 | model,
22 | hidden_layers,
23 | dictionary,
24 | vsa_conversion=False,
25 | vsa_selection=False,
26 | context_superposition=False,
27 | num_rules=5,
28 | shared_rules=False,
29 | program=False,
30 | rule_type="arlc",
31 | num_terms=12,
32 | n=3,
33 | ):
34 | super(RuleLevelReasoner, self).__init__()
35 | self.device = device
36 | self.constellation = constellation
37 | self.model = model
38 | self.program = program
39 | self.rule_type = rule_type
40 | self.num_terms = num_terms
41 | self.num_panels = n
42 | self.d = dictionary.shape[1] * dictionary.shape[2]
43 | self.k = dictionary.shape[1]
44 | self.vsa_conversion = vsa_conversion
45 | self.vsa_selection = vsa_selection
46 | self.context_superposition = context_superposition
47 | self.vsa_converter = VSAConverter(
48 | device, self.constellation, dictionary, dictionary_type="Continuous"
49 | )
50 | self.num_rules = num_rules
51 | self.rules_set = RulesSet(
52 | model=self.model,
53 | hidden_layers=hidden_layers,
54 | num_rules=self.num_rules,
55 | d_in=2 * self.d,
56 | d_out=-1,
57 | d_vsa=self.d,
58 | k=self.k,
59 |             context_superposition=self.context_superposition,
60 | context_keys=None,
61 | program=self.program,
62 | rule_type=self.rule_type,
63 | num_terms=self.num_terms,
64 | num_panels=self.num_panels,
65 | )
66 |
67 | def forward(self, scene_prob, targets=None, distribute=False):
68 | # convert logprob to VSAs
69 | scene_vsa = self.vsa_converter(scene_prob)
70 | # flatten scene
71 | scene = {}
72 | for attr in scene_vsa._fields:
73 | if attr in ["position", "number"] and not distribute:
74 | scene[attr] = None
75 | else:
76 | scene[attr] = t.flatten(
77 | getattr(scene_vsa, attr),
78 | start_dim=len(getattr(scene_vsa, attr).shape) - 2,
79 | )
80 | scene = type(scene_vsa)(**scene)
81 | # set indices for test panels
82 |         if self.num_panels == 10:
83 |             test_indices = [9, 19]
84 |         elif self.num_panels == 3:
85 |             test_indices = [2, 5]
86 |         elif self.num_panels == 5:
87 |             test_indices = [4, 9]
88 | # compute output vectors
89 | output = dict()
90 | tests = dict()
91 | candidates = dict()
92 | for attr in scene._fields:
93 | if attr in ["position", "number"] and not distribute:
94 | tests[attr] = output[attr] = candidates[attr] = None
95 | else:
96 |                 tests[attr] = getattr(scene, attr)[:, test_indices]
97 | output[attr] = self.rules_set(getattr(scene, attr))
98 | candidates[attr] = getattr(scene, attr)[:, -8:]
99 | # compile them in named tuples and return
100 | output = type(scene_vsa)(**output)
101 | tests = type(scene_vsa)(**tests)
102 | candidates = type(scene_vsa)(**candidates)
103 | return output, candidates, tests
104 |
105 | def anneal_softmax(self):
106 | for rule in self.rules_set.rules:
107 | rule.rule.anneal_softmax()
108 |
109 |
110 | class RulesSet(nn.Module):
111 | def __init__(
112 | self,
113 | model,
114 | hidden_layers,
115 | num_rules,
116 | d_in,
117 | d_out,
118 | d_vsa,
119 | k,
120 |         context_superposition=False,
121 | context_keys=None,
122 | program=None,
123 | rule_type="arlc",
124 | num_terms=12,
125 | num_panels=3,
126 | ):
127 | super(RulesSet, self).__init__()
128 | rule_class = GeneralRule
129 | if program:
130 | rules = ["add", "sub", "dist3", "progr"]
131 | else:
132 | rules = [None] * num_rules
133 | self.rules = nn.ModuleList(
134 | [
135 | rule_class(
136 | model,
137 | hidden_layers,
138 | d_in,
139 | d_out,
140 | d_vsa,
141 | k,
142 |                     context_superposition,
143 | context_keys,
144 | program_rule,
145 | num_terms=num_terms,
146 | num_panels=num_panels,
147 | )
148 | for program_rule in rules
149 | ]
150 | )
151 |
152 | def forward(self, attribute):
153 | output_list = [
154 | rule(attribute).reshape((attribute.shape[0], 3, -1)) for rule in self.rules
155 | ]
156 | outputs = t.stack(output_list, dim=1)
157 | return outputs
158 |
159 |
160 | class GeneralRule(nn.Module):
161 | def __init__(
162 | self,
163 | model,
164 | hidden_layers,
165 | d_in,
166 | d_out,
167 | d_vsa,
168 | k,
169 | context_superposition=False,
170 | context_keys=None,
171 | program_rule=None,
172 | num_terms=12,
173 | num_panels=3,
174 | ):
175 | super(GeneralRule, self).__init__()
176 | self.d_in = d_in
177 | self.d_out = d_out
178 | self.d = d_vsa
179 | self.k = k
180 | self.context_superposition = context_superposition
181 | self.context_keys = context_keys
182 | if num_panels == 10:
183 | # I-RAVEN-X
184 |             self.a3_indices = range(0, 9)
185 |             self.a6_indices = range(10, 19)
186 |             self.a9_indices = range(20, 29)
187 |             self.a3_context_indices = [range(10, 20), range(20, 29)]
188 |             self.a6_context_indices = [range(0, 10), range(20, 29)]
189 |             self.a9_context_indices = [range(0, 10), range(10, 19)]
190 | self.rule = IravenxGeneralLearnableFormula(
191 | examples_len=9,
192 | context_len=19,
193 | k=self.k,
194 | num_terms=num_terms,
195 | program_rule=program_rule,
196 | )
197 | elif num_panels == 5:
198 | # I-RAVEN-V
199 |             self.a3_indices = range(0, 4)
200 |             self.a6_indices = range(5, 9)
201 |             self.a9_indices = range(10, 14)
202 |             self.a3_context_indices = [range(5, 10), range(10, 14)]
203 |             self.a6_context_indices = [range(0, 5), range(10, 14)]
204 |             self.a9_context_indices = [range(0, 5), range(5, 9)]
205 | self.rule = IravenVGeneralLearnableFormula(
206 | examples_len=4,
207 | context_len=9,
208 | k=self.k,
209 | num_terms=num_terms,
210 | program_rule=program_rule,
211 | )
212 | elif num_panels == 3 and num_terms == 12:
213 | # I-RAVEN
214 |             self.a3_indices = [0, 1]
215 |             self.a6_indices = [3, 4]
216 |             self.a9_indices = [6, 7]
217 |             self.a3_context_indices = [[3, 4, 5], [6, 7]]
218 |             self.a6_context_indices = [[0, 1, 2], [6, 7]]
219 |             self.a9_context_indices = [[0, 1, 2], [3, 4]]
220 | self.rule = ExtendedGeneralLearnableFormula(
221 | examples_len=2, context_len=5, k=self.k, program_rule=program_rule
222 | )
223 |
224 | def forward(self, x):
225 | a3 = self.rule(
226 |             x=x[:, self.a3_indices],
227 |             ctx=t.cat([x[:, idx] for idx in self.a3_context_indices], dim=1),
228 | )
229 | a6 = self.rule(
230 |             x=x[:, self.a6_indices],
231 |             ctx=t.cat([x[:, idx] for idx in self.a6_context_indices], dim=1),
232 | )
233 | a9 = self.rule(
234 |             x=x[:, self.a9_indices],
235 |             ctx=t.cat([x[:, idx] for idx in self.a9_context_indices], dim=1),
236 | )
237 | return t.cat((a3, a6, a9), dim=1)
238 |
--------------------------------------------------------------------------------
/arlc/utils/raven/raven_one_hot.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import torch
7 | from arlc.utils.const import DIM_ONEHOT
8 |
9 |
10 | def smooth_dist(q, sigma=0.1):
11 |     # sigma < 0 switches to binning-based smoothing; sigma == 0 leaves q untouched
12 |     if sigma < 0:
13 |         return bin_dist(q, -sigma)
14 |     elif sigma == 0:
15 |         return q
16 |     # replace the PMF with a Laplace kernel (scale 2*sigma) centered on its argmax
17 |     l = torch.argmax(q)
18 |     temp = torch.arange(0, q.shape[0])
19 |     temp = torch.exp(torch.div(torch.abs(temp - l), -2 * sigma))
20 |     return temp / temp.sum()
21 |
22 |
23 | def bin_dist(q, threshold):
24 |     """Smooth a probability distribution by binning: the probability of the most
25 |     likely value PT is resampled in [threshold, 1], its right neighbour PN1 in
26 |     [0, 1-PT], and its left neighbour is set to 1-PT-PN1, before renormalizing.
27 |
28 |     Args:
29 |         q (torch.Tensor): input PMF tensor
30 |         threshold (float): lower bound on the probability of the most likely value
31 |     """
32 | l = torch.argmax(q)
33 | q[l] = threshold + (1 - threshold) * torch.rand(1)
34 | right_index = (l + 1) % q.shape[0]
35 | left_index = (l - 1) % q.shape[0]
36 | q[right_index] = (1 - q[l]) * torch.rand(1)
37 | q[left_index] = 1 - q[l] - q[right_index]
38 | return q / q.sum()
39 |
40 |
41 | def create_one_hot(puzzle, panel_constellation, sigma=0.1):
42 | eps = 10 ** (-10)
43 | batch_size, num_panels, _, num_att = puzzle.shape
44 |
45 | if panel_constellation == "center_single":
46 | exist_prob = torch.ones((batch_size, num_panels, 1, 2)) * eps
47 | type_prob = torch.ones((batch_size, num_panels, 1, DIM_ONEHOT)) * eps
48 | size_prob = torch.ones((batch_size, num_panels, 1, DIM_ONEHOT)) * eps
49 | color_prob = torch.ones((batch_size, num_panels, 1, DIM_ONEHOT)) * eps
50 | angle_prob = torch.ones((batch_size, num_panels, 1, DIM_ONEHOT)) * eps
51 | confounders_prob = [
52 | torch.ones((batch_size, num_panels, 1, DIM_ONEHOT)) * eps
53 | for _ in range(num_att - 5)
54 | ]
55 | exist_prob[:, :, 0, 1] = 1
56 | for bs in range(batch_size):
57 | for i in range(num_panels):
58 | exist_prob[bs, i, 0] = smooth_dist(exist_prob[bs, i, 0], sigma)
59 | type_prob[bs, i, 0, int(puzzle[bs, i, 0, 4])] = 1
60 | type_prob[bs, i, 0] = smooth_dist(type_prob[bs, i, 0], sigma)
61 | size_prob[bs, i, 0, int(puzzle[bs, i, 0, 3])] = 1
62 | size_prob[bs, i, 0] = smooth_dist(size_prob[bs, i, 0], sigma)
63 | color_prob[bs, i, 0, int(puzzle[bs, i, 0, 2])] = 1
64 | color_prob[bs, i, 0] = smooth_dist(color_prob[bs, i, 0], sigma)
65 | angle_prob[bs, i, 0, int(puzzle[bs, i, 0, 1])] = 1
66 | angle_prob[bs, i, 0] = smooth_dist(angle_prob[bs, i, 0], sigma)
67 | for j in range(len(confounders_prob)):
68 | confounders_prob[j][bs, i, 0, int(puzzle[bs, i, 0, 5 + j])] = 1
69 | confounders_prob[j][bs, i, 0] = smooth_dist(
70 | confounders_prob[j][bs, i, 0], sigma
71 | )
72 | att_prob = {
73 | "exist": torch.log(exist_prob),
74 | "type": torch.log(type_prob),
75 | "size": torch.log(size_prob),
76 | "color": torch.log(color_prob),
77 | "angle": torch.log(angle_prob),
78 | }
79 | conf_prob = {
80 | f"confounder{i}": torch.log(confounders_prob[i])
81 | for i in range(len(confounders_prob))
82 | }
83 | return {**att_prob, **conf_prob}
84 |
85 | if panel_constellation == "distribute_four":
86 | exist_prob = torch.ones((batch_size, num_panels, 4, 2)) * eps
87 | type_prob = torch.ones((batch_size, num_panels, 4, DIM_ONEHOT)) * eps
88 | size_prob = torch.ones((batch_size, num_panels, 4, DIM_ONEHOT)) * eps
89 | color_prob = torch.ones((batch_size, num_panels, 4, DIM_ONEHOT)) * eps
90 | angle_prob = torch.ones((batch_size, num_panels, 4, DIM_ONEHOT)) * eps
91 |
92 | for bs in range(batch_size):
93 | for i in range(num_panels):
94 | temp = [0, 1, 2, 3]
95 | for j in range(4):
96 | if puzzle[bs, i, j, 0] == -1:
97 | k = temp[0]
98 | exist_prob[bs, i, k, 0] = 1
99 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
100 | temp.remove(k)
101 | else:
102 | k = int(puzzle[bs, i, j, 0])
103 | exist_prob[bs, i, k, 1] = 1
104 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
105 | type_prob[bs, i, k, int(puzzle[bs, i, j, 4])] = 1
106 | type_prob[bs, i, k] = smooth_dist(type_prob[bs, i, k])
107 | size_prob[bs, i, k, int(puzzle[bs, i, j, 3])] = 1
108 | size_prob[bs, i, k] = smooth_dist(size_prob[bs, i, k])
109 | color_prob[bs, i, k, int(puzzle[bs, i, j, 2])] = 1
110 | color_prob[bs, i, k] = smooth_dist(color_prob[bs, i, k])
111 | angle_prob[bs, i, k, int(puzzle[bs, i, j, 1])] = 1
112 | angle_prob[bs, i, k] = smooth_dist(angle_prob[bs, i, k])
113 | temp.remove(k)
114 | return {
115 | "exist": torch.log(exist_prob),
116 | "type": torch.log(type_prob),
117 | "size": torch.log(size_prob),
118 | "color": torch.log(color_prob),
119 | "angle": torch.log(angle_prob),
120 | }
121 |
122 | if panel_constellation == "distribute_nine":
123 | exist_prob = torch.ones((batch_size, num_panels, 9, 2)) * eps
124 | type_prob = torch.ones((batch_size, num_panels, 9, DIM_ONEHOT)) * eps
125 | size_prob = torch.ones((batch_size, num_panels, 9, DIM_ONEHOT)) * eps
126 | color_prob = torch.ones((batch_size, num_panels, 9, DIM_ONEHOT)) * eps
127 | angle_prob = torch.ones((batch_size, num_panels, 9, DIM_ONEHOT)) * eps
128 |
129 | for bs in range(batch_size):
130 | for i in range(num_panels):
131 | temp = [0, 1, 2, 3, 4, 5, 6, 7, 8]
132 | for j in range(9):
133 | if puzzle[bs, i, j, 0] == -1:
134 | k = temp[0]
135 | exist_prob[bs, i, k, 0] = 1
136 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
137 | temp.remove(k)
138 | else:
139 | k = int(puzzle[bs, i, j, 0])
140 | exist_prob[bs, i, k, 1] = 1
141 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
142 | type_prob[bs, i, k, int(puzzle[bs, i, j, 4])] = 1
143 | type_prob[bs, i, k] = smooth_dist(type_prob[bs, i, k])
144 | size_prob[bs, i, k, int(puzzle[bs, i, j, 3])] = 1
145 | size_prob[bs, i, k] = smooth_dist(size_prob[bs, i, k])
146 | color_prob[bs, i, k, int(puzzle[bs, i, j, 2])] = 1
147 | color_prob[bs, i, k] = smooth_dist(color_prob[bs, i, k])
148 | angle_prob[bs, i, k, int(puzzle[bs, i, j, 1])] = 1
149 | angle_prob[bs, i, k] = smooth_dist(angle_prob[bs, i, k])
150 | temp.remove(k)
151 | return {
152 | "exist": torch.log(exist_prob),
153 | "type": torch.log(type_prob),
154 | "size": torch.log(size_prob),
155 | "color": torch.log(color_prob),
156 | "angle": torch.log(angle_prob),
157 | }
158 |
159 | if (
160 | panel_constellation == "left_right"
161 | or panel_constellation == "up_down"
162 | or panel_constellation == "in_out_single"
163 | ):
164 | exist_prob = torch.ones((batch_size, num_panels, 2, 2)) * eps
165 | type_prob = torch.ones((batch_size, num_panels, 2, DIM_ONEHOT)) * eps
166 | size_prob = torch.ones((batch_size, num_panels, 2, DIM_ONEHOT)) * eps
167 | color_prob = torch.ones((batch_size, num_panels, 2, DIM_ONEHOT)) * eps
168 | angle_prob = torch.ones((batch_size, num_panels, 2, DIM_ONEHOT)) * eps
169 | exist_prob[:, :, 0, 1] = 1
170 | exist_prob[:, :, 1, 1] = 1
171 |
172 | for bs in range(batch_size):
173 | for i in range(num_panels):
174 | for j in range(2):
175 | k = int(puzzle[bs, i, j, 0])
176 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
177 | type_prob[bs, i, k, int(puzzle[bs, i, j, 4])] = 1
178 | size_prob[bs, i, k, int(puzzle[bs, i, j, 3])] = 1
179 | color_prob[bs, i, k, int(puzzle[bs, i, j, 2])] = 1
180 | angle_prob[bs, i, k, int(puzzle[bs, i, j, 1])] = 1
181 | return (
182 | torch.log(exist_prob),
183 | torch.log(type_prob),
184 | torch.log(size_prob),
185 | torch.log(color_prob),
186 | torch.log(angle_prob),
187 | )
188 |
189 | if panel_constellation == "in_out_four":
190 | exist_prob = torch.ones((batch_size, num_panels, 5, 2)) * eps
191 | type_prob = torch.ones((batch_size, num_panels, 5, DIM_ONEHOT)) * eps
192 | size_prob = torch.ones((batch_size, num_panels, 5, DIM_ONEHOT)) * eps
193 | color_prob = torch.ones((batch_size, num_panels, 5, DIM_ONEHOT)) * eps
194 | angle_prob = torch.ones((batch_size, num_panels, 5, DIM_ONEHOT)) * eps
195 |
196 | for bs in range(batch_size):
197 | for i in range(num_panels):
198 | temp = [0, 1, 2, 3, 4]
199 | for j in range(5):
200 | if puzzle[bs, i, j, 0] == -1:
201 | k = temp[0]
202 | exist_prob[bs, i, k, 0] = 1
203 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
204 | temp.remove(k)
205 | else:
206 | k = int(puzzle[bs, i, j, 0])
207 | exist_prob[bs, i, k, 1] = 1
208 | exist_prob[bs, i, k] = smooth_dist(exist_prob[bs, i, k])
209 | type_prob[bs, i, k, int(puzzle[bs, i, j, 4])] = 1
210 | type_prob[bs, i, k] = smooth_dist(type_prob[bs, i, k])
211 | size_prob[bs, i, k, int(puzzle[bs, i, j, 3])] = 1
212 | size_prob[bs, i, k] = smooth_dist(size_prob[bs, i, k])
213 | color_prob[bs, i, k, int(puzzle[bs, i, j, 2])] = 1
214 | color_prob[bs, i, k] = smooth_dist(color_prob[bs, i, k])
215 | angle_prob[bs, i, k, int(puzzle[bs, i, j, 1])] = 1
216 | angle_prob[bs, i, k] = smooth_dist(angle_prob[bs, i, k])
217 | temp.remove(k)
218 | return (
219 | torch.log(exist_prob),
220 | torch.log(type_prob),
221 | torch.log(size_prob),
222 | torch.log(color_prob),
223 | torch.log(angle_prob),
224 | )
225 |
--------------------------------------------------------------------------------
/arlc/rule_templates.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import torch as t
7 | import torch.nn as nn
8 | from nvsa.reasoning.vsa_block_utils import (
9 | block_binding2,
10 | block_unbinding2,
11 | )
12 |
13 |
14 | def oh(tensor, idx, val=10001.0):
15 |     tensor[idx] = val  # large logit, so a downstream softmax saturates to ~one-hot at idx
16 |     return tensor
17 |
18 |
19 | class ExtendedGeneralLearnableFormula(nn.Module):
20 | def __init__(self, examples_len, context_len, k, program_rule=None):
21 | super(ExtendedGeneralLearnableFormula, self).__init__()
22 | self.k = k
23 | self.context_len = context_len
24 | self.examples_len = examples_len
25 | self.program_rule = program_rule
26 | if program_rule:
27 | self.program_weights(rule=program_rule)
28 | else:
29 | self.init_terms(12, examples_len + context_len + 1)
30 | self.softmax = nn.Softmax(dim=-1)
31 | self.T = 1
32 |
33 | def init_terms(self, num_terms, num_panels):
34 | terms = list()
35 | for _ in range(num_terms):
36 | terms.append(nn.Parameter(t.randn(num_panels)))
37 | self.terms = nn.ParameterList(terms)
38 |
39 | def program_weights(self, rule, device="cuda"):
40 | print(f"Programming {rule}")
41 | # init every term with the identity
42 | self.terms = []
43 | for i in range(12):
44 | self.terms.append(oh(t.zeros(8, device=device), -1))
45 | if rule == "add":
46 | self.terms[0] = oh(t.zeros(8, device=device), 0) # +x1
47 | self.terms[1] = oh(t.zeros(8, device=device), 1) # +x2
48 | self.terms[2] = oh(t.zeros(8, device=device), 4) # +c3
49 | self.terms[7] = oh(t.zeros(8, device=device), 2) # -c1
50 | self.terms[8] = oh(t.zeros(8, device=device), 3) # -c2
51 | elif rule == "sub":
52 | self.terms[0] = oh(t.zeros(8, device=device), 0) # +x1
53 | self.terms[6] = oh(t.zeros(8, device=device), 1) # -x2
54 | self.terms[2] = oh(t.zeros(8, device=device), 2) # +c1
55 | self.terms[7] = oh(t.zeros(8, device=device), 3) # -c2
56 | self.terms[8] = oh(t.zeros(8, device=device), 4) # -c3
57 | elif rule == "dist3":
58 | self.terms[0] = oh(t.zeros(8, device=device), 2) # +c1
59 | self.terms[1] = oh(t.zeros(8, device=device), 3) # +c2
60 | self.terms[2] = oh(t.zeros(8, device=device), 4) # +c3
61 | self.terms[6] = oh(t.zeros(8, device=device), 0) # -x1
62 | self.terms[7] = oh(t.zeros(8, device=device), 1) # -x2
63 | elif rule == "progr":
64 | self.terms[0] = oh(t.zeros(8, device=device), 1) # +x2
65 | self.terms[1] = oh(t.zeros(8, device=device), 1) # +x2
66 | self.terms[6] = oh(t.zeros(8, device=device), 0) # -x1
67 |
68 | def add_identity(self, x):
69 | identity = t.zeros_like(x[:, 0]).unsqueeze(1)
70 | identity[:, :, :, 0] = 1
71 | x_with_identity = t.cat((x, identity), dim=1)
72 | return x_with_identity
73 |
74 | def forward(self, x, ctx):
75 | x = t.cat([x, ctx], dim=1)
76 | x = x.reshape(x.shape[0], x.shape[1], self.k, -1)
77 | x = self.add_identity(x)
78 | x = x.view(x.shape[0], x.shape[1], -1)
79 |
80 | def wcomb(weights, input):
81 | attn_score = self.softmax(weights.unsqueeze(0).unsqueeze(0) / self.T)
82 | term = (
83 | t.matmul(attn_score.repeat(input.shape[0], 1, 1), input)
84 | .squeeze(1)
85 | .view(input.shape[0], self.k, -1)
86 | )
87 | return term
88 |
89 | def bind_seq(seq):
90 | seq_len = len(seq)
91 | res = seq[0]
92 | for i in range(1, seq_len):
93 | res = block_binding2(res, seq[i])
94 | return res
95 |
96 |         n = bind_seq([wcomb(w, x) for w in self.terms[: len(self.terms) // 2]])
97 |         d = bind_seq([wcomb(w, x) for w in self.terms[len(self.terms) // 2 :]])
98 | output = block_unbinding2(n, d)
99 | output = output.view(output.shape[0], -1)
100 | return output
101 |
102 | def __str__(self):
103 | tl = self.context_len + self.examples_len + 1
104 | cfp = nn.functional.one_hot(self.softmax(self.terms[0]).argmax(), tl)
105 | # add + terms
106 | for i in range(1, len(self.terms) // 2):
107 | cfp += nn.functional.one_hot(self.softmax(self.terms[i]).argmax(), tl)
108 |
109 | cfm = -nn.functional.one_hot(
110 | self.softmax(self.terms[len(self.terms) // 2]).argmax(), tl
111 | )
112 | # add - terms
113 | for i in range(len(self.terms) // 2 + 1, len(self.terms)):
114 | cfm -= nn.functional.one_hot(self.softmax(self.terms[i]).argmax(), tl)
115 |
116 | cf = cfp + cfm
117 |
118 |         terms = ([f"x{i+1}" for i in range(self.examples_len)]
119 |                  + [f"c{i+1}" for i in range(self.context_len)]
120 |                  + ["e"])  # "e" denotes the appended identity term
121 | hr_rule = [
122 | f"{'+' if cf[i]>0 else ''}{cf[i]}{x} " for i, x in enumerate(terms) if cf[i]
123 | ]
124 | return "".join(hr_rule)
125 |
126 | def anneal_softmax(self):
127 | self.T = 0.01
128 |
129 |
130 | class IravenxGeneralLearnableFormula(nn.Module):
131 | def __init__(self, examples_len, context_len, k, num_terms=12, program_rule=None):
132 | super(IravenxGeneralLearnableFormula, self).__init__()
133 | self.k = k
134 | self.context_len = context_len
135 | self.examples_len = examples_len
136 | self.program_rule = program_rule
137 | self.num_terms = num_terms
138 | if program_rule:
139 | self.program_weights(rule=program_rule, num_terms=num_terms)
140 | else:
141 | self.init_terms(num_terms, examples_len + context_len + 1)
142 | self.softmax = nn.Softmax(dim=-1)
143 |
144 | def init_terms(self, num_terms, num_panels):
145 | terms = list()
146 | for _ in range(num_terms):
147 | terms.append(nn.Parameter(t.randn(num_panels)))
148 | # terms.append(nn.Parameter(oh(t.zeros(num_panels), -1, 1)))
149 | # for i in range(12, 13):
150 | # terms[i].data = oh(t.zeros(num_panels), i-12, 1) # -x1 -x2 ... -x9
151 | self.terms = nn.ParameterList(terms)
152 |
153 | def add_identity(self, x):
154 | identity = t.zeros_like(x[:, 0]).unsqueeze(1)
155 | identity[:, :, :, 0] = 1
156 | x_with_identity = t.cat((x, identity), dim=1)
157 | return x_with_identity
158 |
159 | def forward(self, x, ctx):
160 | x = t.cat([x, ctx], dim=1)
161 | x = x.reshape(x.shape[0], x.shape[1], self.k, -1)
162 | x = self.add_identity(x)
163 | x = x.view(x.shape[0], x.shape[1], -1)
164 |
165 | def wcomb(weights, input):
166 | attn_score = self.softmax(weights.unsqueeze(0).unsqueeze(0))
167 | term = (
168 | t.matmul(attn_score.repeat(input.shape[0], 1, 1), input)
169 | .squeeze(1)
170 | .view(input.shape[0], self.k, -1)
171 | )
172 | return term
173 |
174 | def bind_seq(seq):
175 | seq_len = len(seq)
176 | res = seq[0]
177 | for i in range(1, seq_len):
178 | res = block_binding2(res, seq[i])
179 | return res
180 |
181 |         n = bind_seq([wcomb(w, x) for w in self.terms[: len(self.terms) // 2]])
182 |         d = bind_seq([wcomb(w, x) for w in self.terms[len(self.terms) // 2 :]])
183 |
184 | output = block_unbinding2(n, d)
185 | output = output.view(output.shape[0], -1)
186 | return output
187 |
188 | def program_weights(self, rule, num_terms, device="cuda"):
189 | print(f"Programming {rule}")
190 |
191 | # init every term with the identity
192 | self.terms = []
193 | for i in range(num_terms):
194 | self.terms.append(oh(t.zeros(29, device=device), -1))
195 |
196 | if rule == "constant":
197 |             self.terms[0] = oh(self.terms[0], 0) # +x1
198 |             self.terms[1] = oh(self.terms[1], 0) # +x1
199 |             self.terms[12] = oh(self.terms[12], 0) # -x1
200 |
201 | elif rule == "add":
202 | for i in range(9):
203 | self.terms[i] = oh(t.zeros(29, device=device), i) # +x1 +x2 ... +x9
204 | self.terms[9] = oh(t.zeros(29, device=device), 0) # +x1
205 | self.terms[12] = oh(t.zeros(29, device=device), 0) # -x1
206 |
207 | elif rule == "sub":
208 | for i in range(12, 21):
209 | self.terms[i] = oh(
210 | t.zeros(29, device=device), i - 12
211 | ) # -x1 -x2 ... -x9
212 | self.terms[0] = oh(t.zeros(29, device=device), 0) # +x1
213 | self.terms[1] = oh(t.zeros(29, device=device), 0) # +x1
214 |
215 | elif rule == "dist3":
216 | for i in range(10):
217 | self.terms[i] = oh(
218 | t.zeros(29, device=device), 9 + i
219 | ) # +c1 +c2 ... +c10
220 | for i in range(13, 22):
221 | self.terms[i] = oh(
222 | t.zeros(29, device=device), i - 13
223 | ) # -x1 -x2 ... -x9
224 |
225 | elif rule == "progr":
226 | self.terms[0] = oh(t.zeros(29, device=device), 1) # +x2
227 | self.terms[1] = oh(t.zeros(29, device=device), 8) # +x9
228 | self.terms[12] = oh(t.zeros(29, device=device), 0) # -x1
229 |
230 | def __str__(self):
231 | tl = self.context_len + self.examples_len + 1
232 | cfp = nn.functional.one_hot(self.softmax(self.terms[0]).argmax(), tl)
233 | # add + terms
234 | for i in range(1, len(self.terms) // 2):
235 | cfp += nn.functional.one_hot(self.softmax(self.terms[i]).argmax(), tl)
236 |
237 | cfm = -nn.functional.one_hot(
238 | self.softmax(self.terms[len(self.terms) // 2]).argmax(), tl
239 | )
240 | # add - terms
241 | for i in range(len(self.terms) // 2 + 1, len(self.terms)):
242 | cfm -= nn.functional.one_hot(self.softmax(self.terms[i]).argmax(), tl)
243 |
244 | cf = cfp + cfm
245 |
246 | terms = (
247 | [f"x{i+1}" for i in range(self.examples_len)]
248 | + [f"c{i+1}" for i in range(self.context_len)]
249 | + ["e"]
250 | )
251 | hr_rule = [
252 | f"{'+' if cf[i]>0 else ''}{cf[i]}{x} " for i, x in enumerate(terms) if cf[i]
253 | ]
254 | return "".join(
255 | hr_rule
256 | )
257 |
258 |
259 | class IravenVGeneralLearnableFormula(nn.Module):
260 | def __init__(self, examples_len, context_len, k, num_terms=12, program_rule=None):
261 | super(IravenVGeneralLearnableFormula, self).__init__()
262 | self.k = k
263 | self.context_len = context_len
264 | self.examples_len = examples_len
265 | self.program_rule = program_rule
266 | self.num_terms = num_terms
267 | if program_rule:
268 | self.program_weights(rule=program_rule, num_terms=num_terms)
269 | else:
270 | self.init_terms(num_terms, examples_len + context_len + 1)
271 | self.softmax = nn.Softmax(dim=-1)
272 |
273 | def init_terms(self, num_terms, num_panels):
274 | terms = list()
275 | for _ in range(num_terms):
276 | terms.append(nn.Parameter(t.randn(num_panels)))
277 | # terms.append(nn.Parameter(oh(t.zeros(num_panels), -1, 1)))
278 | # for i in range(12, 13):
279 | # terms[i].data = oh(t.zeros(num_panels), i-12, 1) # -x1 -x2 ... -x9
280 | self.terms = nn.ParameterList(terms)
281 |
282 | def add_identity(self, x):
283 | identity = t.zeros_like(x[:, 0]).unsqueeze(1)
284 | identity[:, :, :, 0] = 1
285 | x_with_identity = t.cat((x, identity), dim=1)
286 | return x_with_identity
287 |
288 | def forward(self, x, ctx):
289 | x = t.cat([x, ctx], dim=1)
290 | x = x.reshape(x.shape[0], x.shape[1], self.k, -1)
291 | x = self.add_identity(x)
292 | x = x.view(x.shape[0], x.shape[1], -1)
293 |
294 | def wcomb(weights, input):
295 | attn_score = self.softmax(weights.unsqueeze(0).unsqueeze(0))
296 | term = (
297 | t.matmul(attn_score.repeat(input.shape[0], 1, 1), input)
298 | .squeeze(1)
299 | .view(input.shape[0], self.k, -1)
300 | )
301 | return term
302 |
303 | def bind_seq(seq):
304 | seq_len = len(seq)
305 | res = seq[0]
306 | for i in range(1, seq_len):
307 | res = block_binding2(res, seq[i])
308 | return res
309 |
310 |         n = bind_seq([wcomb(w, x) for w in self.terms[: len(self.terms) // 2]])
311 |         d = bind_seq([wcomb(w, x) for w in self.terms[len(self.terms) // 2 :]])
312 |
313 | output = block_unbinding2(n, d)
314 | output = output.view(output.shape[0], -1)
315 | return output
316 |
317 | def program_weights(self, rule, num_terms, device="cuda"):
318 | print(f"Programming {rule}")
319 |
320 | # init every term with the identity
321 | self.terms = []
322 | for i in range(num_terms):
323 | self.terms.append(oh(t.zeros(14, device=device), -1))
324 |
325 | if rule == "constant":
326 |             self.terms[0] = oh(self.terms[0], 0) # +x1
327 |             self.terms[1] = oh(self.terms[1], 0) # +x1
328 |             self.terms[11] = oh(self.terms[11], 0) # -x1
329 |
330 | elif rule == "add":
331 | for i in range(4):
332 |                 self.terms[i] = oh(t.zeros(14, device=device), i) # +x1 +x2 +x3 +x4
333 | self.terms[4] = oh(t.zeros(14, device=device), 0) # +x1
334 | self.terms[11] = oh(t.zeros(14, device=device), 0) # -x1
335 |
336 | elif rule == "sub":
337 | for i in range(10, 14):
338 | self.terms[i] = oh(
339 | t.zeros(14, device=device), i - 10
340 |                 ) # -x1 -x2 -x3 -x4
341 | self.terms[0] = oh(t.zeros(14, device=device), 0) # +x1
342 | self.terms[1] = oh(t.zeros(14, device=device), 0) # +x1
343 |
344 | elif rule == "dist3":
345 | for i in range(5):
346 | self.terms[i] = oh(
347 | t.zeros(14, device=device), 4 + i
348 |                 ) # +c1 +c2 +c3 +c4 +c5
349 | for i in range(10, 14):
350 | self.terms[i] = oh(
351 | t.zeros(14, device=device), i - 10
352 |                 ) # -x1 -x2 -x3 -x4
353 |
354 | elif rule == "progr":
355 | self.terms[0] = oh(t.zeros(14, device=device), 1) # +x2
356 |             self.terms[1] = oh(t.zeros(14, device=device), 3) # +x4
357 | self.terms[10] = oh(t.zeros(14, device=device), 0) # -x1
358 |
359 | def __str__(self):
360 | tl = self.context_len + self.examples_len + 1
361 | cfp = nn.functional.one_hot(self.softmax(self.terms[0]).argmax(), tl)
362 | # add + terms
363 | for i in range(1, len(self.terms) // 2):
364 | cfp += nn.functional.one_hot(self.softmax(self.terms[i]).argmax(), tl)
365 | cfm = -nn.functional.one_hot(
366 | self.softmax(self.terms[len(self.terms) // 2]).argmax(), tl
367 | )
368 | # add - terms
369 | for i in range(len(self.terms) // 2 + 1, len(self.terms)):
370 | cfm -= nn.functional.one_hot(self.softmax(self.terms[i]).argmax(), tl)
371 | cf = cfp + cfm
372 |         terms = ([f"x{i+1}" for i in range(self.examples_len)]
373 |                  + [f"c{i+1}" for i in range(self.context_len)]
374 |                  + ["e"])  # "e" denotes the appended identity term
375 | hr_rule = [
376 | f"{'+' if cf[i]>0 else ''}{cf[i]}{x} " for i, x in enumerate(terms) if cf[i]
377 | ]
378 | return "".join(
379 | hr_rule
380 | )
381 |
--------------------------------------------------------------------------------
/arlc/utils/raven/extraction.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import numpy as np
7 | import os
8 | import xml.etree.ElementTree as ET
9 | import argparse
10 | import tqdm
11 |
12 | parser = argparse.ArgumentParser(description="NVSA")
13 | parser.add_argument("--data_path", type=str, default="/dccstor/saentis/data/I-RAVEN")
14 |
15 | pos_num_rule_idx_map_four = {
16 | "Constant": 0,
17 | "Progression_One_Pos": 1,
18 | "Progression_Mone_Pos": 2,
19 | "Arithmetic_Plus_Pos": 3,
20 | "Arithmetic_Minus_Pos": 4,
21 | "Distribute_Three_Left_Pos": 5,
22 | "Distribute_Three_Right_Pos": 6,
23 | "Progression_One_Num": 7,
24 | "Progression_Mone_Num": 8,
25 | "Arithmetic_Plus_Num": 9,
26 | "Arithmetic_Minus_Num": 10,
27 | "Distribute_Three_Left_Num": 11,
28 | "Distribute_Three_Right_Num": 12,
29 | }
30 |
31 | pos_num_rule_idx_map_nine = {
32 | "Constant": 0,
33 | "Progression_One_Pos": 1,
34 | "Progression_Mone_Pos": 2,
35 | "Progression_Two_Pos": 3,
36 | "Progression_Mtwo_Pos": 4,
37 | "Arithmetic_Plus_Pos": 5,
38 | "Arithmetic_Minus_Pos": 6,
39 | "Distribute_Three_Left_Pos": 7,
40 | "Distribute_Three_Right_Pos": 8,
41 | "Progression_One_Num": 9,
42 | "Progression_Mone_Num": 10,
43 | "Progression_Two_Num": 11,
44 | "Progression_Mtwo_Num": 12,
45 | "Arithmetic_Plus_Num": 13,
46 | "Arithmetic_Minus_Num": 14,
47 | "Distribute_Three_Left_Num": 15,
48 | "Distribute_Three_Right_Num": 16,
49 | }
50 |
51 | type_rule_idx_map = {
52 | "Constant": 0,
53 | "Progression_One": 1,
54 | "Progression_Mone": 2,
55 | "Progression_Two": 3,
56 | "Progression_Mtwo": 4,
57 | "Distribute_Three_Left": 5,
58 | "Distribute_Three_Right": 6,
59 | }
60 |
61 | size_rule_idx_map = {
62 | "Constant": 0,
63 | "Progression_One": 1,
64 | "Progression_Mone": 2,
65 | "Progression_Two": 3,
66 | "Progression_Mtwo": 4,
67 | "Arithmetic_Plus": 5,
68 | "Arithmetic_Minus": 6,
69 | "Distribute_Three_Left": 7,
70 | "Distribute_Three_Right": 8,
71 | }
72 |
73 | color_rule_idx_map = {
74 | "Constant": 0,
75 | "Progression_One": 1,
76 | "Progression_Mone": 2,
77 | "Progression_Two": 3,
78 | "Progression_Mtwo": 4,
79 | "Arithmetic_Plus": 5,
80 | "Arithmetic_Minus": 6,
81 | "Distribute_Three_Left": 7,
82 | "Distribute_Three_Right": 8,
83 | }
84 |
85 | type_idx_rule_map = {
86 | 0: "Constant",
87 | 1: "Progression_One",
88 | 2: "Progression_Mone",
89 | 3: "Progression_Two",
90 | 4: "Progression_Mtwo",
91 | 5: "Distribute_Three_Left",
92 | 6: "Distribute_Three_Right",
93 | }
94 | size_idx_rule_map = {
95 | 0: "Constant",
96 | 1: "Progression_One",
97 | 2: "Progression_Mone",
98 | 3: "Progression_Two",
99 | 4: "Progression_Mtwo",
100 | 5: "Arithmetic_Plus",
101 | 6: "Arithmetic_Minus",
102 | 7: "Distribute_Three_Left",
103 | 8: "Distribute_Three_Right",
104 | }
105 | color_idx_rule_map = {
106 | 0: "Constant",
107 | 1: "Progression_One",
108 | 2: "Progression_Mone",
109 | 3: "Progression_Two",
110 | 4: "Progression_Mtwo",
111 | 5: "Arithmetic_Plus",
112 | 6: "Arithmetic_Minus",
113 | 7: "Distribute_Three_Left",
114 | 8: "Distribute_Three_Right",
115 | }
116 |
117 | pos_num_rule_four_idx_map = {
118 | 0: "Constant",
119 | 1: "Progression_One_Pos",
120 | 2: "Progression_Mone_Pos",
121 | 3: "Arithmetic_Plus_Pos",
122 | 4: "Arithmetic_Minus_Pos",
123 | 5: "Distribute_Three_Left_Pos",
124 | 6: "Distribute_Three_Right_Pos",
125 | 7: "Progression_One_Num",
126 | 8: "Progression_Mone_Num",
127 | 9: "Arithmetic_Plus_Num",
128 | 10: "Arithmetic_Minus_Num",
129 | 11: "Distribute_Three_Left_Num",
130 | 12: "Distribute_Three_Right_Num",
131 | }
132 |
133 | pos_num_rule_nine_idx_map = {
134 | 0: "Constant",
135 | 1: "Progression_One_Pos",
136 | 2: "Progression_Mone_Pos",
137 | 3: "Progression_Two_Pos",
138 | 4: "Progression_Mtwo_Pos",
139 | 5: "Arithmetic_Plus_Pos",
140 | 6: "Arithmetic_Minus_Pos",
141 | 7: "Distribute_Three_Left_Pos",
142 | 8: "Distribute_Three_Right_Pos",
143 | 9: "Progression_One_Num",
144 | 10: "Progression_Mone_Num",
145 | 11: "Progression_Two_Num",
146 | 12: "Progression_Mtwo_Num",
147 | 13: "Arithmetic_Plus_Num",
148 | 14: "Arithmetic_Minus_Num",
149 | 15: "Distribute_Three_Left_Num",
150 | 16: "Distribute_Three_Right_Num",
151 | }
152 |
153 |
154 | def get_pos_num_rule(
155 | rule_idx, comp_idx, num_elements, pos_num_rule_idx_map, xml_panels, xml_rules
156 | ):
157 | index_name = xml_rules[rule_idx][0].attrib["name"]
158 | attrib_name = xml_rules[rule_idx][0].attrib["attr"][:3]
159 | if index_name == "Progression":
160 | if attrib_name == "Num":
161 | first = int(xml_panels[0][0][comp_idx][0].attrib["Number"])
162 | second = int(xml_panels[1][0][comp_idx][0].attrib["Number"])
163 | if second == first + 1:
164 | index_name += "_One_Num"
165 | if second == first - 1:
166 | index_name += "_Mone_Num"
167 | if second == first + 2:
168 | index_name += "_Two_Num"
169 | if second == first - 2:
170 | index_name += "_Mtwo_Num"
171 | if attrib_name == "Pos":
172 | all_position = eval(xml_panels[0][0][comp_idx][0].attrib["Position"])
173 | first = []
174 | for entity in xml_panels[0][0][comp_idx][0]:
175 | first.append(all_position.index(eval(entity.attrib["bbox"])))
176 | second = []
177 | for entity in xml_panels[1][0][comp_idx][0]:
178 | second.append(all_position.index(eval(entity.attrib["bbox"])))
179 | third = []
180 | for entity in xml_panels[2][0][comp_idx][0]:
181 | third.append(all_position.index(eval(entity.attrib["bbox"])))
182 | fourth = []
183 | for entity in xml_panels[3][0][comp_idx][0]:
184 | fourth.append(all_position.index(eval(entity.attrib["bbox"])))
185 | fifth = []
186 | for entity in xml_panels[4][0][comp_idx][0]:
187 | fifth.append(all_position.index(eval(entity.attrib["bbox"])))
188 | sixth = []
189 | for entity in xml_panels[5][0][comp_idx][0]:
190 | sixth.append(all_position.index(eval(entity.attrib["bbox"])))
191 | seventh = []
192 | for entity in xml_panels[6][0][comp_idx][0]:
193 | seventh.append(all_position.index(eval(entity.attrib["bbox"])))
194 | eighth = []
195 | for entity in xml_panels[7][0][comp_idx][0]:
196 | eighth.append(all_position.index(eval(entity.attrib["bbox"])))
197 | if (
198 | len(
199 | set(map(lambda index: (index + 1) % num_elements, first))
200 | - set(second)
201 | )
202 | == 0
203 | and len(
204 | set(map(lambda index: (index + 1) % num_elements, second))
205 | - set(third)
206 | )
207 | == 0
208 | and len(
209 | set(map(lambda index: (index + 1) % num_elements, fourth))
210 | - set(fifth)
211 | )
212 | == 0
213 | and len(
214 | set(map(lambda index: (index + 1) % num_elements, fifth))
215 | - set(sixth)
216 | )
217 | == 0
218 | and len(
219 | set(map(lambda index: (index + 1) % num_elements, seventh))
220 | - set(eighth)
221 | )
222 | == 0
223 | ):
224 | index_name += "_One_Pos"
225 | if (
226 | len(
227 | set(map(lambda index: (index - 1) % num_elements, first))
228 | - set(second)
229 | )
230 | == 0
231 | and len(
232 | set(map(lambda index: (index - 1) % num_elements, second))
233 | - set(third)
234 | )
235 | == 0
236 | and len(
237 | set(map(lambda index: (index - 1) % num_elements, fourth))
238 | - set(fifth)
239 | )
240 | == 0
241 | and len(
242 | set(map(lambda index: (index - 1) % num_elements, fifth))
243 | - set(sixth)
244 | )
245 | == 0
246 | and len(
247 | set(map(lambda index: (index - 1) % num_elements, seventh))
248 | - set(eighth)
249 | )
250 | == 0
251 | ):
252 | index_name += "_Mone_Pos"
253 | if (
254 | len(
255 | set(map(lambda index: (index + 2) % num_elements, first))
256 | - set(second)
257 | )
258 | == 0
259 | and len(
260 | set(map(lambda index: (index + 2) % num_elements, second))
261 | - set(third)
262 | )
263 | == 0
264 | and len(
265 | set(map(lambda index: (index + 2) % num_elements, fourth))
266 | - set(fifth)
267 | )
268 | == 0
269 | and len(
270 | set(map(lambda index: (index + 2) % num_elements, fifth))
271 | - set(sixth)
272 | )
273 | == 0
274 | and len(
275 | set(map(lambda index: (index + 2) % num_elements, seventh))
276 | - set(eighth)
277 | )
278 | == 0
279 | ):
280 | index_name += "_Two_Pos"
281 | if (
282 | len(
283 | set(map(lambda index: (index - 2) % num_elements, first))
284 | - set(second)
285 | )
286 | == 0
287 | and len(
288 | set(map(lambda index: (index - 2) % num_elements, second))
289 | - set(third)
290 | )
291 | == 0
292 | and len(
293 | set(map(lambda index: (index - 2) % num_elements, fourth))
294 | - set(fifth)
295 | )
296 | == 0
297 | and len(
298 | set(map(lambda index: (index - 2) % num_elements, fifth))
299 | - set(sixth)
300 | )
301 | == 0
302 | and len(
303 | set(map(lambda index: (index - 2) % num_elements, seventh))
304 | - set(eighth)
305 | )
306 | == 0
307 | ):
308 | index_name += "_Mtwo_Pos"
309 | if index_name.endswith("_One_Pos_Mone_Pos"):
310 | if np.random.uniform() >= 0.5:
311 | index_name = "Progression_One_Pos"
312 | else:
313 | index_name = "Progression_Mone_Pos"
314 | if index_name == "Arithmetic":
315 | if attrib_name == "Num":
316 | first = int(xml_panels[0][0][comp_idx][0].attrib["Number"])
317 | second = int(xml_panels[1][0][comp_idx][0].attrib["Number"])
318 | third = int(xml_panels[2][0][comp_idx][0].attrib["Number"])
319 | if third == first + second + 1:
320 | index_name += "_Plus_Num"
321 | if third == first - second - 1:
322 | index_name += "_Minus_Num"
323 | if attrib_name == "Pos":
324 | all_position = eval(xml_panels[0][0][comp_idx][0].attrib["Position"])
325 | first = []
326 | for entity in xml_panels[0][0][comp_idx][0]:
327 | first.append(all_position.index(eval(entity.attrib["bbox"])))
328 | second = []
329 | for entity in xml_panels[1][0][comp_idx][0]:
330 | second.append(all_position.index(eval(entity.attrib["bbox"])))
331 | third = []
332 | for entity in xml_panels[2][0][comp_idx][0]:
333 | third.append(all_position.index(eval(entity.attrib["bbox"])))
334 | if set(third) == set(first).union(set(second)):
335 | index_name += "_Plus_Pos"
336 | if set(third) == set(first) - set(second):
337 | index_name += "_Minus_Pos"
338 | if index_name == "Distribute_Three":
339 | if attrib_name == "Num":
340 | first = int(xml_panels[0][0][comp_idx][0].attrib["Number"])
341 | second_left = int(xml_panels[5][0][comp_idx][0].attrib["Number"])
342 | second_right = int(xml_panels[4][0][comp_idx][0].attrib["Number"])
343 | if second_left == first:
344 | index_name += "_Left_Num"
345 | if second_right == first:
346 | index_name += "_Right_Num"
347 | if attrib_name == "Pos":
348 | all_position = eval(xml_panels[0][0][comp_idx][0].attrib["Position"])
349 | first = []
350 | for entity in xml_panels[0][0][comp_idx][0]:
351 | first.append(all_position.index(eval(entity.attrib["bbox"])))
352 | second_left = []
353 | for entity in xml_panels[5][0][comp_idx][0]:
354 | second_left.append(all_position.index(eval(entity.attrib["bbox"])))
355 | second_right = []
356 | for entity in xml_panels[4][0][comp_idx][0]:
357 | second_right.append(all_position.index(eval(entity.attrib["bbox"])))
358 | if set(second_left) == set(first):
359 | index_name += "_Left_Pos"
360 | if set(second_right) == set(first):
361 | index_name += "_Right_Pos"
362 | return pos_num_rule_idx_map[index_name]
363 |
364 |
365 | def get_type_rule(
366 | rule_idx, comp_idx, num_elements, pos_num_rule_idx_map, xml_panels, xml_rules
367 | ):
368 | index_name = xml_rules[rule_idx][1].attrib["name"]
369 | if index_name == "Progression":
370 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Type"])
371 | second = int(xml_panels[1][0][comp_idx][0][0].attrib["Type"])
372 | if second == first + 1:
373 | index_name += "_One"
374 | if second == first - 1:
375 | index_name += "_Mone"
376 | if second == first + 2:
377 | index_name += "_Two"
378 | if second == first - 2:
379 | index_name += "_Mtwo"
380 | if index_name == "Distribute_Three":
381 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Type"])
382 | second_left = int(xml_panels[5][0][comp_idx][0][0].attrib["Type"])
383 | second_right = int(xml_panels[4][0][comp_idx][0][0].attrib["Type"])
384 | if second_left == first:
385 | index_name += "_Left"
386 | if second_right == first:
387 | index_name += "_Right"
388 | return type_rule_idx_map[index_name]
389 |
390 |
391 | def get_size_rule(
392 | rule_idx, comp_idx, num_elements, pos_num_rule_idx_map, xml_panels, xml_rules
393 | ):
394 | index_name = xml_rules[rule_idx][2].attrib["name"]
395 | if index_name == "Progression":
396 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Size"])
397 | second = int(xml_panels[1][0][comp_idx][0][0].attrib["Size"])
398 | if second == first + 1:
399 | index_name += "_One"
400 | if second == first - 1:
401 | index_name += "_Mone"
402 | if second == first + 2:
403 | index_name += "_Two"
404 | if second == first - 2:
405 | index_name += "_Mtwo"
406 | if index_name == "Arithmetic":
407 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Size"])
408 | second = int(xml_panels[1][0][comp_idx][0][0].attrib["Size"])
409 | third = int(xml_panels[2][0][comp_idx][0][0].attrib["Size"])
410 | if third == first + second + 1:
411 | index_name += "_Plus"
412 | if third == first - second - 1:
413 | index_name += "_Minus"
414 | if index_name == "Distribute_Three":
415 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Size"])
416 | second_left = int(xml_panels[5][0][comp_idx][0][0].attrib["Size"])
417 | second_right = int(xml_panels[4][0][comp_idx][0][0].attrib["Size"])
418 | if second_left == first:
419 | index_name += "_Left"
420 | if second_right == first:
421 | index_name += "_Right"
422 | return size_rule_idx_map[index_name]
423 |
424 |
425 | def get_color_rule(
426 | rule_idx, comp_idx, num_elements, pos_num_rule_idx_map, xml_panels, xml_rules
427 | ):
428 | index_name = xml_rules[rule_idx][3].attrib["name"]
429 | if index_name == "Progression":
430 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Color"])
431 | second = int(xml_panels[1][0][comp_idx][0][0].attrib["Color"])
432 | if second == first + 1:
433 | index_name += "_One"
434 | if second == first - 1:
435 | index_name += "_Mone"
436 | if second == first + 2:
437 | index_name += "_Two"
438 | if second == first - 2:
439 | index_name += "_Mtwo"
440 | if index_name == "Arithmetic":
441 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Color"])
442 | second = int(xml_panels[1][0][comp_idx][0][0].attrib["Color"])
443 | third = int(xml_panels[2][0][comp_idx][0][0].attrib["Color"])
444 | fourth = int(xml_panels[3][0][comp_idx][0][0].attrib["Color"])
445 | fifth = int(xml_panels[4][0][comp_idx][0][0].attrib["Color"])
446 | sixth = int(xml_panels[5][0][comp_idx][0][0].attrib["Color"])
447 | if (third == first + second) and (sixth == fourth + fifth):
448 | index_name += "_Plus"
449 | if (third == first - second) and (sixth == fourth - fifth):
450 | index_name += "_Minus"
451 | if index_name == "Distribute_Three":
452 | first = int(xml_panels[0][0][comp_idx][0][0].attrib["Color"])
453 | second_left = int(xml_panels[5][0][comp_idx][0][0].attrib["Color"])
454 | second_right = int(xml_panels[4][0][comp_idx][0][0].attrib["Color"])
455 | if second_left == first:
456 | index_name += "_Left"
457 | if second_right == first:
458 | index_name += "_Right"
459 | return color_rule_idx_map[index_name]
460 |
461 |
462 | def main():
463 | args = parser.parse_args()
464 | DATA_PATH = args.data_path
465 | constellation_name_list = [
466 | "center_single",
467 | "distribute_four",
468 | "distribute_nine",
469 | "left_center_single_right_center_single",
470 | "up_center_single_down_center_single",
471 | "in_center_single_out_center_single",
472 | "in_distribute_four_out_center_single",
473 | ]
474 | save_name_list = [
475 | "center_single_extracted_with_rules",
476 | "distribute_four_extracted_with_rules",
477 | "distribute_nine_extracted_with_rules",
478 | "left_right_extracted_with_rules",
479 | "up_down_extracted_with_rules",
480 | "in_out_single_extracted_with_rules",
481 | "in_out_four_extracted_with_rules",
482 | ]
483 |
484 | obj_name_list = ["train", "val", "test"]
485 | my_bbox = {
486 | "in_out_four_extracted_with_rules": {
487 | "[0.5, 0.5, 1, 1]": 0,
488 | "[0.42, 0.42, 0.15, 0.15]": 1,
489 | "[0.42, 0.58, 0.15, 0.15]": 2,
490 | "[0.58, 0.42, 0.15, 0.15]": 3,
491 | "[0.58, 0.58, 0.15, 0.15]": 4,
492 | },
493 | "in_out_single_extracted_with_rules": {
494 | "[0.5, 0.5, 1, 1]": 0,
495 | "[0.5, 0.5, 0.33, 0.33]": 1,
496 | },
497 | "up_down_extracted_with_rules": {
498 | "[0.25, 0.5, 0.5, 0.5]": 0,
499 | "[0.75, 0.5, 0.5, 0.5]": 1,
500 | },
501 | "left_right_extracted_with_rules": {
502 | "[0.5, 0.25, 0.5, 0.5]": 0,
503 | "[0.5, 0.75, 0.5, 0.5]": 1,
504 | },
505 | "distribute_nine_extracted_with_rules": {
506 | "[0.16, 0.16, 0.33, 0.33]": 0,
507 | "[0.16, 0.5, 0.33, 0.33]": 1,
508 | "[0.16, 0.83, 0.33, 0.33]": 2,
509 | "[0.5, 0.16, 0.33, 0.33]": 3,
510 | "[0.5, 0.5, 0.33, 0.33]": 4,
511 | "[0.5, 0.83, 0.33, 0.33]": 5,
512 | "[0.83, 0.16, 0.33, 0.33]": 6,
513 | "[0.83, 0.5, 0.33, 0.33]": 7,
514 | "[0.83, 0.83, 0.33, 0.33]": 8,
515 | },
516 | "distribute_four_extracted_with_rules": {
517 | "[0.25, 0.25, 0.5, 0.5]": 0,
518 | "[0.25, 0.75, 0.5, 0.5]": 1,
519 | "[0.75, 0.25, 0.5, 0.5]": 2,
520 | "[0.75, 0.75, 0.5, 0.5]": 3,
521 | },
522 | }
523 | for w in range(len(constellation_name_list)):
524 | file_type = constellation_name_list[w]
525 | save_name = save_name_list[w]
526 | path = os.path.join(DATA_PATH, save_name)
527 | path_train, path_val, path_test = (
528 | os.path.join(path, "train"),
529 | os.path.join(path, "val"),
530 | os.path.join(path, "test"),
531 | )
532 | os.makedirs(path, exist_ok=True)
533 | os.makedirs(path_train, exist_ok=True)
534 | os.makedirs(path_val, exist_ok=True)
535 | os.makedirs(path_test, exist_ok=True)
536 | for n in range(len(obj_name_list)):
537 | count = 0
538 | obj_name = obj_name_list[n]
539 |             for j in tqdm.tqdm(range(10001)):  # scan all candidate sample indices; missing files are skipped
540 | try:
541 | tree = ET.parse(
542 | "{0}/{1}/RAVEN_{2}_{3}.xml".format(
543 | DATA_PATH, file_type, j, obj_name
544 | )
545 | )
546 |                 except (OSError, ET.ParseError):  # skip indices with missing or unparsable XML
547 | continue
548 | root = tree.getroot()
549 | xml_panels = root[0]
550 | xml_rules = root[1]
551 | rule_idx = 0
552 | comp_idx = 0
553 | num_elements = 9
554 | pos_num_rule_idx_map = pos_num_rule_idx_map_four
555 | if file_type == "distribute_four":
556 | num_elements = 4
557 | pos_num_rule_idx_map = pos_num_rule_idx_map_four
558 | elif file_type == "distribute_nine":
559 | num_elements = 9
560 | pos_num_rule_idx_map = pos_num_rule_idx_map_nine
561 |                 rule_args = [  # renamed from `args` to avoid shadowing the parsed CLI args
562 |                     rule_idx,
563 |                     comp_idx,
564 |                     num_elements,
565 |                     pos_num_rule_idx_map,
566 |                     xml_panels,
567 |                     xml_rules,
568 |                 ]
569 |                 pos_num_rule = np.array(get_pos_num_rule(*rule_args))
570 |                 color_rule = np.array(get_color_rule(*rule_args))
571 |                 size_rule = np.array(get_size_rule(*rule_args))
572 |                 type_rule = np.array(get_type_rule(*rule_args))
573 | rules = np.array([pos_num_rule, color_rule, size_rule, type_rule])
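    |                 # Build one attribute row [pos, angle, color, size, type]
    |                 # per entity; each panel is padded to 9 rows and panels
    |                 # are stacked into ext_sample.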
574 | idx_panel = 0
575 | for panel in root[0]:
576 | idx = 0
577 | for component in panel[0]:
578 | for entity in component[0]:
579 | a = entity.attrib
580 | angle, color, size, typ, bbox = (
581 | int(a.get("Angle")),
582 | int(a.get("Color")),
583 | int(a.get("Size")),
584 | int(a.get("Type")) - 1,
585 | a.get("bbox"),
586 | )
587 | pos = (
588 | my_bbox[save_name][bbox]
589 | if save_name != "center_single_extracted_with_rules"
590 | else 0
591 | )
592 | ext_comp = [pos, angle, color, size, typ]
593 | ext_comp = np.expand_dims(ext_comp, axis=0)
594 | ext_panel = (
595 | ext_comp
596 | if (idx == 0)
597 | else np.concatenate((ext_panel, ext_comp), axis=0)
598 | )
599 |                     idx += 1
600 |                 c = 9 - idx  # number of unused entity slots; panels are padded to 9
601 | if c > 0:
602 | filler = np.ones((c, 5)) * (-1)
603 | ext_panel = np.concatenate((ext_panel, filler), axis=0)
604 | ext_panel = np.expand_dims(ext_panel, axis=0)
605 | ext_sample = (
606 | ext_panel
607 | if (idx_panel == 0)
608 | else np.concatenate((ext_sample, ext_panel), axis=0)
609 | )
610 |             idx_panel += 1
611 | file = np.load(
612 | "{0}/{1}/RAVEN_{2}_{3}.npz".format(
613 | DATA_PATH, file_type, j, obj_name
614 | )
615 | )
616 | filename = "{0}/{1}/{2}/RAVEN_{3}_{2}.npz".format(
617 | DATA_PATH, save_name, obj_name, count
618 | )
619 | np.savez(
620 | filename,
621 | target=file["target"],
622 | predict=file["predict"],
623 | image=file["image"],
624 | meta_matrix=file["meta_matrix"],
625 | meta_structure=file["meta_structure"],
626 | meta_target=file["meta_target"],
627 | structure=file["structure"],
628 | extracted_meta=ext_sample,
629 | rules=rules,
630 | )
631 |                 count += 1
632 | print("finished with: ", save_name)
633 |
634 |
635 | if __name__ == "__main__":
636 | main()
637 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # *----------------------------------------------------------------------------*
2 | # * Copyright (C) 2024 IBM Inc. All rights reserved *
3 | # * SPDX-License-Identifier: GPL-3.0-only *
4 | # *----------------------------------------------------------------------------*
5 |
6 | import json
7 | import os
8 | import random
9 | from tqdm import tqdm
10 | import numpy as np
11 | import torch
12 | import torch.optim as optim
13 | from torch.utils.data import DataLoader
14 | from torch.utils.tensorboard import SummaryWriter
15 | from collections import defaultdict as dd
16 |
17 | import arlc.utils.raven.env as reasoning_env
18 | from arlc.utils.averagemeter import AverageMeter
19 | from arlc.utils.checkpath import check_paths, save_checkpoint
20 | from arlc.datasets import GeneralIRAVENDataset
21 | from arlc.execution import RuleLevelReasoner
22 | from arlc.selection import RuleSelector
23 | from arlc.utils.vsa import generate_nvsa_codebooks
24 | import arlc.losses as losses
25 | from arlc.utils.raven.raven_one_hot import create_one_hot
26 | from arlc.utils.parsing import parse_args
27 | from arlc.utils.general import iravenx_rule_map, iravenx_index_map
28 |
29 |
30 | def compute_loss_and_scores(
31 | outputs,
32 | tests,
33 | candidates,
34 | targets,
35 | loss_fn,
36 | distribute,
37 |     params=None,  # currently unused
38 | confounders=False,
39 | attr_entropy=None,
40 | use_entropy=False,
41 | ):
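    |     # Accumulate the per-attribute loss of the outputs against the tests
    |     # and the ground-truth candidate, plus per-candidate scores; returns
    |     # (loss, scores, [type, color, size] attribute scores).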
42 | loss = 0
43 | scores = 0
44 | att_scores = dd(lambda: 0)
45 | for attr in outputs._fields:
46 |         # in constellations without position/number attributes, do not compute a loss on them
47 | if not distribute and (attr == "position" or attr == "number"):
48 | continue
49 | # if confounders is turned off, disregard confounding attributes
50 | if (attr == "angle" or "confounder" in attr) and not confounders:
51 | pass
52 | # add to the loss the contribution of attr
53 | else:
54 | # attribute entropy regularization
55 | entropy_reg = np.clip(attr_entropy[attr], 0.1, 10) if use_entropy else 1
56 | loss += (
57 | loss_fn(
58 | getattr(outputs, attr),
59 | torch.cat(
60 | (
61 | getattr(tests, attr),
62 | getattr(candidates, attr)[
63 | torch.arange(getattr(candidates, attr).shape[0]),
64 | targets,
65 | ].unsqueeze(1),
66 | ),
67 | dim=1,
68 | ),
69 | ).mean(dim=-1)
70 | / entropy_reg
71 | )
72 | # compute attribute scores
73 | att_scores[attr] = loss_fn.score(
74 | getattr(outputs, attr)[:, -1].unsqueeze(1).repeat(1, 8, 1),
75 | getattr(candidates, attr),
76 | )
77 | scores += att_scores[attr] / entropy_reg
78 |
79 | return loss, scores, [att_scores[att] for att in ["type", "color", "size"]]
80 |
81 |
82 | def train(args, env, device, confounders=False):
83 | """
84 |     Training and validation of the learnable NVSA backend.
85 | """
86 |
87 | def inference_epoch(epoch, loader, train=True):
88 | if train:
89 | model.train()
90 | if args.config == "in_out_four":
91 | model2.train()
92 | rule_selector.train()
93 | else:
94 | model.eval()
95 | if args.config == "in_out_four":
96 | model2.eval()
97 | rule_selector.eval()
98 |
99 | # Define tracking meters
100 | loss_avg = AverageMeter("Loss", ":.3f")
101 | acc_avg = AverageMeter("Accuracy", ":.3f")
102 |
103 | for counter, (extracted, targets, all_action_rule) in enumerate(tqdm(loader)):
104 | extracted, targets, all_action_rule = (
105 | extracted.to(device),
106 | targets.to(device),
107 | all_action_rule.to(device),
108 | )
109 | att_logprob = create_one_hot(extracted, args.config, args.sigma)
110 | model_output = {k: v.to(device) for k, v in att_logprob.items()}
111 | scene_prob, _ = env.prepare(model_output)
112 | if args.config in ["center_single", "distribute_four", "distribute_nine"]:
113 | outputs, candidates, tests = model(
114 | scene_prob, targets, distribute=distribute
115 | )
116 | outputs, attr_entropy = rule_selector(
117 | outputs, tests, candidates, targets
118 | )
119 | loss, scores, _ = compute_loss_and_scores(
120 | outputs,
121 | tests,
122 | candidates,
123 | targets,
124 | loss_fn,
125 | "distribute" in args.config,
126 | [p.data for p in model.parameters() if p.requires_grad],
127 | confounders=confounders,
128 | attr_entropy=attr_entropy,
129 | use_entropy=args.entropy,
130 | )
131 | else:
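    |                 # two-component constellations: run the reasoner on each
    |                 # component separately; losses and scores are combined below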
132 | outputs1, candidates1, tests1 = model(
133 | scene_prob[0], distribute=distribute
134 | )
135 | outputs1, attr_entropy1 = rule_selector(
136 | outputs1, tests1, candidates1, targets
137 | )
138 | if args.config == "in_out_four":
139 | outputs2, candidates2, tests2 = model2(scene_prob[1])
140 | else:
141 | outputs2, candidates2, tests2 = model(scene_prob[1])
142 |
143 | outputs2, attr_entropy2 = rule_selector(
144 | outputs2, tests2, candidates2, targets
145 | )
146 | loss1, scores1, _ = compute_loss_and_scores(
147 | outputs1,
148 | tests1,
149 | candidates1,
150 | targets,
151 | loss_fn,
152 | distribute=args.config == "in_out_four",
153 | confounders=confounders,
154 | attr_entropy=attr_entropy1,
155 | use_entropy=args.entropy,
156 | )
157 | loss2, scores2, _ = compute_loss_and_scores(
158 | outputs2,
159 | tests2,
160 | candidates2,
161 | targets,
162 | loss_fn,
163 | distribute=False,
164 | confounders=confounders,
165 | attr_entropy=attr_entropy2,
166 | use_entropy=args.entropy,
167 | )
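    |                 # combine the component losses with an 80/20 weighting at
    |                 # training time (test() averages them 50/50)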
168 | loss = 0.8 * loss1 + 0.2 * loss2
169 | scores = 0.8 * scores1 + 0.2 * scores2
170 |
171 | predictions = torch.argmax(scores, dim=-1)
172 | accuracy = ((predictions == targets).sum() / len(targets)) * 100
173 | loss_avg.update(loss.item(), extracted.size(0))
174 | acc_avg.update(accuracy.item(), extracted.size(0))
176 |
177 | if train:
178 | optimizer.zero_grad()
179 | loss.backward()
180 | if args.clip:
181 | torch.nn.utils.clip_grad_norm_(
182 | parameters=train_param, max_norm=args.clip, norm_type=2.0
183 | )
184 | optimizer.step()
185 |
186 | if train:
187 | print(
188 | "Epoch {}, Total Iter: {}, Train Avg Loss: {:.6f}, Train Avg Accuracy: {:.6f}".format(
189 | epoch, counter, loss_avg.avg, acc_avg.avg
190 | )
191 | )
192 | writer.add_scalar("loss/training", loss_avg.avg, epoch)
193 | writer.add_scalar("accuracy/training", acc_avg.avg, epoch)
194 | else:
195 | print(
196 | "Epoch {}, Valid Avg Loss: {:.6f}, Valid Avg Acc: {:.4f}".format(
197 | epoch, loss_avg.avg, acc_avg.avg
198 | )
199 | )
200 | writer.add_scalar("loss/validation", loss_avg.avg, epoch)
201 | writer.add_scalar("accuracy/validation", acc_avg.avg, epoch)
202 | for r in model.rules_set.rules:
203 | print(str(r.rule))
204 | return acc_avg.avg
205 |
206 | # Set random seed
207 | np.random.seed(args.seed)
208 | torch.manual_seed(args.seed)
209 | if args.cuda:
210 | torch.cuda.manual_seed(args.seed)
211 | torch.backends.cudnn.benchmark = False
212 |
213 | writer = SummaryWriter(args.log_dir)
214 |
215 | # Init model
216 | model = RuleLevelReasoner(
217 | args.device,
218 | args.config,
219 | model=args.model,
220 | hidden_layers=args.hidden_layers,
221 | dictionary=args.backend_cb,
222 | vsa_conversion=args.vsa_conversion,
223 | vsa_selection=args.vsa_selection,
224 | context_superposition=args.context_superposition,
225 | num_rules=args.num_rules,
226 | shared_rules=args.shared_rules,
227 | program=args.program,
228 | num_terms=args.num_terms,
229 | n=args.n,
230 | )
231 | model.to(args.device)
232 | if args.config == "in_out_four":
233 | model2 = RuleLevelReasoner(
234 | args.device,
235 | "center_single",
236 | model=args.model,
237 | hidden_layers=args.hidden_layers,
238 | dictionary=args.backend_cb,
239 | vsa_conversion=args.vsa_conversion,
240 | vsa_selection=args.vsa_selection,
241 | context_superposition=args.context_superposition,
242 | num_rules=args.num_rules,
243 | shared_rules=args.shared_rules,
244 | program=args.program,
245 | num_terms=args.num_terms,
246 | n=args.n,
247 | )
248 | model2.to(args.device)
249 |
250 |     distribute = "distribute" in args.config or args.config == "in_out_four"
251 | # Init loss
252 | loss_fn = getattr(losses, args.loss_fn)()
253 |
254 | rule_selector = RuleSelector(
255 | loss_fn, args.rule_selector_temperature, rule_selector=args.rule_selector
256 | )
257 |
258 | # Init optimizers
259 | train_param = list(model.parameters())
260 | if args.config == "in_out_four":
261 | train_param += list(model2.parameters())
262 | optimizer = optim.AdamW(train_param, args.lr, weight_decay=args.weight_decay)
263 |
264 | # Load all checkpoints
265 | rule_path = os.path.join(args.resume, "checkpoint.pth.tar")
266 | if os.path.isfile(rule_path):
267 | checkpoint = torch.load(rule_path)
268 | model.load_state_dict(checkpoint["state_dict_model"])
269 | if args.config == "in_out_four":
270 | model2.load_state_dict(checkpoint["state_dict_model2"])
271 | best_accuracy = checkpoint["best_accuracy"]
272 | start_epoch = checkpoint["epoch"]
273 | optimizer.load_state_dict(checkpoint["optimizer"])
274 | print(
275 | "=> loaded checkpoint '{}' at Epoch {:.3f}".format(
276 | rule_path, checkpoint["epoch"]
277 | )
278 | )
279 | else:
280 | best_accuracy = 0
281 | start_epoch = 0
282 |
283 | # Dataset loader
284 | train_set = GeneralIRAVENDataset(
285 | "train",
286 | args.data_dir,
287 | constellation_filter=args.config,
288 | rule_filter=args.gen_rule,
289 | attribute_filter=args.gen_attribute,
290 | n_train=args.n_train,
291 | maxval=args.dyn_range,
292 | partition=args.partition,
293 | n=args.n,
294 | n_confounders=args.orientation_confounder,
295 | )
296 | train_loader = DataLoader(
297 | train_set,
298 | batch_size=args.batch_size,
299 | shuffle=True,
300 | num_workers=args.num_workers,
301 | )
302 | val_set = GeneralIRAVENDataset(
303 | "val",
304 | args.data_dir,
305 | constellation_filter=args.config,
306 | rule_filter=args.gen_rule,
307 | attribute_filter=args.gen_attribute,
308 | n_train=args.n_train,
309 | maxval=args.dyn_range,
310 | partition=args.partition,
311 | n=args.n,
312 | )
313 | val_loader = DataLoader(
314 |         val_set, batch_size=args.batch_size * 15, num_workers=args.num_workers  # validation runs without gradients, so a larger batch is used
315 | )
316 |
317 | # training loop starts
318 | for epoch in range(start_epoch, args.epochs):
319 | inference_epoch(epoch, loader=train_loader, train=True)
320 | with torch.no_grad():
321 | accuracy = inference_epoch(epoch, loader=val_loader, train=False)
322 |
323 | # store model(s)
324 | is_best = accuracy > best_accuracy
325 | best_accuracy = max(accuracy, best_accuracy)
326 | if args.config == "in_out_four":
327 | save_checkpoint(
328 | {
329 | "epoch": epoch + 1,
330 | "state_dict_model": model.state_dict(),
331 | "state_dict_model2": model2.state_dict(),
332 | "best_accuracy": accuracy,
333 | "optimizer": optimizer.state_dict(),
334 | },
335 | is_best,
336 | savedir=args.checkpoint_dir,
337 | )
338 | else:
339 | save_checkpoint(
340 | {
341 | "epoch": epoch + 1,
342 | "state_dict_model": model.state_dict(),
343 | "best_accuracy": best_accuracy,
344 | "accuracy": accuracy,
345 | "optimizer": optimizer.state_dict(),
346 | },
347 | is_best,
348 | savedir=args.checkpoint_dir,
349 | )
350 | return writer
351 |
352 |
353 | def test(args, env, device, writer=None, dset="RAVEN", confounders=False):
354 | """
355 |     Testing of the NVSA backend.
356 | """
357 |
358 | def test_epoch():
359 | model.eval()
360 | if args.config == "in_out_four":
361 | model2.eval()
362 | rule_selector.eval()
363 |
364 | loss_avg = AverageMeter("Loss", ":.3f")
365 | acc_avg = AverageMeter("Accuracy", ":.3f")
366 | rule_acc_avg = {
367 | rule: AverageMeter("Accuracy", ":.3f") for rule in iravenx_rule_map.keys()
368 | }
369 |
370 | for extracted, targets, all_action_rule in tqdm(test_loader):
371 | extracted, targets, all_action_rule = (
372 | extracted.to(device),
373 | targets.to(device),
374 | all_action_rule.to(device),
375 | )
376 | att_logprob = create_one_hot(extracted, args.config, args.sigma)
377 | model_output = {k: v.to(device) for k, v in att_logprob.items()}
378 | scene_prob, _ = env.prepare(model_output)
379 | if args.config in ["center_single", "distribute_four", "distribute_nine"]:
380 | outputs, candidates, tests = model(scene_prob, distribute=distribute)
381 | outputs, attr_entropy = rule_selector(outputs, tests)
382 | loss, scores, attscores = compute_loss_and_scores(
383 | outputs,
384 | tests,
385 | candidates,
386 | targets,
387 | loss_fn,
388 | "distribute" in args.config,
389 | [p.data for p in model.parameters() if p.requires_grad],
390 | confounders=confounders,
391 | attr_entropy=attr_entropy,
392 | use_entropy=args.entropy,
393 | )
394 |
395 | else:
396 | outputs1, candidates1, tests1 = model(
397 | scene_prob[0], distribute=distribute
398 | )
399 | outputs1, attr_entropy1 = rule_selector(outputs1, tests1)
400 | if args.config == "in_out_four":
401 | outputs2, candidates2, tests2 = model2(
402 | scene_prob[1], distribute=False
403 | )
404 | else:
405 | outputs2, candidates2, tests2 = model(
406 | scene_prob[1], distribute=False
407 | )
408 | outputs2, attr_entropy2 = rule_selector(outputs2, tests2)
409 | loss1, scores1, _ = compute_loss_and_scores(
410 | outputs1,
411 | tests1,
412 | candidates1,
413 | targets,
414 | loss_fn,
415 | distribute=args.config == "in_out_four",
416 | confounders=confounders,
417 | attr_entropy=attr_entropy1,
418 | use_entropy=args.entropy,
419 | )
420 | loss2, scores2, _ = compute_loss_and_scores(
421 | outputs2,
422 | tests2,
423 | candidates2,
424 | targets,
425 | loss_fn,
426 | distribute=False,
427 | confounders=confounders,
428 | attr_entropy=attr_entropy2,
429 | use_entropy=args.entropy,
430 | )
431 | loss = (loss1 + loss2) / 2
432 | scores = (scores1 + scores2) / 2
433 |
434 | # accuracy and loss computation
435 | predictions = torch.argmax(scores, dim=-1)
436 | accuracy = ((predictions == targets).sum() / len(targets)) * 100
437 | loss_avg.update(loss.item(), extracted.size(0))
438 | acc_avg.update(accuracy.item(), extracted.size(0))
439 |             # --- rule accuracy computation (start) ---
440 | if args.evaluate_rule:
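    |                 # Per-rule accuracy: for sigma >= 0, count per attribute how
    |                 # often the target attains the maximum attribute score
    |                 # (all-equal score vectors are excluded); for sigma < 0,
    |                 # fall back to overall prediction correctness.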
441 | expanded_rules = all_action_rule[:, 1:]
442 | batch_size = all_action_rule.shape[0]
443 | for rule in iravenx_index_map.keys():
444 | total = 0
445 | correct = 0
446 | rule_mask = expanded_rules == rule
447 | if rule_mask.any():
448 | if args.sigma >= 0:
449 | for attribute_idx in range(0, 3):
450 | max_scores = attscores[attribute_idx].max(dim=-1).values
451 | not_max_entropy = ~torch.all(
452 | attscores[attribute_idx]
453 | == attscores[attribute_idx][:, 0:1],
454 | dim=1,
455 | )
456 | correct += (
457 | (
458 | torch.logical_and(
459 | attscores[attribute_idx][
460 | torch.arange(batch_size), targets
461 | ]
462 | == max_scores,
463 | torch.logical_and(
464 | rule_mask[:, attribute_idx],
465 | not_max_entropy,
466 | ),
467 | )
468 | )
469 | .sum()
470 | .item()
471 | )
472 | total += rule_mask[:, attribute_idx].sum().item()
473 | else:
474 | correct += (
475 | (rule_mask.sum(dim=-1) * (predictions == targets))
476 | .sum()
477 | .item()
478 | )
479 | total += rule_mask.sum().item()
480 |
481 | rule_acc = correct / total if total else 0
482 | rule_acc_avg[iravenx_index_map[rule]].update(
483 | rule_acc, extracted.size(0)
484 | )
485 |             # --- rule accuracy computation (end) ---
486 |
487 | # Save final result as npz (and potentially in Tensorboard)
488 | if args.resume == "":
489 | if writer is not None:
490 | writer.add_scalar("accuracy/testing-{}".format(dset), acc_avg.avg, 0)
491 |                 np.savez(  # note: accuracy is stored under the "loss" key
492 |                     os.path.join(args.save_dir, "result_{}.npz".format(dset)), loss=acc_avg.avg
493 |                 )
494 | else:
495 | args.save_dir = args.resume.replace("ckpt/", "save/")
496 |                 np.savez(
497 |                     os.path.join(args.save_dir, "result_{}.npz".format(dset)), loss=acc_avg.avg
498 |                 )
499 |
500 | print("Test Avg Accuracy: {:.4f}".format(acc_avg.avg))
501 | if args.evaluate_rule:
502 | for rule in iravenx_rule_map.keys():
503 | print(f"Rule {rule} Avg Accuracy: {rule_acc_avg[rule].avg * 100:.2f}")
504 | for r in model.rules_set.rules:
505 | print(str(r.rule))
506 | return {
507 | **{"acc": acc_avg.avg},
508 | **{rule: rule_acc_avg[rule].avg * 100 for rule in iravenx_rule_map.keys()},
509 | }
510 |
511 | # Load all checkpoint
512 | model_path = os.path.join(args.resume, "model_best.pth.tar")
513 | # model_path = os.path.join(args.resume, "checkpoint.pth.tar")
514 |     print("Looking for checkpoint at {}".format(model_path))
515 | if os.path.isfile(model_path):
516 | checkpoint = torch.load(model_path)
517 | print(
518 | "=> loaded checkpoint '{}', epoch {}, with accuracy {:.3f}".format(
519 | model_path, checkpoint["epoch"], checkpoint["best_accuracy"]
520 | )
521 | )
522 | else:
523 | print(
524 | f"Careful! The model is not loaded from checkpoint. Program is: {args.program}"
525 | )
526 | # raise ValueError("No checkpoint found at {:}".format(model_path))
527 | test_acc = dict()
528 | configs = [
529 | "center_single",
530 | # "distribute_four",
531 | # "distribute_nine",
532 | # "left_right",
533 | # "up_down",
534 | # "in_out_single",
535 | # "in_out_four",
536 | ]
537 | for config in configs:
538 | args.config = config
539 | env = reasoning_env.get_env(args.configs_map[args.config], device)
540 | # Init the model
541 | model = RuleLevelReasoner(
542 | args.device,
543 | config,
544 | model=args.model,
545 | hidden_layers=args.hidden_layers,
546 | dictionary=args.backend_cb,
547 | vsa_conversion=args.vsa_conversion,
548 | vsa_selection=args.vsa_selection,
549 | context_superposition=args.context_superposition,
550 | num_rules=args.num_rules,
551 | shared_rules=args.shared_rules,
552 | program=args.program,
553 | num_terms=args.num_terms,
554 | n=args.n,
555 | )
556 | model.to(device)
557 | if not args.program:
558 | model.load_state_dict(checkpoint["state_dict_model"])
559 | if args.data_dir == "/dccstor/saentis/data/I-RAVEN":
560 | model.anneal_softmax()
561 | if config == "in_out_four":
562 | model2 = RuleLevelReasoner(
563 | args.device,
564 | "center_single",
565 | model=args.model,
566 | hidden_layers=args.hidden_layers,
567 | dictionary=args.backend_cb,
568 | vsa_conversion=args.vsa_conversion,
569 | vsa_selection=args.vsa_selection,
570 | context_superposition=args.context_superposition,
571 | num_rules=args.num_rules,
572 | shared_rules=args.shared_rules,
573 | program=args.program,
574 | num_terms=args.num_terms,
575 | n=args.n,
576 | )
577 | model2.to(device)
578 | if not args.program:
579 |                 model2.load_state_dict(checkpoint["state_dict_model2"])  # model2 has its own checkpoint entry
580 |         distribute = "distribute" in config or config == "in_out_four"
581 | # Init loss
582 | loss_fn = getattr(losses, args.loss_fn)()
583 |
584 | rule_selector = RuleSelector(
585 | loss_fn, args.rule_selector_temperature, rule_selector=args.rule_selector
586 | )
587 |
588 | # Dataset loader
589 | test_set = GeneralIRAVENDataset(
590 | "test",
591 | args.data_dir,
592 | constellation_filter=config,
593 | rule_filter=args.gen_rule,
594 | attribute_filter=args.gen_attribute,
595 | maxval=args.dyn_range,
596 | partition=args.partition,
597 | n=args.n,
598 | n_confounders=args.orientation_confounder,
599 | )
600 | test_loader = DataLoader(
601 | test_set, batch_size=args.batch_size, num_workers=args.num_workers
602 | )
603 | print("Evaluating on {}".format(config))
604 | with torch.no_grad():
605 | acc = test_epoch()
606 | test_acc[config] = acc
607 |
608 |     with open(os.path.join(args.resume, "eval.json"), "w") as fp:
609 | json.dump(test_acc, fp)
610 | return writer
611 |
612 |
613 | def main():
614 | args = parse_args()
615 |
616 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
617 | args.cuda = torch.cuda.is_available()
618 |
619 |     # Use a dedicated RNG for reproducible codebook generation
620 | rng = np.random.default_rng(seed=args.seed)
621 |
622 | torch.manual_seed(args.seed)
623 | np.random.seed(args.seed)
624 | random.seed(args.seed)
625 | os.environ["PYTHONHASHSEED"] = str(args.seed)
626 |
627 | # Load or define new codebooks
628 | backend_cb_cont, backend_cb_discrete = generate_nvsa_codebooks(args, rng)
629 |
630 | args.backend_cb_discrete = backend_cb_discrete
631 | args.backend_cb_cont = backend_cb_cont
632 |
633 | print(f"Sigma: {args.sigma}")
634 |
635 | if args.model == "LearnableFormula":
636 | args.backend_cb = backend_cb_cont
637 | else:
638 | args.backend_cb = backend_cb_discrete
639 |
640 |     # map short config names to the dataset's constellation directory names
641 | input_configs = [
642 | "center_single",
643 | "left_right",
644 | "up_down",
645 | "in_out_single",
646 | "distribute_four",
647 | "in_out_four",
648 | "distribute_nine",
649 | ]
650 | output_configs = [
651 | "center_single",
652 | "left_center_single_right_center_single",
653 | "up_center_single_down_center_single",
654 | "in_center_single_out_center_single",
655 | "distribute_four",
656 | "in_distribute_four_out_center_single",
657 | "distribute_nine",
658 | ]
659 | args.configs_map = dict(zip(input_configs, output_configs))
660 |
661 | env = reasoning_env.get_env(args.configs_map[args.config], args.device)
662 |
663 | if args.mode == "train":
664 | args.exp_dir = os.path.join(args.exp_dir, args.run_name, str(args.seed))
665 | args.checkpoint_dir = os.path.join(args.exp_dir, "ckpt")
666 | args.save_dir = os.path.join(args.exp_dir, "save")
667 | args.log_dir = os.path.join(args.exp_dir, "log")
668 | check_paths(args)
669 |
670 | # Run the actual training
671 | writer = train(args, env, args.device, confounders=args.orientation_confounder)
672 |
673 | # Do final testing
674 | args.resume = args.checkpoint_dir
675 | writer = test(
676 | args, env, args.device, writer, confounders=args.orientation_confounder
677 | )
678 |
679 | writer.close()
680 |
681 | elif args.mode == "test":
682 | test(args, env, args.device, confounders=args.orientation_confounder)
683 |
684 |
685 | if __name__ == "__main__":
686 | main()
687 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
--------------------------------------------------------------------------------