├── src
│ ├── pdp
│ │ ├── __init__.py
│ │ ├── nn
│ │ │ ├── __init__.py
│ │ │ ├── pdp_predict.py
│ │ │ ├── pdp_propagate.py
│ │ │ ├── pdp_decimate.py
│ │ │ ├── util.py
│ │ │ └── solver.py
│ │ ├── factorgraph
│ │ │ ├── __init__.py
│ │ │ ├── dataset.py
│ │ │ └── base.py
│ │ ├── trainer.py
│ │ └── generator.py
│ ├── .DS_Store
│ ├── satyr.py
│ ├── dimacs2json.py
│ └── satyr-train-test.py
├── .DS_Store
├── config
│ ├── Predict
│ │ ├── PDP-p-d-p-walksat-pytorch.yaml
│ │ ├── PDP-p-d-p-sp-pytorch.yaml
│ │ ├── PDP-p-d-p-reinforce-pytorch.yaml
│ │ └── PDP-np-nd-np-gcnf-10-100-pytorch.yaml
│ └── Train
│   ├── p-prodec2-ws-cnf-pytorch.yaml
│   ├── p-prodec2-gcnf-4SAT-pytorch.yaml
│   ├── p-prodec2-sp-cnf-3-10-pytorch.yaml
│   ├── p-prodec2-reinforce-cnf-pytorch.yaml
│   ├── p-prodec2-nsp-cnf-3-10-pytorch.yaml
│   ├── p-prodec2-gcnf-10-100-pytorch.yaml
│   ├── p-prodec2-modular-variable-pytorch-2.yaml
│   ├── p-prodec2-ndec-cnf-3-10-pytorch.yaml
│   ├── p-prodec2-modular-variable-pytorch.yaml
│   └── p-prodec2-modular-4SAT-pytorch.yaml
├── .pylintrc
├── LICENSE
├── setup.py
├── .gitignore
├── SECURITY.md
└── README.md
/src/pdp/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | PDP Solver.
3 | """
4 |
--------------------------------------------------------------------------------
/src/pdp/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | NN layers.
3 | """
4 |
--------------------------------------------------------------------------------
/config/Predict/PDP-p-d-p-walksat-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_type: "walk-sat"
2 | model_name: "p-prodec2-walksat-pytorch"
3 |
--------------------------------------------------------------------------------
/config/Predict/PDP-p-d-p-sp-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_type: "p-d-p"
2 | model_name: "p-prodec2-sp-pytorch"
3 | tolerance: 0.02
4 | t_max: 100
5 |
--------------------------------------------------------------------------------
/config/Predict/PDP-p-d-p-reinforce-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_type: "reinforce"
2 | model_name: "p-prodec2-reinforce-pytorch"
3 | pi: 0.01
4 | decimation_probability: 0.5
5 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 | disable=not-callable
3 |
4 | [FORMAT]
5 | max-line-length=100
6 | max-module-lines=1000
7 |
8 | [TYPECHECK]
9 | generated-members=numpy,np,torch
10 |
--------------------------------------------------------------------------------
/src/pdp/factorgraph/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Factor Graph trainer base functionality.
3 | """
4 |
5 | from pdp.factorgraph.base import FactorGraphTrainerBase
6 |
7 |
8 | __all__ = ["FactorGraphTrainerBase"]
9 |
--------------------------------------------------------------------------------
/config/Predict/PDP-np-nd-np-gcnf-10-100-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_type: "np-nd-np"
2 | has_meta_data: false
3 | model_name: "p-prodec2-gcnf-10-100-pytorch"
4 | model_path: "../../Trained-models/SAT/p-prodec2-gcnf-10-100-pytorch/2.0g/best"
5 | label_dim: 1
6 | edge_feature_dim: 1
7 | meta_feature_dim: 0
8 | prediction_dim: 1
9 | hidden_dim: 150
10 | mem_hidden_dim: 100
11 | agg_hidden_dim: 100
12 | mem_agg_hidden_dim: 50
13 | classifier_dim: 50
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-ws-cnf-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-ws-cnf-pytorch"
2 | model_type: "walk-sat"
3 | version: "0.0"
4 | has_meta_data: true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: ["../../datasets/SAT/4SAT-test"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/4SAT-100"]
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 200000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 157 # 110
19 | mem_hidden_dim: 50
20 | agg_hidden_dim: 50 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 30 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 20
30 | test_recurrence_num: 1000
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | min_n: 10
39 | max_n: 100
40 | min_alpha: 2
41 | max_alpha: 10
42 | min_k: 2
43 | max_k: 10
44 | local_search_iteration: 1000
45 | epsilon: 0.5
46 | lambda: 1
47 |
--------------------------------------------------------------------------------
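A hedged usage note (not a file in the repository): each Train config above is passed as the single positional argument to src/satyr-train-test.py, whose command-line flags are defined in that script further below. Assuming the dataset paths referenced by the config exist on disk, a minimal training run would look like:

    python3 src/satyr-train-test.py config/Train/p-prodec2-ws-cnf-pytorch.yaml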
/config/Train/p-prodec2-gcnf-4SAT-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-gcnf-4SAT-pytorch"
2 | model_type: "np-nd-np"
3 | version: "0.0g"
4 | has_meta_data: false # true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: ["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/4SAT-100"]
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 40000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 0 # 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 150 # 110
19 | mem_hidden_dim: 100
20 | agg_hidden_dim: 100 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 10
30 | test_recurrence_num: 1000
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | generator: "uniform"
39 | min_n: 5
40 | max_n: 100
41 | min_alpha: 7
42 | max_alpha: 10
43 | min_k: 4
44 | max_k: 4
45 | local_search_iteration: 100
46 | epsilon: 0.05
47 | lambda: 1
48 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-sp-cnf-3-10-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-sp-cnf-3-10-pytorch"
2 | model_type: "p-d-p"
3 | version: "2.0"
4 | has_meta_data: true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: ["../../datasets/SAT/3SAT-100"] #["../../datasets/SAT/4SAT-test"] # ["../../datasets/SAT/cnf-10.json"] #
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 200000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 157 # 110
19 | mem_hidden_dim: 50
20 | agg_hidden_dim: 50 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 30 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 20
30 | test_recurrence_num: 1000
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | min_n: 10
39 | max_n: 100
40 | min_alpha: 2
41 | max_alpha: 10
42 | min_k: 2
43 | max_k: 10
44 | local_search_iteration: 1000
45 | epsilon: 0.5
46 | tolerance: 0.02
47 | t_max: 100
48 | lambda: 1
49 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-reinforce-cnf-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-reinforce-cnf-pytorch"
2 | model_type: "reinforce"
3 | version: "1.0"
4 | has_meta_data: true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: ["../../datasets/SAT/3SAT-100"] # ["../../datasets/SAT/3SAT-100"] # ["../../datasets/SAT/4SAT-100"] # ["../../datasets/SAT/cnf-10.json"] #
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 200000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 157 # 110
19 | mem_hidden_dim: 50
20 | agg_hidden_dim: 50 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 30 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 20
30 | test_recurrence_num: 1000
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | min_n: 10
39 | max_n: 100
40 | min_alpha: 2
41 | max_alpha: 10
42 | min_k: 2
43 | max_k: 10
44 | local_search_iteration: 1000
45 | epsilon: 0.5
46 | pi: 0.01
47 | decimation_probability: 0.5
48 | lambda: 1
49 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-nsp-cnf-3-10-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-nsp-cnf-3-10-pytorch"
2 | model_type: "p-nd-np"
3 | version: "2.0"
4 | has_meta_data: true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: [["../../datasets/SAT/cnf-gc-6-10-test.json"],
8 | ["../../datasets/SAT/cnf-10.json"],
9 | ["../../datasets/SAT/cnf-20.json"],
10 | ["../../datasets/SAT/cnf-40.json"],
11 | ["../../datasets/SAT/cnf-60.json"],
12 | ["../../datasets/SAT/cnf-80.json"]]
13 | model_path: "../../Trained-models/SAT"
14 | repetition_num: 1
15 | train_epoch_size: 200000
16 | epoch_num: 500
17 | label_dim: 1
18 | edge_feature_dim: 1
19 | meta_feature_dim: 1
20 | error_dim: 3
21 | metric_index: 0
22 | prediction_dim: 1
23 | hidden_dim: 150 # 110
24 | mem_hidden_dim: 50
25 | agg_hidden_dim: 50 # 135
26 | mem_agg_hidden_dim: 50 # 50
27 | classifier_dim: 50 # 30 # 100
28 | batch_size: 5000
29 | learning_rate: 0.0001
30 | exploration: 0.1
31 | verbose: true
32 | randomized: true
33 | train_inner_recurrence_num: 1
34 | train_outer_recurrence_num: 10
35 | test_recurrence_num: 20
36 | max_cache_size: 100000
37 | dropout: 0.2
38 | clip_norm: 0.65
39 | weight_decay: 0.0000000001
40 | loss_sharpness: 5
41 | train_batch_limit: 4000000
42 | test_batch_limit: 40000000
43 | min_n: 10
44 | max_n: 100
45 | min_alpha: 2
46 | max_alpha: 10
47 | min_k: 2
48 | max_k: 10
49 | local_search_iteration: 10
50 | epsilon: 0.05
51 | lambda: 1
52 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-gcnf-10-100-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-gcnf-10-100-pytorch"
2 | model_type: "np-nd-np"
3 | version: "2.0g"
4 | has_meta_data: false # true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: ["../../datasets/SAT/sat-race-2015.json"] #["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/4SAT-subsample"] #["../../datasets/SAT/M0-4SAT-100"] # ["../../datasets/SAT/3SAT-100"] # ["../../datasets/SAT/4SAT-100"] # ["../../datasets/SAT/cnf-10.json"] #
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 40000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 0 # 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 150 # 110
19 | mem_hidden_dim: 100
20 | agg_hidden_dim: 100 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 10
30 | test_recurrence_num: 8800
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | generator: "uniform"
39 | min_n: 4
40 | max_n: 100
41 | min_alpha: 2
42 | max_alpha: 10
43 | min_k: 2
44 | max_k: 10
45 | local_search_iteration: 1000
46 | epsilon: 0.5
47 | lambda: 1
48 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-modular-variable-pytorch-2.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-modular-variable-pytorch-2"
2 | model_type: "np-nd-np"
3 | version: "2.0g"
4 | has_meta_data: false # true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"]
7 | test_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/sat-race-test.json"] # ["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/3SAT-100"] #
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 40000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 0 # 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 200 # 110
19 | mem_hidden_dim: 150
20 | agg_hidden_dim: 150 # 135
21 | mem_agg_hidden_dim: 100 # 50
22 | classifier_dim: 100 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 10
30 | test_recurrence_num: 200
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | generator: "v-modular"
39 | min_n: 5
40 | max_n: 100
41 | min_alpha: 5
42 | max_alpha: 10
43 | min_k: 2
44 | max_k: 10
45 | min_q: 0.8
46 | max_q: 0.9
47 | min_c: 10
48 | max_c: 20
49 | local_search_iteration: 0
50 | epsilon: 0.5
51 | lambda: 0.95
52 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-ndec-cnf-3-10-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-ndec-cnf-3-10-pytorch"
2 | model_type: "np-d-np"
3 | version: "0.0"
4 | has_meta_data: false
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/cnf-10.json"]
7 | test_path: [["../../datasets/SAT/cnf-gc-6-10-test.json"],
8 | ["../../datasets/SAT/cnf-10.json"],
9 | ["../../datasets/SAT/cnf-20.json"],
10 | ["../../datasets/SAT/cnf-40.json"],
11 | ["../../datasets/SAT/cnf-60.json"],
12 | ["../../datasets/SAT/cnf-80.json"]]
13 | model_path: "../../Trained-models/SAT"
14 | repetition_num: 1
15 | train_epoch_size: 200000
16 | epoch_num: 500
17 | label_dim: 1
18 | edge_feature_dim: 1
19 | meta_feature_dim: 0
20 | error_dim: 3
21 | metric_index: 0
22 | prediction_dim: 1
23 | hidden_dim: 150 # 110
24 | mem_hidden_dim: 100
25 | agg_hidden_dim: 100 # 135
26 | mem_agg_hidden_dim: 50 # 50
27 | classifier_dim: 50 # 100
28 | batch_size: 5000
29 | learning_rate: 0.0001
30 | exploration: 0.1
31 | verbose: true
32 | randomized: true
33 | train_inner_recurrence_num: 1
34 | train_outer_recurrence_num: 10
35 | test_recurrence_num: 20
36 | max_cache_size: 100000
37 | dropout: 0.2
38 | clip_norm: 0.65
39 | weight_decay: 0.0000000001
40 | loss_sharpness: 5
41 | train_batch_limit: 4000000
42 | test_batch_limit: 40000000
43 | generator: "uniform"
44 | min_n: 10
45 | max_n: 100
46 | min_alpha: 2
47 | max_alpha: 10
48 | min_k: 2
49 | max_k: 10
50 | local_search_iteration: 0
51 | epsilon: 0.05
52 | tolerance: 0.02
53 | t_max: 10
54 | lambda: 1
55 |
--------------------------------------------------------------------------------
/config/Train/p-prodec2-modular-variable-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-modular-variable-pytorch"
2 | model_type: "np-nd-np"
3 | version: "1.0g"
4 | has_meta_data: false # true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"]
7 | test_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/sat-race-test.json"] # ["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/3SAT-100"] #
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 40000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 0 # 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 150 # 110
19 | mem_hidden_dim: 100
20 | agg_hidden_dim: 100 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 100
23 | batch_size: 5000
24 | learning_rate: 0.0001
25 | exploration: 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 10
30 | test_recurrence_num: 200
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 3500000
37 | test_batch_limit: 40000000
38 | generator: "v-modular"
39 | min_n: 5
40 | max_n: 100
41 | min_alpha: 5
42 | max_alpha: 10
43 | min_k: 2
44 | max_k: 10
45 | min_q: 0.8
46 | max_q: 0.9
47 | min_c: 10
48 | max_c: 20
49 | local_search_iteration: 0
50 | epsilon: 0.5
51 | lambda: 1
52 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # To make sure pylint is happy with the tests, run:
4 | # python3 setup.py develop
5 | # every time we add or remove new top-level packages to the project.
6 |
7 | import os
8 | from setuptools import setup
9 |
10 |
11 | def _read(fname):
12 | with open(os.path.join(os.path.dirname(__file__), fname)) as stream:
13 | return stream.read()
14 |
15 |
16 | setup(
17 | name='PDP_Solver',
18 | author="Saeed Amizadeh",
19 | author_email="saamizad@microsoft.com",
20 | description="PDP Framework for Neural Constraint Satisfaction Solving",
21 | long_description=_read('./README.md'),
22 | long_description_content_type='text/markdown',
23 | keywords="pdp sat solver pytorch neurosymbolic",
24 | license="MIT",
25 | classifiers=[
26 | # Trove classifiers
27 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python',
30 | 'Programming Language :: Python :: 3',
31 | 'Programming Language :: Python :: 3.5',
32 | 'Programming Language :: Python :: Implementation :: CPython',
33 | ],
34 | url="https://github.com/Microsoft/PDP-Solver",
35 | version='0.1',
36 | python_requires=">=3.5",
37 | install_requires=[
38 | "numpy >= 1.10",
39 | "torch >= 0.4",
40 | ],
41 | package_dir={"": "src"},
42 | packages=[
43 | "pdp",
44 | "pdp.factorgraph",
45 | "pdp.nn"
46 | ],
47 | scripts=[
48 | "src/satyr.py",
49 | "src/satyr-train-test.py",
50 | "src/dimacs2json.py",
51 | "src/pdp/generator.py"
52 | ]
53 | )
54 |
--------------------------------------------------------------------------------
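As the header comment in setup.py above suggests, the project is meant to be installed in development mode so that the pdp packages and the top-level scripts resolve correctly. A minimal sketch, assuming a Python 3.5+ environment with numpy and torch available:

    python3 setup.py develop
    # or equivalently with pip (an assumption, not stated in the repository):
    pip install -e .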
/config/Train/p-prodec2-modular-4SAT-pytorch.yaml:
--------------------------------------------------------------------------------
1 | model_name: "p-prodec2-modular-4SAT-pytorch"
2 | model_type: "np-nd-np"
3 | version: "5.0g"
4 | has_meta_data: false # true
5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"]
6 | validation_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"]
7 | test_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/sat-race-test.json"] # ["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/3SAT-100"] #
8 | model_path: "../../Trained-models/SAT"
9 | repetition_num: 1
10 | train_epoch_size: 40000
11 | epoch_num: 500
12 | label_dim: 1
13 | edge_feature_dim: 1
14 | meta_feature_dim: 0 # 1
15 | error_dim: 3
16 | metric_index: 0
17 | prediction_dim: 1
18 | hidden_dim: 150 # 110
19 | mem_hidden_dim: 100
20 | agg_hidden_dim: 100 # 135
21 | mem_agg_hidden_dim: 50 # 50
22 | classifier_dim: 50 # 100
23 | batch_size: 5000
24 | learning_rate: 0.00001 # 0.0001
25 | exploration: 0.05 # 0.1
26 | verbose: true
27 | randomized: true
28 | train_inner_recurrence_num: 1
29 | train_outer_recurrence_num: 10
30 | test_recurrence_num: 20
31 | max_cache_size: 100000
32 | dropout: 0.2
33 | clip_norm: 0.65
34 | weight_decay: 0.0000000001
35 | loss_sharpness: 5
36 | train_batch_limit: 4000000
37 | test_batch_limit: 40000000
38 | generator: "modular"
39 | min_n: 5
40 | max_n: 100
41 | min_alpha: 5
42 | max_alpha: 10
43 | min_k: 4
44 | min_q: 0.8
45 | max_q: 0.9
46 | min_c: 10
47 | max_c: 20
48 | local_search_iteration: 0
49 | epsilon: 0.5
50 | lambda: 1
51 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | /.vscode/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/src/satyr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Main script to run a trained PDP solver against a test dataset.
4 | """
5 |
6 | # Copyright (c) Microsoft. All rights reserved.
7 | # Licensed under the MIT license. See LICENSE.md file
8 | # in the project root for full license information.
9 |
10 | import argparse
11 | import yaml, os, logging, sys
12 | import numpy as np
13 | import torch
14 | from datetime import datetime
15 |
16 | import dimacs2json
17 |
18 | from pdp.trainer import SatFactorGraphTrainer
19 |
20 |
21 | def run(config, logger, output):
22 | "Runs the prediction engine."
23 |
24 | np.random.seed(config['random_seed'])
25 | torch.manual_seed(config['random_seed'])
26 |
27 | if config['verbose']:
28 | logger.info("Building the computational graph...")
29 |
30 | predicter = SatFactorGraphTrainer(config=config, use_cuda=not config['cpu_mode'], logger=logger)
31 |
32 | if config['verbose']:
33 | logger.info("Starting the prediction phase...")
34 |
35 | predicter._counter = 0
36 | if output == '':
37 | predicter.predict(test_list=config['test_path'], out_file=sys.stdout, import_path_base=config['model_path'],
38 | post_processor=predicter._post_process_predictions, batch_replication=config['batch_replication'])
39 | else:
40 | with open(output, 'w') as file:
41 | predicter.predict(test_list=config['test_path'], out_file=file, import_path_base=config['model_path'],
42 | post_processor=predicter._post_process_predictions, batch_replication=config['batch_replication'])
43 |
44 |
45 | if __name__ == '__main__':
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('model_config', help='The model configuration yaml file')
48 | parser.add_argument('test_path', help='The input test path')
49 | parser.add_argument('test_recurrence_num', help='The number of iterations for the PDP', type=int)
50 | parser.add_argument('-b', '--batch_replication', help='Batch replication factor', type=int, default=1)
51 | parser.add_argument('-z', '--batch_size', help='Batch size', type=int, default=5000)
52 | parser.add_argument('-m', '--max_cache_size', help='Maximum cache size', type=int, default=100000)
53 | parser.add_argument('-l', '--test_batch_limit', help='Memory limit for mini-batches', type=int, default=40000000)
54 | parser.add_argument('-w', '--local_search_iteration', help='Number of iterations for post-processing local search', type=int, default=100)
55 | parser.add_argument('-e', '--epsilon', help='Epsilon probability for post-processing local search', type=float, default=0.5)
56 | parser.add_argument('-v', '--verbose', help='Verbose', action='store_true')
57 | parser.add_argument('-c', '--cpu_mode', help='Run on CPU', action='store_true')
58 | parser.add_argument('-d', '--dimacs', help='The input folder contains DIMACS files', action='store_true')
59 | parser.add_argument('-s', '--random_seed', help='Random seed', type=int, default=int(datetime.now().microsecond))
60 | parser.add_argument('-o', '--output', help='The JSON output file', default='')
61 |
62 | args = vars(parser.parse_args())
63 |
64 | # Load the model config
65 | with open(args['model_config'], 'r') as f:
66 | model_config = yaml.safe_load(f)
67 |
68 | # Set the logger
69 | format = '[%(levelname)s] %(asctime)s - %(name)s: %(message)s'
70 | logging.basicConfig(level=logging.DEBUG, format=format)
71 | logger = logging.getLogger(model_config['model_name'])
72 |
73 | # Convert DIMACS input files into JSON
74 | if args['dimacs']:
75 | if args['verbose']:
76 | logger.info("Converting DIMACS files into JSON...")
77 | temp_file_name = 'temp_problem_file.json'
78 |
79 | if os.path.isfile(args['test_path']):
80 | head, _ = os.path.split(args['test_path'])
81 | temp_file_name = os.path.join(head, temp_file_name)
82 | dimacs2json.convert_file(args['test_path'], temp_file_name, False)
83 | else:
84 | temp_file_name = os.path.join(args['test_path'], temp_file_name)
85 | dimacs2json.convert_directory(args['test_path'], temp_file_name, False)
86 |
87 | args['test_path'] = temp_file_name
88 |
89 | # Merge model config and other arguments into one config dict
90 | config = {**model_config, **args}
91 |
92 | if config['model_type'] == 'p-d-p' or config['model_type'] == 'walk-sat' or config['model_type'] == 'reinforce':
93 | config['model_path'] = None
94 | config['hidden_dim'] = 3
95 |
96 | if config['model_type'] == 'walk-sat':
97 | config['local_search_iteration'] = config['test_recurrence_num']
98 |
99 | config['dropout'] = 0
100 | config['error_dim'] = 1
101 | config['exploration'] = 0
102 |
103 | # Run the prediction engine
104 | run(config, logger, config['output'])
105 |
106 | if args['dimacs']:
107 | os.remove(temp_file_name)
108 |
109 | print('')
110 |
--------------------------------------------------------------------------------
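A hedged invocation sketch based on the argparse block above (the test paths and iteration counts are placeholders, not files shipped with the repository):

    # Run a trained neural PDP model on a JSON test set for 100 iterations and write the solutions to a JSON file:
    python3 src/satyr.py config/Predict/PDP-np-nd-np-gcnf-10-100-pytorch.yaml path/to/test.json 100 -o results.json
    # Run the non-trainable SP baseline directly on a folder of DIMACS files, on CPU:
    python3 src/satyr.py config/Predict/PDP-p-d-p-sp-pytorch.yaml path/to/dimacs_dir 1000 -d -c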
/src/dimacs2json.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Auxiliary script for converting sets of DIMACS files into PDP's compact JSON format.
4 | """
5 |
6 | # Copyright (c) Microsoft. All rights reserved.
7 | #
8 | # Licensed under the MIT license. See LICENSE.md file
9 | # in the project root for full license information.
10 |
11 | import sys
12 | import argparse
13 |
14 | from os import listdir
15 | from os.path import isfile, join, split, splitext
16 |
17 | import numpy as np
18 |
19 |
20 | class CompactDimacs:
21 | "Encapsulates a CNF file given in the DIMACS format."
22 |
23 | def __init__(self, dimacs_file, output, propagate):
24 |
25 | self.propagate = propagate
26 | self.file_name = split(dimacs_file)[1]
27 |
28 | with open(dimacs_file, 'r') as f:
29 | j = 0
30 | for line in f:
31 | seg = line.split(" ")
32 | if seg[0] == 'c':
33 | continue
34 |
35 | if seg[0] == 'p':
36 | var_num = int(seg[2])
37 | clause_num = int(seg[3])
38 | self._clause_mat = np.zeros((clause_num, var_num), dtype=np.int32)
39 |
40 | elif len(seg) <= 1:
41 | continue
42 | else:
43 | temp = np.array(seg[:-1], dtype=np.int32)
44 | self._clause_mat[j, np.abs(temp) - 1] = np.sign(temp)
45 | j += 1
46 |
47 | ind = np.where(np.sum(np.abs(self._clause_mat), 1) > 0)[0]
48 | self._clause_mat = self._clause_mat[ind, :]
49 |
50 | ind = np.where(np.sum(np.abs(self._clause_mat), 0) > 0)[0]
51 | self._clause_mat = self._clause_mat[:, ind]
52 |
53 | if propagate:
54 | self._clause_mat = self._propagate_constraints(self._clause_mat)
55 |
56 | self._output = output
57 |
58 | def _propagate_constraints(self, clause_mat):
59 | n = clause_mat.shape[0]
60 | if n < 2:
61 | return clause_mat
62 |
63 | length = np.tile(np.sum(np.abs(clause_mat), 1), (n, 1))
64 | intersection_len = np.matmul(clause_mat, np.transpose(clause_mat))
65 |
66 | temp = intersection_len == np.transpose(length)
67 | temp *= np.tri(*temp.shape, k=-1, dtype=bool)
68 | flags = np.logical_not(np.any(temp, 0))
69 |
70 | clause_mat = clause_mat[flags, :]
71 |
72 | n = clause_mat.shape[0]
73 | if n < 2:
74 | return clause_mat
75 |
76 | length = np.tile(np.sum(np.abs(clause_mat), 1), (n, 1))
77 | intersection_len = np.matmul(clause_mat, np.transpose(clause_mat))
78 |
79 | temp = intersection_len == length
80 | temp *= np.tri(*temp.shape, k=-1, dtype=bool)
81 | flags = np.logical_not(np.any(temp, 1))
82 |
83 | return clause_mat[flags, :]
84 |
85 | def to_json(self):
86 | clause_list = []
87 | clause_num, var_num = self._clause_mat.shape
88 |
89 | ind = np.nonzero(self._clause_mat)
90 | return [[var_num, clause_num], list((ind[1] + 1) * self._clause_mat[ind]),
91 | list(ind[0] + 1), self._output, [self.file_name]]
92 |
93 |
94 | def convert_directory(dimacs_dir, output_file, propagate, only_positive=False):
95 | file_list = [join(dimacs_dir, f) for f in listdir(dimacs_dir) if isfile(join(dimacs_dir, f))]
96 |
97 | with open(output_file, 'w') as f:
98 | for i in range(len(file_list)):
99 | name, ext = splitext(file_list[i])
100 | ext = ext.lower()
101 |
102 | if ext != '.dimacs' and ext != '.cnf':
103 | continue
104 |
105 | label = float(name[-1]) if name[-1].isdigit() else -1
106 |
107 | if only_positive and label == 0:
108 | continue
109 |
110 | bc = CompactDimacs(file_list[i], label, propagate)
111 | f.write(str(bc.to_json()).replace("'", '"') + '\n')
112 | print("Generating JSON input file: %6.2f%% complete..." % (
113 | (i + 1) * 100.0 / len(file_list)), end='\r', file=sys.stderr)
114 |
115 |
116 | def convert_file(file_name, output_file, propagate):
117 | with open(output_file, 'w') as f:
118 | if len(file_name) < 8:
119 | label = -1
120 | else:
121 | temp = file_name[-8]
122 | label = float(temp) if temp.isdigit() else -1
123 |
124 | bc = CompactDimacs(file_name, label, propagate)
125 | f.write(str(bc.to_json()).replace("'", '"') + '\n')
126 |
127 |
128 | if __name__ == '__main__':
129 | parser = argparse.ArgumentParser()
130 | parser.add_argument('in_dir', action='store', type=str)
131 | parser.add_argument('out_file', action='store', type=str)
132 | parser.add_argument('-s', '--simplify', help='Propagate binary constraints', required=False, action='store_true', default=False)
133 | parser.add_argument('-p', '--positive', help='Output only positive examples', required=False, action='store_true', default=False)
134 | args = vars(parser.parse_args())
135 |
136 | convert_directory(args['in_dir'], args['out_file'], args['simplify'], args['positive'])
137 |
--------------------------------------------------------------------------------
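A hedged usage sketch for the converter above (the input directory is a placeholder):

    # Convert every .cnf/.dimacs file in a folder into one JSON problem per line, with constraint propagation enabled:
    python3 src/dimacs2json.py path/to/dimacs_dir problems.json -s

Each output line follows the compact format produced by CompactDimacs.to_json(): [[variable_num, clause_num], signed 1-based variable indices, 1-based clause indices, label, [file name]].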
/src/satyr-train-test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """The main entry point to the PDP trainer/tester/predictor."""
3 |
4 | # Copyright (c) Microsoft. All rights reserved.
5 | # Licensed under the MIT license. See LICENSE.md file
6 | # in the project root for full license information.
7 |
8 | import numpy as np
9 | import torch
10 | import torch.optim as optim
11 | import logging
12 | import argparse, os, yaml, csv
13 |
14 | from pdp.generator import *
15 | from pdp.trainer import SatFactorGraphTrainer
16 |
17 |
18 | ##########################################################################################################################
19 |
20 | def write_to_csv(result_list, file_path):
21 | with open(file_path, mode='w', newline='') as f:
22 | writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
23 |
24 | for row in result_list:
25 | writer.writerow([row[0], row[1][1, 0]])
26 |
27 | def write_to_csv_time(result_list, file_path):
28 | with open(file_path, mode='w', newline='') as f:
29 | writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
30 |
31 | for row in result_list:
32 | writer.writerow([row[0], row[2]])
33 |
34 | def run(random_seed, config_file, is_training, load_model, cpu, reset_step, use_generator, batch_replication):
35 | "Runs the train/test/predict procedures."
36 |
37 | if not use_generator:
38 | np.random.seed(random_seed)
39 | torch.manual_seed(random_seed)
40 |
41 | # Set the configurations (from either JSON or YAML file)
42 | with open(config_file, 'r') as f:
43 | config = yaml.safe_load(f)
44 |
45 | # Set the logger
46 | format = '[%(levelname)s] %(asctime)s - %(name)s: %(message)s'
47 | logging.basicConfig(level=logging.DEBUG, format=format)
48 | logger = logging.getLogger(config['model_name'] + ' (' + config['version'] + ')')
49 |
50 | # If the train/validation paths are not given as lists, treat them as directories and collect the .json files inside
51 | if not isinstance(config['train_path'], list):
52 | config['train_path'] = [os.path.join(config['train_path'], f) \
53 | for f in os.listdir(config['train_path']) if os.path.isfile(os.path.join(config['train_path'], f)) and f.endswith('.json')]
54 |
55 | if not isinstance(config['validation_path'], list):
56 | config['validation_path'] = [os.path.join(config['validation_path'], f) \
57 | for f in os.listdir(config['validation_path']) if os.path.isfile(os.path.join(config['validation_path'], f)) and f.endswith('.json')]
58 |
59 | if config['verbose']:
60 | if use_generator:
61 | logger.info("Generating training examples via %s generator." % config['generator'])
62 | else:
63 | logger.info("Training file(s): %s" % config['train_path'])
64 | logger.info("Validation file(s): %s" % config['validation_path'])
65 |
66 | best_model_path_base = os.path.join(os.path.relpath(config['model_path']),
67 | config['model_name'], config['version'], "best")
68 |
69 | last_model_path_base = os.path.join(os.path.relpath(config['model_path']),
70 | config['model_name'], config['version'], "last")
71 |
72 | if not os.path.exists(best_model_path_base):
73 | os.makedirs(best_model_path_base)
74 |
75 | if not os.path.exists(last_model_path_base):
76 | os.makedirs(last_model_path_base)
77 |
78 | trainer = SatFactorGraphTrainer(config=config, use_cuda=not cpu, logger=logger)
79 |
80 | # Training
81 | if is_training:
82 | if config['verbose']:
83 | logger.info("Starting the training phase...")
84 |
85 | generator = None
86 |
87 | if use_generator:
88 | if config['generator'] == 'modular':
89 | generator = ModularCNFGenerator(config['min_k'], config['min_n'], config['max_n'], config['min_q'],
90 | config['max_q'], config['min_c'], config['max_c'], config['min_alpha'], config['max_alpha'])
91 | elif config['generator'] == 'v-modular':
92 | generator = VariableModularCNFGenerator(config['min_k'], config['max_k'], config['min_n'], config['max_n'], config['min_q'],
93 | config['max_q'], config['min_c'], config['max_c'], config['min_alpha'], config['max_alpha'])
94 | else:
95 | generator = UniformCNFGenerator(config['min_n'], config['max_n'], config['min_k'], config['max_k'], config['min_alpha'], config['max_alpha'])
96 |
97 | model_list, errors, losses = trainer.train(
98 | train_list=config['train_path'], validation_list=config['validation_path'],
99 | optimizer=optim.Adam(trainer.get_parameter_list(), lr=config['learning_rate'],
100 | weight_decay=config['weight_decay']), last_export_path_base=last_model_path_base,
101 | best_export_path_base=best_model_path_base, metric_index=config['metric_index'],
102 | load_model=load_model, reset_step=reset_step, generator=generator,
103 | train_epoch_size=config['train_epoch_size'])
104 |
105 | if config['verbose']:
106 | logger.info("Starting the test phase...")
107 |
108 | for test_files in config['test_path']:
109 | if config['verbose']:
110 | logger.info("Testing " + test_files)
111 |
112 | if load_model == "last":
113 | import_path_base = last_model_path_base
114 | elif load_model == "best":
115 | import_path_base = best_model_path_base
116 | else:
117 | import_path_base = None
118 |
119 | result = trainer.test(test_list=test_files, import_path_base=import_path_base,
120 | batch_replication=batch_replication)
121 |
122 | if config['verbose']:
123 | for row in result:
124 | filename, errors, _ = row
125 | print('Dataset: ' + filename)
126 | print("Accuracy: \t%s" % (1 - errors[0]))
127 | print("Recall: \t%s" % (1 - errors[1]))
128 |
129 | if os.path.isdir(test_files):
130 | write_to_csv(result, os.path.join(test_files, config['model_type'] + '_' + config['model_name'] + '_' + config['version'] + '-results.csv'))
131 | write_to_csv_time(result, os.path.join(test_files, config['model_type'] + '_' + config['model_name'] + '_' + config['version'] + '-results-time.csv'))
132 |
133 |
134 | if __name__ == '__main__':
135 | parser = argparse.ArgumentParser()
136 | parser.add_argument('config', help='The configuration YAML file')
137 | parser.add_argument('-t', '--test', help='Run in test mode (no training)', action='store_true')
138 | parser.add_argument('-l', '--load_model', help='Load the previous model')
139 | parser.add_argument('-c', '--cpu_mode', help='Run on CPU', action='store_true')
140 | parser.add_argument('-r', '--reset', help='Reset the global step', action='store_true')
141 | parser.add_argument('-g', '--use_generator', help='Generate training examples on-the-fly with the configured problem generator', action='store_true')
142 | parser.add_argument('-b', '--batch_replication', help='Batch replication factor', type=int, default=1)
143 |
144 | args = parser.parse_args()
145 | run(0, args.config, not args.test, args.load_model,
146 | args.cpu_mode, args.reset, args.use_generator, args.batch_replication)
147 |
--------------------------------------------------------------------------------
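Two further hedged invocation sketches for the argparse block above (config names are taken from config/Train; the checkpoint directories are created by run() under model_path/model_name/version):

    # Train with on-the-fly problem generation instead of the JSON training files (-g selects the generator named in the config):
    python3 src/satyr-train-test.py config/Train/p-prodec2-modular-4SAT-pytorch.yaml -g
    # Skip training and evaluate the best saved checkpoint on the configured test_path:
    python3 src/satyr-train-test.py config/Train/p-prodec2-modular-4SAT-pytorch.yaml -t -l best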
/src/pdp/factorgraph/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3 |
4 | # factor_graph_input_pipeline.py : Defines the input pipeline for the PDP framework.
5 |
6 | import linecache, json
7 | import collections
8 | import numpy as np
9 |
10 | import torch
11 | import torch.utils.data as data
12 |
13 | from os import listdir
14 | from os.path import isfile, join
15 |
16 |
17 | class DynamicBatchDivider(object):
18 | "Implements the dynamic batching process."
19 |
20 | def __init__(self, limit, hidden_dim):
21 | self.limit = limit
22 | self.hidden_dim = hidden_dim
23 |
24 | def divide(self, variable_num, function_num, graph_map, edge_feature, graph_feature, label, misc_data):
25 | batch_size = len(variable_num)
26 | edge_num = [len(n) for n in edge_feature]
27 |
28 | graph_map_list = []
29 | edge_feature_list = []
30 | graph_feature_list = []
31 | variable_num_list = []
32 | function_num_list = []
33 | label_list = []
34 | misc_data_list = []
35 |
36 | if (self.limit // (max(edge_num) * self.hidden_dim)) >= batch_size:
37 | if graph_feature[0] is None:
38 | graph_feature_list = [[None]]
39 | else:
40 | graph_feature_list = [graph_feature]
41 |
42 | graph_map_list = [graph_map]
43 | edge_feature_list = [edge_feature]
44 | variable_num_list = [variable_num]
45 | function_num_list = [function_num]
46 | label_list = [label]
47 | misc_data_list = [misc_data]
48 |
49 | else:
50 |
51 | indices = sorted(range(len(edge_num)), reverse=True, key=lambda k: edge_num[k])
52 | sorted_edge_num = sorted(edge_num, reverse=True)
53 |
54 | i = 0
55 |
56 | while i < batch_size:
57 | allowed_batch_size = self.limit // (sorted_edge_num[i] * self.hidden_dim)
58 | ind = indices[i:min(i + allowed_batch_size, batch_size)]
59 |
60 | if graph_feature[0] is None:
61 | graph_feature_list += [[None]]
62 | else:
63 | graph_feature_list += [[graph_feature[j] for j in ind]]
64 |
65 | edge_feature_list += [[edge_feature[j] for j in ind]]
66 | variable_num_list += [[variable_num[j] for j in ind]]
67 | function_num_list += [[function_num[j] for j in ind]]
68 | graph_map_list += [[graph_map[j] for j in ind]]
69 | label_list += [[label[j] for j in ind]]
70 | misc_data_list += [[misc_data[j] for j in ind]]
71 |
72 | i += allowed_batch_size
73 |
74 | return variable_num_list, function_num_list, graph_map_list, edge_feature_list, graph_feature_list, label_list, misc_data_list
75 |
76 |
77 | ###############################################################
78 |
79 |
80 | class FactorGraphDataset(data.Dataset):
81 | "Implements a PyTorch Dataset class for reading and parsing CNFs in the JSON format from disk."
82 |
83 | def __init__(self, input_file, limit, hidden_dim, max_cache_size=100000, generator=None, epoch_size=0, batch_replication=1):
84 |
85 | self._cache = collections.OrderedDict()
86 | self._generator = generator
87 | self._epoch_size = epoch_size
88 | self._input_file = input_file
89 | self._max_cache_size = max_cache_size
90 |
91 | if self._generator is None:
92 | with open(self._input_file, 'r') as fh_input:
93 | self._row_num = len(fh_input.readlines())
94 |
95 | self.batch_divider = DynamicBatchDivider(limit // batch_replication, hidden_dim)
96 |
97 | def __len__(self):
98 | if self._generator is not None:
99 | return self._epoch_size
100 | else:
101 | return self._row_num
102 |
103 | def __getitem__(self, idx):
104 | if self._generator is not None:
105 | return self._generator.generate()
106 |
107 | else:
108 | if idx in self._cache:
109 | return self._cache[idx]
110 |
111 | line = linecache.getline(self._input_file, idx + 1)
112 | result = self._convert_line(line)
113 |
114 | if len(self._cache) >= self._max_cache_size:
115 | self._cache.popitem(last=False)
116 |
117 | self._cache[idx] = result
118 | return result
119 |
120 | def _convert_line(self, json_str):
121 |
122 | input_data = json.loads(json_str)
123 | variable_num, function_num = input_data[0]
124 |
125 | variable_ind = np.abs(np.array(input_data[1], dtype=np.int32)) - 1
126 | function_ind = np.abs(np.array(input_data[2], dtype=np.int32)) - 1
127 | edge_feature = np.sign(np.array(input_data[1], dtype=np.float32))
128 |
129 | graph_map = np.stack((variable_ind, function_ind))
130 | alpha = float(function_num) / variable_num
131 |
132 | misc_data = []
133 | if len(input_data) > 4:
134 | misc_data = input_data[4]
135 |
136 | return (variable_num, function_num, graph_map, edge_feature, None, float(input_data[3]), misc_data)
137 |
138 | def dag_collate_fn(self, input_data):
139 | "Torch dataset loader collation function for factor graph input."
140 |
141 | vn, fn, gm, ef, gf, l, md = zip(*input_data)
142 |
143 | variable_num, function_num, graph_map, edge_feature, graph_feat, label, misc_data = \
144 | self.batch_divider.divide(vn, fn, gm, ef, gf, l, md)
145 | segment_num = len(variable_num)
146 |
147 | graph_feat_batch = []
148 | graph_map_batch = []
149 | batch_variable_map_batch = []
150 | batch_function_map_batch = []
151 | edge_feature_batch = []
152 | label_batch = []
153 |
154 | for i in range(segment_num):
155 |
156 | # Create the graph features batch
157 | graph_feat_batch += [None if graph_feat[i][0] is None else torch.from_numpy(np.stack(graph_feat[i])).float()]
158 |
159 | # Create the edge feature batch
160 | edge_feature_batch += [torch.from_numpy(np.expand_dims(np.concatenate(edge_feature[i]), 1)).float()]
161 |
162 | # Create the label batch
163 | label_batch += [torch.from_numpy(np.expand_dims(np.array(label[i]), 1)).float()]
164 |
165 | # Create the graph map, variable map and function map batches
166 | g_map_b = np.zeros((2, 0), dtype=np.int32)
167 | v_map_b = np.zeros(0, dtype=np.int32)
168 | f_map_b = np.zeros(0, dtype=np.int32)
169 | variable_ind = 0
170 | function_ind = 0
171 |
172 | for j in range(len(graph_map[i])):
173 | graph_map[i][j][0, :] += variable_ind
174 | graph_map[i][j][1, :] += function_ind
175 | g_map_b = np.concatenate((g_map_b, graph_map[i][j]), axis=1)
176 |
177 | v_map_b = np.concatenate((v_map_b, np.tile(j, variable_num[i][j])))
178 | f_map_b = np.concatenate((f_map_b, np.tile(j, function_num[i][j])))
179 |
180 | variable_ind += variable_num[i][j]
181 | function_ind += function_num[i][j]
182 |
183 | graph_map_batch += [torch.from_numpy(g_map_b).int()]
184 | batch_variable_map_batch += [torch.from_numpy(v_map_b).int()]
185 | batch_function_map_batch += [torch.from_numpy(f_map_b).int()]
186 |
187 | return graph_map_batch, batch_variable_map_batch, batch_function_map_batch, edge_feature_batch, graph_feat_batch, label_batch, misc_data
188 |
189 | @staticmethod
190 | def get_loader(input_file, limit, hidden_dim, batch_size, shuffle, num_workers,
191 | max_cache_size=100000, use_cuda=True, generator=None, epoch_size=0, batch_replication=1):
192 | "Return the torch dataset loader object for the input."
193 |
194 | dataset = FactorGraphDataset(
195 | input_file=input_file,
196 | limit=limit,
197 | hidden_dim=hidden_dim,
198 | max_cache_size=max_cache_size,
199 | generator=generator,
200 | epoch_size=epoch_size,
201 | batch_replication=batch_replication)
202 |
203 | data_loader = torch.utils.data.DataLoader(
204 | dataset=dataset,
205 | batch_size=batch_size,
206 | shuffle=shuffle,
207 | num_workers=num_workers,
208 | collate_fn=dataset.dag_collate_fn,
209 | pin_memory=use_cuda)
210 |
211 | return data_loader
212 |
213 |
214 |
215 |
216 |
217 |
--------------------------------------------------------------------------------
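A minimal Python sketch of driving FactorGraphDataset.get_loader directly, outside the trainer classes (the JSON path and the dimensions are illustrative; in the repository the loader is normally constructed by the factor-graph trainer):

    from pdp.factorgraph.dataset import FactorGraphDataset

    # Build a loader over a compact-JSON problem file; 'limit' caps edge_num * hidden_dim per
    # mini-batch segment and drives the DynamicBatchDivider defined above.
    loader = FactorGraphDataset.get_loader(
        input_file='path/to/problems.json',   # hypothetical file in the compact JSON format
        limit=40000000, hidden_dim=150,
        batch_size=5000, shuffle=False, num_workers=0,
        max_cache_size=100000, use_cuda=False)

    for graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, label, misc_data in loader:
        pass  # each item is a list of per-segment tensors produced by dag_collate_fn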
/src/pdp/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3 |
4 | # PDP_solver_trainer.py : Implements a factor graph trainer for various types of PDP SAT solvers.
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import sys
12 |
13 | from pdp.factorgraph import FactorGraphTrainerBase
14 | from pdp.nn import solver, util
15 |
16 |
17 | ##########################################################################################################################
18 |
19 |
20 | class Perceptron(nn.Module):
21 | "Implements a 1-layer perceptron."
22 |
23 | def __init__(self, input_dimension, hidden_dimension, output_dimension):
24 | super(Perceptron, self).__init__()
25 | self._layer1 = nn.Linear(input_dimension, hidden_dimension)
26 | self._layer2 = nn.Linear(hidden_dimension, output_dimension, bias=False)
27 |
28 | def forward(self, inp):
29 | return F.sigmoid(self._layer2(F.relu(self._layer1(inp))))
30 |
31 |
32 | ##########################################################################################################################
33 |
34 | class SatFactorGraphTrainer(FactorGraphTrainerBase):
35 | "Implements a factor graph trainer for various types of PDP SAT solvers."
36 |
37 | def __init__(self, config, use_cuda, logger):
38 | super(SatFactorGraphTrainer, self).__init__(config=config,
39 | has_meta_data=False, error_dim=config['error_dim'], loss=None,
40 | evaluator=nn.L1Loss(), use_cuda=use_cuda, logger=logger)
41 |
42 | self._eps = 1e-8 * torch.ones(1, device=self._device)
43 | self._loss_evaluator = util.SatLossEvaluator(alpha = self._config['exploration'], device = self._device)
44 | self._cnf_evaluator = util.SatCNFEvaluator(device = self._device)
45 | self._counter = 0
46 | self._max_coeff = 10.0
47 |
48 | def _build_graph(self, config):
49 | model_list = []
50 |
51 | if config['model_type'] == 'np-nd-np':
52 | model_list += [solver.NeuralPropagatorDecimatorSolver(device=self._device, name=config['model_name'],
53 | edge_dimension=config['edge_feature_dim'], meta_data_dimension=config['meta_feature_dim'],
54 | propagator_dimension=config['hidden_dim'], decimator_dimension=config['hidden_dim'],
55 | mem_hidden_dimension=config['mem_hidden_dim'],
56 | agg_hidden_dimension=config['agg_hidden_dim'], mem_agg_hidden_dimension=config['mem_agg_hidden_dim'],
57 | prediction_dimension=config['prediction_dim'],
58 | variable_classifier=Perceptron(config['hidden_dim'], config['classifier_dim'], config['prediction_dim']),
59 | function_classifier=None, dropout=config['dropout'],
60 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])]
61 |
62 | elif config['model_type'] == 'p-nd-np':
63 | model_list += [solver.NeuralSurveyPropagatorSolver(device=self._device, name=config['model_name'],
64 | edge_dimension=config['edge_feature_dim'], meta_data_dimension=config['meta_feature_dim'],
65 | decimator_dimension=config['hidden_dim'],
66 | mem_hidden_dimension=config['mem_hidden_dim'],
67 | agg_hidden_dimension=config['agg_hidden_dim'], mem_agg_hidden_dimension=config['mem_agg_hidden_dim'],
68 | prediction_dimension=config['prediction_dim'],
69 | variable_classifier=Perceptron(config['hidden_dim'], config['classifier_dim'], config['prediction_dim']),
70 | function_classifier=None, dropout=config['dropout'],
71 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])]
72 |
73 | elif config['model_type'] == 'np-d-np':
74 | model_list += [solver.NeuralSequentialDecimatorSolver(device=self._device, name=config['model_name'],
75 | edge_dimension=config['edge_feature_dim'], meta_data_dimension=config['meta_feature_dim'],
76 | propagator_dimension=config['hidden_dim'], decimator_dimension=config['hidden_dim'],
77 | mem_hidden_dimension=config['mem_hidden_dim'],
78 | agg_hidden_dimension=config['agg_hidden_dim'], mem_agg_hidden_dimension=config['mem_agg_hidden_dim'],
79 | classifier_dimension=config['classifier_dim'],
80 | dropout=config['dropout'], tolerance=config['tolerance'], t_max=config['t_max'],
81 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])]
82 |
83 | elif config['model_type'] == 'p-d-p':
84 | model_list += [solver.SurveyPropagatorSolver(device=self._device, name=config['model_name'],
85 | tolerance=config['tolerance'], t_max=config['t_max'],
86 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])]
87 |
88 | elif config['model_type'] == 'walk-sat':
89 | model_list += [solver.WalkSATSolver(device=self._device, name=config['model_name'],
90 | iteration_num=config['local_search_iteration'], epsilon=config['epsilon'])]
91 |
92 | elif config['model_type'] == 'reinforce':
93 | model_list += [solver.ReinforceSurveyPropagatorSolver(device=self._device, name=config['model_name'],
94 | pi=config['pi'], decimation_probability=config['decimation_probability'],
95 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])]
96 |
97 | if config['verbose']:
98 | self._logger.info("The model parameter count is %d." % model_list[0].parameter_count())
99 | return model_list
100 |
101 | def _compute_loss(self, model, loss, prediction, label, graph_map, batch_variable_map,
102 | batch_function_map, edge_feature, meta_data):
103 |
104 | return self._loss_evaluator(variable_prediction=prediction[0], label=label, graph_map=graph_map,
105 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map,
106 | edge_feature=edge_feature, meta_data=meta_data, global_step=model._global_step,
107 | eps=self._eps, max_coeff=self._max_coeff, loss_sharpness=self._config['loss_sharpness'])
108 |
109 | def _compute_evaluation_metrics(self, model, evaluator, prediction, label, graph_map,
110 | batch_variable_map, batch_function_map, edge_feature, meta_data):
111 |
112 | output, _ = self._cnf_evaluator(variable_prediction=prediction[0], graph_map=graph_map,
113 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map,
114 | edge_feature=edge_feature, meta_data=meta_data)
115 |
116 | recall = torch.sum(label * ((output > 0.5).float() - label).abs()) / torch.max(torch.sum(label), self._eps)
117 | accuracy = evaluator((output > 0.5).float(), label).unsqueeze(0)
118 | loss_value = self._loss_evaluator(variable_prediction=prediction[0], label=label, graph_map=graph_map,
119 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map,
120 | edge_feature=edge_feature, meta_data=meta_data, global_step=model._global_step,
121 | eps=self._eps, max_coeff=self._max_coeff, loss_sharpness=self._config['loss_sharpness']).unsqueeze(0)
122 |
123 | return torch.cat([accuracy, recall, loss_value], 0)
124 |
125 | def _post_process_predictions(self, model, prediction, graph_map,
126 | batch_variable_map, batch_function_map, edge_feature, graph_feat, label, misc_data):
127 |         "Formats the prediction and the corresponding solution as JSON."
128 |
129 | message = ""
130 | labs = label.detach().cpu().numpy()
131 |
132 | res = self._cnf_evaluator(variable_prediction=prediction[0], graph_map=graph_map,
133 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map,
134 | edge_feature=edge_feature, meta_data=graph_feat)
135 | output, unsat_clause_num = [a.detach().cpu().numpy() for a in res]
136 |
137 | for i in range(output.shape[0]):
138 | instance = {
139 | 'ID': misc_data[i][0] if len(misc_data[i]) > 0 else "",
140 | 'label': int(labs[i, 0]),
141 | 'solved': int(output[i].flatten()[0] == 1),
142 | 'unsat_clauses': int(unsat_clause_num[i].flatten()[0]),
143 | 'solution': (prediction[0][batch_variable_map == i, 0].detach().cpu().numpy().flatten() > 0.5).astype(int).tolist()
144 | }
145 | message += (str(instance).replace("'", '"') + "\n")
146 | self._counter += 1
147 |
148 | return message
149 |
150 | def _check_recurrence_termination(self, active, prediction, sat_problem):
151 |         "De-activates the CNF instances for which the model has already found a SAT solution."
152 |
153 | output, _ = self._cnf_evaluator(variable_prediction=prediction[0], graph_map=sat_problem._graph_map,
154 | batch_variable_map=sat_problem._batch_variable_map, batch_function_map=sat_problem._batch_function_map,
155 | edge_feature=sat_problem._edge_feature, meta_data=sat_problem._meta_data)#.detach().cpu().numpy()
156 |
157 | if sat_problem._batch_replication > 1:
158 | real_batch = torch.mm(sat_problem._replication_mask_tuple[1], (output > 0.5).float())
159 | dup_batch = torch.mm(sat_problem._replication_mask_tuple[0], (real_batch == 0).float())
160 | active[active[:, 0], 0] = (dup_batch[active[:, 0], 0] > 0)
161 | else:
162 | active[active[:, 0], 0] = (output[active[:, 0], 0] <= 0.5)
163 |
--------------------------------------------------------------------------------
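Note: an illustrative sketch (not repository code) of the configuration keys that SatFactorGraphTrainer and the 'p-d-p' branch of _build_graph above read; all values below are placeholders chosen for this example, not the shipped settings.

    example_config = {
        'model_type': 'p-d-p',            # selects solver.SurveyPropagatorSolver
        'model_name': 'sp-example',
        'error_dim': 3,                   # presumably the 3 metrics returned by _compute_evaluation_metrics
        'exploration': 0.5,               # alpha passed to util.SatLossEvaluator
        'loss_sharpness': 2,
        'tolerance': 1e-3,                # SP convergence threshold
        't_max': 200,                     # SP iteration cap before forced decimation
        'local_search_iteration': 100,    # post-processing local search steps
        'epsilon': 0.05,                  # local search randomisation (illustrative value)
        'verbose': True,
    }
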
/src/pdp/nn/pdp_predict.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3 |
4 | # pdp_predict.py : Defines various predictors and scorers for the PDP framework.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from pdp.nn import util
11 |
12 |
13 | ###############################################################
14 | ### The Predictor Classes
15 | ###############################################################
16 |
17 |
18 | class NeuralPredictor(nn.Module):
19 | "Implements the neural predictor."
20 |
21 | def __init__(self, device, decimator_dimension, prediction_dimension,
22 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension,
23 | variable_classifier=None, function_classifier=None):
24 |
25 | super(NeuralPredictor, self).__init__()
26 | self._device = device
27 | self._module_list = nn.ModuleList()
28 |
29 | self._variable_classifier = variable_classifier
30 | self._function_classifier = function_classifier
31 | self._hidden_dimension = decimator_dimension
32 |
33 | if variable_classifier is not None:
34 | self._variable_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension,
35 | decimator_dimension, mem_hidden_dimension,
36 | mem_agg_hidden_dimension, agg_hidden_dimension, 0, include_self_message=True)
37 |
38 | self._module_list.append(self._variable_aggregator)
39 | self._module_list.append(self._variable_classifier)
40 |
41 | if function_classifier is not None:
42 | self._function_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension,
43 | decimator_dimension, mem_hidden_dimension,
44 | mem_agg_hidden_dimension, agg_hidden_dimension, 0, include_self_message=True)
45 |
46 | self._module_list.append(self._function_aggregator)
47 | self._module_list.append(self._function_classifier)
48 |
49 | def forward(self, decimator_state, sat_problem, last_call=False):
50 |
51 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple
52 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = sat_problem._batch_mask_tuple
53 |
54 | variable_prediction = None
55 | function_prediction = None
56 |
57 | if sat_problem._meta_data is not None:
58 | graph_feat = torch.mm(b_variable_mask, sat_problem._meta_data)
59 | graph_feat = torch.mm(variable_mask_transpose, graph_feat)
60 |
61 | if len(decimator_state) == 3:
62 | decimator_variable_state, decimator_function_state, edge_mask = decimator_state
63 | else:
64 | decimator_variable_state, decimator_function_state = decimator_state
65 | edge_mask = None
66 |
67 | if self._variable_classifier is not None:
68 |
69 | aggregated_variable_state = torch.cat((decimator_variable_state, sat_problem._edge_feature), 1)
70 |
71 | if sat_problem._meta_data is not None:
72 | aggregated_variable_state = torch.cat((aggregated_variable_state, graph_feat), 1)
73 |
74 | aggregated_variable_state = self._variable_aggregator(
75 | aggregated_variable_state, None, variable_mask, variable_mask_transpose, edge_mask)
76 |
77 | variable_prediction = self._variable_classifier(aggregated_variable_state)
78 |
79 | if self._function_classifier is not None:
80 |
81 | aggregated_function_state = torch.cat((decimator_function_state, sat_problem._edge_feature), 1)
82 |
83 | if sat_problem._meta_data is not None:
84 | aggregated_function_state = torch.cat((aggregated_function_state, graph_feat), 1)
85 |
86 | aggregated_function_state = self._function_aggregator(
87 | aggregated_function_state, None, function_mask, function_mask_transpose, edge_mask)
88 |
89 | function_prediction = self._function_classifier(aggregated_function_state)
90 |
91 | return variable_prediction, function_prediction
92 |
93 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
94 |
95 | edge_num = graph_map.size(1) * batch_replication
96 |
97 | if randomized:
98 | variable_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0
99 | function_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0
100 | else:
101 | variable_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device)
102 | function_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device)
103 |
104 | return (variable_state, function_state)
105 |
106 |
107 | ###############################################################
108 |
109 |
110 | class IdentityPredictor(nn.Module):
111 |     "Implements the Identity predictor (reads the prediction directly from the SAT problem's solution assignment)."
112 |
113 | def __init__(self, device, random_fill=False):
114 | super(IdentityPredictor, self).__init__()
115 | self._random_fill = random_fill
116 | self._device = device
117 |
118 | def forward(self, decimator_state, sat_problem, last_call=False):
119 | pred = sat_problem._solution.unsqueeze(1)
120 |
121 | if self._random_fill and last_call:
122 | active_var_num = (sat_problem._active_variables[:, 0] > 0).long().sum()
123 |
124 | if active_var_num > 0:
125 | pred[sat_problem._active_variables[:, 0] > 0, 0] = \
126 | torch.rand(active_var_num.item(), device=self._device)
127 |
128 | return pred, None
129 |
130 |
131 | ###############################################################
132 |
133 |
134 | class SurveyScorer(nn.Module):
135 |     "Implements the variable scoring mechanism for SP-guided decimation."
136 |
137 | def __init__(self, device, message_dimension, include_adaptors=False, pi=0.0):
138 | super(SurveyScorer, self).__init__()
139 | self._device = device
140 | self._include_adaptors = include_adaptors
141 | self._eps = torch.tensor([1e-10], device=self._device)
142 | self._max_logit = torch.tensor([30.0], device=self._device)
143 | self._pi = torch.tensor([pi], dtype=torch.float32, device=device)
144 |
145 | if self._include_adaptors:
146 | self._projector = nn.Linear(message_dimension, 2, bias=False)
147 | self._module_list = nn.ModuleList([self._projector])
148 |
149 | def safe_log(self, x):
150 | return torch.max(x, self._eps).log()
151 |
152 | def safe_exp(self, x):
153 | return torch.min(x, self._max_logit).exp()
154 |
155 | def forward(self, message_state, sat_problem, last_call=False):
156 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple
157 | b_variable_mask, _, _, _ = sat_problem._batch_mask_tuple
158 | p_variable_mask, _, _, _ = sat_problem._pos_mask_tuple
159 | n_variable_mask, _, _, _ = sat_problem._neg_mask_tuple
160 |
161 | if self._include_adaptors:
162 | function_message = self._projector(message_state[1])
163 | function_message[:, 0] = F.sigmoid(function_message[:, 0])
164 | function_message[:, 1] = torch.sign(function_message[:, 1])
165 | else:
166 | function_message = message_state[1]
167 |
168 | external_force = torch.sign(torch.mm(variable_mask, function_message[:, 1].unsqueeze(1)))
169 | function_message = self.safe_log(1 - function_message[:, 0]).unsqueeze(1)
170 |
171 | edge_mask = torch.mm(function_mask_transpose, sat_problem._active_functions)
172 | function_message = function_message * edge_mask
173 |
174 | pos = torch.mm(p_variable_mask, function_message) + self.safe_log(1.0 - self._pi * (external_force == 1).float())
175 | neg = torch.mm(n_variable_mask, function_message) + self.safe_log(1.0 - self._pi * (external_force == -1).float())
176 |
177 | pos_neg_sum = pos + neg
178 |
179 | dont_care = torch.mm(variable_mask, function_message) + self.safe_log(1.0 - self._pi)
180 |
181 | bias = (2 * pos_neg_sum + dont_care) / 4.0
182 | pos = pos - bias
183 | neg = neg - bias
184 | pos_neg_sum = pos_neg_sum - bias
185 | dont_care = self.safe_exp(dont_care - bias)
186 |
187 | q_0 = self.safe_exp(pos) - self.safe_exp(pos_neg_sum)
188 | q_1 = self.safe_exp(neg) - self.safe_exp(pos_neg_sum)
189 |
190 | total = self.safe_log(q_0 + q_1 + dont_care)
191 |
192 | return self.safe_exp(self.safe_log(q_1) - total) - self.safe_exp(self.safe_log(q_0) - total), None
193 |
194 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
195 |
196 | edge_num = graph_map.size(1) * batch_replication
197 |
198 | if randomized:
199 | variable_state = torch.rand(edge_num, 3, dtype=torch.float32, device=self._device)
200 | # variable_state = variable_state / torch.sum(variable_state, 1).unsqueeze(1)
201 | function_state = torch.rand(edge_num, 2, dtype=torch.float32, device=self._device)
202 | function_state[:, 1] = 0
203 | else:
204 | variable_state = torch.ones(edge_num, 3, dtype=torch.float32, device=self._device) / 3.0
205 | function_state = 0.5 * torch.ones(edge_num, 2, dtype=torch.float32, device=self._device)
206 | function_state[:, 1] = 0
207 |
208 | return (variable_state, function_state)
209 |
210 |
211 | ###############################################################
212 |
213 |
214 | class ReinforcePredictor(nn.Module):
215 | "Implements the prediction mechanism for the Reinforce Algorithm."
216 |
217 | def __init__(self, device):
218 | super(ReinforcePredictor, self).__init__()
219 | self._device = device
220 |
221 | def forward(self, decimator_state, sat_problem, last_call=False):
222 |
223 | pred = decimator_state[1][:, 1].unsqueeze(1)
224 | pred = (torch.mm(sat_problem._graph_mask_tuple[0], pred) > 0).float()
225 |
226 | return pred, None
227 |
--------------------------------------------------------------------------------
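Note: up to the log-domain bookkeeping, the score SurveyScorer returns per variable is the SP decimation bias. A scalar sketch of that computation (illustrative code, not part of the repository; it ignores the pi-driven external-force terms and the numerical-stability offset used above):

    import math

    def sp_variable_bias(surveys_pos, surveys_neg):
        # surveys_pos / surveys_neg: surveys (eta) of the clauses in which
        # the variable occurs positively / negatively.
        prod_pos = math.prod(1.0 - s for s in surveys_pos)
        prod_neg = math.prod(1.0 - s for s in surveys_neg)
        q_0 = prod_pos * (1.0 - prod_neg)     # cf. safe_exp(pos) - safe_exp(pos_neg_sum)
        q_1 = prod_neg * (1.0 - prod_pos)     # cf. safe_exp(neg) - safe_exp(pos_neg_sum)
        dont_care = prod_pos * prod_neg
        return (q_1 - q_0) / (q_0 + q_1 + dont_care)
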
/src/pdp/nn/pdp_propagate.py:
--------------------------------------------------------------------------------
1 | """
2 | Define various propagators for the PDP framework.
3 | """
4 |
5 | # Copyright (c) Microsoft. All rights reserved.
6 | # Licensed under the MIT license. See LICENSE.md file
7 | # in the project root for full license information.
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from pdp.nn import util
14 |
15 |
16 | ###############################################################
17 | ### The Propagator Classes
18 | ###############################################################
19 |
20 |
21 | class NeuralMessagePasser(nn.Module):
22 | "Implements the neural propagator."
23 |
24 | def __init__(self, device, edge_dimension, decimator_dimension, meta_data_dimension, hidden_dimension, mem_hidden_dimension,
25 | mem_agg_hidden_dimension, agg_hidden_dimension, dropout):
26 |
27 | super(NeuralMessagePasser, self).__init__()
28 | self._device = device
29 | self._module_list = nn.ModuleList()
30 | self._drop_out = dropout
31 |
32 | self._variable_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension,
33 | hidden_dimension, mem_hidden_dimension,
34 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, include_self_message=False)
35 | self._function_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension,
36 | hidden_dimension, mem_hidden_dimension,
37 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, include_self_message=False)
38 |
39 | self._module_list.append(self._variable_aggregator)
40 | self._module_list.append(self._function_aggregator)
41 |
42 | self._hidden_dimension = hidden_dimension
43 | self._mem_hidden_dimension = mem_hidden_dimension
44 | self._agg_hidden_dimension = agg_hidden_dimension
45 | self._mem_agg_hidden_dimension = mem_agg_hidden_dimension
46 |
47 | def forward(self, init_state, decimator_state, sat_problem, is_training, active_mask=None):
48 |
49 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple
50 | b_variable_mask, _, _, _ = sat_problem._batch_mask_tuple
51 |
52 | if active_mask is not None:
53 | mask = torch.mm(b_variable_mask, active_mask.float())
54 | mask = torch.mm(variable_mask_transpose, mask)
55 | else:
56 | edge_num = init_state[0].size(0)
57 | mask = torch.ones(edge_num, 1, device=self._device)
58 |
59 | if sat_problem._meta_data is not None:
60 | graph_feat = torch.mm(b_variable_mask, sat_problem._meta_data)
61 | graph_feat = torch.mm(variable_mask_transpose, graph_feat)
62 |
63 | if len(decimator_state) == 3:
64 | decimator_variable_state, decimator_function_state, edge_mask = decimator_state
65 | else:
66 | decimator_variable_state, decimator_function_state = decimator_state
67 | edge_mask = None
68 |
69 | variable_state, function_state = init_state
70 |
71 | ## variables --> functions
72 | decimator_variable_state = torch.cat((decimator_variable_state, sat_problem._edge_feature), 1)
73 |
74 | if sat_problem._meta_data is not None:
75 | decimator_variable_state = torch.cat((decimator_variable_state, graph_feat), 1)
76 |
77 | function_state = mask * self._variable_aggregator(
78 | decimator_variable_state, sat_problem._edge_feature, variable_mask, variable_mask_transpose, edge_mask) + (1 - mask) * function_state
79 |
80 | function_state = F.dropout(function_state, p=self._drop_out, training=is_training)
81 |
82 | ## functions --> variables
83 | decimator_function_state = torch.cat((decimator_function_state, sat_problem._edge_feature), 1)
84 |
85 | if sat_problem._meta_data is not None:
86 | decimator_function_state = torch.cat((decimator_function_state, graph_feat), 1)
87 |
88 | variable_state = mask * self._function_aggregator(
89 | decimator_function_state, sat_problem._edge_feature, function_mask, function_mask_transpose, edge_mask) + (1 - mask) * variable_state
90 |
91 | variable_state = F.dropout(variable_state, p=self._drop_out, training=is_training)
92 |
93 | del mask
94 |
95 | return variable_state, function_state
96 |
97 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
98 |
99 | edge_num = graph_map.size(1) * batch_replication
100 |
101 | if randomized:
102 | variable_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0
103 | function_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0
104 | else:
105 | variable_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device)
106 | function_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device)
107 |
108 | return (variable_state, function_state)
109 |
110 |
111 | ###############################################################
112 |
113 |
114 | class SurveyPropagator(nn.Module):
115 | "Implements the Survey Propagator (SP)."
116 |
117 | def __init__(self, device, decimator_dimension, include_adaptors=False, pi=0.0):
118 |
119 | super(SurveyPropagator, self).__init__()
120 | self._device = device
121 | self._function_message_dim = 3
122 | self._variable_message_dim = 2
123 | self._include_adaptors = include_adaptors
124 | self._eps = torch.tensor([1e-40], device=self._device)
125 | self._max_logit = torch.tensor([30.0], device=self._device)
126 | self._pi = torch.tensor([pi], dtype=torch.float32, device=device)
127 |
128 | if self._include_adaptors:
129 | self._variable_input_projector = nn.Linear(decimator_dimension, self._variable_message_dim, bias=False)
130 | self._function_input_projector = nn.Linear(decimator_dimension, 1, bias=False)
131 | self._module_list = nn.ModuleList([self._variable_input_projector, self._function_input_projector])
132 |
133 | def safe_log(self, x):
134 | return torch.max(x, self._eps).log()
135 |
136 | def safe_exp(self, x):
137 | return torch.min(x, self._max_logit).exp()
138 |
139 | def forward(self, init_state, decimator_state, sat_problem, is_training, active_mask=None):
140 |
141 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple
142 | b_variable_mask, _, _, _ = sat_problem._batch_mask_tuple
143 | p_variable_mask, _, _, _ = sat_problem._pos_mask_tuple
144 | n_variable_mask, _, _, _ = sat_problem._neg_mask_tuple
145 |
146 | if active_mask is not None:
147 | mask = torch.mm(b_variable_mask, active_mask.float())
148 | mask = torch.mm(variable_mask_transpose, mask)
149 | else:
150 | edge_num = init_state[0].size(0)
151 | mask = torch.ones(edge_num, 1, device=self._device)
152 |
153 | if len(decimator_state) == 3:
154 | decimator_variable_state, decimator_function_state, edge_mask = decimator_state
155 | else:
156 | decimator_variable_state, decimator_function_state = decimator_state
157 | edge_mask = None
158 |
159 | variable_state, function_state = init_state
160 |
161 |         ## variables --> functions
162 |
163 | if self._include_adaptors:
164 | decimator_variable_state = F.logsigmoid(self._function_input_projector(decimator_variable_state))
165 | else:
166 | decimator_variable_state = self.safe_log(decimator_variable_state[:, 0]).unsqueeze(1)
167 |
168 | if edge_mask is not None:
169 | decimator_variable_state = decimator_variable_state * edge_mask
170 |
171 | aggregated_variable_state = torch.mm(function_mask, decimator_variable_state)
172 | aggregated_variable_state = torch.mm(function_mask_transpose, aggregated_variable_state)
173 | aggregated_variable_state = aggregated_variable_state - decimator_variable_state
174 |
175 | function_state = mask * self.safe_exp(aggregated_variable_state) + (1 - mask) * function_state[:, 0].unsqueeze(1)
176 |
177 | ## functions --> variables
178 |
179 | if self._include_adaptors:
180 | decimator_function_state = self._variable_input_projector(decimator_function_state)
181 | decimator_function_state[:, 0] = F.sigmoid(decimator_function_state[:, 0])
182 | decimator_function_state[:, 1] = torch.sign(decimator_function_state[:, 1])
183 |
184 | external_force = decimator_function_state[:, 1].unsqueeze(1)
185 | decimator_function_state = self.safe_log(1 - decimator_function_state[:, 0]).unsqueeze(1)
186 |
187 | if edge_mask is not None:
188 | decimator_function_state = decimator_function_state * edge_mask
189 |
190 | pos = torch.mm(p_variable_mask, decimator_function_state)
191 | pos = torch.mm(variable_mask_transpose, pos)
192 | neg = torch.mm(n_variable_mask, decimator_function_state)
193 | neg = torch.mm(variable_mask_transpose, neg)
194 |
195 | same_sign = 0.5 * (1 + sat_problem._edge_feature) * pos + 0.5 * (1 - sat_problem._edge_feature) * neg
196 | same_sign = same_sign - decimator_function_state
197 | same_sign += self.safe_log(1.0 - self._pi * (external_force == sat_problem._edge_feature).float())
198 |
199 | opposite_sign = 0.5 * (1 - sat_problem._edge_feature) * pos + 0.5 * (1 + sat_problem._edge_feature) * neg
200 | # The opposite sign edge aggregation does not include the current edge by definition, therefore no need for subtraction.
201 | opposite_sign += self.safe_log(1.0 - self._pi * (external_force == -sat_problem._edge_feature).float())
202 |
203 | dont_care = same_sign + opposite_sign
204 |
205 | bias = 0 #(2 * dont_care) / 3.0
206 | same_sign = same_sign - bias
207 | opposite_sign = opposite_sign - bias
208 | dont_care = self.safe_exp(dont_care - bias)
209 |
210 | same_sign = self.safe_exp(same_sign)
211 | opposite_sign = self.safe_exp(opposite_sign)
212 | q_u = same_sign * (1 - opposite_sign)
213 | q_s = opposite_sign * (1 - same_sign)
214 |
215 | total = q_u + q_s + dont_care
216 | temp = torch.cat((q_u, q_s, dont_care), 1) / total
217 |
218 | variable_state = mask * temp + (1 - mask) * variable_state
219 |
220 | del mask
221 | return variable_state, torch.cat((function_state, external_force), 1)
222 |
223 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
224 |
225 | edge_num = graph_map.size(1) * batch_replication
226 |
227 | if randomized:
228 | variable_state = torch.rand(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device)
229 | variable_state = variable_state / torch.sum(variable_state, 1).unsqueeze(1)
230 | function_state = torch.rand(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device)
231 | function_state[:, 1] = 0
232 | else:
233 | variable_state = torch.ones(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device) / self._function_message_dim
234 | function_state = 0.5 * torch.ones(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device)
235 | function_state[:, 1] = 0
236 |
237 | return (variable_state, function_state)
238 |
239 |
240 | ###############################################################
241 |
--------------------------------------------------------------------------------
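Note: the two halves of SurveyPropagator.forward above are the SP update equations computed in log space over the whole edge set. A per-edge scalar sketch (illustrative code, not part of the repository; pi = 0 and no decimation mask assumed):

    import math

    def clause_to_variable_survey(warning_probs_others):
        # eta for the current edge: product of component 0 of the messages
        # sent to this clause by its other variables (first half of forward).
        return math.prod(warning_probs_others)

    def variable_to_clause_message(etas_same_sign, etas_opposite_sign):
        # Surveys from the *other* clauses containing this variable, split by
        # whether their literal sign agrees with the current clause's
        # (second half of forward).
        same = math.prod(1.0 - e for e in etas_same_sign)
        opposite = math.prod(1.0 - e for e in etas_opposite_sign)
        q_u = same * (1.0 - opposite)
        q_s = opposite * (1.0 - same)
        dont_care = same * opposite
        total = q_u + q_s + dont_care
        return (q_u / total, q_s / total, dont_care / total)
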
/src/pdp/nn/pdp_decimate.py:
--------------------------------------------------------------------------------
1 | """
2 | Define various decimators for the PDP framework.
3 | """
4 |
5 | # Copyright (c) Microsoft. All rights reserved.
6 | # Licensed under the MIT license. See LICENSE.md file
7 | # in the project root for full license information.
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from pdp.nn import util
14 |
15 |
16 | ###############################################################
17 | ### The Decimator Classes
18 | ###############################################################
19 |
20 |
21 | class NeuralDecimator(nn.Module):
22 | "Implements a neural decimator."
23 |
24 | def __init__(self, device, message_dimension, meta_data_dimension, hidden_dimension, mem_hidden_dimension,
25 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, dropout):
26 |
27 | super(NeuralDecimator, self).__init__()
28 | self._device = device
29 | self._module_list = nn.ModuleList()
30 | self._drop_out = dropout
31 |
32 | if isinstance(message_dimension, tuple):
33 | variable_message_dim, function_message_dim = message_dimension
34 | else:
35 | variable_message_dim = message_dimension
36 | function_message_dim = message_dimension
37 |
38 | self._variable_rnn_cell = nn.GRUCell(
39 | variable_message_dim + edge_dimension + meta_data_dimension, hidden_dimension, bias=True)
40 | self._function_rnn_cell = nn.GRUCell(
41 | function_message_dim + edge_dimension + meta_data_dimension, hidden_dimension, bias=True)
42 |
43 | self._module_list.append(self._variable_rnn_cell)
44 | self._module_list.append(self._function_rnn_cell)
45 |
46 | self._hidden_dimension = hidden_dimension
47 | self._mem_hidden_dimension = mem_hidden_dimension
48 | self._agg_hidden_dimension = agg_hidden_dimension
49 | self._mem_agg_hidden_dimension = mem_agg_hidden_dimension
50 |
51 | def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None):
52 |
53 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple
54 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = sat_problem._batch_mask_tuple
55 |
56 | if active_mask is not None:
57 | mask = torch.mm(b_variable_mask, active_mask.float())
58 | mask = torch.mm(variable_mask_transpose, mask)
59 | else:
60 | edge_num = init_state[0].size(0)
61 | mask = torch.ones(edge_num, 1, device=self._device)
62 |
63 | if sat_problem._meta_data is not None:
64 | graph_feat = torch.mm(b_variable_mask, sat_problem._meta_data)
65 | graph_feat = torch.mm(variable_mask_transpose, graph_feat)
66 |
67 | variable_state, function_state = message_state
68 |
69 | # Variable states
70 | variable_state = torch.cat((variable_state, sat_problem._edge_feature), 1)
71 |
72 | if sat_problem._meta_data is not None:
73 | variable_state = torch.cat((variable_state, graph_feat), 1)
74 |
75 | variable_state = mask * self._variable_rnn_cell(variable_state, init_state[0]) + (1 - mask) * init_state[0]
76 |
77 | # Function states
78 | function_state = torch.cat((function_state, sat_problem._edge_feature), 1)
79 |
80 | if sat_problem._meta_data is not None:
81 | function_state = torch.cat((function_state, graph_feat), 1)
82 |
83 | function_state = mask * self._function_rnn_cell(function_state, init_state[1]) + (1 - mask) * init_state[1]
84 |
85 | del mask
86 |
87 | return variable_state, function_state
88 |
89 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
90 |
91 | edge_num = graph_map.size(1) * batch_replication
92 |
93 | if randomized:
94 | variable_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0
95 | function_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0
96 | else:
97 | variable_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device)
98 | function_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device)
99 |
100 | return (variable_state, function_state)
101 |
102 |
103 | ###############################################################
104 |
105 |
106 | class SequentialDecimator(nn.Module):
107 | "Implements the general (greedy) sequential decimator."
108 |
109 | def __init__(self, device, message_dimension, scorer, tolerance, t_max):
110 | super(SequentialDecimator, self).__init__()
111 |
112 | self._device = device
113 | self._tolerance = tolerance
114 | self._scorer = scorer
115 | self._previous_function_state = None
116 | self._message_dimension = message_dimension
117 | self._constant = torch.ones(1, 1, device=self._device)
118 | self._t_max = t_max
119 | self._counters = None
120 | self._module_list = nn.ModuleList([self._scorer])
121 |
122 | def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None):
123 |
124 | if self._counters is None:
125 | self._counters = torch.zeros(sat_problem._batch_size, 1, device=self._device)
126 |
127 | if active_mask is not None:
128 | survey = message_state[1][:, 0].unsqueeze(1)
129 | survey = util.sparse_smooth_max(survey, sat_problem._graph_mask_tuple[0], self._device)
130 | survey = survey * sat_problem._active_variables
131 | survey = util.sparse_max(survey.squeeze(1), sat_problem._batch_mask_tuple[0], self._device).unsqueeze(1)
132 |
133 | active_mask[survey <= 1e-10] = 0
134 |
135 | if self._previous_function_state is not None and sat_problem._active_variables.sum() > 0:
136 | function_diff = (self._previous_function_state - message_state[1][:, 0]).abs().unsqueeze(1)
137 |
138 | if sat_problem._edge_mask is not None:
139 | function_diff = function_diff * sat_problem._edge_mask
140 |
141 | sum_diff = util.sparse_smooth_max(function_diff, sat_problem._graph_mask_tuple[0], self._device)
142 | sum_diff = sum_diff * sat_problem._active_variables
143 | sum_diff = util.sparse_max(sum_diff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device).unsqueeze(1)
144 |
145 | self._counters[sum_diff[:, 0] < self._tolerance, 0] = 0
146 | sum_diff = (sum_diff < self._tolerance).float()
147 | sum_diff[self._counters[:, 0] >= self._t_max, 0] = 1
148 | self._counters[self._counters[:, 0] >= self._t_max, 0] = 0
149 |
150 | sum_diff = torch.mm(sat_problem._batch_mask_tuple[0], sum_diff)
151 |
152 | if sum_diff.sum() > 0:
153 | score, _ = self._scorer(message_state, sat_problem)
154 |
155 | # Find the variable index with max score for each instance in the batch
156 | coeff = score.abs() * sat_problem._active_variables * sum_diff
157 |
158 | if coeff.sum() > 0:
159 | max_ind = util.sparse_argmax(coeff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device)
160 | norm = torch.mm(sat_problem._batch_mask_tuple[1], coeff)
161 |
162 | if active_mask is not None:
163 | max_ind = max_ind[(active_mask * (norm != 0)).squeeze(1)]
164 | else:
165 | max_ind = max_ind[norm.squeeze(1) != 0]
166 |
167 | if max_ind.size()[0] > 0:
168 | assignment = torch.zeros(sat_problem._variable_num, 1, device=self._device)
169 | assignment[max_ind, 0] = score.sign()[max_ind, 0]
170 |
171 | sat_problem.set_variables(assignment)
172 |
173 | self._counters = self._counters + 1
174 |
175 | self._previous_function_state = message_state[1][:, 0]
176 |
177 | return message_state
178 |
179 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
180 | self._previous_function_state = None
181 | self._counters = None
182 |
183 | return self._scorer.get_init_state(graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication)
184 |
185 |
186 | ###############################################################
187 |
188 |
189 | class ReinforceDecimator(nn.Module):
190 | "Implements the (distributed) Reinforce decimator."
191 |
192 | def __init__(self, device, scorer, decimation_probability=0.5):
193 | super(ReinforceDecimator, self).__init__()
194 |
195 | self._device = device
196 | self._scorer = scorer
197 | self._decimation_probability = decimation_probability
198 | self._function_message_dim = 3
199 | self._variable_message_dim = 2
200 | self._previous_function_state = None
201 |
202 | def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None):
203 | variable_state, function_state = message_state
204 |
205 | if active_mask is not None and self._previous_function_state is not None and sat_problem._active_variables.sum() > 0:
206 | function_diff = (self._previous_function_state - message_state[1][:, 0]).abs().unsqueeze(1)
207 |
208 | if sat_problem._edge_mask is not None:
209 | function_diff = function_diff * sat_problem._edge_mask
210 |
211 | sum_diff = util.sparse_smooth_max(function_diff, sat_problem._graph_mask_tuple[0], self._device)
212 | sum_diff = sum_diff * sat_problem._active_variables
213 | sum_diff = util.sparse_max(sum_diff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device)
214 | active_mask[sum_diff <= 0.01, 0] = 0
215 |
216 | self._previous_function_state = message_state[1][:, 0]
217 |
218 | if torch.rand(1, device=self._device) < self._decimation_probability:
219 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple
220 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = sat_problem._batch_mask_tuple
221 |
222 | if active_mask is not None:
223 | mask = torch.mm(b_variable_mask, active_mask.float())
224 | mask = torch.mm(variable_mask_transpose, mask)
225 | else:
226 | mask = torch.ones(sat_problem._edge_num, 1, device=self._device)
227 |
228 | mask = mask.squeeze(1)
229 | score, _ = self._scorer(message_state, sat_problem)
230 | score = torch.mm(variable_mask_transpose, torch.sign(score)).squeeze(1)
231 |
232 | function_state[:, 1] = mask * score + (1 - mask) * function_state[:, 1]
233 |
234 | return variable_state, function_state
235 |
236 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication):
237 |
238 | edge_num = graph_map.size(1) * batch_replication
239 | self._previous_function_state = None
240 |
241 | if randomized:
242 | variable_state = torch.rand(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device)
243 | function_state = torch.rand(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device)
244 | function_state[:, 1] = 0
245 | else:
246 | variable_state = torch.ones(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device) / self._function_message_dim
247 | function_state = 0.5 * torch.ones(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device)
248 | function_state[:, 1] = 0
249 |
250 | return (variable_state, function_state)
251 |
--------------------------------------------------------------------------------
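Note: as a plain-Python summary of SequentialDecimator (a sketch, not repository code): propagation for an instance continues until its largest message change drops below the tolerance or t_max rounds have passed, after which the highest-|score| unassigned variable is greedily fixed to the sign of its score.

    def should_decimate(max_message_change, stalled_rounds, tolerance, t_max):
        # Convergence / patience test mirrored from SequentialDecimator.forward.
        return max_message_change < tolerance or stalled_rounds >= t_max

    def greedy_decimation(scores, unassigned):
        # scores: variable index -> SP bias; fix the strongest bias to its sign
        # (cf. the sparse_argmax call and sat_problem.set_variables above).
        target = max(unassigned, key=lambda i: abs(scores[i]))
        return target, (1.0 if scores[target] >= 0 else -1.0)
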
/src/pdp/nn/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3 |
4 | # util.py : Defines the utility functionalities for the PDP framework.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class MessageAggregator(nn.Module):
12 | "Implements a deep set function for message aggregation at variable and function nodes."
13 |
14 | def __init__(self, device, input_dimension, output_dimension, mem_hidden_dimension,
15 | mem_agg_hidden_dimension, agg_hidden_dimension, feature_dimension, include_self_message):
16 |
17 | super(MessageAggregator, self).__init__()
18 | self._device = device
19 | self._include_self_message = include_self_message
20 | self._module_list = nn.ModuleList()
21 |
22 | if mem_hidden_dimension > 0 and mem_agg_hidden_dimension > 0:
23 |
24 | self._W1_m = nn.Linear(
25 | input_dimension, mem_hidden_dimension, bias=True) # .to(self._device)
26 |
27 | self._W2_m = nn.Linear(
28 | mem_hidden_dimension, mem_agg_hidden_dimension, bias=False) # .to(self._device)
29 |
30 | self._module_list.append(self._W1_m)
31 | self._module_list.append(self._W2_m)
32 |
33 | if agg_hidden_dimension > 0 and mem_agg_hidden_dimension > 0:
34 |
35 | if mem_hidden_dimension <= 0:
36 | mem_agg_hidden_dimension = input_dimension
37 |
38 | self._W1_a = nn.Linear(
39 | mem_agg_hidden_dimension + feature_dimension, agg_hidden_dimension, bias=True) # .to(self._device)
40 |
41 | self._W2_a = nn.Linear(
42 | agg_hidden_dimension, output_dimension, bias=False) # .to(self._device)
43 |
44 | self._module_list.append(self._W1_a)
45 | self._module_list.append(self._W2_a)
46 |
47 | self._agg_hidden_dimension = agg_hidden_dimension
48 | self._mem_hidden_dimension = mem_hidden_dimension
49 | self._mem_agg_hidden_dimension = mem_agg_hidden_dimension
50 |
51 | def forward(self, state, feature, mask, mask_transpose, edge_mask=None):
52 |
53 | # Apply the pre-aggregation transform
54 | if self._mem_hidden_dimension > 0 and self._mem_agg_hidden_dimension > 0:
55 | state = F.logsigmoid(self._W2_m(F.logsigmoid(self._W1_m(state))))
56 |
57 | if edge_mask is not None:
58 | state = state * edge_mask
59 |
60 | aggregated_state = torch.mm(mask, state)
61 |
62 | if not self._include_self_message:
63 | aggregated_state = torch.mm(mask_transpose, aggregated_state)
64 |
65 | if edge_mask is not None:
66 | aggregated_state = aggregated_state - state * edge_mask
67 | else:
68 | aggregated_state = aggregated_state - state
69 |
70 | if feature is not None:
71 | aggregated_state = torch.cat((aggregated_state, feature), 1)
72 |
73 | # Apply the post-aggregation transform
74 | if self._agg_hidden_dimension > 0 and self._mem_agg_hidden_dimension > 0:
75 | aggregated_state = F.logsigmoid(self._W2_a(F.logsigmoid(self._W1_a(aggregated_state))))
76 |
77 | return aggregated_state
78 |
79 |
80 | ###############################################################
81 |
82 |
83 | class MultiLayerPerceptron(nn.Module):
84 | "Implements a standard fully-connected, multi-layer perceptron."
85 |
86 | def __init__(self, device, layer_dims):
87 |
88 | super(MultiLayerPerceptron, self).__init__()
89 | self._device = device
90 | self._module_list = nn.ModuleList()
91 | self._layer_num = len(layer_dims) - 1
92 |
93 | self._inner_layers = []
94 | for i in range(self._layer_num - 1):
95 | self._inner_layers += [nn.Linear(layer_dims[i], layer_dims[i + 1])]
96 | self._module_list.append(self._inner_layers[i])
97 |
98 | self._output_layer = nn.Linear(layer_dims[self._layer_num - 1], layer_dims[self._layer_num], bias=False)
99 | self._module_list.append(self._output_layer)
100 |
101 | def forward(self, inp):
102 | x = inp
103 |
104 | for layer in self._inner_layers:
105 | x = F.relu(layer(x))
106 |
107 | return F.sigmoid(self._output_layer(x))
108 |
109 |
110 | ##########################################################################################################################
111 |
112 |
113 | class SatLossEvaluator(nn.Module):
114 | "Implements a module to calculate the energy (i.e. the loss) for the current prediction."
115 |
116 | def __init__(self, alpha, device):
117 | super(SatLossEvaluator, self).__init__()
118 | self._alpha = alpha
119 | self._device = device
120 |
121 | @staticmethod
122 | def safe_log(x, eps):
123 | return torch.max(x, eps).log()
124 |
125 | @staticmethod
126 | def compute_masks(graph_map, batch_variable_map, batch_function_map, edge_feature, device):
127 | edge_num = graph_map.size(1)
128 | variable_num = batch_variable_map.size(0)
129 | function_num = batch_function_map.size(0)
130 | all_ones = torch.ones(edge_num, device=device)
131 | edge_num_range = torch.arange(edge_num, dtype=torch.int64, device=device)
132 |
133 | variable_sparse_ind = torch.stack([edge_num_range, graph_map[0, :].long()])
134 | function_sparse_ind = torch.stack([graph_map[1, :].long(), edge_num_range])
135 |
136 | if device.type == 'cuda':
137 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, edge_feature.squeeze(1),
138 | torch.Size([edge_num, variable_num]), device=device)
139 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, all_ones,
140 | torch.Size([function_num, edge_num]), device=device)
141 | else:
142 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, edge_feature.squeeze(1),
143 | torch.Size([edge_num, variable_num]), device=device)
144 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, all_ones,
145 | torch.Size([function_num, edge_num]), device=device)
146 |
147 | return variable_mask, function_mask
148 |
149 | @staticmethod
150 | def compute_batch_mask(batch_variable_map, batch_function_map, device):
151 | variable_num = batch_variable_map.size()[0]
152 | function_num = batch_function_map.size()[0]
153 | variable_all_ones = torch.ones(variable_num, device=device)
154 | function_all_ones = torch.ones(function_num, device=device)
155 | variable_range = torch.arange(variable_num, dtype=torch.int64, device=device)
156 | function_range = torch.arange(function_num, dtype=torch.int64, device=device)
157 | batch_size = (batch_variable_map.max() + 1).long().item()
158 |
159 | variable_sparse_ind = torch.stack([variable_range, batch_variable_map.long()])
160 | function_sparse_ind = torch.stack([function_range, batch_function_map.long()])
161 |
162 | if device.type == 'cuda':
163 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, variable_all_ones,
164 | torch.Size([variable_num, batch_size]), device=device)
165 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, function_all_ones,
166 | torch.Size([function_num, batch_size]), device=device)
167 | else:
168 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, variable_all_ones,
169 | torch.Size([variable_num, batch_size]), device=device)
170 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, function_all_ones,
171 | torch.Size([function_num, batch_size]), device=device)
172 |
173 | variable_mask_transpose = variable_mask.transpose(0, 1)
174 | function_mask_transpose = function_mask.transpose(0, 1)
175 |
176 | return (variable_mask, variable_mask_transpose, function_mask, function_mask_transpose)
177 |
178 | def forward(self, variable_prediction, label, graph_map, batch_variable_map,
179 | batch_function_map, edge_feature, meta_data, global_step, eps, max_coeff, loss_sharpness):
180 |
181 | coeff = torch.min(global_step.pow(self._alpha), torch.tensor([max_coeff], device=self._device))
182 |
183 | signed_variable_mask_transpose, function_mask = \
184 | SatLossEvaluator.compute_masks(graph_map, batch_variable_map, batch_function_map,
185 | edge_feature, self._device)
186 |
187 | edge_values = torch.mm(signed_variable_mask_transpose, variable_prediction)
188 | edge_values = edge_values + (1 - edge_feature) / 2
189 |
190 | weights = (coeff * edge_values).exp()
191 |
192 | nominator = torch.mm(function_mask, weights * edge_values)
193 | denominator = torch.mm(function_mask, weights)
194 |
195 | clause_value = denominator / torch.max(nominator, eps)
196 | clause_value = 1 + (clause_value - 1).pow(loss_sharpness)
197 | return torch.mean(SatLossEvaluator.safe_log(clause_value, eps))
198 |
199 |
200 | ##########################################################################################################################
201 |
202 |
203 | class SatCNFEvaluator(nn.Module):
204 | "Implements a module to evaluate the current prediction."
205 |
206 | def __init__(self, device):
207 | super(SatCNFEvaluator, self).__init__()
208 | self._device = device
209 |
210 | def forward(self, variable_prediction, graph_map, batch_variable_map,
211 | batch_function_map, edge_feature, meta_data):
212 |
213 | variable_num = batch_variable_map.size(0)
214 | function_num = batch_function_map.size(0)
215 | batch_size = (batch_variable_map.max() + 1).item()
216 | all_ones = torch.ones(function_num, 1, device=self._device)
217 |
218 | signed_variable_mask_transpose, function_mask = \
219 | SatLossEvaluator.compute_masks(graph_map, batch_variable_map, batch_function_map,
220 | edge_feature, self._device)
221 |
222 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = \
223 | SatLossEvaluator.compute_batch_mask(
224 | batch_variable_map, batch_function_map, self._device)
225 |
226 | edge_values = torch.mm(signed_variable_mask_transpose, variable_prediction)
227 | edge_values = edge_values + (1 - edge_feature) / 2
228 | edge_values = (edge_values > 0.5).float()
229 |
230 | clause_values = torch.mm(function_mask, edge_values)
231 | clause_values = (clause_values > 0).float()
232 |
233 | max_sat = torch.mm(b_function_mask_transpose, all_ones)
234 | batch_values = torch.mm(b_function_mask_transpose, clause_values)
235 |
236 | return (max_sat == batch_values).float(), max_sat - batch_values
237 |
238 |
239 | ##########################################################################################################################
240 |
241 |
242 | class PerceptronTanh(nn.Module):
243 |     "Implements a 1-layer perceptron with Tanh activation."
244 |
245 | def __init__(self, input_dimension, hidden_dimension, output_dimension):
246 | super(PerceptronTanh, self).__init__()
247 | self._layer1 = nn.Linear(input_dimension, hidden_dimension)
248 | self._layer2 = nn.Linear(hidden_dimension, output_dimension, bias=False)
249 |
250 | def forward(self, inp):
251 | return F.tanh(self._layer2(F.relu(self._layer1(inp))))
252 |
253 |
254 | ##########################################################################################################################
255 |
256 |
257 | def sparse_argmax(x, mask, device):
258 | "Implements the exact, memory-inefficient argmax operation for a row vector input."
259 |
260 | if device.type == 'cuda':
261 | dense_mat = torch.cuda.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense()
262 | else:
263 | dense_mat = torch.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense()
264 |
265 | return torch.argmax(dense_mat, 0)
266 |
267 | def sparse_max(x, mask, device):
268 | "Implements the exact, memory-inefficient max operation for a row vector input."
269 |
270 | if device.type == 'cuda':
271 | dense_mat = torch.cuda.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense()
272 | else:
273 | dense_mat = torch.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense()
274 |
275 | return torch.max(dense_mat, 0)[0] + x.min() - 1
276 |
277 | def safe_exp(x, device):
278 | "Implements safe exp operation."
279 |
280 | return torch.min(x, torch.tensor([30.0], device=device)).exp()
281 |
282 | def sparse_smooth_max(x, mask, device, alpha=30):
283 | "Implements the approximate, memory-efficient max operation for a row vector input."
284 |
285 | coeff = safe_exp(alpha * x, device)
286 | return torch.mm(mask, x * coeff) / torch.max(torch.mm(mask, coeff), torch.ones(1, device=device))
287 |
--------------------------------------------------------------------------------
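Note: SatLossEvaluator turns each clause into a differentiable energy: the exponential weights make the weighted mean of literal values a smooth max, so a clause's contribution vanishes as soon as one of its literals is (softly) satisfied. A scalar sketch for a single clause (illustrative code, not part of the repository):

    import math

    def clause_energy(literal_values, coeff=10.0, sharpness=2, eps=1e-8):
        # literal_values: the `edge_values` above, i.e. the prediction for a
        # positive literal and 1 - prediction for a negative one.
        weights = [math.exp(coeff * v) for v in literal_values]
        smooth_max = sum(w * v for w, v in zip(weights, literal_values)) / sum(weights)
        clause_value = 1.0 / max(smooth_max, eps)
        clause_value = 1.0 + (clause_value - 1.0) ** sharpness
        return math.log(max(clause_value, eps))
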
/src/pdp/generator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Various types of CNF generators for producing CNF instances on the fly.
4 | """
5 |
6 | # Copyright (c) Microsoft. All rights reserved.
7 | # Licensed under the MIT license. See LICENSE.md file
8 | # in the project root for full license information.
9 |
10 | import numpy as np
11 | import argparse
12 | import os, sys
13 |
14 |
15 | def is_sat(_var_num, iclause_list):
16 | ## Note: Invoke your SAT solver of choice here for generating labeled data.
17 | return False
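# For example (assumption: the optional `python-sat` package is installed and
# provides pysat.solvers.Glucose3), a drop-in labeller could look like:
#
#     from pysat.solvers import Glucose3
#
#     def is_sat(_var_num, iclause_list):
#         with Glucose3(bootstrap_with=iclause_list) as solver:
#             return solver.solve()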
18 |
19 | ##########################################################################
20 |
21 |
22 | class CNFGeneratorBase(object):
23 | "The base class for all CNF generators."
24 |
25 | def __init__(self, min_n, max_n, min_alpha, max_alpha, alpha_resolution=10):
26 | self._min_n = min_n
27 | self._max_n = max_n
28 | self._min_alpha = min_alpha
29 | self._max_alpha = max_alpha
30 | self._alpha = min_alpha
31 | self._alpha_inc = (max_alpha - min_alpha) / alpha_resolution
32 | self._alpha_resolution = alpha_resolution
33 |
34 | def generate(self):
35 | "Generates unlabeled CNF instances."
36 | pass
37 |
38 | def generate_complete(self):
39 | "Generates labeled CNF instances."
40 | pass
41 |
42 | def _to_json(self, n, m, graph_map, edge_feature, label):
43 | return [[n, m], list(((graph_map[0, :] + 1) * edge_feature).astype(int)), list(graph_map[1, :] + 1), label]
44 |
45 | def _to_dimacs(self, n, m, clause_list):
46 | body = ''
47 |
48 | for clause in clause_list:
49 | body += (str(clause)[1:-1].replace(',', '')) + ' 0\n'
50 |
51 | return 'p cnf ' + str(n) + ' ' + str(m) + '\n' + body
52 |
53 | def generate_dataset(self, size, output_dimacs_path, json_output, name, sat_only=True):
54 |
55 | max_trial = 50
56 |
57 | if not os.path.exists(output_dimacs_path):
58 | os.makedirs(output_dimacs_path)
59 |
60 | if not os.path.exists(json_output):
61 | os.makedirs(json_output)
62 |
63 | output_dimacs_path = os.path.join(output_dimacs_path, name)
64 | json_output = os.path.join(json_output, name)
65 |
66 | for j in range(self._alpha_resolution):
67 | postfix = '_' + str(j) + '_' + str(self._alpha) + '_' + str(self._alpha + self._alpha_inc)
68 |
69 | if not os.path.exists(output_dimacs_path + postfix):
70 | os.makedirs(output_dimacs_path + postfix)
71 |
72 | with open(json_output + postfix + ".json", 'w') as f:
73 | for i in range(size):
74 | flag = False
75 | for _ in range(max_trial):
76 | n, m, graph_map, edge_feature, _, label, clause_list = self.generate_complete()
77 |
78 | if (not sat_only) or (label == 1):
79 | flag = True
80 | break
81 |
82 | if flag:
83 | f.write(str(self._to_json(n, m, graph_map, edge_feature, label)).replace("'", '"') + '\n')
84 |
85 | dimacs_file_name = 'dimacs_' + str(i) + '_sat=' + str(label) + '.DIMACS'
86 | with open(os.path.join(output_dimacs_path + postfix, dimacs_file_name), 'w') as g:
87 | g.write(self._to_dimacs(n, m, clause_list) + '\n')
88 |
89 | sys.stdout.write("Dataset {:2d}/{:2d}: {:.2f} % complete \r".format(j + 1, self._alpha_resolution, 100*float(i+1) / size))
90 | sys.stdout.flush()
91 |
92 | self._alpha += self._alpha_inc
93 |
94 |
95 | ###############################################################################################
96 |
97 |
98 | class UniformCNFGenerator(CNFGeneratorBase):
99 | "Implements the uniformly random CNF generator."
100 |
101 | def __init__(self, min_n, max_n, min_k, max_k, min_alpha, max_alpha, alpha_resolution=10):
102 |
103 | super(UniformCNFGenerator, self).__init__(min_n, max_n, min_alpha, max_alpha, alpha_resolution)
104 | self._min_k = min_k
105 | self._max_k = max_k
106 |
107 | def generate(self):
108 | n = np.random.randint(self._min_n, self._max_n + 1)
109 | alpha = np.random.uniform(self._min_alpha, self._max_alpha)
110 | m = int(n * alpha)
111 |
112 | clause_length = [np.random.randint(self._min_k, min(self._max_k, n-1) + 1) for _ in range(m)]
113 | edge_num = np.sum(clause_length)
114 |
115 | graph_map = np.zeros((2, edge_num), dtype=np.int32)
116 |
117 | ind = 0
118 | for i in range(m):
119 | graph_map[0, ind:(ind+clause_length[i])] = np.random.choice(n, clause_length[i], replace=False)
120 | graph_map[1, ind:(ind+clause_length[i])] = i
121 | ind += clause_length[i]
122 |
123 | edge_feature = 2.0 * np.random.choice(2, edge_num) - 1
124 |
125 | return n, m, graph_map, edge_feature, None, -1.0, []
126 |
127 | def generate_complete(self):
128 | n = np.random.randint(self._min_n, self._max_n + 1)
129 | alpha = np.random.uniform(self._alpha, self._alpha + self._alpha_inc)
130 | m = int(n * alpha)
131 | max_trial = 10
132 |
133 | clause_set = set()
134 | clause_list = []
135 | graph_map = np.zeros((2, 0), dtype=np.int32)
136 | edge_features = np.zeros(0)
137 |
138 | i = -1
139 | for _ in range(m):
140 | for _ in range(max_trial):
141 | clause_length = np.random.randint(self._min_k, min(self._max_k, n-1) + 1)
142 | literals = np.sort(np.random.choice(n, clause_length, replace=False))
143 | edge_feature = 2.0 * np.random.choice(2, clause_length) - 1
144 | iclause = list(((literals + 1) * edge_feature).astype(int))
145 |
146 | if str(iclause) not in clause_set:
147 | i += 1
148 | break
149 |
150 | clause_set.add(str(iclause))
151 | clause_list += [iclause]
152 |
153 | graph_map = np.concatenate((graph_map, np.stack((literals, i * np.ones(clause_length, dtype=np.int32)))), 1)
154 | edge_features = np.concatenate((edge_features, edge_feature))
155 |
156 | label = is_sat(n, clause_list)
157 | return n, m, graph_map, edge_features, None, label, clause_list
158 |
159 |
160 | ###############################################################################################
161 |
162 |
163 | class ModularCNFGenerator(CNFGeneratorBase):
164 | "Implements the modular random CNF generator according to the Community Attachment model (https://www.iiia.csic.es/sites/default/files/aij16.pdf)"
165 |
166 | def __init__(self, k, min_n, max_n, min_q, max_q, min_c, max_c, min_alpha, max_alpha, alpha_resolution=10):
167 |
168 | super(ModularCNFGenerator, self).__init__(min_n, max_n, min_alpha, max_alpha, alpha_resolution)
169 | self._k = k
170 | self._min_c = min_c
171 | self._max_c = max_c
172 | self._min_q = min_q
173 | self._max_q = max_q
174 |
175 | def generate(self):
176 | n = np.random.randint(self._min_n, self._max_n + 1)
177 | alpha = np.random.uniform(self._min_alpha, self._max_alpha)
178 | m = int(n * alpha)
179 |
180 | q = np.random.uniform(self._min_q, self._max_q)
181 | c = np.random.randint(self._min_c, self._max_c + 1)
182 | c = max(1, min(c, int(n / self._k) - 1))
183 | size = int(n / c)
184 | community_size = size * np.ones(c, dtype=np.int32)
185 | community_size[c - 1] += (n - np.sum(community_size))
186 |
187 | p = q + 1.0 / c
188 | edge_num = m * self._k
189 |
190 | graph_map = np.zeros((2, edge_num), dtype=np.int32)
191 | index = np.random.permutation(n)
192 |
193 | ind = 0
194 | for i in range(m):
195 | coin = np.random.uniform()
196 | if coin <= p: # Pick from the same community
197 | community = np.random.randint(0, c)
198 | graph_map[0, ind:(ind + self._k)] = index[np.random.choice(range(size*community, size*community + community_size[community]), self._k, replace=False)]
199 | else: # Pick from different communities
200 | if c >= self._k:
201 | communities = np.random.choice(c, self._k, replace=False)
202 | temp = np.random.uniform(size = self._k)
203 | inner_offset = (temp * community_size[communities]).astype(int)
204 | graph_map[0, ind:(ind + self._k)] = index[size*communities + inner_offset]
205 | else:
206 | graph_map[0, ind:(ind + self._k)] = np.random.choice(n, self._k, replace=False)
207 |
208 | graph_map[1, ind:(ind+self._k)] = i
209 | ind += self._k
210 |
211 | edge_feature = 2.0 * np.random.choice(2, edge_num) - 1
212 |
213 | return n, m, graph_map, edge_feature, None, -1.0, []
214 |
215 | def generate_complete(self):
216 | n = np.random.randint(self._min_n, self._max_n + 1)
217 | alpha = np.random.uniform(self._alpha, self._alpha + self._alpha_inc)
218 | m = int(n * alpha)
219 | max_trial = 10
220 |
221 | q = np.random.uniform(self._min_q, self._max_q)
222 | c = np.random.randint(self._min_c, self._max_c + 1)
223 | c = max(self._k + 1, min(c, int(n / self._k) - 1))
224 | size = int(n / c)
225 | community_size = size * np.ones(c, dtype=np.int32)
226 | community_size[c - 1] += (n - np.sum(community_size))
227 |
228 | p = q + 1.0 / c
229 | edge_num = m * self._k
230 |
231 | index = np.random.permutation(n)
232 | clause_set = set()
233 | clause_list = []
234 | graph_map = np.zeros((2, 0), dtype=np.int32)
235 | edge_features = np.zeros(0)
236 |
237 | i = -1
238 | for _ in range(m):
239 | for _ in range(max_trial):
240 | coin = np.random.uniform()
241 | if coin <= p: # Pick from the same community
242 | community = np.random.randint(0, c)
243 | literals = np.sort(index[np.random.choice(range(size*community, size*community + community_size[community]), self._k, replace=False)])
244 | else: # Pick from different communities
245 | communities = np.random.choice(c, self._k, replace=False)
246 | temp = np.random.uniform(size = self._k)
247 | inner_offset = (temp * community_size[communities]).astype(int)
248 | literals = np.sort(index[size*communities + inner_offset])
249 |
250 | edge_feature = 2.0 * np.random.choice(2, self._k) - 1
251 | iclause = list(((literals + 1) * edge_feature).astype(int))
252 |
253 | if str(iclause) not in clause_set:
254 | i += 1
255 | break
256 |
257 | clause_set.add(str(iclause))
258 | clause_list += [iclause]
259 |
260 | graph_map = np.concatenate((graph_map, np.stack((literals, i * np.ones(self._k, dtype=np.int32)))), 1)
261 | edge_features = np.concatenate((edge_features, edge_feature))
262 |
263 | label = is_sat(n, clause_list)
264 | return n, m, graph_map, edge_features, None, label, clause_list
265 |
266 |
267 | ###############################################################################################
268 |
269 |
270 | class VariableModularCNFGenerator(CNFGeneratorBase):
271 | "Implements a variation of the Community Attachment model with variable sized clauses."
272 |
273 | def __init__(self, min_k, max_k, min_n, max_n, min_q, max_q, min_c, max_c, min_alpha, max_alpha, alpha_resolution=10):
274 |
275 | super(VariableModularCNFGenerator, self).__init__(min_n, max_n, min_alpha, max_alpha, alpha_resolution)
276 | self._min_k = min_k
277 | self._max_k = max_k
278 | self._min_c = min_c
279 | self._max_c = max_c
280 | self._min_q = min_q
281 | self._max_q = max_q
282 |
283 | def generate(self):
284 | n = np.random.randint(self._min_n, self._max_n + 1)
285 | alpha = np.random.uniform(self._min_alpha, self._max_alpha)
286 | m = int(n * alpha)
287 |
288 | q = np.random.uniform(self._min_q, self._max_q)
289 | c = np.random.randint(self._min_c, self._max_c + 1)
290 | c = max(1, min(c, n))
291 | size = int(n / c)
292 | community_size = size * np.ones(c, dtype=np.int32)
293 | community_size[c - 1] += (n - np.sum(community_size))
294 |
295 | p = q + 1.0 / c
296 | clause_length = [np.random.randint(min(self._min_k, size), min(self._max_k, n-1, size) + 1) for _ in range(m)]
297 | edge_num = np.sum(clause_length)
298 |
299 | graph_map = np.zeros((2, edge_num), dtype=np.int32)
300 | index = np.random.permutation(n)
301 |
302 | ind = 0
303 | for i in range(m):
304 | coin = np.random.uniform()
305 | if coin <= p: # Pick from the same community
306 | community = np.random.randint(0, c)
307 | graph_map[0, ind:(ind + clause_length[i])] = index[np.random.choice(range(size*community, size*community + community_size[community]), clause_length[i], replace=False)]
308 | else: # Pick from different communities
309 | if c >= clause_length[i]:
310 | communities = np.random.choice(c, clause_length[i], replace=False)
311 | temp = np.random.uniform(size = clause_length[i])
312 | inner_offset = (temp * community_size[communities]).astype(int)
313 | graph_map[0, ind:(ind + clause_length[i])] = index[size*communities + inner_offset]
314 | else:
315 | graph_map[0, ind:(ind + clause_length[i])] = np.random.choice(n, clause_length[i], replace=False)
316 |
317 | graph_map[1, ind:(ind+clause_length[i])] = i
318 | ind += clause_length[i]
319 |
320 | edge_feature = 2.0 * np.random.choice(2, edge_num) - 1
321 |
322 | return n, m, graph_map, edge_feature, None, -1.0, []
323 |
324 | def generate_complete(self):
325 | n = np.random.randint(self._min_n, self._max_n + 1)
326 | alpha = np.random.uniform(self._alpha, self._alpha + self._alpha_inc)
327 | m = int(n * alpha)
328 | max_trial = 10
329 |
330 | q = np.random.uniform(self._min_q, self._max_q)
331 | c = np.random.randint(self._min_c, self._max_c + 1)
332 | c = max(self._max_k + 1, min(c, int(n / self._max_k) - 1))  # this class has no fixed self._k; bound by the largest clause size
333 | size = int(n / c)
334 | community_size = size * np.ones(c, dtype=np.int32)
335 | community_size[c - 1] += (n - np.sum(community_size))
336 |
337 | p = q + 1.0 / c
338 | edge_num = m * self._max_k  # upper bound on the edge count (not used below)
339 |
340 | index = np.random.permutation(n)
341 | clause_set = set()
342 | clause_list = []
343 | graph_map = np.zeros((2, 0), dtype=np.int32)
344 | edge_features = np.zeros(0)
345 |
346 | i = -1
347 | for _ in range(m):
348 | for _ in range(max_trial):
349 | clause_length = np.random.randint(min(self._min_k, size), min(self._max_k, n-1, size) + 1)
350 | coin = np.random.uniform()
351 | if coin <= p: # Pick from the same community
352 | community = np.random.randint(0, c)
353 | literals = np.sort(index[np.random.choice(range(size*community, size*community + community_size[community]), clause_length, replace=False)])
354 | else: # Pick from different communities
355 | if c >= clause_length:
356 | communities = np.random.choice(c, clause_length, replace=False)
357 | temp = np.random.uniform(size = clause_length)
358 | inner_offset = (temp * community_size[communities]).astype(int)
359 | literals = np.sort(index[size*communities + inner_offset])
360 | else:
361 | literals = np.random.choice(n, clause_length, replace=False)
362 |
363 | edge_feature = 2.0 * np.random.choice(2, clause_length) - 1
364 | iclause = list(((literals + 1) * edge_feature).astype(int))
365 |
366 | if str(iclause) not in clause_set:
367 | i += 1
368 | break
369 |
370 | clause_set.add(str(iclause))
371 | clause_list += [iclause]
372 |
373 | graph_map = np.concatenate((graph_map, np.stack((literals, i * np.ones(clause_length, dtype=np.int32)))), 1)
374 | edge_features = np.concatenate((edge_features, edge_feature))
375 |
376 | label = is_sat(n, clause_list)
377 | return n, m, graph_map, edge_features, None, label, clause_list
378 |
379 |
380 | ###############################################################################################
381 |
382 |
383 | if __name__ == '__main__':
384 |
385 | parser = argparse.ArgumentParser()
386 | parser.add_argument('out_dir', action='store', type=str)
387 | parser.add_argument('out_json', action='store', type=str)
388 | parser.add_argument('name', action='store', type=str)
389 | parser.add_argument('size', action='store', type=int)
390 | parser.add_argument('method', action='store', type=str)
391 |
392 | parser.add_argument('--min_n', action='store', dest='min_n', type=int, default=40)
393 | parser.add_argument('--max_n', action='store', dest='max_n', type=int, default=40)
394 |
395 | parser.add_argument('--min_c', action='store', dest='min_c', type=int, default=10)
396 | parser.add_argument('--max_c', action='store', dest='max_c', type=int, default=40)
397 |
398 | parser.add_argument('--min_q', action='store', dest='min_q', type=float, default=0.3)
399 | parser.add_argument('--max_q', action='store', dest='max_q', type=float, default=0.9)
400 |
401 | parser.add_argument('--min_k', action='store', dest='min_k', type=int, default=3)
402 | parser.add_argument('--max_k', action='store', dest='max_k', type=int, default=5)
403 |
404 | parser.add_argument('--min_a', action='store', dest='min_a', type=float, default=2)
405 | parser.add_argument('--max_a', action='store', dest='max_a', type=float, default=10)
406 | parser.add_argument('--res', action='store', dest='res', type=int, default=5)
407 |
408 | parser.add_argument('-s', '--sat_only', help='Include SAT examples only', required=False, action='store_true', default=False)
409 |
410 | args = vars(parser.parse_args())
411 |
412 | if args['method'] == 'modular':
413 | generator = ModularCNFGenerator(k=args['min_k'], min_n=args['min_n'], max_n=args['max_n'], min_q=args['min_q'],
414 | max_q=args['max_q'], min_c=args['min_c'], max_c=args['max_c'], min_alpha=args['min_a'], max_alpha=args['max_a'], alpha_resolution=args['res'])
415 | elif args['method'] == 'v-modular':
416 | generator = VariableModularCNFGenerator(min_k=args['min_k'], max_k=args['max_k'], min_n=args['min_n'], max_n=args['max_n'], min_q=args['min_q'],
417 | max_q=args['max_q'], min_c=args['min_c'], max_c=args['max_c'], min_alpha=args['min_a'], max_alpha=args['max_a'], alpha_resolution=args['res'])
418 | else:
419 | generator = UniformCNFGenerator(min_n=args['min_n'], max_n=args['max_n'], min_k=args['min_k'],
420 | max_k=args['max_k'], min_alpha=args['min_a'], max_alpha=args['max_a'], alpha_resolution=args['res'])
421 |
422 | generator.generate_dataset(args['size'], args['out_dir'], args['out_json'], args['name'], args['sat_only'])
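# Example usage (an illustrative sketch; the paths, dataset name and sizes below are
# hypothetical and not part of this repository):
#
#   python generator.py ./data ./data/uniform_train.json uniform-3sat 1000 uniform \
#       --min_n 40 --max_n 40 --min_k 3 --max_k 3 --min_a 2 --max_a 10
#
# Assuming generate_dataset behaves as its call above suggests, this would produce 1000
# random uniform CNF instances under ./data together with their JSON representation in
# ./data/uniform_train.json.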
--------------------------------------------------------------------------------
/src/pdp/factorgraph/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3 |
4 | # base.py : Defines the trainer base class for the PDP framework.
5 |
6 | import os
7 | import time
8 | import math
9 | import multiprocessing
10 |
11 | import numpy as np
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 | from pdp.factorgraph.dataset import FactorGraphDataset
17 |
18 |
19 | def _module(model):
20 | return model.module if isinstance(model, nn.DataParallel) else model
21 |
22 |
23 | # pylint: disable=protected-access
24 | class FactorGraphTrainerBase:
25 | "Base class of the Factor Graph trainer pipeline (abstract)."
26 |
27 | # pylint: disable=unused-argument
28 | def __init__(self, config, has_meta_data, error_dim, loss, evaluator, use_cuda, logger):
29 |
30 | self._config = config
31 | self._logger = logger
32 | self._use_cuda = use_cuda and torch.cuda.is_available()
33 |
34 | if config['verbose']:
35 | if self._use_cuda:
36 | self._logger.info('Using GPU...')
37 | else:
38 | self._logger.info('Using CPU...')
39 |
40 | self._device = torch.device("cuda" if self._use_cuda else "cpu")
41 |
42 | self._error_dim = error_dim
43 | self._num_cores = multiprocessing.cpu_count()
44 | self._loss = loss
45 | self._evaluator = evaluator
46 |
47 | if config['verbose']:
48 | self._logger.info("The number of CPU cores is %s." % self._num_cores)
49 |
50 | torch.set_num_threads(self._num_cores)
51 |
52 | # Build the network
53 | self._model_list = [self._set_device(model) for model in self._build_graph(self._config)]
54 |
55 | def _build_graph(self, config):
56 | "Builds the forward computational graph."
57 |
58 | raise NotImplementedError("Subclass must implement abstract method")
59 |
60 | # pylint: disable=unused-argument
61 | def _compute_loss(self, model, loss, prediction, label, graph_map, batch_variable_map,
62 | batch_function_map, edge_feature, meta_data):
63 | "Computes the loss function."
64 |
65 | return loss(prediction, label)
66 |
67 | # pylint: disable=unused-argument
68 | def _compute_evaluation_metrics(self, model, evaluator, prediction, label, graph_map,
69 | batch_variable_map, batch_function_map, edge_feature, meta_data):
70 | "Computes the evaluation function."
71 |
72 | return evaluator(prediction, label)
73 |
74 | def _load(self, import_path_base):
75 | "Loads the model(s) from file."
76 |
77 | for model in self._model_list:
78 | _module(model).load(import_path_base)
79 |
80 | def _save(self, export_path_base):
81 | "Saves the model(s) to file."
82 |
83 | for model in self._model_list:
84 | _module(model).save(export_path_base)
85 |
86 | def _reset_global_step(self):
87 | "Resets the global step counter."
88 |
89 | for model in self._model_list:
90 | _module(model)._global_step.data = torch.tensor(
91 | [0], dtype=torch.float, device=self._device)
92 |
93 | def _set_device(self, model):
94 | "Sets the CPU/GPU device."
95 |
96 | if self._use_cuda:
97 | return nn.DataParallel(model).cuda(self._device)
98 | return model.cpu()
99 |
100 | def _to_cuda(self, data):
101 | if isinstance(data, list):
102 | return data
103 |
104 | if data is not None and self._use_cuda:
105 | return data.cuda(self._device, non_blocking=True)
106 | return data
107 |
108 | def get_parameter_list(self):
109 | "Returns list of dictionaries with models' parameters."
110 | return [{'params': filter(lambda p: p.requires_grad, model.parameters())}
111 | for model in self._model_list]
112 |
113 | def _train_epoch(self, train_loader, optimizer):
114 |
115 | train_batch_num = math.ceil(len(train_loader.dataset) / self._config['batch_size'])
116 |
117 | total_loss = np.zeros(len(self._model_list), dtype=np.float32)
118 | total_example_num = 0
119 |
120 | for (j, data) in enumerate(train_loader, 1):
121 | segment_num = len(data[0])
122 |
123 | for i in range(segment_num):
124 |
125 | (graph_map, batch_variable_map, batch_function_map,
126 | edge_feature, graph_feat, label, _) = [self._to_cuda(d[i]) for d in data]
127 | total_example_num += (batch_variable_map.max() + 1)
128 |
129 | self._train_batch(total_loss, optimizer, graph_map, batch_variable_map, batch_function_map,
130 | edge_feature, graph_feat, label)
131 |
132 | if self._config['verbose']:
133 | print("Training epoch with batch of size {:4d} ({:4d}/{:4d}): {:3d}% complete...".format(
134 | batch_variable_map.max().item(), total_example_num % self._config['batch_size'], self._config['batch_size'],
135 | int(j * 100.0 / train_batch_num)), end='\r')
136 |
137 | del graph_map
138 | del batch_variable_map
139 | del batch_function_map
140 | del edge_feature
141 | del graph_feat
142 | del label
143 |
144 | for model in self._model_list:
145 | _module(model)._global_step += 1
146 |
147 | return total_loss / total_example_num # max(1, len(train_loader))
148 |
149 | def _train_batch(self, total_loss, optimizer, graph_map, batch_variable_map, batch_function_map,
150 | edge_feature, graph_feat, label):
151 |
152 | optimizer.zero_grad()
153 | lambda_value = torch.tensor([self._config['lambda']], dtype=torch.float32, device=self._device)
154 |
155 | for (i, model) in enumerate(self._model_list):
156 |
157 | state = _module(model).get_init_state(graph_map, batch_variable_map, batch_function_map,
158 | edge_feature, graph_feat, self._config['randomized'])
159 |
160 | loss = torch.zeros(1, device=self._device)
161 |
162 | for t in torch.arange(self._config['train_outer_recurrence_num'], dtype=torch.int32, device=self._device):
163 |
164 | prediction, state = model(
165 | init_state=state, graph_map=graph_map, batch_variable_map=batch_variable_map,
166 | batch_function_map=batch_function_map, edge_feature=edge_feature,
167 | meta_data=graph_feat, is_training=True, iteration_num=self._config['train_inner_recurrence_num'])
168 |
169 | loss += self._compute_loss(
170 | model=_module(model), loss=self._loss, prediction=prediction,
171 | label=label, graph_map=graph_map, batch_variable_map=batch_variable_map,
172 | batch_function_map=batch_function_map, edge_feature=edge_feature, meta_data=graph_feat) * \
173 | lambda_value.pow((self._config['train_outer_recurrence_num'] - t - 1).float())
174 |
175 | loss.backward()
176 | nn.utils.clip_grad_norm_(model.parameters(), self._config['clip_norm'])
177 | total_loss[i] += loss.detach().cpu().numpy()
178 |
179 | for s in state:
180 | del s
181 |
182 | optimizer.step()
183 |
184 | def _test_epoch(self, validation_loader, batch_replication):
185 |
186 | test_batch_num = math.ceil(len(validation_loader.dataset) / self._config['batch_size'])
187 |
188 | with torch.no_grad():
189 |
190 | error = np.zeros(
191 | (self._error_dim, len(self._model_list)), dtype=np.float32)
192 | total_example_num = 0
193 |
194 | for (j, data) in enumerate(validation_loader, 1):
195 | segment_num = len(data[0])
196 |
197 | for i in range(segment_num):
198 |
199 | (graph_map, batch_variable_map, batch_function_map,
200 | edge_feature, graph_feat, label, _) = [self._to_cuda(d[i]) for d in data]
201 | total_example_num += (batch_variable_map.max() + 1).detach().cpu().numpy()
202 |
203 | self._test_batch(error, graph_map, batch_variable_map, batch_function_map,
204 | edge_feature, graph_feat, label, batch_replication)
205 |
206 | if self._config['verbose']:
207 | print("Testing epoch with batch of size {:4d} ({:4d}/{:4d}): {:3d}% complete...".format(
208 | batch_variable_map.max().item(), total_example_num % self._config['batch_size'], self._config['batch_size'],
209 | int(j * 100.0 / test_batch_num)), end='\r')
210 |
211 | del graph_map
212 | del batch_variable_map
213 | del batch_function_map
214 | del edge_feature
215 | del graph_feat
216 | del label
217 |
218 | # if self._use_cuda:
219 | # torch.cuda.empty_cache()
220 |
221 | return error / total_example_num
222 |
223 | def _test_batch(self, error, graph_map, batch_variable_map, batch_function_map,
224 | edge_feature, graph_feat, label, batch_replication):
225 |
226 | this_batch_size = batch_variable_map.max() + 1
227 | edge_num = graph_map.size(1)
228 |
229 | for (i, model) in enumerate(self._model_list):
230 |
231 | state = _module(model).get_init_state(graph_map, batch_variable_map, batch_function_map,
232 | edge_feature, graph_feat, randomized=True, batch_replication=batch_replication)
233 |
234 | prediction, _ = model(
235 | init_state=state, graph_map=graph_map, batch_variable_map=batch_variable_map,
236 | batch_function_map=batch_function_map, edge_feature=edge_feature,
237 | meta_data=graph_feat, is_training=False, iteration_num=self._config['test_recurrence_num'],
238 | check_termination=self._check_recurrence_termination, batch_replication=batch_replication)
239 |
240 | error[:, i] += (this_batch_size.float() * self._compute_evaluation_metrics(
241 | model=_module(model), evaluator=self._evaluator,
242 | prediction=prediction, label=label, graph_map=graph_map,
243 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map,
244 | edge_feature=edge_feature, meta_data=graph_feat)).detach().cpu().numpy()
245 |
246 | for p in prediction:
247 | del p
248 |
249 | for s in state:
250 | del s
251 |
252 | def _predict_epoch(self, validation_loader, post_processor, batch_replication, file):
253 |
254 | test_batch_num = math.ceil(len(validation_loader.dataset) / self._config['batch_size'])
255 |
256 | with torch.no_grad():
257 |
258 | for (j, data) in enumerate(validation_loader, 1):
259 | segment_num = len(data[0])
260 |
261 | for i in range(segment_num):
262 |
263 | (graph_map, batch_variable_map, batch_function_map,
264 | edge_feature, graph_feat, label, misc_data) = [self._to_cuda(d[i]) for d in data]
265 |
266 | self._predict_batch(graph_map, batch_variable_map, batch_function_map,
267 | edge_feature, graph_feat, label, misc_data, post_processor, batch_replication, file)
268 |
269 | del graph_map
270 | del batch_variable_map
271 | del batch_function_map
272 | del edge_feature
273 | del graph_feat
274 | del label
275 |
276 | # if self._config['verbose']:
277 | # print("Predicting epoch: %3d%% complete..."
278 | # % (j * 100.0 / test_batch_num), end='\r')
279 |
280 | def _predict_batch(self, graph_map, batch_variable_map, batch_function_map,
281 | edge_feature, graph_feat, label, misc_data, post_processor, batch_replication, file):
282 |
283 | edge_num = graph_map.size(1)
284 |
285 | for (i, model) in enumerate(self._model_list):
286 |
287 | state = _module(model).get_init_state(graph_map, batch_variable_map, batch_function_map,
288 | edge_feature, graph_feat, randomized=False, batch_replication=batch_replication)
289 |
290 | prediction, _ = model(
291 | init_state=state, graph_map=graph_map, batch_variable_map=batch_variable_map,
292 | batch_function_map=batch_function_map, edge_feature=edge_feature,
293 | meta_data=graph_feat, is_training=False, iteration_num=self._config['test_recurrence_num'],
294 | check_termination=self._check_recurrence_termination, batch_replication=batch_replication)
295 |
296 | if post_processor is not None and callable(post_processor):
297 | message = post_processor(_module(model), prediction, graph_map,
298 | batch_variable_map, batch_function_map, edge_feature, graph_feat, label, misc_data)
299 | print(message, file=file)
300 |
301 | for p in prediction:
302 | del p
303 |
304 | for s in state:
305 | del s
306 |
307 | def _check_recurrence_termination(self, active, prediction, sat_problem):
308 | "De-actives the CNF examples which the model has already found a SAT solution for."
309 | pass
310 |
311 | def train(self, train_list, validation_list, optimizer, last_export_path_base=None,
312 | best_export_path_base=None, metric_index=0, load_model=None, reset_step=False,
313 | generator=None, train_epoch_size=0):
314 | "Trains the PDP model."
315 |
316 | # Build the input pipeline
317 | train_loader = FactorGraphDataset.get_loader(
318 | input_file=train_list[0], limit=self._config['train_batch_limit'],
319 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=True,
320 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'], generator=generator,
321 | epoch_size=train_epoch_size)
322 |
323 | validation_loader = FactorGraphDataset.get_loader(
324 | input_file=validation_list[0], limit=self._config['test_batch_limit'],
325 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=False,
326 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'])
327 |
328 | model_num = len(self._model_list)
329 |
330 | errors = np.zeros(
331 | (self._error_dim, model_num, self._config['epoch_num'],
332 | self._config['repetition_num']), dtype=np.float32)
333 |
334 | losses = np.zeros(
335 | (model_num, self._config['epoch_num'], self._config['repetition_num']),
336 | dtype=np.float32)
337 |
338 | best_errors = np.repeat(np.inf, model_num)
339 |
340 | if self._use_cuda:
341 | torch.backends.cudnn.benchmark = True
342 |
343 | for rep in range(self._config['repetition_num']):
344 |
345 | if load_model == "best" and best_export_path_base is not None:
346 | self._load(best_export_path_base)
347 | elif load_model == "last" and last_export_path_base is not None:
348 | self._load(last_export_path_base)
349 |
350 | if reset_step:
351 | self._reset_global_step()
352 |
353 | for epoch in range(self._config['epoch_num']):
354 |
355 | # Training
356 | start_time = time.time()
357 | losses[:, epoch, rep] = self._train_epoch(train_loader, optimizer)
358 |
359 | if self._use_cuda:
360 | torch.cuda.empty_cache()
361 |
362 | # Validation
363 | errors[:, :, epoch, rep] = self._test_epoch(validation_loader, 1)
364 | duration = time.time() - start_time
365 |
366 | # Checkpoint the best models so far
367 | if last_export_path_base is not None:
368 | for (i, model) in enumerate(self._model_list):
369 | _module(model).save(last_export_path_base)
370 |
371 | if best_export_path_base is not None:
372 | for (i, model) in enumerate(self._model_list):
373 | if errors[metric_index, i, epoch, rep] < best_errors[i]:
374 | best_errors[i] = errors[metric_index, i, epoch, rep]
375 | _module(model).save(best_export_path_base)
376 |
377 | if self._use_cuda:
378 | torch.cuda.empty_cache()
379 |
380 | if self._config['verbose']:
381 | message = ''
382 | for (i, model) in enumerate(self._model_list):
383 | name = _module(model)._name
384 | message += 'Step {:d}: {:s} error={:s}, {:s} loss={:5.5f} |'.format(
385 | _module(model)._global_step.int()[0], name,
386 | np.array_str(errors[:, i, epoch, rep].flatten()),
387 | name, losses[i, epoch, rep])
388 |
389 | self._logger.info('Rep {:2d}, Epoch {:2d}: {:s}'.format(rep + 1, epoch + 1, message))
390 | self._logger.info('Time spent: %s seconds' % duration)
391 |
392 | if self._use_cuda:
393 | torch.backends.cudnn.benchmark = False
394 |
395 | if best_export_path_base is not None:
396 | # Save losses and errors
397 | base = os.path.relpath(best_export_path_base)
398 | np.save(os.path.join(base, "losses"), losses, allow_pickle=False)
399 | np.save(os.path.join(base, "errors"), errors, allow_pickle=False)
400 |
401 | # Save the model
402 | self._save(best_export_path_base)
403 |
404 | return self._model_list, errors, losses
405 |
406 | def test(self, test_list, import_path_base=None, batch_replication=1):
407 | "Tests the PDP model and generates test stats."
408 |
409 | if isinstance(test_list, list):
410 | test_files = test_list
411 | elif os.path.isdir(test_list):
412 | test_files = [os.path.join(test_list, f) for f in os.listdir(test_list) \
413 | if os.path.isfile(os.path.join(test_list, f)) and f[-5:].lower() == '.json' ]
414 | elif isinstance(test_list, str):
415 | test_files = [test_list]
416 | else:
417 | return None
418 |
419 | result = []
420 |
421 | for file in test_files:
422 | # Build the input pipeline
423 | test_loader = FactorGraphDataset.get_loader(
424 | input_file=file, limit=self._config['test_batch_limit'],
425 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=False,
426 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'], batch_replication=batch_replication)
427 |
428 | if import_path_base is not None:
429 | self._load(import_path_base)
430 |
431 | start_time = time.time()
432 | error = self._test_epoch(test_loader, batch_replication)
433 | duration = time.time() - start_time
434 |
435 | if self._use_cuda:
436 | torch.cuda.empty_cache()
437 |
438 | if self._config['verbose']:
439 | message = ''
440 | for (i, model) in enumerate(self._model_list):
441 | message += '{:s}, dataset:{:s} error={:s}|'.format(
442 | _module(model)._name, file, np.array_str(error[:, i].flatten()))
443 |
444 | self._logger.info(message)
445 | self._logger.info('Time spent: %s seconds' % duration)
446 |
447 | result += [[file, error, duration]]
448 |
449 | return result
450 |
451 | def predict(self, test_list, out_file, import_path_base=None, post_processor=None, batch_replication=1):
452 | "Produces predictions for the trained PDP model."
453 |
454 | # Build the input pipeline
455 | test_loader = FactorGraphDataset.get_loader(
456 | input_file=test_list, limit=self._config['test_batch_limit'],
457 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=False,
458 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'], batch_replication=batch_replication)
459 |
460 | if import_path_base is not None:
461 | self._load(import_path_base)
462 |
463 | start_time = time.time()
464 | self._predict_epoch(test_loader, post_processor, batch_replication, out_file)
465 |
466 | duration = time.time() - start_time
467 |
468 | if self._use_cuda:
469 | torch.cuda.empty_cache()
470 |
471 | if self._config['verbose']:
472 | self._logger.info('Time spent: %s seconds' % duration)
473 |
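# A minimal subclassing sketch (illustrative only; MySolverTrainer, MySolverModel, my_loss,
# my_evaluator, my_logger and the file paths below are hypothetical, not part of this repository):
#
#   class MySolverTrainer(FactorGraphTrainerBase):
#       def _build_graph(self, config):
#           # Return the list of solver modules to train; each is expected to provide the
#           # get_init_state/save/load methods and _global_step/_name attributes used above.
#           return [MySolverModel(config)]
#
#   trainer = MySolverTrainer(config, has_meta_data=False, error_dim=3,
#                             loss=my_loss, evaluator=my_evaluator,
#                             use_cuda=True, logger=my_logger)
#   optimizer = torch.optim.Adam(trainer.get_parameter_list(), lr=config['learning_rate'])
#   trainer.train(train_list=['train.json'], validation_list=['validation.json'],
#                 optimizer=optimizer, last_export_path_base='./last',
#                 best_export_path_base='./best')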
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PDP Framework for Neural Constraint Satisfaction Solving
2 |
3 | The PDP framework is a generic framework based on the idea of Propagation, Decimation and Prediction (PDP) for learning and implementing message passing-based solvers for constraint satisfaction problems (CSP). In particular, it provides an elegant unsupervised framework for training neural solvers based on the idea of energy minimization. Our SAT solver adaptation of the PDP framework, referred to as **SATYR**, supports a wide spectrum of solvers, from fully neural architectures to classical inference-based techniques (such as Survey Propagation), with hybrid methods in between. For further theoretical details of the framework, please refer to our paper:
4 |
5 | **Saeed Amizadeh, Sergiy Matusevych and Markus Weimer, [PDP: A General Neural Framework for Learning Constraint Satisfaction Solvers](https://arxiv.org/abs/1903.01969), arXiv preprint arXiv:1903.01969, 2019.** [**[Video]**](https://www.youtube.com/watch?v=PwuCc7Lylww&ab_channel=CP2020)
6 |
7 | ```
8 | @article{amizadeh2019pdp,
9 | title={PDP: A General Neural Framework for Learning Constraint Satisfaction Solvers},
10 | author={Amizadeh, Saeed and Matusevych, Sergiy and Weimer, Markus},
11 | journal={arXiv preprint arXiv:1903.01969},
12 | year={2019}
13 | }
14 | ```
15 |
16 | We also note the present work is still far away from competing with modern industrial solvers; nevertheless, we believe it is a significant step in the right direction for machine learning-based methods. Hence, we are glad to open source our code to the researchers in related fields including the neuro-symbolic community as well as the classical SAT community.
17 |
18 | # SATYR
19 |
20 | SATYR is the adaptation of the PDP framework for training and deploying neural Boolean Satisfiability solvers. In particular, SATYR implements:
21 |
22 | 1. Fully or partially neural SAT solvers that can be trained toward solving SAT for a specific distribution of problem instances. The training is based on unsupervised energy minimization and can be performed on an infinite stream of unlabeled, random instances sampled from the target distribution.
23 |
24 | 2. Non-learnable classical solvers based on message passing in graphical models (e.g. Survey Propagation). Even though these solvers are non-learnable, they still benefit from the embarrassingly parallel implementation via the PDP framework on GPUs.
25 |
26 | It should be noted that all the SATYR solvers try to find a satisfying assignment for input SAT formulas. However, if the SATYR solvers cannot find a satisfying solution for a given problem within their iteration budget, it does NOT necessarily mean that the input problem is UNSAT. In other words, none of the SATYR solvers provides a proof of unsatisfiability.
27 |
28 | # Setup
29 |
30 | ## Prerequisites
31 |
32 | * Python 3.5 or higher.
33 | * PyTorch 0.4.0 or higher.
34 |
35 | Run:
36 |
37 | ```
38 | > python setup.py install
39 | ```
40 |
41 | # Usage
42 |
43 | The SATYR solvers can be used in two main modes: (1) apply an already-trained model or non-ML algorithm to test data, and (2) train/test new models.
44 |
45 | ## Running a (Trained) SATYR Solver
46 |
47 | The usage for running a SATYR solver against a set of SAT problems (represented in Conjunctive Normal Form (CNF)) is:
48 |
49 | ```
50 | > python satyr.py [-h] [-b BATCH_REPLICATION] [-z BATCH_SIZE]
51 | [-m MAX_CACHE_SIZE] [-l TEST_BATCH_LIMIT]
52 | [-w LOCAL_SEARCH_ITERATION] [-e EPSILON] [-v] [-c] [-d]
53 | [-s RANDOM_SEED] [-o OUTPUT]
54 | model_config test_path test_recurrence_num
55 | ```
56 |
57 | The commandline arguments are:
58 |
59 | + **-h** or **--help**: Shows the commandline options.
60 |
61 | + **-b BATCH_REPLICATION** or **--batch_replication BATCH_REPLICATION**: BATCH_REPLICATION is the replication factor for input problems to further benefit from parallelization (default: 1).
62 |
63 | + **-z BATCH_SIZE** or **--batch_size BATCH_SIZE**: BATCH_SIZE is the batch size (default: 5000).
64 |
65 | + **-m MAX_CACHE_SIZE** or **--max_cache_size MAX_CACHE_SIZE**: MAX_CACHE_SIZE is the maximum size of the cache containing the parsed CNFs loaded from disk (mostly useful for iterative training) (default: 100000).
66 |
67 | + **-l TEST_BATCH_LIMIT** or **--test_batch_limit TEST_BATCH_LIMIT**: TEST_BATCH_LIMIT is the memory limit used for dynamic batching. It must be empirically tuned by the user depending on the available GPU memory (default: 40000000).
68 |
69 | + **-w LOCAL_SEARCH_ITERATION** or **--local_search_iteration LOCAL_SEARCH_ITERATION**: LOCAL_SEARCH_ITERATION is the maximum number of local search (i.e. Walk-SAT) iterations that can be optionally applied as a post-processing step after the main solver terminates (default: 100).
70 |
71 | + **-e EPSILON** or **--epsilon EPSILON**: EPSILON is the probability with which the post-processing local search picks a random variable instead of the best option for flipping (default: 0.5).
72 |
73 | + **-v** or **--verbose**: Prints the log messages to STDOUT.
74 |
75 | + **-c** or **--cpu_mode**: Forces the solver to run on CPU.
76 |
77 | + **-d** or **--dimacs**: Notifies the solver that the input path is a directory of DIMACS files.
78 |
79 | + **-s RANDOM_SEED** or **--random_seed RANDOM_SEED**: RANDOM_SEED is the random seed directly affecting the randomized initial values of the messages in the PDP solvers.
80 |
81 | + **-o OUTPUT** or **--output OUTPUT**: OUTPUT is the path to the output JSON file that would contain the solutions for the input CNFs. If not specified, the output is directed to STDOUT.
82 |
83 | + **model_config**: The path to YAML config file specifying the model used for SAT solving. A few example config files are provided [here](https://github.com/Microsoft/PDP-Solver/tree/master/config/Predict). A model config file specifies the following properties:
84 |
85 | * **model_type**: The type of solver. So far, we have implemented six different types of solvers in SATYR:
86 |
87 | * *'np-nd-np'*: A fully neural PDP solver.
88 |
89 | * *'p-d-p'*: A PDP solver that implements the classical Survey Propagation with greedy sequential decimation.[[1]](#reference-1)
90 |
91 | * *'p-nd-np'*: A PDP solver that implements the classical Survey Propagation except with neural decimation.
92 |
93 | * *'np-d-np'*: A PDP solver that implements neural propagation with a greedy sequential decimation.
94 |
95 | * *'reinforce'*: A PDP solver that implements the classical Survey Propagation with concurrent distributed decimation (The REINFORCE algorithm).[[2]](#reference-2)
96 |
97 | * *'walk-sat'*: A PDP solver that implements the classical local search Walk-SAT algorithm.[[3]](#reference-3)
98 |
99 | * **has_meta_data**: Whether the input problem instances contain meta features other than the CNF itself (Note: loading of such features is not supported by the current input pipeline).
100 |
101 | * **model_name**: The name picked for the model by the user.
102 |
103 | * **model_path**: The path to the saved weights for the trained model.
104 |
105 | * **label_dim**: Always set to 1 for SATYR.
106 |
107 | * **edge_feature_dim**: Always set to 1 for SATYR.
108 |
109 | * **meta_feature_dim**: The dimensionality of the meta features (0 for now).
110 |
111 | * **prediction_dim**: Always set to 1 for SATYR.
112 |
113 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*:
114 |
115 | * **hidden_dim**: The dimensionality of messages between the propagator and the decimator.
116 |
117 | If **model_type** is *'np-nd-np'* or *'np-d-np'*:
118 |
119 | * **mem_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *before* the aggregation step in the propagator.
120 |
121 | * **agg_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *after* the aggregation step in the propagator.
122 |
123 | * **mem_agg_hidden_dim**: The output dimensionality of the perceptron that is applied to messages *before* the aggregation step in the propagator.
124 |
125 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*:
126 |
127 | * **classifier_dim**: The dimensionality of the hidden layer for the perceptron that is used as the final predictor.
128 |
129 | If **model_type** is *'p-d-p'* or *'np-d-np'*:
130 |
131 | * **tolerance**: The convergence tolerance for the propagator before the sequential decimator is invoked.
132 |
133 | * **t_max**: The maximum iteration number for the propagator before the sequential decimator is invoked.
134 |
135 | If **model_type** is *'reinforce'*:
136 |
137 | * **pi**: The external force magnitude parameter for the REINFORCE algorithm.
138 |
139 | * **decimation_probability**: The probability with which the distributed decimation is invoked in the REINFORCE algorithm.
140 |
141 | + **test_path**: The path to the input JSON file containing the test CNFs (in the case of '-d' option, the path to the directory containing the test DIMACS files.)
142 |
143 | + **test_recurrence_num**: The maximum number of iterations the solver is allowed to run before termination.
144 |
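For example, a hypothetical invocation might look like the following (the config file name, test set path and iteration budget are illustrative, not files shipped with this repository):

```
> python satyr.py -v -o results.json config/Predict/my-sp-model.yaml test_set.json 1000
```

where `config/Predict/my-sp-model.yaml` would be a model config of the kind described above, for instance a partial sketch such as:

```
model_type: "p-d-p"
model_name: "my-sp-model"
tolerance: 0.01
t_max: 200
```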
145 |
146 | ## Training/Testing a SATYR Solver
147 |
148 | The usage for training/testing new SATYR models is:
149 |
150 | ```
151 | > python satyr-train-test.py [-h] [-t] [-l LOAD_MODEL] [-c] [-r] [-g]
152 | [-b BATCH_REPLICATION]
153 | config
154 | ```
155 |
156 | The commandline arguments are:
157 |
158 | + **-h** or **--help**: Shows the commandline options.
159 |
160 | + **-t** or **--test**: Skips the training stage directly to the testing stage.
161 |
162 | + **-l LOAD_MODEL** or **-load LOAD_MODEL**: LOAD_MODEL is:
163 |
164 | * *best*: The model is initialized by the best model (according to the validation metric) saved from the previous run.
165 |
166 | * *last*: The model is initialized by the last model saved from the previous run.
167 |
168 | * Otherwise, the model is initialized by random weights.
169 |
170 | + **-c** or **--cpu_mode**: Forces the training/testing to run on CPU.
171 |
172 | + **-r** or **--reset**: Resets the global time parameter to 0 (used for annealing the temperature).
173 |
174 | + **-g** or **--use_generator**: Makes the training process use one of the provided CNF generators to generate unlabeled training CNF instances on the fly.
175 |
176 | + **-b BATCH_REPLICATION** or **--batch_replication BATCH_REPLICATION**: BATCH_REPLICATION is the replication factor for input problems to further benefit from parallelization (default: 1).
177 |
178 | + **config**: The path to YAML config file specifying the model as well as the training parameters. A few example training config files are provided [here](https://github.com/Microsoft/PDP-Solver/tree/master/config/Train). A training config file specifies the following properties:
179 |
180 | * **model_name**: The name picked for the model by the user.
181 |
182 | * **model_type**: The model type explained above.
183 |
184 | * **version**: The model version.
185 |
186 | * **has_meta_data**: Whether the input problem instances contain meta features other than the CNF itself (Note: loading of such features is not supported by the current input pipeline).
187 |
188 | * **train_path**: A one-element list containing the path to the training JSON file. Will be ignored in the case of using option -g.
189 |
190 | * **validation_path**: A one-element list containing the path to the validation JSON file. Validation set is used for picking the best model during each training run.
191 |
192 | * **test_path**: A list containing the path(s) to test JSON files.
193 |
194 | * **model_path**: The parent directory of the location where the best and the last models are saved.
195 |
196 | * **repetition_num**: Number of repetitions for the training process (for regular scenarios: 1).
197 |
198 | * **train_epoch_size**: The size of one epoch in the case of using CNF generators via option -g.
199 |
200 | * **epoch_num**: The number of epochs for training.
201 |
202 | * **label_dim**: Always set to 1 for SATYR.
203 |
204 | * **edge_feature_dim**: Always set to 1 for SATYR.
205 |
206 | * **meta_feature_dim**: The dimensionality of the meta features (0 for now).
207 |
208 | * **error_dim**: The number of error metrics the model reports on the validation/test sets (3 for now: accuracy, recall and test loss).
209 |
210 | * **metric_index**: The 0-based index of the error metric used to pick the best model.
211 |
212 | * **prediction_dim**: Always set to 1 for SATYR.
213 |
214 | * **batch_size**: The batch size used for training/testing.
215 |
216 | * **learning_rate**: The learning rate for ADAM optimization algorithm used for training.
217 |
218 | * **exploration**: The exploration factor used for annealing the temperature.
219 |
220 | * **verbose**: If TRUE prints log messages to STDOUT.
221 |
222 | * **randomized**: If TRUE initializes the propagator and the decimator messages with random values; otherwise with zeros.
223 |
224 | * **train_inner_recurrence_num**: The number of inner loop iterations before the loss function is computed (typically set to 1).
225 |
226 | * **train_outer_recurrence_num**: The number of outer loop iterations (T in the paper) used during training.
227 |
228 | * **test_recurrence_num**: The number of outer loop iterations (T in the paper) used during testing.
229 |
230 | * **max_cache_size**: The maximum size of the cache containing the parsed CNFs loaded from disk during training.
231 |
232 | * **dropout**: The dropout factor during training.
233 |
234 | * **clip_norm**: The clip norm ratio used for gradient clipping during training.
235 |
236 | * **weight_decay**: The weight decay coefficient for ADAM optimizer used for training.
237 |
238 | * **loss_sharpness**: The sharpness of the step function used for calculating loss (the kappa parameter in the paper).
239 |
240 | * **train_batch_limit**: The memory limit used for dynamic batching during training. It must be empirically tuned by the user depending on the available GPU memory.
241 |
242 | * **test_batch_limit**: The memory limit used for dynamic batching during testing. It must be empirically tuned by the user depending on the available GPU memory.
243 |
244 | * **generator**: The type of CNF generator used when the -g option is deployed:
245 |
246 | * *'uniform'*: The uniform random k-SAT generator.
247 |
248 | * *'modular'*: The modular random k-SAT generator with fixed k (specified by **min_k**) according to the Community Attachment model[[4]](#reference-4).
249 |
250 | * *'v-modular'*: The modular random k-SAT generator with variable size k according to the Community Attachment model.
251 |
252 | * **min_n**: The minimum number of variables for a random training CNF instance in the case -g option is deployed.
253 |
254 | * **max_n**: The maximum number of variables for a random training CNF instance in the case -g option is deployed.
255 |
256 | * **min_alpha**: The minimum clause/variable ratio for a random training CNF instance in the case -g option is deployed.
257 |
258 | * **max_alpha**: The maximum clause/variable ratio for a random training CNF instance in the case -g option is deployed.
259 |
260 | * **min_k**: The minimum clause size for a random training CNF instance in the case -g option is deployed.
261 |
262 | * **max_k**: The maximum clause size for a random training CNF instance in the case -g option is deployed (not supported for *'v-modular'* generator).
263 |
264 | * **min_q**: The minimum modularity value for a random training CNF instance generated according to the Community Attachment model in the case -g option is deployed (not supported for *'uniform'* generator).
265 |
266 | * **max_q**: The maximum modularity value for a random training CNF instance generated according to the Community Attachment model in the case -g option is deployed (not supported for *'uniform'* generator).
267 |
268 | * **min_c**: The minimum number of communities for a random training CNF instance generated according to the Community Attachment model in the case -g option is deployed (not supported for *'uniform'* generator).
269 |
270 | * **max_c**: The maximum number of communities for a random training CNF instance generated according to the Community Attachment model in the case -g option is deployed (not supported for *'uniform'* generator).
271 |
272 | * **local_search_iteration**: The maximum number of local search (i.e. Walk-SAT) iterations that can be optionally applied as a post-processing step after the main solver terminates during testing.
273 |
274 | * **epsilon**: The probability with which the optional post-processing local search picks a random variable instead of the best option for flipping.
275 |
276 | * **lambda**: The discounting factor in (0, 1] used for loss calculation (the lambda parameter in the paper).
277 |
278 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*:
279 |
280 | * **hidden_dim**: The dimensionality of messages between the propagator and the decimator.
281 |
282 | If **model_type** is *'np-nd-np'* or *'np-d-np'*:
283 |
284 | * **mem_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *before* the aggregation step in the propagator.
285 |
286 | * **agg_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *after* the aggregation step in the propagator.
287 |
288 | * **mem_agg_hidden_dim**: The output dimensionality of the perceptron that is applied to messages *before* the aggregation step in the propagator.
289 |
290 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*:
291 |
292 | * **classifier_dim**: The dimensionality of the hidden layer for the perceptron that is used as the final predictor.
293 |
294 | If **model_type** is *'p-d-p'* or *'np-d-np'*:
295 |
296 | * **tolerance**: The convergence tolerance for the propagator before the sequential decimator is invoked.
297 |
298 | * **t_max**: The maximum iteration number for the propagator before the sequential decimator is invoked.
299 |
300 | If **model_type** is *'reinforce'*:
301 |
302 | * **pi**: The external force magnitude parameter for the REINFORCE algorithm.
303 |
304 | * **decimation_probability**: The probability with which the distributed decimation is invoked in the REINFORCE algorithm.
305 |
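For example, a hypothetical pair of runs might look like the following (the config path is illustrative, not a file shipped with this repository):

```
> python satyr-train-test.py -g config/Train/my-training-config.yaml
> python satyr-train-test.py -t -l best config/Train/my-training-config.yaml
```

The first command trains with CNF instances generated on the fly (-g); the second skips training (-t) and evaluates the best saved model (-l best) on the configured **test_path** files.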
306 | # Input/Output Formats
307 |
308 | ## Input
309 |
310 | SATYR effectively works with the standard DIMACS format for representing CNF formulas. However, to improve input-ingestion efficiency, the actual solvers consume CNF data through an intermediate JSON format instead of the DIMACS representation. A key feature of the intermediate JSON format is that an entire set of DIMACS files can be represented by a single JSON file, where each row of the JSON file corresponds to one DIMACS file.
311 |
312 | The train/test script assumes the train/validation/test sets are already in the JSON format. In order to convert a set of DIMACS files into a single JSON file, we have provided the following script:
313 |
314 | ```
315 | > python dimacs2json.py [-h] [-s] [-p] in_dir out_file
316 | ```
317 |
318 | where the commandline arguments are:
319 |
320 | + **-h** or **--help**: Shows the commandline options.
321 |
322 | + **-s** or **--simplify**: Performs elementary clause propagation simplification on the CNF formulas before converting them to JSON format. This option is not recommended for large formulas as it takes quadratic memory and time in terms of the number of clauses.
323 |
324 | + **-p** or **--positives**: Writes only the satisfiable examples in the output JSON file. This option is especially useful for creating all-SAT validation/test sets. Note that this option does NOT invoke any external solver to find out whether an example is SAT or not. Instead, it only works if the SAT/UNSAT labels are already provided in the names of the DIMACS files. In particular, if the name of an input DIMACS file ends in '1', it will be regarded as a SAT (positive) example.
325 |
326 | + **in_dir**: The path to the parent directory of the input DIMACS files.
327 |
328 | + **out_file**: The path of the output JSON file.
329 |
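For example (the paths below are illustrative):

```
> python dimacs2json.py -s my_dimacs_dir/ test_set.json
```

converts all DIMACS files under `my_dimacs_dir/` into a single `test_set.json`, applying the elementary clause propagation simplification first.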
330 | The solver script, however, does not require the input problems to be in the JSON format; they can be in the DIMACS format as long as the -d option is deployed. Nevertheless, for repeated applications of the solver script on the same input set, we would recommend externally converting the input DIMACS files into the JSON format once and consuming only the JSON file afterwards.
331 |
332 | ## Output
333 |
334 | The output of the solver script is a JSON file where each line corresponds to one input CNF instance and is a dictionary with the following key:value pairs:
335 |
336 | + **"ID"**: The DIMACS file name associated with a CNF example.
337 |
338 | + **"label"**: The binary SAT/UNSAT (0/1) label associated with a CNF example (only if it is already provided in the DIMACS filename).
339 |
340 | + **"solved"**: The binary flag showing whether the provided solution satisfies the CNF.
341 |
342 | + **"unsat_clauses"**: The number of clauses in the CNF that are not satisfied by the provided solution (0 if the CNF is satisfied by the solution).
343 |
344 | + **"solution"**: The provided solution by the solver. The variable assignments in the list are ordered based on the increasing variable indices in the original DIMACS file.
345 |
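For example, a single output line might look like the following (an illustrative, hand-written example rather than the output of an actual run; the exact value types may differ):

```
{"ID": "instance_017_1.cnf", "label": 1, "solved": 1, "unsat_clauses": 0, "solution": [1, 0, 0, 1, 1]}
```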
346 | # Main Contributors
347 |
348 | + [Saeed Amizadeh](mailto:saeed.amizadeh@gmail.com), Microsoft Inc.
349 | + [Sergiy Matusevych](mailto:sergiy.matusevych@gmail.com), Microsoft Inc.
350 |
351 | # Contributing
352 |
353 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
354 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
355 | the rights to use your contribution. For details, visit https://cla.microsoft.com.
356 |
357 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
358 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
359 | provided by the bot. You will only need to do this once across all repos using our CLA.
360 |
361 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
362 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
363 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
364 |
365 | ## Extending The PDP Framework
366 |
367 | The PDP framework supports a wide range of solvers from fully neural solvers to hybrid, neuro-symbolic models all the way to classical, non-learnable algorithms. So far, we have only implemented six different types, but there is definitely room for more. Therefore, we highly encourage contributions in the form of other types of PDP-based solvers. Furthermore, we welcome contributions with PDP-based adaptations for other types of constraint satisfaction problems beyond SAT.
368 |
369 | # References
370 |
371 | 1. Mezard, M. and Montanari, A. Information, physics, and computation. Oxford University Press, 2009.
372 | 2. Chavas, J., Furtlehner, C., Mezard, M., and Zecchina, R. Survey-propagation decimation through distributed local computations. Journal of Statistical Mechanics: Theory and Experiment, 2005(11):P11016, 2005.
373 | 3. Hoos, Holger H. On the Run-time Behaviour of Stochastic Local Search Algorithms for SAT. In AAAI/IAAI, pp. 661-666. 1999.
374 | 4. Giraldez-Cru, J. and Levy, J. Generating sat instances with community structure. Artificial Intelligence, 238:119–134, 2016.
375 |
--------------------------------------------------------------------------------
/src/pdp/nn/solver.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3 |
4 | # solver.py : Defines the base class for all PDP Solvers as well as the various inherited solvers.
5 |
6 | import os
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from pdp.nn import pdp_propagate, pdp_decimate, pdp_predict, util
13 |
14 |
15 | ###############################################################
16 | ### The Problem Class
17 | ###############################################################
18 |
19 | class SATProblem(object):
20 | "The class that encapsulates a batch of CNF problem instances."
21 |
22 | def __init__(self, data_batch, device, batch_replication=1):
23 | self._device = device
24 | self._batch_replication = batch_replication
25 | self.setup_problem(data_batch, batch_replication)
26 | self._edge_mask = None
27 |
28 | def setup_problem(self, data_batch, batch_replication):
29 | "Setup the problem properties as well as the relevant sparse matrices."
30 |
31 | if batch_replication > 1:
32 | self._replication_mask_tuple = self._compute_batch_replication_map(data_batch[1], batch_replication)
33 | self._graph_map, self._batch_variable_map, self._batch_function_map, self._edge_feature, self._meta_data, _ = self._replicate_batch(data_batch, batch_replication)
34 | else:
35 | self._graph_map, self._batch_variable_map, self._batch_function_map, self._edge_feature, self._meta_data, _ = data_batch
36 |
37 | self._variable_num = self._batch_variable_map.size()[0]
38 | self._function_num = self._batch_function_map.size()[0]
39 | self._edge_num = self._graph_map.size()[1]
40 |
41 | self._vf_mask_tuple = self._compute_variable_function_map(self._graph_map, self._batch_variable_map,
42 | self._batch_function_map, self._edge_feature)
43 | self._batch_mask_tuple = self._compute_batch_map(self._batch_variable_map, self._batch_function_map)
44 | self._graph_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map)
45 | self._pos_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map, (self._edge_feature == 1).squeeze(1).float())
46 | self._neg_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map, (self._edge_feature == -1).squeeze(1).float())
47 | self._signed_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map, self._edge_feature.squeeze(1))
48 |
49 | self._active_variables = torch.ones(self._variable_num, 1, device=self._device)
50 | self._active_functions = torch.ones(self._function_num, 1, device=self._device)
51 | self._solution = 0.5 * torch.ones(self._variable_num, device=self._device)
52 |
53 | self._batch_size = (self._batch_variable_map.max() + 1).long().item()
54 | self._is_sat = 0.5 * torch.ones(self._batch_size, device=self._device)
55 |
56 | def _replicate_batch(self, data_batch, batch_replication):
57 | "Implements the batch replication."
58 |
59 | graph_map, batch_variable_map, batch_function_map, edge_feature, meta_data, label = data_batch
60 | edge_num = graph_map.size()[1]
61 | batch_size = (batch_variable_map.max() + 1).long().item()
62 | variable_num = batch_variable_map.size()[0]
63 | function_num = batch_function_map.size()[0]
64 |
65 | ind = torch.arange(batch_replication, dtype=torch.int32, device=self._device).unsqueeze(1).repeat(1, edge_num).view(1, -1)
66 | graph_map = graph_map.repeat(1, batch_replication) + ind.repeat(2, 1) * torch.tensor([[variable_num], [function_num]], dtype=torch.int32, device=self._device)
67 |
68 | ind = torch.arange(batch_replication, dtype=torch.int32, device=self._device).unsqueeze(1).repeat(1, variable_num).view(1, -1)
69 | batch_variable_map = batch_variable_map.repeat(batch_replication) + ind * batch_size
70 |
71 | ind = torch.arange(batch_replication, dtype=torch.int32, device=self._device).unsqueeze(1).repeat(1, function_num).view(1, -1)
72 | batch_function_map = batch_function_map.repeat(batch_replication) + ind * batch_size
73 |
74 | edge_feature = edge_feature.repeat(batch_replication, 1)
75 |
76 | if meta_data is not None:
77 | meta_data = meta_data.repeat(batch_replication, 1)
78 |
79 | if label is not None:
80 | label = label.repeat(batch_replication, 1)
81 |
82 | return graph_map, batch_variable_map.squeeze(0), batch_function_map.squeeze(0), edge_feature, meta_data, label
83 |
84 | def _compute_batch_replication_map(self, batch_variable_map, batch_replication):
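   | "Builds the sparse [batch_size * batch_replication, batch_size] map (and its transpose) that links every replicated instance back to its original problem instance."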
85 | batch_size = (batch_variable_map.max() + 1).long().item()
86 | x_ind = torch.arange(batch_size * batch_replication, dtype=torch.int64, device=self._device)
87 | y_ind = torch.arange(batch_size, dtype=torch.int64, device=self._device).repeat(batch_replication)
88 | ind = torch.stack([x_ind, y_ind])
89 | all_ones = torch.ones(batch_size * batch_replication, device=self._device)
90 |
91 | if self._device.type == 'cuda':
92 | mask = torch.cuda.sparse.FloatTensor(ind, all_ones,
93 | torch.Size([batch_size * batch_replication, batch_size]), device=self._device)
94 | else:
95 | mask = torch.sparse.FloatTensor(ind, all_ones,
96 | torch.Size([batch_size * batch_replication, batch_size]), device=self._device)
97 |
98 | mask_transpose = mask.transpose(0, 1)
99 | return (mask, mask_transpose)
100 |
101 | def _compute_variable_function_map(self, graph_map, batch_variable_map, batch_function_map, edge_feature):
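   | "Builds the sparse variable-clause incidence matrices: an unsigned mask and a signed mask carrying the literal polarities, together with their transposes."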
102 | edge_num = graph_map.size()[1]
103 | variable_num = batch_variable_map.size()[0]
104 | function_num = batch_function_map.size()[0]
105 | all_ones = torch.ones(edge_num, device=self._device)
106 |
107 | if self._device.type == 'cuda':
108 | mask = torch.cuda.sparse.FloatTensor(graph_map.long(), all_ones,
109 | torch.Size([variable_num, function_num]), device=self._device)
110 | signed_mask = torch.cuda.sparse.FloatTensor(graph_map.long(), edge_feature.squeeze(1),
111 | torch.Size([variable_num, function_num]), device=self._device)
112 | else:
113 | mask = torch.sparse.FloatTensor(graph_map.long(), all_ones,
114 | torch.Size([variable_num, function_num]), device=self._device)
115 | signed_mask = torch.sparse.FloatTensor(graph_map.long(), edge_feature.squeeze(1),
116 | torch.Size([variable_num, function_num]), device=self._device)
117 |
118 | mask_transpose = mask.transpose(0, 1)
119 | signed_mask_transpose = signed_mask.transpose(0, 1)
120 |
121 | return (mask, mask_transpose, signed_mask, signed_mask_transpose)
122 |
123 | def _compute_batch_map(self, batch_variable_map, batch_function_map):
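   | "Builds the sparse masks that assign each variable node and each clause node to its problem instance within the batch, together with their transposes."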
124 | variable_num = batch_variable_map.size()[0]
125 | function_num = batch_function_map.size()[0]
126 | variable_all_ones = torch.ones(variable_num, device=self._device)
127 | function_all_ones = torch.ones(function_num, device=self._device)
128 | variable_range = torch.arange(variable_num, dtype=torch.int64, device=self._device)
129 | function_range = torch.arange(function_num, dtype=torch.int64, device=self._device)
130 | batch_size = (batch_variable_map.max() + 1).long().item()
131 |
132 | variable_sparse_ind = torch.stack([variable_range, batch_variable_map.long()])
133 | function_sparse_ind = torch.stack([function_range, batch_function_map.long()])
134 |
135 | if self._device.type == 'cuda':
136 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, variable_all_ones,
137 | torch.Size([variable_num, batch_size]), device=self._device)
138 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, function_all_ones,
139 | torch.Size([function_num, batch_size]), device=self._device)
140 | else:
141 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, variable_all_ones,
142 | torch.Size([variable_num, batch_size]), device=self._device)
143 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, function_all_ones,
144 | torch.Size([function_num, batch_size]), device=self._device)
145 |
146 | variable_mask_transpose = variable_mask.transpose(0, 1)
147 | function_mask_transpose = function_mask.transpose(0, 1)
148 |
149 | return (variable_mask, variable_mask_transpose, function_mask, function_mask_transpose)
150 |
151 | def _compute_graph_mask(self, graph_map, batch_variable_map, batch_function_map, edge_values=None):
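   | "Builds the sparse variable-to-edge and clause-to-edge incidence masks, optionally weighted by edge_values (all ones by default), together with their transposes."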
152 | edge_num = graph_map.size()[1]
153 | variable_num = batch_variable_map.size()[0]
154 | function_num = batch_function_map.size()[0]
155 |
156 | if edge_values is None:
157 | edge_values = torch.ones(edge_num, device=self._device)
158 |
159 | edge_num_range = torch.arange(edge_num, dtype=torch.int64, device=self._device)
160 |
161 | variable_sparse_ind = torch.stack([graph_map[0, :].long(), edge_num_range])
162 | function_sparse_ind = torch.stack([graph_map[1, :].long(), edge_num_range])
163 |
164 | if self._device.type == 'cuda':
165 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, edge_values,
166 | torch.Size([variable_num, edge_num]), device=self._device)
167 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, edge_values,
168 | torch.Size([function_num, edge_num]), device=self._device)
169 | else:
170 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, edge_values,
171 | torch.Size([variable_num, edge_num]), device=self._device)
172 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, edge_values,
173 | torch.Size([function_num, edge_num]), device=self._device)
174 |
175 | variable_mask_transpose = variable_mask.transpose(0, 1)
176 | function_mask_transpose = function_mask.transpose(0, 1)
177 |
178 | return (variable_mask, variable_mask_transpose, function_mask, function_mask_transpose)
179 |
180 | def _peel(self):
181 | "Implements the peeling algorithm."
182 |
183 | vf_map, vf_map_transpose, signed_vf_map, _ = self._vf_mask_tuple
184 |
185 | variable_degree = torch.mm(vf_map, self._active_functions)
186 | signed_variable_degree = torch.mm(signed_vf_map, self._active_functions)
187 |
188 | while True:
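   | # A variable is "single" (a pure literal) when all of its remaining active clauses contain it with the same polarity; it can then be fixed to satisfy them.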
189 | single_variables = (variable_degree == signed_variable_degree.abs()).float() * self._active_variables
190 |
191 | if torch.sum(single_variables) <= 0:
192 | break
193 |
194 | single_functions = (torch.mm(vf_map_transpose, single_variables) > 0).float() * self._active_functions
195 | degree_delta = torch.mm(vf_map, single_functions) * self._active_variables
196 | signed_degree_delta = torch.mm(signed_vf_map, single_functions) * self._active_variables
197 | self._solution[single_variables[:, 0] == 1] = (signed_variable_degree[single_variables[:, 0] == 1, 0].sign() + 1) / 2.0
198 |
199 | variable_degree -= degree_delta
200 | signed_variable_degree -= signed_degree_delta
201 |
202 | self._active_variables[single_variables[:, 0] == 1, 0] = 0
203 | self._active_functions[single_functions[:, 0] == 1, 0] = 0
204 |
205 | def _set_variable_core(self, assignment):
206 | "Fixes variables to certain binary values."
207 |
208 | _, vf_map_transpose, _, signed_vf_map_transpose = self._vf_mask_tuple
209 |
210 | assignment *= self._active_variables
211 |
212 | # Compute the number of inputs for each function node
213 | input_num = torch.mm(vf_map_transpose, assignment.abs())
214 |
215 | # Compute the signed evaluation for each function node
216 | function_eval = torch.mm(signed_vf_map_transpose, assignment)
217 |
218 | # Compute the de-activated functions
219 | deactivated_functions = (function_eval > -input_num).float() * self._active_functions
220 |
221 | # De-activate functions and variables
222 | self._active_variables[assignment[:, 0].abs() == 1, 0] = 0
223 | self._active_functions[deactivated_functions[:, 0] == 1, 0] = 0
224 |
225 | # Update the solution
226 | self._solution[assignment[:, 0].abs() == 1] = (assignment[assignment[:, 0].abs() == 1, 0] + 1) / 2.0
227 |
228 | def _propagate_single_clauses(self):
229 | "Implements unit clause propagation algorithm."
230 |
231 | vf_map, vf_map_transpose, signed_vf_map, _ = self._vf_mask_tuple
232 | b_variable_mask, b_variable_mask_transpose, b_function_mask, _ = self._batch_mask_tuple
233 |
234 | while True:
235 | function_degree = torch.mm(vf_map_transpose, self._active_variables)
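   | # Unit clauses: active clauses with exactly one active variable left.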
236 | single_functions = (function_degree == 1).float() * self._active_functions
237 |
238 | if torch.sum(single_functions) <= 0:
239 | break
240 |
241 | # Compute the number of inputs for each variable node
242 | input_num = torch.mm(vf_map, single_functions)
243 |
244 | # Compute the signed evaluation for each variable node
245 | variable_eval = torch.mm(signed_vf_map, single_functions)
246 |
247 | # Detect and de-activate the UNSAT examples
248 | conflict_variables = (variable_eval.abs() != input_num).float() * self._active_variables
249 | if torch.sum(conflict_variables) > 0:
250 |
251 | # Detect the UNSAT examples
252 | unsat_examples = torch.mm(b_variable_mask_transpose, conflict_variables)
253 | self._is_sat[unsat_examples[:, 0] >= 1] = 0
254 |
255 | # De-activate the function nodes related to unsat examples
256 | unsat_functions = torch.mm(b_function_mask, unsat_examples) * self._active_functions
257 | self._active_functions[unsat_functions[:, 0] == 1, 0] = 0
258 |
259 | # De-activate the variable nodes related to unsat examples
260 | unsat_variables = torch.mm(b_variable_mask, unsat_examples) * self._active_variables
261 | self._active_variables[unsat_variables[:, 0] == 1, 0] = 0
262 |
263 | # Compute the assigned variables
264 | assigned_variables = (variable_eval.abs() == input_num).float() * self._active_variables
265 |
266 | # Compute the variable assignment
267 | assignment = torch.sign(variable_eval) * assigned_variables
268 |
269 | # De-activate single functions
270 | self._active_functions[single_functions[:, 0] == 1, 0] = 0
271 |
272 | # Set the corresponding variables
273 | self._set_variable_core(assignment)
274 |
275 | def set_variables(self, assignment):
276 | "Fixes variables to certain binary values and simplifies the CNF accordingly."
277 |
278 | self._set_variable_core(assignment)
279 | self.simplify()
280 |
281 | def simplify(self):
282 | "Simplifies the CNF."
283 |
284 | self._propagate_single_clauses()
285 | self._peel()
286 |
287 |
288 | ###############################################################
289 | ### The Solver Classes
290 | ###############################################################
291 |
292 |
293 | class PropagatorDecimatorSolverBase(nn.Module):
294 | "The base class for all PDP SAT solvers."
295 |
296 | def __init__(self, device, name, propagator, decimator, predictor, local_search_iterations=0, epsilon=0.05):
297 |
298 | super(PropagatorDecimatorSolverBase, self).__init__()
299 | self._device = device
300 | self._module_list = nn.ModuleList()
301 |
302 | self._propagator = propagator
303 | self._decimator = decimator
304 | self._predictor = predictor
305 |
306 | self._module_list.append(self._propagator)
307 | self._module_list.append(self._decimator)
308 | self._module_list.append(self._predictor)
309 |
310 | self._global_step = nn.Parameter(torch.tensor([0], dtype=torch.float, device=self._device), requires_grad=False)
311 | self._name = name
312 | self._local_search_iterations = local_search_iterations
313 | self._epsilon = epsilon
314 |
315 | def parameter_count(self):
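   | "Returns the number of trainable parameters of the solver."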
316 | return sum(p.numel() for p in self.parameters() if p.requires_grad)
317 |
318 | def save(self, export_path_base):
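   | "Saves the solver's state dictionary under export_path_base, using the solver's name as the file name."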
319 | torch.save(self.state_dict(), os.path.join(export_path_base, self._name))
320 |
321 | def load(self, import_path_base):
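   | "Loads the solver's state dictionary from import_path_base, using the solver's name as the file name."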
322 | self.load_state_dict(torch.load(os.path.join(import_path_base, self._name)))
323 |
324 | def forward(self, init_state,
325 | graph_map, batch_variable_map, batch_function_map, edge_feature,
326 | meta_data, is_training=True, iteration_num=1, check_termination=None, simplify=True, batch_replication=1):
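   | "Runs the solver on a batch of CNF problems: optionally simplifies and replicates the batch, iterates the propagator and the decimator, applies the predictor and (at test time) a post-processing local search, de-duplicates replicated batches, and returns the prediction together with the final states."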
327 |
328 | init_propagator_state, init_decimator_state = init_state
329 | batch_replication = 1 if is_training else batch_replication
330 | sat_problem = SATProblem((graph_map, batch_variable_map, batch_function_map, edge_feature, meta_data, None), self._device, batch_replication)
331 |
332 | if simplify and not is_training:
333 | sat_problem.simplify()
334 |
335 | if self._propagator is not None and self._decimator is not None:
336 | propagator_state, decimator_state = self._forward_core(init_propagator_state, init_decimator_state,
337 | sat_problem, iteration_num, is_training, check_termination)
338 | else:
339 | decimator_state = None
340 | propagator_state = None
341 |
342 | prediction = self._predictor(decimator_state, sat_problem, True)
343 |
344 | # Post-processing local search
345 | if not is_training:
346 | prediction = self._local_search(prediction, sat_problem, batch_replication)
347 |
348 | prediction = self._update_solution(prediction, sat_problem)
349 |
350 | if batch_replication > 1:
351 | prediction, propagator_state, decimator_state = self._deduplicate(prediction, propagator_state, decimator_state, sat_problem)
352 |
353 | return (prediction, (propagator_state, decimator_state))
354 |
355 | def _forward_core(self, init_propagator_state, init_decimator_state, sat_problem, iteration_num, is_training, check_termination):
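   | "Runs the propagator/decimator message-passing loop for iteration_num steps, masking edges incident to deactivated nodes and, when check_termination is given, stopping early once every instance is deactivated."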
356 |
357 | propagator_state = init_propagator_state
358 | decimator_state = init_decimator_state
359 |
360 | if check_termination is None:
361 | active_mask = None
362 | else:
363 | active_mask = torch.ones(sat_problem._batch_size, 1, dtype=torch.uint8, device=self._device)
364 |
365 | for _ in torch.arange(iteration_num, dtype=torch.int32, device=self._device):
366 |
367 | propagator_state = self._propagator(propagator_state, decimator_state, sat_problem, is_training, active_mask)
368 | decimator_state = self._decimator(decimator_state, propagator_state, sat_problem, is_training, active_mask)
369 |
370 | sat_problem._edge_mask = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) * \
371 | torch.mm(sat_problem._graph_mask_tuple[3], sat_problem._active_functions)
372 |
373 | if sat_problem._edge_mask.sum() < sat_problem._edge_num:
374 | decimator_state += (sat_problem._edge_mask,)
375 |
376 | if check_termination is not None:
377 | prediction = self._predictor(decimator_state, sat_problem)
378 | prediction = self._update_solution(prediction, sat_problem)
379 |
380 | check_termination(active_mask, prediction, sat_problem)
381 | num_active = active_mask.sum()
382 |
383 | if num_active <= 0:
384 | break
385 |
386 | return propagator_state, decimator_state
387 |
388 | def _update_solution(self, prediction, sat_problem):
389 | "Updates the the SAT problem object's solution according to the cuerrent prediction."
390 |
391 | if prediction[0] is not None:
392 | variable_solution = sat_problem._active_variables * prediction[0] + \
393 | (1.0 - sat_problem._active_variables) * sat_problem._solution.unsqueeze(1)
394 | sat_problem._solution[sat_problem._active_variables[:, 0] == 1] = \
395 | variable_solution[sat_problem._active_variables[:, 0] == 1, 0]
396 | else:
397 | variable_solution = None
398 |
399 | return variable_solution, prediction[1]
400 |
401 | def _deduplicate(self, prediction, propagator_state, decimator_state, sat_problem):
402 | "De-duplicates the current batch (to neutralize the batch replication) by finding the replica with minimum energy for each problem instance. "
403 |
404 | if sat_problem._batch_replication <= 1 or sat_problem._replication_mask_tuple is None:
405 | return None, None, None
406 |
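   | # Map the [0, 1] predictions to [-1, +1] assignments and, for each original instance, keep the replica with the lowest energy (fewest unsatisfied clauses).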
407 | assignment = 2 * prediction[0] - 1.0
408 | energy, _ = self._compute_energy(assignment, sat_problem)
409 | max_ind = util.sparse_argmax(-energy.squeeze(1), sat_problem._replication_mask_tuple[0], device=self._device)
410 |
411 | batch_flag = torch.zeros(sat_problem._batch_size, 1, device=self._device)
412 | batch_flag[max_ind, 0] = 1
413 |
414 | flag = torch.mm(sat_problem._batch_mask_tuple[0], batch_flag)
415 | variable_prediction = (flag * prediction[0]).view(sat_problem._batch_replication, -1).sum(dim=0).unsqueeze(1)
416 |
417 | flag = torch.mm(sat_problem._graph_mask_tuple[1], flag)
418 | new_propagator_state = ()
419 | for x in propagator_state:
420 | new_propagator_state += ((flag * x).view(sat_problem._batch_replication, sat_problem._edge_num // sat_problem._batch_replication, -1).sum(dim=0),)
421 |
422 | new_decimator_state = ()
423 | for x in decimator_state:
424 | new_decimator_state += ((flag * x).view(sat_problem._batch_replication, sat_problem._edge_num // sat_problem._batch_replication, -1).sum(dim=0),)
425 |
426 | function_prediction = None
427 | if prediction[1] is not None:
428 | flag = torch.mm(sat_problem._batch_mask_tuple[2], batch_flag)
429 | function_prediction = (flag * prediction[1]).view(sat_problem._batch_replication, -1).sum(dim=0).unsqueeze(1)
430 |
431 | return (variable_prediction, function_prediction), new_propagator_state, new_decimator_state
432 |
433 | def _local_search(self, prediction, sat_problem, batch_replication):
434 | "Implements the Walk-SAT algorithm for post-processing."
435 |
436 | assignment = (prediction[0] > 0.5).float()
437 | assignment = sat_problem._active_variables * (2*assignment - 1.0)
438 |
439 | sat_problem._edge_mask = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) * \
440 | torch.mm(sat_problem._graph_mask_tuple[3], sat_problem._active_functions)
441 |
442 | for _ in range(self._local_search_iterations):
443 | unsat_examples, unsat_functions = self._compute_energy(assignment, sat_problem)
444 | unsat_examples = (unsat_examples > 0).float()
445 |
446 | if batch_replication > 1:
447 | compact_unsat_examples = 1 - (torch.mm(sat_problem._replication_mask_tuple[1], 1 - unsat_examples) > 0).float()
448 | if compact_unsat_examples.sum() == 0:
449 | break
450 | elif unsat_examples.sum() == 0:
451 | break
452 |
453 | delta_energy = self._compute_energy_diff(assignment, sat_problem)
454 | max_delta_ind = util.sparse_argmax(-delta_energy.squeeze(1), sat_problem._batch_mask_tuple[0], device=self._device)
455 |
456 | unsat_variables = torch.mm(sat_problem._vf_mask_tuple[0], unsat_functions) * sat_problem._active_variables
457 | unsat_variables = (unsat_variables > 0).float() * torch.rand([sat_problem._variable_num, 1], device=self._device)
458 | random_ind = util.sparse_argmax(unsat_variables.squeeze(1), sat_problem._batch_mask_tuple[0], device=self._device)
459 |
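   | # Epsilon-greedy flip: with probability (1 - epsilon) pick the variable whose flip lowers the energy the most, otherwise pick a random variable from an unsatisfied clause; only instances that still have unsatisfied clauses are flipped.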
460 | coin = (torch.rand(sat_problem._batch_size, device=self._device) > self._epsilon).long()
461 | max_ind = coin * max_delta_ind + (1 - coin) * random_ind
462 | max_ind = max_ind[unsat_examples[:, 0] > 0]
463 |
464 | # Flipping the selected variables
465 | assignment[max_ind, 0] = -assignment[max_ind, 0]
466 |
467 | return (assignment + 1) / 2.0, prediction[1]
468 |
469 | def _compute_energy_diff(self, assignment, sat_problem):
470 | "Computes the delta energy if each variable to be flipped during the local search."
471 |
472 | distributed_assignment = torch.mm(sat_problem._signed_mask_tuple[1], assignment * sat_problem._active_variables)
473 | aggregated_assignment = torch.mm(sat_problem._graph_mask_tuple[2], distributed_assignment)
474 | aggregated_assignment = torch.mm(sat_problem._graph_mask_tuple[3], aggregated_assignment)
475 | aggregated_assignment = aggregated_assignment - distributed_assignment
476 |
477 | function_degree = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables)
478 | function_degree = torch.mm(sat_problem._graph_mask_tuple[2], function_degree)
479 | function_degree = torch.mm(sat_problem._graph_mask_tuple[3], function_degree)
480 |
481 | critical_edges = (aggregated_assignment == (1 - function_degree)).float() * sat_problem._edge_mask
482 | delta = torch.mm(sat_problem._graph_mask_tuple[0], critical_edges * distributed_assignment)
483 |
484 | return delta
485 |
486 | def _compute_energy(self, assignment, sat_problem):
487 | "Computes the energy of each CNF instance present in the batch."
488 |
489 | aggregated_assignment = torch.mm(sat_problem._signed_mask_tuple[1], assignment * sat_problem._active_variables)
490 | aggregated_assignment = torch.mm(sat_problem._graph_mask_tuple[2], aggregated_assignment)
491 |
492 | function_degree = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables)
493 | function_degree = torch.mm(sat_problem._graph_mask_tuple[2], function_degree)
494 |
495 | unsat_functions = (aggregated_assignment == -function_degree).float() * sat_problem._active_functions
496 | return torch.mm(sat_problem._batch_mask_tuple[3], unsat_functions), unsat_functions
497 |
498 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication=1):
499 | "Initializes the propgator and the decimator messages in each direction."
500 |
501 | if self._propagator is None:
502 | init_propagator_state = None
503 | else:
504 | init_propagator_state = self._propagator.get_init_state(graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication)
505 |
506 | if self._decimator is None:
507 | init_decimator_state = None
508 | else:
509 | init_decimator_state = self._decimator.get_init_state(graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication)
510 |
511 | return init_propagator_state, init_decimator_state
512 |
513 |
514 | ###############################################################
515 |
516 |
517 | class NeuralPropagatorDecimatorSolver(PropagatorDecimatorSolverBase):
518 | "Implements a fully neural PDP SAT solver with both the propagator and the decimator being neural."
519 |
520 | def __init__(self, device, name, edge_dimension, meta_data_dimension,
521 | propagator_dimension, decimator_dimension,
522 | mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension, prediction_dimension,
523 | variable_classifier=None, function_classifier=None, dropout=0,
524 | local_search_iterations=0, epsilon=0.05):
525 |
526 | super(NeuralPropagatorDecimatorSolver, self).__init__(
527 | device=device, name=name,
528 | propagator=pdp_propagate.NeuralMessagePasser(device, edge_dimension, decimator_dimension,
529 | meta_data_dimension, propagator_dimension, mem_hidden_dimension,
530 | mem_agg_hidden_dimension, agg_hidden_dimension, dropout),
531 | decimator=pdp_decimate.NeuralDecimator(device, propagator_dimension, meta_data_dimension,
532 | decimator_dimension, mem_hidden_dimension,
533 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, dropout),
534 | predictor=pdp_predict.NeuralPredictor(device, decimator_dimension, prediction_dimension,
535 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension,
536 | mem_agg_hidden_dimension, variable_classifier, function_classifier),
537 | local_search_iterations=local_search_iterations, epsilon=epsilon)
538 |
539 |
540 | ###############################################################
541 |
542 |
543 | class NeuralSurveyPropagatorSolver(PropagatorDecimatorSolverBase):
544 | "Implements a PDP solver with the SP propgator and a neural decimator."
545 |
546 | def __init__(self, device, name, edge_dimension, meta_data_dimension,
547 | decimator_dimension,
548 | mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension, prediction_dimension,
549 | variable_classifier=None, function_classifier=None, dropout=0,
550 | local_search_iterations=0, epsilon=0.05):
551 |
552 | super(NeuralSurveyPropagatorSolver, self).__init__(
553 | device=device, name=name,
554 | propagator=pdp_propagate.SurveyPropagator(device, decimator_dimension, include_adaptors=True),
555 | decimator=pdp_decimate.NeuralDecimator(device, (3, 1), meta_data_dimension,
556 | decimator_dimension, mem_hidden_dimension,
557 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, dropout),
558 | predictor=pdp_predict.NeuralPredictor(device, decimator_dimension, prediction_dimension,
559 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension,
560 | mem_agg_hidden_dimension, variable_classifier, function_classifier),
561 | local_search_iterations=local_search_iterations, epsilon=epsilon)
562 |
563 |
564 | ###############################################################
565 |
566 |
567 | class SurveyPropagatorSolver(PropagatorDecimatorSolverBase):
568 | "Implements the classical SP-guided decimation solver via the PDP framework."
569 |
570 | def __init__(self, device, name, tolerance, t_max, local_search_iterations=0, epsilon=0.05):
571 |
572 | super(SurveyPropagatorSolver, self).__init__(
573 | device=device, name=name,
574 | propagator=pdp_propagate.SurveyPropagator(device, decimator_dimension=1, include_adaptors=False),
575 | decimator=pdp_decimate.SequentialDecimator(device, message_dimension=(3, 1),
576 | scorer=pdp_predict.SurveyScorer(device, message_dimension=1, include_adaptors=False), tolerance=tolerance, t_max=t_max),
577 | predictor=pdp_predict.IdentityPredictor(device=device, random_fill=True),
578 | local_search_iterations=local_search_iterations, epsilon=epsilon)
579 |
580 |
581 | ###############################################################
582 |
583 |
584 | class WalkSATSolver(PropagatorDecimatorSolverBase):
585 | "Implements the classical Walk-SAT solver via the PDP framework."
586 |
587 | def __init__(self, device, name, iteration_num, epsilon=0.05):
588 |
589 | super(WalkSATSolver, self).__init__(
590 | device=device, name=name, propagator=None, decimator=None,
591 | predictor=pdp_predict.IdentityPredictor(device=device, random_fill=True),
592 | local_search_iterations=iteration_num, epsilon=epsilon)
593 |
594 |
595 | ###############################################################
596 |
597 |
598 | class ReinforceSurveyPropagatorSolver(PropagatorDecimatorSolverBase):
599 | "Implements the classical Reinforce solver via the PDP framework."
600 |
601 | def __init__(self, device, name, pi=0.1, decimation_probability=0.5, local_search_iterations=0, epsilon=0.05):
602 |
603 | super(ReinforceSurveyPropagatorSolver, self).__init__(
604 | device=device, name=name,
605 | propagator=pdp_propagate.SurveyPropagator(device, decimator_dimension=1, include_adaptors=False, pi=pi),
606 | decimator=pdp_decimate.ReinforceDecimator(device,
607 | scorer=pdp_predict.SurveyScorer(device, message_dimension=1, include_adaptors=False, pi=pi),
608 | decimation_probability=decimation_probability),
609 | predictor=pdp_predict.ReinforcePredictor(device=device),
610 | local_search_iterations=local_search_iterations, epsilon=epsilon)
611 |
612 |
613 | ###############################################################
614 |
615 |
616 | class NeuralSequentialDecimatorSolver(PropagatorDecimatorSolverBase):
617 | "Implements a PDP solver with a neural propgator and the sequential decimator."
618 |
619 | def __init__(self, device, name, edge_dimension, meta_data_dimension,
620 | propagator_dimension, decimator_dimension,
621 | mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension,
622 | classifier_dimension, dropout, tolerance, t_max,
623 | local_search_iterations=0, epsilon=0.05):
624 |
625 | super(NeuralSequentialDecimatorSolver, self).__init__(
626 | device=device, name=name,
627 | propagator=pdp_propagate.NeuralMessagePasser(device, edge_dimension, decimator_dimension,
628 | meta_data_dimension, propagator_dimension, mem_hidden_dimension,
629 | mem_agg_hidden_dimension, agg_hidden_dimension, dropout),
630 | decimator=pdp_decimate.SequentialDecimator(device, message_dimension=(3, 1),
631 | scorer=pdp_predict.NeuralPredictor(device, decimator_dimension, 1,
632 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension,
633 | mem_agg_hidden_dimension, variable_classifier=util.PerceptronTanh(decimator_dimension,
634 | classifier_dimension, 1), function_classifier=None),
635 | tolerance=tolerance, t_max=t_max),
636 | predictor=pdp_predict.IdentityPredictor(device=device, random_fill=True),
637 | local_search_iterations=local_search_iterations, epsilon=epsilon)
638 |
--------------------------------------------------------------------------------