├── src ├── pdp │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── pdp_predict.py │ │ ├── pdp_propagate.py │ │ ├── pdp_decimate.py │ │ ├── util.py │ │ └── solver.py │ ├── factorgraph │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── base.py │ ├── trainer.py │ └── generator.py ├── .DS_Store ├── satyr.py ├── dimacs2json.py └── satyr-train-test.py ├── .DS_Store ├── config ├── Predict │ ├── PDP-p-d-p-walksat-pytorch.yaml │ ├── PDP-p-d-p-sp-pytorch.yaml │ ├── PDP-p-d-p-reinforce-pytorch.yaml │ └── PDP-np-nd-np-gcnf-10-100-pytorch.yaml └── Train │ ├── p-prodec2-ws-cnf-pytorch.yaml │ ├── p-prodec2-gcnf-4SAT-pytorch.yaml │ ├── p-prodec2-sp-cnf-3-10-pytorch.yaml │ ├── p-prodec2-reinforce-cnf-pytorch.yaml │ ├── p-prodec2-nsp-cnf-3-10-pytorch.yaml │ ├── p-prodec2-gcnf-10-100-pytorch.yaml │ ├── p-prodec2-modular-variable-pytorch-2.yaml │ ├── p-prodec2-ndec-cnf-3-10-pytorch.yaml │ ├── p-prodec2-modular-variable-pytorch.yaml │ └── p-prodec2-modular-4SAT-pytorch.yaml ├── .pylintrc ├── LICENSE ├── setup.py ├── .gitignore ├── SECURITY.md └── README.md /src/pdp/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | PDP Solver. 3 | """ 4 | -------------------------------------------------------------------------------- /src/pdp/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | NN layers. 3 | """ 4 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PDP-Solver/HEAD/.DS_Store -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PDP-Solver/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /config/Predict/PDP-p-d-p-walksat-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_type: "walk-sat" 2 | model_name: "p-prodec2-walksat-pytorch" 3 | -------------------------------------------------------------------------------- /config/Predict/PDP-p-d-p-sp-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_type: "p-d-p" 2 | model_name: "p-prodec2-sp-pytorch" 3 | tolerance: 0.02 4 | t_max: 100 5 | -------------------------------------------------------------------------------- /config/Predict/PDP-p-d-p-reinforce-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_type: "reinforce" 2 | model_name: "p-prodec2-reinforce-pytorch" 3 | pi: 0.01 4 | decimation_probability: 0.5 5 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable=not-callable 3 | 4 | [FORMAT] 5 | max-line-length=100 6 | max-module-lines=1000 7 | 8 | [TYPECHECK] 9 | generated-members=numpy,np,torch 10 | -------------------------------------------------------------------------------- /src/pdp/factorgraph/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factor Graph trainer base funstionality. 
3 | """ 4 | 5 | from pdp.factorgraph.base import FactorGraphTrainerBase 6 | 7 | 8 | __all__ = ["FactorGraphTrainerBase"] 9 | -------------------------------------------------------------------------------- /config/Predict/PDP-np-nd-np-gcnf-10-100-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_type: "np-nd-np" 2 | has_meta_data: false 3 | model_name: "p-prodec2-gcnf-10-100-pytorch" 4 | model_path: "../../Trained-models/SAT/p-prodec2-gcnf-10-100-pytorch/2.0g/best" 5 | label_dim: 1 6 | edge_feature_dim: 1 7 | meta_feature_dim: 0 8 | prediction_dim: 1 9 | hidden_dim: 150 10 | mem_hidden_dim: 100 11 | agg_hidden_dim: 100 12 | mem_agg_hidden_dim: 50 13 | classifier_dim: 50 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-ws-cnf-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-ws-cnf-pytorch" 2 | model_type: "walk-sat" 3 | version: "0.0" 4 | has_meta_data: true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: ["../../datasets/SAT/4SAT-test"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/4SAT-100"] 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 200000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 157 # 110 19 | mem_hidden_dim: 50 20 | agg_hidden_dim: 50 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 30 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 20 30 | test_recurrence_num: 1000 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | min_n: 10 39 | max_n: 100 40 | min_alpha: 2 41 | max_alpha: 10 42 | min_k: 2 43 | max_k: 10 44 | local_search_iteration: 1000 45 | epsilon: 0.5 46 | lambda: 1 47 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-gcnf-4SAT-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-gcnf-4SAT-pytorch" 2 | model_type: "np-nd-np" 3 | version: "0.0g" 4 | has_meta_data: false # true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: ["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/4SAT-100"] 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 40000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 0 # 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 150 # 110 19 | mem_hidden_dim: 100 20 | agg_hidden_dim: 100 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 10 30 | test_recurrence_num: 1000 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | generator: "uniform" 39 | min_n: 5 40 | max_n: 100 41 | min_alpha: 7 42 | max_alpha: 10 43 | min_k: 4 44 | max_k: 4 45 | local_search_iteration: 100 46 | epsilon: 0.05 47 | lambda: 1 48 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-sp-cnf-3-10-pytorch.yaml: 
-------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-sp-cnf-3-10-pytorch" 2 | model_type: "p-d-p" 3 | version: "2.0" 4 | has_meta_data: true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: ["../../datasets/SAT/3SAT-100"] #["../../datasets/SAT/4SAT-test"] # ["../../datasets/SAT/cnf-10.json"] # 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 200000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 157 # 110 19 | mem_hidden_dim: 50 20 | agg_hidden_dim: 50 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 30 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 20 30 | test_recurrence_num: 1000 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | min_n: 10 39 | max_n: 100 40 | min_alpha: 2 41 | max_alpha: 10 42 | min_k: 2 43 | max_k: 10 44 | local_search_iteration: 1000 45 | epsilon: 0.5 46 | tolerance: 0.02 47 | t_max: 100 48 | lambda: 1 49 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-reinforce-cnf-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-reinforce-cnf-pytorch" 2 | model_type: "reinforce" 3 | version: "1.0" 4 | has_meta_data: true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: ["../../datasets/SAT/3SAT-100"] # ["../../datasets/SAT/3SAT-100"] # ["../../datasets/SAT/4SAT-100"] # ["../../datasets/SAT/cnf-10.json"] # 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 200000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 157 # 110 19 | mem_hidden_dim: 50 20 | agg_hidden_dim: 50 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 30 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 20 30 | test_recurrence_num: 1000 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | min_n: 10 39 | max_n: 100 40 | min_alpha: 2 41 | max_alpha: 10 42 | min_k: 2 43 | max_k: 10 44 | local_search_iteration: 1000 45 | epsilon: 0.5 46 | pi: 0.01 47 | decimation_probability: 0.5 48 | lambda: 1 49 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-nsp-cnf-3-10-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-nsp-cnf-3-10-pytorch" 2 | model_type: "p-nd-np" 3 | version: "2.0" 4 | has_meta_data: true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: 
[["../../datasets/SAT/cnf-gc-6-10-test.json"], 8 | ["../../datasets/SAT/cnf-10.json"], 9 | ["../../datasets/SAT/cnf-20.json"], 10 | ["../../datasets/SAT/cnf-40.json"], 11 | ["../../datasets/SAT/cnf-60.json"], 12 | ["../../datasets/SAT/cnf-80.json"]] 13 | model_path: "../../Trained-models/SAT" 14 | repetition_num: 1 15 | train_epoch_size: 200000 16 | epoch_num: 500 17 | label_dim: 1 18 | edge_feature_dim: 1 19 | meta_feature_dim: 1 20 | error_dim: 3 21 | metric_index: 0 22 | prediction_dim: 1 23 | hidden_dim: 150 # 110 24 | mem_hidden_dim: 50 25 | agg_hidden_dim: 50 # 135 26 | mem_agg_hidden_dim: 50 # 50 27 | classifier_dim: 50 # 30 # 100 28 | batch_size: 5000 29 | learning_rate: 0.0001 30 | exploration: 0.1 31 | verbose: true 32 | randomized: true 33 | train_inner_recurrence_num: 1 34 | train_outer_recurrence_num: 10 35 | test_recurrence_num: 20 36 | max_cache_size: 100000 37 | dropout: 0.2 38 | clip_norm: 0.65 39 | weight_decay: 0.0000000001 40 | loss_sharpness: 5 41 | train_batch_limit: 4000000 42 | test_batch_limit: 40000000 43 | min_n: 10 44 | max_n: 100 45 | min_alpha: 2 46 | max_alpha: 10 47 | min_k: 2 48 | max_k: 10 49 | local_search_iteration: 10 50 | epsilon: 0.05 51 | lambda: 1 52 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-gcnf-10-100-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-gcnf-10-100-pytorch" 2 | model_type: "np-nd-np" 3 | version: "2.0g" 4 | has_meta_data: false # true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: ["../../datasets/SAT/sat-race-2015.json"] #["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/4SAT-subsample"] #["../../datasets/SAT/M0-4SAT-100"] # ["../../datasets/SAT/3SAT-100"] # ["../../datasets/SAT/4SAT-100"] # ["../../datasets/SAT/cnf-10.json"] # 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 40000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 0 # 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 150 # 110 19 | mem_hidden_dim: 100 20 | agg_hidden_dim: 100 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 10 30 | test_recurrence_num: 8800 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | generator: "uniform" 39 | min_n: 4 40 | max_n: 100 41 | min_alpha: 2 42 | max_alpha: 10 43 | min_k: 2 44 | max_k: 10 45 | local_search_iteration: 1000 46 | epsilon: 0.5 47 | lambda: 1 48 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-modular-variable-pytorch-2.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-modular-variable-pytorch-2" 2 | model_type: "np-nd-np" 3 | version: "2.0g" 4 | has_meta_data: false # true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] 7 | test_path: 
["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/sat-race-test.json"] # ["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/3SAT-100"] # 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 40000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 0 # 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 200 # 110 19 | mem_hidden_dim: 150 20 | agg_hidden_dim: 150 # 135 21 | mem_agg_hidden_dim: 100 # 50 22 | classifier_dim: 100 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 10 30 | test_recurrence_num: 200 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | generator: "v-modular" 39 | min_n: 5 40 | max_n: 100 41 | min_alpha: 5 42 | max_alpha: 10 43 | min_k: 2 44 | max_k: 10 45 | min_q: 0.8 46 | max_q: 0.9 47 | min_c: 10 48 | max_c: 20 49 | local_search_iteration: 0 50 | epsilon: 0.5 51 | lambda: 0.95 52 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-ndec-cnf-3-10-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-ndec-cnf-3-10-pytorch" 2 | model_type: "np-d-np" 3 | version: "0.0" 4 | has_meta_data: false 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: ["../../datasets/SAT/cnf-10.json"] 7 | test_path: [["../../datasets/SAT/cnf-gc-6-10-test.json"], 8 | ["../../datasets/SAT/cnf-10.json"], 9 | ["../../datasets/SAT/cnf-20.json"], 10 | ["../../datasets/SAT/cnf-40.json"], 11 | ["../../datasets/SAT/cnf-60.json"], 12 | ["../../datasets/SAT/cnf-80.json"]] 13 | model_path: "../../Trained-models/SAT" 14 | repetition_num: 1 15 | train_epoch_size: 200000 16 | epoch_num: 500 17 | label_dim: 1 18 | edge_feature_dim: 1 19 | meta_feature_dim: 0 20 | error_dim: 3 21 | metric_index: 0 22 | prediction_dim: 1 23 | hidden_dim: 150 # 110 24 | mem_hidden_dim: 100 25 | agg_hidden_dim: 100 # 135 26 | mem_agg_hidden_dim: 50 # 50 27 | classifier_dim: 50 # 100 28 | batch_size: 5000 29 | learning_rate: 0.0001 30 | exploration: 0.1 31 | verbose: true 32 | randomized: true 33 | train_inner_recurrence_num: 1 34 | train_outer_recurrence_num: 10 35 | test_recurrence_num: 20 36 | max_cache_size: 100000 37 | dropout: 0.2 38 | clip_norm: 0.65 39 | weight_decay: 0.0000000001 40 | loss_sharpness: 5 41 | train_batch_limit: 4000000 42 | test_batch_limit: 40000000 43 | generator: "uniform" 44 | min_n: 10 45 | max_n: 100 46 | min_alpha: 2 47 | max_alpha: 10 48 | min_k: 2 49 | max_k: 10 50 | local_search_iteration: 0 51 | epsilon: 0.05 52 | tolerance: 0.02 53 | t_max: 10 54 | lambda: 1 55 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-modular-variable-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-modular-variable-pytorch" 2 | model_type: "np-nd-np" 3 | version: "1.0g" 4 | has_meta_data: false # true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: 
["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] 7 | test_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/sat-race-test.json"] # ["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/3SAT-100"] # 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 40000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 0 # 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 150 # 110 19 | mem_hidden_dim: 100 20 | agg_hidden_dim: 100 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 100 23 | batch_size: 5000 24 | learning_rate: 0.0001 25 | exploration: 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 10 30 | test_recurrence_num: 200 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 3500000 37 | test_batch_limit: 40000000 38 | generator: "v-modular" 39 | min_n: 5 40 | max_n: 100 41 | min_alpha: 5 42 | max_alpha: 10 43 | min_k: 2 44 | max_k: 10 45 | min_q: 0.8 46 | max_q: 0.9 47 | min_c: 10 48 | max_c: 20 49 | local_search_iteration: 0 50 | epsilon: 0.5 51 | lambda: 1 52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # To make sure pylint is happy with the tests, run: 4 | # python3 setup.py develop 5 | # every time we add or remove new top-level packages to the project. 6 | 7 | import os 8 | from setuptools import setup 9 | 10 | 11 | def _read(fname): 12 | with open(os.path.join(os.path.dirname(__file__), fname)) as stream: 13 | return stream.read() 14 | 15 | 16 | setup( 17 | name='PDP_Solver', 18 | author="Saeed Amizadeh", 19 | author_email="saamizad@microsoft.com", 20 | description="PDP Framework for Neural Constraint Satisfaction Solving", 21 | long_description=_read('./README.md'), 22 | long_description_content_type='text/markdown', 23 | keywords="pdp sat solver pytorch neurosymbolic", 24 | license="MIT", 25 | classifiers=[ 26 | # Trove classifiers 27 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 28 | 'License :: OSI Approved :: MIT License', 29 | 'Programming Language :: Python', 30 | 'Programming Language :: Python :: 3', 31 | 'Programming Language :: Python :: 3.5', 32 | 'Programming Language :: Python :: Implementation :: CPython', 33 | ], 34 | url="https://github/Microsoft/PDP-Solver", 35 | version='0.1', 36 | python_requires=">=3.5", 37 | install_requires=[ 38 | "numpy >= 1.10", 39 | "torch >= 0.4", 40 | ], 41 | package_dir={"": "src"}, 42 | packages=[ 43 | "pdp", 44 | "pdp.factorgraph", 45 | "pdp.nn" 46 | ], 47 | scripts=[ 48 | "src/satyr.py", 49 | "src/satyr-train-test.py", 50 | "src/dimacs2json.py", 51 | "src/pdp/generator.py" 52 | ] 53 | ) 54 | -------------------------------------------------------------------------------- /config/Train/p-prodec2-modular-4SAT-pytorch.yaml: -------------------------------------------------------------------------------- 1 | model_name: "p-prodec2-modular-4SAT-pytorch" 2 | model_type: "np-nd-np" 3 | version: "5.0g" 4 | has_meta_data: false # true 5 | train_path: ["../../datasets/SAT/p-cnf-3-10-large.json"] 6 | validation_path: 
["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] 7 | test_path: ["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M-4SAT-validation-2/M-4SAT-validation-2_0_7.0_10.0.json"] #["../../datasets/SAT/M0-4SAT-100"] #["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/sat-race-test.json"] # ["../../datasets/SAT/cnf-10.json"] #["../../datasets/SAT/3SAT-100"] # 8 | model_path: "../../Trained-models/SAT" 9 | repetition_num: 1 10 | train_epoch_size: 40000 11 | epoch_num: 500 12 | label_dim: 1 13 | edge_feature_dim: 1 14 | meta_feature_dim: 0 # 1 15 | error_dim: 3 16 | metric_index: 0 17 | prediction_dim: 1 18 | hidden_dim: 150 # 110 19 | mem_hidden_dim: 100 20 | agg_hidden_dim: 100 # 135 21 | mem_agg_hidden_dim: 50 # 50 22 | classifier_dim: 50 # 100 23 | batch_size: 5000 24 | learning_rate: 0.00001 # 0.0001 25 | exploration: 0.05 # 0.1 26 | verbose: true 27 | randomized: true 28 | train_inner_recurrence_num: 1 29 | train_outer_recurrence_num: 10 30 | test_recurrence_num: 20 31 | max_cache_size: 100000 32 | dropout: 0.2 33 | clip_norm: 0.65 34 | weight_decay: 0.0000000001 35 | loss_sharpness: 5 36 | train_batch_limit: 4000000 37 | test_batch_limit: 40000000 38 | generator: "modular" 39 | min_n: 5 40 | max_n: 100 41 | min_alpha: 5 42 | max_alpha: 10 43 | min_k: 4 44 | min_q: 0.8 45 | max_q: 0.9 46 | min_c: 10 47 | max_c: 20 48 | local_search_iteration: 0 49 | epsilon: 0.5 50 | lambda: 1 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/satyr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Main script to run a trained PDP solver against a test dataset. 4 | """ 5 | 6 | # Copyright (c) Microsoft. All rights reserved. 7 | # Licensed under the MIT license. See LICENSE.md file 8 | # in the project root for full license information. 9 | 10 | import argparse 11 | import yaml, os, logging, sys 12 | import numpy as np 13 | import torch 14 | from datetime import datetime 15 | 16 | import dimacs2json 17 | 18 | from pdp.trainer import SatFactorGraphTrainer 19 | 20 | 21 | def run(config, logger, output): 22 | "Runs the prediction engine." 23 | 24 | np.random.seed(config['random_seed']) 25 | torch.manual_seed(config['random_seed']) 26 | 27 | if config['verbose']: 28 | logger.info("Building the computational graph...") 29 | 30 | predicter = SatFactorGraphTrainer(config=config, use_cuda=not config['cpu_mode'], logger=logger) 31 | 32 | if config['verbose']: 33 | logger.info("Starting the prediction phase...") 34 | 35 | predicter._counter = 0 36 | if output == '': 37 | predicter.predict(test_list=config['test_path'], out_file=sys.stdout, import_path_base=config['model_path'], 38 | post_processor=predicter._post_process_predictions, batch_replication=config['batch_replication']) 39 | else: 40 | with open(output, 'w') as file: 41 | predicter.predict(test_list=config['test_path'], out_file=file, import_path_base=config['model_path'], 42 | post_processor=predicter._post_process_predictions, batch_replication=config['batch_replication']) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('model_config', help='The model configuration yaml file') 48 | parser.add_argument('test_path', help='The input test path') 49 | parser.add_argument('test_recurrence_num', help='The number of iterations for the PDP', type=int) 50 | parser.add_argument('-b', '--batch_replication', help='Batch replication factor', type=int, default=1) 51 | parser.add_argument('-z', '--batch_size', help='Batch size', type=int, default=5000) 52 | parser.add_argument('-m', '--max_cache_size', help='Maximum cache size', type=int, default=100000) 53 | parser.add_argument('-l', '--test_batch_limit', help='Memory limit for mini-batches', type=int, default=40000000) 54 | parser.add_argument('-w', '--local_search_iteration', help='Number of iterations for post-processing local search', type=int, 
default=100) 55 | parser.add_argument('-e', '--epsilon', help='Epsilon probability for post-processing local search', type=float, default=0.5) 56 | parser.add_argument('-v', '--verbose', help='Verbose', action='store_true') 57 | parser.add_argument('-c', '--cpu_mode', help='Run on CPU', action='store_true') 58 | parser.add_argument('-d', '--dimacs', help='The input folder contains DIMACS files', action='store_true') 59 | parser.add_argument('-s', '--random_seed', help='Random seed', type=int, default=int(datetime.now().microsecond)) 60 | parser.add_argument('-o', '--output', help='The JSON output file', default='') 61 | 62 | args = vars(parser.parse_args()) 63 | 64 | # Load the model config 65 | with open(args['model_config'], 'r') as f: 66 | model_config = yaml.load(f) 67 | 68 | # Set the logger 69 | format = '[%(levelname)s] %(asctime)s - %(name)s: %(message)s' 70 | logging.basicConfig(level=logging.DEBUG, format=format) 71 | logger = logging.getLogger(model_config['model_name']) 72 | 73 | # Convert DIMACS input files into JSON 74 | if args['dimacs']: 75 | if args['verbose']: 76 | logger.info("Converting DIMACS files into JSON...") 77 | temp_file_name = 'temp_problem_file.json' 78 | 79 | if os.path.isfile(args['test_path']): 80 | head, _ = os.path.split(args['test_path']) 81 | temp_file_name = os.path.join(head, temp_file_name) 82 | dimacs2json.convert_file(args['test_path'], temp_file_name, False) 83 | else: 84 | temp_file_name = os.path.join(args['test_path'], temp_file_name) 85 | dimacs2json.convert_directory(args['test_path'], temp_file_name, False) 86 | 87 | args['test_path'] = temp_file_name 88 | 89 | # Merge model config and other arguments into one config dict 90 | config = {**model_config, **args} 91 | 92 | if config['model_type'] == 'p-d-p' or config['model_type'] == 'walk-sat' or config['model_type'] == 'reinforce': 93 | config['model_path'] = None 94 | config['hidden_dim'] = 3 95 | 96 | if config['model_type'] == 'walk-sat': 97 | config['local_search_iteration'] = config['test_recurrence_num'] 98 | 99 | config['dropout'] = 0 100 | config['error_dim'] = 1 101 | config['exploration'] = 0 102 | 103 | # Run the prediction engine 104 | run(config, logger, config['output']) 105 | 106 | if args['dimacs']: 107 | os.remove(temp_file_name) 108 | 109 | print('') 110 | -------------------------------------------------------------------------------- /src/dimacs2json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Auxiliary script for converting sets of DIMACS files into PDP's compact JSON format. 4 | """ 5 | 6 | # Copyright (c) Microsoft. All rights reserved. 7 | # 8 | # Licensed under the MIT license. See LICENSE.md file 9 | # in the project root for full license information. 10 | 11 | import sys 12 | import argparse 13 | 14 | from os import listdir 15 | from os.path import isfile, join, split, splitext 16 | 17 | import numpy as np 18 | 19 | 20 | class CompactDimacs: 21 | "Encapsulates a CNF file given in the DIMACS format."
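# For illustration (assuming every variable occurs in at least one clause, so no rows or
# columns get pruned): a hypothetical input file "example.cnf" encoding the formula
# (x1 OR NOT x2) AND (x2 OR x3) with label 1 would be serialized by to_json() below roughly as
#     [[3, 2], [1, -2, 2, 3], [1, 1, 2, 2], 1, ["example.cnf"]]
# i.e. [[#variables, #clauses], the signed 1-based literals, the 1-based clause index of each
# literal, the SAT/UNSAT label, [file name]].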
22 | 23 | def __init__(self, dimacs_file, output, propagate): 24 | 25 | self.propagate = propagate 26 | self.file_name = split(dimacs_file)[1] 27 | 28 | with open(dimacs_file, 'r') as f: 29 | j = 0 30 | for line in f: 31 | seg = line.split(" ") 32 | if seg[0] == 'c': 33 | continue 34 | 35 | if seg[0] == 'p': 36 | var_num = int(seg[2]) 37 | clause_num = int(seg[3]) 38 | self._clause_mat = np.zeros((clause_num, var_num), dtype=np.int32) 39 | 40 | elif len(seg) <= 1: 41 | continue 42 | else: 43 | temp = np.array(seg[:-1], dtype=np.int32) 44 | self._clause_mat[j, np.abs(temp) - 1] = np.sign(temp) 45 | j += 1 46 | 47 | ind = np.where(np.sum(np.abs(self._clause_mat), 1) > 0)[0] 48 | self._clause_mat = self._clause_mat[ind, :] 49 | 50 | ind = np.where(np.sum(np.abs(self._clause_mat), 0) > 0)[0] 51 | self._clause_mat = self._clause_mat[:, ind] 52 | 53 | if propagate: 54 | self._clause_mat = self._propagate_constraints(self._clause_mat) 55 | 56 | self._output = output 57 | 58 | def _propagate_constraints(self, clause_mat): 59 | n = clause_mat.shape[0] 60 | if n < 2: 61 | return clause_mat 62 | 63 | length = np.tile(np.sum(np.abs(clause_mat), 1), (n, 1)) 64 | intersection_len = np.matmul(clause_mat, np.transpose(clause_mat)) 65 | 66 | temp = intersection_len == np.transpose(length) 67 | temp *= np.tri(*temp.shape, k=-1, dtype=bool) 68 | flags = np.logical_not(np.any(temp, 0)) 69 | 70 | clause_mat = clause_mat[flags, :] 71 | 72 | n = clause_mat.shape[0] 73 | if n < 2: 74 | return clause_mat 75 | 76 | length = np.tile(np.sum(np.abs(clause_mat), 1), (n, 1)) 77 | intersection_len = np.matmul(clause_mat, np.transpose(clause_mat)) 78 | 79 | temp = intersection_len == length 80 | temp *= np.tri(*temp.shape, k=-1, dtype=bool) 81 | flags = np.logical_not(np.any(temp, 1)) 82 | 83 | return clause_mat[flags, :] 84 | 85 | def to_json(self): 86 | clause_list = [] 87 | clause_num, var_num = self._clause_mat.shape 88 | 89 | ind = np.nonzero(self._clause_mat) 90 | return [[var_num, clause_num], list((ind[1] + 1) * self._clause_mat[ind]), 91 | list(ind[0] + 1), self._output, [self.file_name]] 92 | 93 | 94 | def convert_directory(dimacs_dir, output_file, propagate, only_positive=False): 95 | file_list = [join(dimacs_dir, f) for f in listdir(dimacs_dir) if isfile(join(dimacs_dir, f))] 96 | 97 | with open(output_file, 'w') as f: 98 | for i in range(len(file_list)): 99 | name, ext = splitext(file_list[i]) 100 | ext = ext.lower() 101 | 102 | if ext != '.dimacs' and ext != '.cnf': 103 | continue 104 | 105 | label = float(name[-1]) if name[-1].isdigit() else -1 106 | 107 | if only_positive and label == 0: 108 | continue 109 | 110 | bc = CompactDimacs(file_list[i], label, propagate) 111 | f.write(str(bc.to_json()).replace("'", '"') + '\n') 112 | print("Generating JSON input file: %6.2f%% complete..." 
% ( 113 | (i + 1) * 100.0 / len(file_list)), end='\r', file=sys.stderr) 114 | 115 | 116 | def convert_file(file_name, output_file, propagate): 117 | with open(output_file, 'w') as f: 118 | if len(file_name) < 8: 119 | label = -1 120 | else: 121 | temp = file_name[-8] 122 | label = float(temp) if temp.isdigit() else -1 123 | 124 | bc = CompactDimacs(file_name, label, propagate) 125 | f.write(str(bc.to_json()).replace("'", '"') + '\n') 126 | 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('in_dir', action='store', type=str) 131 | parser.add_argument('out_file', action='store', type=str) 132 | parser.add_argument('-s', '--simplify', help='Propagate binary constraints', required=False, action='store_true', default=False) 133 | parser.add_argument('-p', '--positive', help='Output only positive examples', required=False, action='store_true', default=False) 134 | args = vars(parser.parse_args()) 135 | 136 | convert_directory(args['in_dir'], args['out_file'], args['simplify'], args['positive']) 137 | -------------------------------------------------------------------------------- /src/satyr-train-test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """The main entry point to the PDP trainer/tester/predictor.""" 3 | 4 | # Copyright (c) Microsoft. All rights reserved. 5 | # Licensed under the MIT license. See LICENSE.md file 6 | # in the project root for full license information. 7 | 8 | import numpy as np 9 | import torch 10 | import torch.optim as optim 11 | import logging 12 | import argparse, os, yaml, csv 13 | 14 | from pdp.generator import * 15 | from pdp.trainer import SatFactorGraphTrainer 16 | 17 | 18 | ########################################################################################################################## 19 | 20 | def write_to_csv(result_list, file_path): 21 | with open(file_path, mode='w', newline='') as f: 22 | writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 23 | 24 | for row in result_list: 25 | writer.writerow([row[0], row[1][1, 0]]) 26 | 27 | def write_to_csv_time(result_list, file_path): 28 | with open(file_path, mode='w', newline='') as f: 29 | writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 30 | 31 | for row in result_list: 32 | writer.writerow([row[0], row[2]]) 33 | 34 | def run(random_seed, config_file, is_training, load_model, cpu, reset_step, use_generator, batch_replication): 35 | "Runs the train/test/predict procedures." 
36 | 37 | if not use_generator: 38 | np.random.seed(random_seed) 39 | torch.manual_seed(random_seed) 40 | 41 | # Set the configurations (from either JSON or YAML file) 42 | with open(config_file, 'r') as f: 43 | config = yaml.load(f) 44 | 45 | # Set the logger 46 | format = '[%(levelname)s] %(asctime)s - %(name)s: %(message)s' 47 | logging.basicConfig(level=logging.DEBUG, format=format) 48 | logger = logging.getLogger(config['model_name'] + ' (' + config['version'] + ')') 49 | 50 | # Check if the input path is a list or a directory of JSON files 51 | if not isinstance(config['train_path'], list): 52 | config['train_path'] = [os.path.join(config['train_path'], f) \ 53 | for f in os.listdir(config['train_path']) if os.path.isfile(os.path.join(config['train_path'], f)) and f.endswith('.json')] 54 | 55 | if not isinstance(config['validation_path'], list): 56 | config['validation_path'] = [os.path.join(config['validation_path'], f) \ 57 | for f in os.listdir(config['validation_path']) if os.path.isfile(os.path.join(config['validation_path'], f)) and f.endswith('.json')] 58 | 59 | if config['verbose']: 60 | if use_generator: 61 | logger.info("Generating training examples via %s generator." % config['generator']) 62 | else: 63 | logger.info("Training file(s): %s" % config['train_path']) 64 | logger.info("Validation file(s): %s" % config['validation_path']) 65 | 66 | best_model_path_base = os.path.join(os.path.relpath(config['model_path']), 67 | config['model_name'], config['version'], "best") 68 | 69 | last_model_path_base = os.path.join(os.path.relpath(config['model_path']), 70 | config['model_name'], config['version'], "last") 71 | 72 | if not os.path.exists(best_model_path_base): 73 | os.makedirs(best_model_path_base) 74 | 75 | if not os.path.exists(last_model_path_base): 76 | os.makedirs(last_model_path_base) 77 | 78 | trainer = SatFactorGraphTrainer(config=config, use_cuda=not cpu, logger=logger) 79 | 80 | # Training 81 | if is_training: 82 | if config['verbose']: 83 | logger.info("Starting the training phase...") 84 | 85 | generator = None 86 | 87 | if use_generator: 88 | if config['generator'] == 'modular': 89 | generator = ModularCNFGenerator(config['min_k'], config['min_n'], config['max_n'], config['min_q'], 90 | config['max_q'], config['min_c'], config['max_c'], config['min_alpha'], config['max_alpha']) 91 | elif config['generator'] == 'v-modular': 92 | generator = VariableModularCNFGenerator(config['min_k'], config['max_k'], config['min_n'], config['max_n'], config['min_q'], 93 | config['max_q'], config['min_c'], config['max_c'], config['min_alpha'], config['max_alpha']) 94 | else: 95 | generator = UniformCNFGenerator(config['min_n'], config['max_n'], config['min_k'], config['max_k'], config['min_alpha'], config['max_alpha']) 96 | 97 | model_list, errors, losses = trainer.train( 98 | train_list=config['train_path'], validation_list=config['validation_path'], 99 | optimizer=optim.Adam(trainer.get_parameter_list(), lr=config['learning_rate'], 100 | weight_decay=config['weight_decay']), last_export_path_base=last_model_path_base, 101 | best_export_path_base=best_model_path_base, metric_index=config['metric_index'], 102 | load_model=load_model, reset_step=reset_step, generator=generator, 103 | train_epoch_size=config['train_epoch_size']) 104 | 105 | if config['verbose']: 106 | logger.info("Starting the test phase...") 107 | 108 | for test_files in config['test_path']: 109 | if config['verbose']: 110 | logger.info("Testing " + test_files) 111 | 112 | if load_model == "last": 113 | import_path_base =
last_model_path_base 114 | elif load_model == "best": 115 | import_path_base = best_model_path_base 116 | else: 117 | import_path_base = None 118 | 119 | result = trainer.test(test_list=test_files, import_path_base=import_path_base, 120 | batch_replication=batch_replication) 121 | 122 | if config['verbose']: 123 | for row in result: 124 | filename, errors, _ = row 125 | print('Dataset: ' + filename) 126 | print("Accuracy: \t%s" % (1 - errors[0])) 127 | print("Recall: \t%s" % (1 - errors[1])) 128 | 129 | if os.path.isdir(test_files): 130 | write_to_csv(result, os.path.join(test_files, config['model_type'] + '_' + config['model_name'] + '_' + config['version'] + '-results.csv')) 131 | write_to_csv_time(result, os.path.join(test_files, config['model_type'] + '_' + config['model_name'] + '_' + config['version'] + '-results-time.csv')) 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('config', help='The configuration JSON file') 137 | parser.add_argument('-t', '--test', help='The test mode', action='store_true') 138 | parser.add_argument('-l', '--load_model', help='Load the previous model') 139 | parser.add_argument('-c', '--cpu_mode', help='Run on CPU', action='store_true') 140 | parser.add_argument('-r', '--reset', help='Reset the global step', action='store_true') 141 | parser.add_argument('-g', '--use_generator', help='Reset the global step', action='store_true') 142 | parser.add_argument('-b', '--batch_replication', help='Batch replication factor', type=int, default=1) 143 | 144 | args = parser.parse_args() 145 | run(0, args.config, not args.test, args.load_model, 146 | args.cpu_mode, args.reset, args.use_generator, args.batch_replication) 147 | -------------------------------------------------------------------------------- /src/pdp/factorgraph/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 3 | 4 | # factor_graph_input_pipeline.py : Defines the input pipeline for the PDP framework. 5 | 6 | import linecache, json 7 | import collections 8 | import numpy as np 9 | 10 | import torch 11 | import torch.utils.data as data 12 | 13 | from os import listdir 14 | from os.path import isfile, join 15 | 16 | 17 | class DynamicBatchDivider(object): 18 | "Implements the dynamic batching process." 
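# Rough sketch of divide() below: when the whole batch fits the memory budget, i.e.
# limit // (max #edges * hidden_dim) >= batch size, it is passed through as a single segment;
# otherwise instances are sorted by edge count (largest first) and greedily packed into
# segments of at most limit // (#edges of the largest member * hidden_dim) instances each.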
19 | 20 | def __init__(self, limit, hidden_dim): 21 | self.limit = limit 22 | self.hidden_dim = hidden_dim 23 | 24 | def divide(self, variable_num, function_num, graph_map, edge_feature, graph_feature, label, misc_data): 25 | batch_size = len(variable_num) 26 | edge_num = [len(n) for n in edge_feature] 27 | 28 | graph_map_list = [] 29 | edge_feature_list = [] 30 | graph_feature_list = [] 31 | variable_num_list = [] 32 | function_num_list = [] 33 | label_list = [] 34 | misc_data_list = [] 35 | 36 | if (self.limit // (max(edge_num) * self.hidden_dim)) >= batch_size: 37 | if graph_feature[0] is None: 38 | graph_feature_list = [[None]] 39 | else: 40 | graph_feature_list = [graph_feature] 41 | 42 | graph_map_list = [graph_map] 43 | edge_feature_list = [edge_feature] 44 | variable_num_list = [variable_num] 45 | function_num_list = [function_num] 46 | label_list = [label] 47 | misc_data_list = [misc_data] 48 | 49 | else: 50 | 51 | indices = sorted(range(len(edge_num)), reverse=True, key=lambda k: edge_num[k]) 52 | sorted_edge_num = sorted(edge_num, reverse=True) 53 | 54 | i = 0 55 | 56 | while i < batch_size: 57 | allowed_batch_size = self.limit // (sorted_edge_num[i] * self.hidden_dim) 58 | ind = indices[i:min(i + allowed_batch_size, batch_size)] 59 | 60 | if graph_feature[0] is None: 61 | graph_feature_list += [[None]] 62 | else: 63 | graph_feature_list += [[graph_feature[j] for j in ind]] 64 | 65 | edge_feature_list += [[edge_feature[j] for j in ind]] 66 | variable_num_list += [[variable_num[j] for j in ind]] 67 | function_num_list += [[function_num[j] for j in ind]] 68 | graph_map_list += [[graph_map[j] for j in ind]] 69 | label_list += [[label[j] for j in ind]] 70 | misc_data_list += [[misc_data[j] for j in ind]] 71 | 72 | i += allowed_batch_size 73 | 74 | return variable_num_list, function_num_list, graph_map_list, edge_feature_list, graph_feature_list, label_list, misc_data_list 75 | 76 | 77 | ############################################################### 78 | 79 | 80 | class FactorGraphDataset(data.Dataset): 81 | "Implements a PyTorch Dataset class for reading and parsing CNFs in the JSON format from disk." 
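# Each line of the input file is one problem in the compact JSON format produced by
# dimacs2json.py. Lines are read lazily via linecache and memoized in an ordered cache
# capped at max_cache_size; if a generator is supplied instead, examples are synthesized
# on the fly and the dataset length is epoch_size.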
82 | 83 | def __init__(self, input_file, limit, hidden_dim, max_cache_size=100000, generator=None, epoch_size=0, batch_replication=1): 84 | 85 | self._cache = collections.OrderedDict() 86 | self._generator = generator 87 | self._epoch_size = epoch_size 88 | self._input_file = input_file 89 | self._max_cache_size = max_cache_size 90 | 91 | if self._generator is None: 92 | with open(self._input_file, 'r') as fh_input: 93 | self._row_num = len(fh_input.readlines()) 94 | 95 | self.batch_divider = DynamicBatchDivider(limit // batch_replication, hidden_dim) 96 | 97 | def __len__(self): 98 | if self._generator is not None: 99 | return self._epoch_size 100 | else: 101 | return self._row_num 102 | 103 | def __getitem__(self, idx): 104 | if self._generator is not None: 105 | return self._generator.generate() 106 | 107 | else: 108 | if idx in self._cache: 109 | return self._cache[idx] 110 | 111 | line = linecache.getline(self._input_file, idx + 1) 112 | result = self._convert_line(line) 113 | 114 | if len(self._cache) >= self._max_cache_size: 115 | self._cache.popitem(last=False) 116 | 117 | self._cache[idx] = result 118 | return result 119 | 120 | def _convert_line(self, json_str): 121 | 122 | input_data = json.loads(json_str) 123 | variable_num, function_num = input_data[0] 124 | 125 | variable_ind = np.abs(np.array(input_data[1], dtype=np.int32)) - 1 126 | function_ind = np.abs(np.array(input_data[2], dtype=np.int32)) - 1 127 | edge_feature = np.sign(np.array(input_data[1], dtype=np.float32)) 128 | 129 | graph_map = np.stack((variable_ind, function_ind)) 130 | alpha = float(function_num) / variable_num 131 | 132 | misc_data = [] 133 | if len(input_data) > 4: 134 | misc_data = input_data[4] 135 | 136 | return (variable_num, function_num, graph_map, edge_feature, None, float(input_data[3]), misc_data) 137 | 138 | def dag_collate_fn(self, input_data): 139 | "Torch dataset loader collation function for factor graph input." 
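# Sketch of the collation below: the sampled instances are first split into memory-bounded
# segments by DynamicBatchDivider.divide(); within each segment the individual factor graphs
# are concatenated into one large graph (variable/function indices offset per instance), and
# batch_variable_map / batch_function_map record which instance every node belongs to.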
140 | 141 | vn, fn, gm, ef, gf, l, md = zip(*input_data) 142 | 143 | variable_num, function_num, graph_map, edge_feature, graph_feat, label, misc_data = \ 144 | self.batch_divider.divide(vn, fn, gm, ef, gf, l, md) 145 | segment_num = len(variable_num) 146 | 147 | graph_feat_batch = [] 148 | graph_map_batch = [] 149 | batch_variable_map_batch = [] 150 | batch_function_map_batch = [] 151 | edge_feature_batch = [] 152 | label_batch = [] 153 | 154 | for i in range(segment_num): 155 | 156 | # Create the graph features batch 157 | graph_feat_batch += [None if graph_feat[i][0] is None else torch.from_numpy(np.stack(graph_feat[i])).float()] 158 | 159 | # Create the edge feature batch 160 | edge_feature_batch += [torch.from_numpy(np.expand_dims(np.concatenate(edge_feature[i]), 1)).float()] 161 | 162 | # Create the label batch 163 | label_batch += [torch.from_numpy(np.expand_dims(np.array(label[i]), 1)).float()] 164 | 165 | # Create the graph map, variable map and function map batches 166 | g_map_b = np.zeros((2, 0), dtype=np.int32) 167 | v_map_b = np.zeros(0, dtype=np.int32) 168 | f_map_b = np.zeros(0, dtype=np.int32) 169 | variable_ind = 0 170 | function_ind = 0 171 | 172 | for j in range(len(graph_map[i])): 173 | graph_map[i][j][0, :] += variable_ind 174 | graph_map[i][j][1, :] += function_ind 175 | g_map_b = np.concatenate((g_map_b, graph_map[i][j]), axis=1) 176 | 177 | v_map_b = np.concatenate((v_map_b, np.tile(j, variable_num[i][j]))) 178 | f_map_b = np.concatenate((f_map_b, np.tile(j, function_num[i][j]))) 179 | 180 | variable_ind += variable_num[i][j] 181 | function_ind += function_num[i][j] 182 | 183 | graph_map_batch += [torch.from_numpy(g_map_b).int()] 184 | batch_variable_map_batch += [torch.from_numpy(v_map_b).int()] 185 | batch_function_map_batch += [torch.from_numpy(f_map_b).int()] 186 | 187 | return graph_map_batch, batch_variable_map_batch, batch_function_map_batch, edge_feature_batch, graph_feat_batch, label_batch, misc_data 188 | 189 | @staticmethod 190 | def get_loader(input_file, limit, hidden_dim, batch_size, shuffle, num_workers, 191 | max_cache_size=100000, use_cuda=True, generator=None, epoch_size=0, batch_replication=1): 192 | "Return the torch dataset loader object for the input." 193 | 194 | dataset = FactorGraphDataset( 195 | input_file=input_file, 196 | limit=limit, 197 | hidden_dim=hidden_dim, 198 | max_cache_size=max_cache_size, 199 | generator=generator, 200 | epoch_size=epoch_size, 201 | batch_replication=batch_replication) 202 | 203 | data_loader = torch.utils.data.DataLoader( 204 | dataset=dataset, 205 | batch_size=batch_size, 206 | shuffle=shuffle, 207 | num_workers=num_workers, 208 | collate_fn=dataset.dag_collate_fn, 209 | pin_memory=use_cuda) 210 | 211 | return data_loader 212 | 213 | 214 | 215 | 216 | 217 | -------------------------------------------------------------------------------- /src/pdp/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 3 | 4 | # PDP_solver_trainer.py : Implements a factor graph trainer for various types of PDP SAT solvers. 
5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import sys 12 | 13 | from pdp.factorgraph import FactorGraphTrainerBase 14 | from pdp.nn import solver, util 15 | 16 | 17 | ########################################################################################################################## 18 | 19 | 20 | class Perceptron(nn.Module): 21 | "Implements a 1-layer perceptron." 22 | 23 | def __init__(self, input_dimension, hidden_dimension, output_dimension): 24 | super(Perceptron, self).__init__() 25 | self._layer1 = nn.Linear(input_dimension, hidden_dimension) 26 | self._layer2 = nn.Linear(hidden_dimension, output_dimension, bias=False) 27 | 28 | def forward(self, inp): 29 | return F.sigmoid(self._layer2(F.relu(self._layer1(inp)))) 30 | 31 | 32 | ########################################################################################################################## 33 | 34 | class SatFactorGraphTrainer(FactorGraphTrainerBase): 35 | "Implements a factor graph trainer for various types of PDP SAT solvers." 36 | 37 | def __init__(self, config, use_cuda, logger): 38 | super(SatFactorGraphTrainer, self).__init__(config=config, 39 | has_meta_data=False, error_dim=config['error_dim'], loss=None, 40 | evaluator=nn.L1Loss(), use_cuda=use_cuda, logger=logger) 41 | 42 | self._eps = 1e-8 * torch.ones(1, device=self._device) 43 | self._loss_evaluator = util.SatLossEvaluator(alpha = self._config['exploration'], device = self._device) 44 | self._cnf_evaluator = util.SatCNFEvaluator(device = self._device) 45 | self._counter = 0 46 | self._max_coeff = 10.0 47 | 48 | def _build_graph(self, config): 49 | model_list = [] 50 | 51 | if config['model_type'] == 'np-nd-np': 52 | model_list += [solver.NeuralPropagatorDecimatorSolver(device=self._device, name=config['model_name'], 53 | edge_dimension=config['edge_feature_dim'], meta_data_dimension=config['meta_feature_dim'], 54 | propagator_dimension=config['hidden_dim'], decimator_dimension=config['hidden_dim'], 55 | mem_hidden_dimension=config['mem_hidden_dim'], 56 | agg_hidden_dimension=config['agg_hidden_dim'], mem_agg_hidden_dimension=config['mem_agg_hidden_dim'], 57 | prediction_dimension=config['prediction_dim'], 58 | variable_classifier=Perceptron(config['hidden_dim'], config['classifier_dim'], config['prediction_dim']), 59 | function_classifier=None, dropout=config['dropout'], 60 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])] 61 | 62 | elif config['model_type'] == 'p-nd-np': 63 | model_list += [solver.NeuralSurveyPropagatorSolver(device=self._device, name=config['model_name'], 64 | edge_dimension=config['edge_feature_dim'], meta_data_dimension=config['meta_feature_dim'], 65 | decimator_dimension=config['hidden_dim'], 66 | mem_hidden_dimension=config['mem_hidden_dim'], 67 | agg_hidden_dimension=config['agg_hidden_dim'], mem_agg_hidden_dimension=config['mem_agg_hidden_dim'], 68 | prediction_dimension=config['prediction_dim'], 69 | variable_classifier=Perceptron(config['hidden_dim'], config['classifier_dim'], config['prediction_dim']), 70 | function_classifier=None, dropout=config['dropout'], 71 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])] 72 | 73 | elif config['model_type'] == 'np-d-np': 74 | model_list += [solver.NeuralSequentialDecimatorSolver(device=self._device, name=config['model_name'], 75 | edge_dimension=config['edge_feature_dim'], meta_data_dimension=config['meta_feature_dim'], 76 | 
propagator_dimension=config['hidden_dim'], decimator_dimension=config['hidden_dim'], 77 | mem_hidden_dimension=config['mem_hidden_dim'], 78 | agg_hidden_dimension=config['agg_hidden_dim'], mem_agg_hidden_dimension=config['mem_agg_hidden_dim'], 79 | classifier_dimension=config['classifier_dim'], 80 | dropout=config['dropout'], tolerance=config['tolerance'], t_max=config['t_max'], 81 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])] 82 | 83 | elif config['model_type'] == 'p-d-p': 84 | model_list += [solver.SurveyPropagatorSolver(device=self._device, name=config['model_name'], 85 | tolerance=config['tolerance'], t_max=config['t_max'], 86 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])] 87 | 88 | elif config['model_type'] == 'walk-sat': 89 | model_list += [solver.WalkSATSolver(device=self._device, name=config['model_name'], 90 | iteration_num=config['local_search_iteration'], epsilon=config['epsilon'])] 91 | 92 | elif config['model_type'] == 'reinforce': 93 | model_list += [solver.ReinforceSurveyPropagatorSolver(device=self._device, name=config['model_name'], 94 | pi=config['pi'], decimation_probability=config['decimation_probability'], 95 | local_search_iterations=config['local_search_iteration'], epsilon=config['epsilon'])] 96 | 97 | if config['verbose']: 98 | self._logger.info("The model parameter count is %d." % model_list[0].parameter_count()) 99 | return model_list 100 | 101 | def _compute_loss(self, model, loss, prediction, label, graph_map, batch_variable_map, 102 | batch_function_map, edge_feature, meta_data): 103 | 104 | return self._loss_evaluator(variable_prediction=prediction[0], label=label, graph_map=graph_map, 105 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map, 106 | edge_feature=edge_feature, meta_data=meta_data, global_step=model._global_step, 107 | eps=self._eps, max_coeff=self._max_coeff, loss_sharpness=self._config['loss_sharpness']) 108 | 109 | def _compute_evaluation_metrics(self, model, evaluator, prediction, label, graph_map, 110 | batch_variable_map, batch_function_map, edge_feature, meta_data): 111 | 112 | output, _ = self._cnf_evaluator(variable_prediction=prediction[0], graph_map=graph_map, 113 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map, 114 | edge_feature=edge_feature, meta_data=meta_data) 115 | 116 | recall = torch.sum(label * ((output > 0.5).float() - label).abs()) / torch.max(torch.sum(label), self._eps) 117 | accuracy = evaluator((output > 0.5).float(), label).unsqueeze(0) 118 | loss_value = self._loss_evaluator(variable_prediction=prediction[0], label=label, graph_map=graph_map, 119 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map, 120 | edge_feature=edge_feature, meta_data=meta_data, global_step=model._global_step, 121 | eps=self._eps, max_coeff=self._max_coeff, loss_sharpness=self._config['loss_sharpness']).unsqueeze(0) 122 | 123 | return torch.cat([accuracy, recall, loss_value], 0) 124 | 125 | def _post_process_predictions(self, model, prediction, graph_map, 126 | batch_variable_map, batch_function_map, edge_feature, graph_feat, label, misc_data): 127 | "Formats the prediction and the output solution into JSON format." 
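        # Each instance is serialized as one JSON line built from the dict below; an
        # illustrative (hypothetical) record:
        # {"ID": "uf20-01", "label": 1, "solved": 1, "unsat_clauses": 0, "solution": [0, 1, 1]}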
128 | 129 | message = "" 130 | labs = label.detach().cpu().numpy() 131 | 132 | res = self._cnf_evaluator(variable_prediction=prediction[0], graph_map=graph_map, 133 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map, 134 | edge_feature=edge_feature, meta_data=graph_feat) 135 | output, unsat_clause_num = [a.detach().cpu().numpy() for a in res] 136 | 137 | for i in range(output.shape[0]): 138 | instance = { 139 | 'ID': misc_data[i][0] if len(misc_data[i]) > 0 else "", 140 | 'label': int(labs[i, 0]), 141 | 'solved': int(output[i].flatten()[0] == 1), 142 | 'unsat_clauses': int(unsat_clause_num[i].flatten()[0]), 143 | 'solution': (prediction[0][batch_variable_map == i, 0].detach().cpu().numpy().flatten() > 0.5).astype(int).tolist() 144 | } 145 | message += (str(instance).replace("'", '"') + "\n") 146 | self._counter += 1 147 | 148 | return message 149 | 150 | def _check_recurrence_termination(self, active, prediction, sat_problem): 151 | "De-actives the CNF examples which the model has already found a SAT solution for." 152 | 153 | output, _ = self._cnf_evaluator(variable_prediction=prediction[0], graph_map=sat_problem._graph_map, 154 | batch_variable_map=sat_problem._batch_variable_map, batch_function_map=sat_problem._batch_function_map, 155 | edge_feature=sat_problem._edge_feature, meta_data=sat_problem._meta_data)#.detach().cpu().numpy() 156 | 157 | if sat_problem._batch_replication > 1: 158 | real_batch = torch.mm(sat_problem._replication_mask_tuple[1], (output > 0.5).float()) 159 | dup_batch = torch.mm(sat_problem._replication_mask_tuple[0], (real_batch == 0).float()) 160 | active[active[:, 0], 0] = (dup_batch[active[:, 0], 0] > 0) 161 | else: 162 | active[active[:, 0], 0] = (output[active[:, 0], 0] <= 0.5) 163 | -------------------------------------------------------------------------------- /src/pdp/nn/pdp_predict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 3 | 4 | # pdp_predict.py : Defines various predictors and scorers for the PDP framework. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from pdp.nn import util 11 | 12 | 13 | ############################################################### 14 | ### The Predictor Classes 15 | ############################################################### 16 | 17 | 18 | class NeuralPredictor(nn.Module): 19 | "Implements the neural predictor." 
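    # Aggregates the decimator's per-edge states into per-variable (and optionally per-function)
    # summaries with util.MessageAggregator and passes them through the supplied classifier
    # (e.g. the Perceptron defined in trainer.py) to produce the predictions.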
20 | 21 | def __init__(self, device, decimator_dimension, prediction_dimension, 22 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension, 23 | variable_classifier=None, function_classifier=None): 24 | 25 | super(NeuralPredictor, self).__init__() 26 | self._device = device 27 | self._module_list = nn.ModuleList() 28 | 29 | self._variable_classifier = variable_classifier 30 | self._function_classifier = function_classifier 31 | self._hidden_dimension = decimator_dimension 32 | 33 | if variable_classifier is not None: 34 | self._variable_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension, 35 | decimator_dimension, mem_hidden_dimension, 36 | mem_agg_hidden_dimension, agg_hidden_dimension, 0, include_self_message=True) 37 | 38 | self._module_list.append(self._variable_aggregator) 39 | self._module_list.append(self._variable_classifier) 40 | 41 | if function_classifier is not None: 42 | self._function_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension, 43 | decimator_dimension, mem_hidden_dimension, 44 | mem_agg_hidden_dimension, agg_hidden_dimension, 0, include_self_message=True) 45 | 46 | self._module_list.append(self._function_aggregator) 47 | self._module_list.append(self._function_classifier) 48 | 49 | def forward(self, decimator_state, sat_problem, last_call=False): 50 | 51 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple 52 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = sat_problem._batch_mask_tuple 53 | 54 | variable_prediction = None 55 | function_prediction = None 56 | 57 | if sat_problem._meta_data is not None: 58 | graph_feat = torch.mm(b_variable_mask, sat_problem._meta_data) 59 | graph_feat = torch.mm(variable_mask_transpose, graph_feat) 60 | 61 | if len(decimator_state) == 3: 62 | decimator_variable_state, decimator_function_state, edge_mask = decimator_state 63 | else: 64 | decimator_variable_state, decimator_function_state = decimator_state 65 | edge_mask = None 66 | 67 | if self._variable_classifier is not None: 68 | 69 | aggregated_variable_state = torch.cat((decimator_variable_state, sat_problem._edge_feature), 1) 70 | 71 | if sat_problem._meta_data is not None: 72 | aggregated_variable_state = torch.cat((aggregated_variable_state, graph_feat), 1) 73 | 74 | aggregated_variable_state = self._variable_aggregator( 75 | aggregated_variable_state, None, variable_mask, variable_mask_transpose, edge_mask) 76 | 77 | variable_prediction = self._variable_classifier(aggregated_variable_state) 78 | 79 | if self._function_classifier is not None: 80 | 81 | aggregated_function_state = torch.cat((decimator_function_state, sat_problem._edge_feature), 1) 82 | 83 | if sat_problem._meta_data is not None: 84 | aggregated_function_state = torch.cat((aggregated_function_state, graph_feat), 1) 85 | 86 | aggregated_function_state = self._function_aggregator( 87 | aggregated_function_state, None, function_mask, function_mask_transpose, edge_mask) 88 | 89 | function_prediction = self._function_classifier(aggregated_function_state) 90 | 91 | return variable_prediction, function_prediction 92 | 93 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication): 94 | 95 | edge_num = graph_map.size(1) * batch_replication 96 | 97 | if randomized: 98 | variable_state = 
2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0 99 | function_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0 100 | else: 101 | variable_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) 102 | function_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) 103 | 104 | return (variable_state, function_state) 105 | 106 | 107 | ############################################################### 108 | 109 | 110 | class IdentityPredictor(nn.Module): 111 | "Implements the Identity predictor (prediction based on the assignments to the solution property of the SAT problem)." 112 | 113 | def __init__(self, device, random_fill=False): 114 | super(IdentityPredictor, self).__init__() 115 | self._random_fill = random_fill 116 | self._device = device 117 | 118 | def forward(self, decimator_state, sat_problem, last_call=False): 119 | pred = sat_problem._solution.unsqueeze(1) 120 | 121 | if self._random_fill and last_call: 122 | active_var_num = (sat_problem._active_variables[:, 0] > 0).long().sum() 123 | 124 | if active_var_num > 0: 125 | pred[sat_problem._active_variables[:, 0] > 0, 0] = \ 126 | torch.rand(active_var_num.item(), device=self._device) 127 | 128 | return pred, None 129 | 130 | 131 | ############################################################### 132 | 133 | 134 | class SurveyScorer(nn.Module): 135 | "Implements the varaible scoring mechanism for SP-guided decimation." 136 | 137 | def __init__(self, device, message_dimension, include_adaptors=False, pi=0.0): 138 | super(SurveyScorer, self).__init__() 139 | self._device = device 140 | self._include_adaptors = include_adaptors 141 | self._eps = torch.tensor([1e-10], device=self._device) 142 | self._max_logit = torch.tensor([30.0], device=self._device) 143 | self._pi = torch.tensor([pi], dtype=torch.float32, device=device) 144 | 145 | if self._include_adaptors: 146 | self._projector = nn.Linear(message_dimension, 2, bias=False) 147 | self._module_list = nn.ModuleList([self._projector]) 148 | 149 | def safe_log(self, x): 150 | return torch.max(x, self._eps).log() 151 | 152 | def safe_exp(self, x): 153 | return torch.min(x, self._max_logit).exp() 154 | 155 | def forward(self, message_state, sat_problem, last_call=False): 156 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple 157 | b_variable_mask, _, _, _ = sat_problem._batch_mask_tuple 158 | p_variable_mask, _, _, _ = sat_problem._pos_mask_tuple 159 | n_variable_mask, _, _, _ = sat_problem._neg_mask_tuple 160 | 161 | if self._include_adaptors: 162 | function_message = self._projector(message_state[1]) 163 | function_message[:, 0] = F.sigmoid(function_message[:, 0]) 164 | function_message[:, 1] = torch.sign(function_message[:, 1]) 165 | else: 166 | function_message = message_state[1] 167 | 168 | external_force = torch.sign(torch.mm(variable_mask, function_message[:, 1].unsqueeze(1))) 169 | function_message = self.safe_log(1 - function_message[:, 0]).unsqueeze(1) 170 | 171 | edge_mask = torch.mm(function_mask_transpose, sat_problem._active_functions) 172 | function_message = function_message * edge_mask 173 | 174 | pos = torch.mm(p_variable_mask, function_message) + self.safe_log(1.0 - self._pi * (external_force == 1).float()) 175 | neg = torch.mm(n_variable_mask, function_message) + self.safe_log(1.0 - self._pi * (external_force == -1).float()) 176 
| 177 | pos_neg_sum = pos + neg 178 | 179 | dont_care = torch.mm(variable_mask, function_message) + self.safe_log(1.0 - self._pi) 180 | 181 | bias = (2 * pos_neg_sum + dont_care) / 4.0 182 | pos = pos - bias 183 | neg = neg - bias 184 | pos_neg_sum = pos_neg_sum - bias 185 | dont_care = self.safe_exp(dont_care - bias) 186 | 187 | q_0 = self.safe_exp(pos) - self.safe_exp(pos_neg_sum) 188 | q_1 = self.safe_exp(neg) - self.safe_exp(pos_neg_sum) 189 | 190 | total = self.safe_log(q_0 + q_1 + dont_care) 191 | 192 | return self.safe_exp(self.safe_log(q_1) - total) - self.safe_exp(self.safe_log(q_0) - total), None 193 | 194 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication): 195 | 196 | edge_num = graph_map.size(1) * batch_replication 197 | 198 | if randomized: 199 | variable_state = torch.rand(edge_num, 3, dtype=torch.float32, device=self._device) 200 | # variable_state = variable_state / torch.sum(variable_state, 1).unsqueeze(1) 201 | function_state = torch.rand(edge_num, 2, dtype=torch.float32, device=self._device) 202 | function_state[:, 1] = 0 203 | else: 204 | variable_state = torch.ones(edge_num, 3, dtype=torch.float32, device=self._device) / 3.0 205 | function_state = 0.5 * torch.ones(edge_num, 2, dtype=torch.float32, device=self._device) 206 | function_state[:, 1] = 0 207 | 208 | return (variable_state, function_state) 209 | 210 | 211 | ############################################################### 212 | 213 | 214 | class ReinforcePredictor(nn.Module): 215 | "Implements the prediction mechanism for the Reinforce Algorithm." 216 | 217 | def __init__(self, device): 218 | super(ReinforcePredictor, self).__init__() 219 | self._device = device 220 | 221 | def forward(self, decimator_state, sat_problem, last_call=False): 222 | 223 | pred = decimator_state[1][:, 1].unsqueeze(1) 224 | pred = (torch.mm(sat_problem._graph_mask_tuple[0], pred) > 0).float() 225 | 226 | return pred, None 227 | -------------------------------------------------------------------------------- /src/pdp/nn/pdp_propagate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define various propagators for the PDP framework. 3 | """ 4 | 5 | # Copyright (c) Microsoft. All rights reserved. 6 | # Licensed under the MIT license. See LICENSE.md file 7 | # in the project root for full license information. 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from pdp.nn import util 14 | 15 | 16 | ############################################################### 17 | ### The Propagator Classes 18 | ############################################################### 19 | 20 | 21 | class NeuralMessagePasser(nn.Module): 22 | "Implements the neural propagator." 
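    # Runs one round of variable -> function and function -> variable message passing over the
    # factor graph using one MessageAggregator per direction; the (1 - mask) terms below keep
    # the states of de-activated batch instances unchanged.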
23 | 24 | def __init__(self, device, edge_dimension, decimator_dimension, meta_data_dimension, hidden_dimension, mem_hidden_dimension, 25 | mem_agg_hidden_dimension, agg_hidden_dimension, dropout): 26 | 27 | super(NeuralMessagePasser, self).__init__() 28 | self._device = device 29 | self._module_list = nn.ModuleList() 30 | self._drop_out = dropout 31 | 32 | self._variable_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension, 33 | hidden_dimension, mem_hidden_dimension, 34 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, include_self_message=False) 35 | self._function_aggregator = util.MessageAggregator(device, decimator_dimension + edge_dimension + meta_data_dimension, 36 | hidden_dimension, mem_hidden_dimension, 37 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, include_self_message=False) 38 | 39 | self._module_list.append(self._variable_aggregator) 40 | self._module_list.append(self._function_aggregator) 41 | 42 | self._hidden_dimension = hidden_dimension 43 | self._mem_hidden_dimension = mem_hidden_dimension 44 | self._agg_hidden_dimension = agg_hidden_dimension 45 | self._mem_agg_hidden_dimension = mem_agg_hidden_dimension 46 | 47 | def forward(self, init_state, decimator_state, sat_problem, is_training, active_mask=None): 48 | 49 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple 50 | b_variable_mask, _, _, _ = sat_problem._batch_mask_tuple 51 | 52 | if active_mask is not None: 53 | mask = torch.mm(b_variable_mask, active_mask.float()) 54 | mask = torch.mm(variable_mask_transpose, mask) 55 | else: 56 | edge_num = init_state[0].size(0) 57 | mask = torch.ones(edge_num, 1, device=self._device) 58 | 59 | if sat_problem._meta_data is not None: 60 | graph_feat = torch.mm(b_variable_mask, sat_problem._meta_data) 61 | graph_feat = torch.mm(variable_mask_transpose, graph_feat) 62 | 63 | if len(decimator_state) == 3: 64 | decimator_variable_state, decimator_function_state, edge_mask = decimator_state 65 | else: 66 | decimator_variable_state, decimator_function_state = decimator_state 67 | edge_mask = None 68 | 69 | variable_state, function_state = init_state 70 | 71 | ## variables --> functions 72 | decimator_variable_state = torch.cat((decimator_variable_state, sat_problem._edge_feature), 1) 73 | 74 | if sat_problem._meta_data is not None: 75 | decimator_variable_state = torch.cat((decimator_variable_state, graph_feat), 1) 76 | 77 | function_state = mask * self._variable_aggregator( 78 | decimator_variable_state, sat_problem._edge_feature, variable_mask, variable_mask_transpose, edge_mask) + (1 - mask) * function_state 79 | 80 | function_state = F.dropout(function_state, p=self._drop_out, training=is_training) 81 | 82 | ## functions --> variables 83 | decimator_function_state = torch.cat((decimator_function_state, sat_problem._edge_feature), 1) 84 | 85 | if sat_problem._meta_data is not None: 86 | decimator_function_state = torch.cat((decimator_function_state, graph_feat), 1) 87 | 88 | variable_state = mask * self._function_aggregator( 89 | decimator_function_state, sat_problem._edge_feature, function_mask, function_mask_transpose, edge_mask) + (1 - mask) * variable_state 90 | 91 | variable_state = F.dropout(variable_state, p=self._drop_out, training=is_training) 92 | 93 | del mask 94 | 95 | return variable_state, function_state 96 | 97 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, 
randomized, batch_replication): 98 | 99 | edge_num = graph_map.size(1) * batch_replication 100 | 101 | if randomized: 102 | variable_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0 103 | function_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0 104 | else: 105 | variable_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) 106 | function_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) 107 | 108 | return (variable_state, function_state) 109 | 110 | 111 | ############################################################### 112 | 113 | 114 | class SurveyPropagator(nn.Module): 115 | "Implements the Survey Propagator (SP)." 116 | 117 | def __init__(self, device, decimator_dimension, include_adaptors=False, pi=0.0): 118 | 119 | super(SurveyPropagator, self).__init__() 120 | self._device = device 121 | self._function_message_dim = 3 122 | self._variable_message_dim = 2 123 | self._include_adaptors = include_adaptors 124 | self._eps = torch.tensor([1e-40], device=self._device) 125 | self._max_logit = torch.tensor([30.0], device=self._device) 126 | self._pi = torch.tensor([pi], dtype=torch.float32, device=device) 127 | 128 | if self._include_adaptors: 129 | self._variable_input_projector = nn.Linear(decimator_dimension, self._variable_message_dim, bias=False) 130 | self._function_input_projector = nn.Linear(decimator_dimension, 1, bias=False) 131 | self._module_list = nn.ModuleList([self._variable_input_projector, self._function_input_projector]) 132 | 133 | def safe_log(self, x): 134 | return torch.max(x, self._eps).log() 135 | 136 | def safe_exp(self, x): 137 | return torch.min(x, self._max_logit).exp() 138 | 139 | def forward(self, init_state, decimator_state, sat_problem, is_training, active_mask=None): 140 | 141 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple 142 | b_variable_mask, _, _, _ = sat_problem._batch_mask_tuple 143 | p_variable_mask, _, _, _ = sat_problem._pos_mask_tuple 144 | n_variable_mask, _, _, _ = sat_problem._neg_mask_tuple 145 | 146 | if active_mask is not None: 147 | mask = torch.mm(b_variable_mask, active_mask.float()) 148 | mask = torch.mm(variable_mask_transpose, mask) 149 | else: 150 | edge_num = init_state[0].size(0) 151 | mask = torch.ones(edge_num, 1, device=self._device) 152 | 153 | if len(decimator_state) == 3: 154 | decimator_variable_state, decimator_function_state, edge_mask = decimator_state 155 | else: 156 | decimator_variable_state, decimator_function_state = decimator_state 157 | edge_mask = None 158 | 159 | variable_state, function_state = init_state 160 | 161 | ## functions --> variables 162 | 163 | if self._include_adaptors: 164 | decimator_variable_state = F.logsigmoid(self._function_input_projector(decimator_variable_state)) 165 | else: 166 | decimator_variable_state = self.safe_log(decimator_variable_state[:, 0]).unsqueeze(1) 167 | 168 | if edge_mask is not None: 169 | decimator_variable_state = decimator_variable_state * edge_mask 170 | 171 | aggregated_variable_state = torch.mm(function_mask, decimator_variable_state) 172 | aggregated_variable_state = torch.mm(function_mask_transpose, aggregated_variable_state) 173 | aggregated_variable_state = aggregated_variable_state - decimator_variable_state 174 | 175 | function_state = mask * self.safe_exp(aggregated_variable_state) + (1 - mask) * 
function_state[:, 0].unsqueeze(1) 176 | 177 | ## functions --> variables 178 | 179 | if self._include_adaptors: 180 | decimator_function_state = self._variable_input_projector(decimator_function_state) 181 | decimator_function_state[:, 0] = F.sigmoid(decimator_function_state[:, 0]) 182 | decimator_function_state[:, 1] = torch.sign(decimator_function_state[:, 1]) 183 | 184 | external_force = decimator_function_state[:, 1].unsqueeze(1) 185 | decimator_function_state = self.safe_log(1 - decimator_function_state[:, 0]).unsqueeze(1) 186 | 187 | if edge_mask is not None: 188 | decimator_function_state = decimator_function_state * edge_mask 189 | 190 | pos = torch.mm(p_variable_mask, decimator_function_state) 191 | pos = torch.mm(variable_mask_transpose, pos) 192 | neg = torch.mm(n_variable_mask, decimator_function_state) 193 | neg = torch.mm(variable_mask_transpose, neg) 194 | 195 | same_sign = 0.5 * (1 + sat_problem._edge_feature) * pos + 0.5 * (1 - sat_problem._edge_feature) * neg 196 | same_sign = same_sign - decimator_function_state 197 | same_sign += self.safe_log(1.0 - self._pi * (external_force == sat_problem._edge_feature).float()) 198 | 199 | opposite_sign = 0.5 * (1 - sat_problem._edge_feature) * pos + 0.5 * (1 + sat_problem._edge_feature) * neg 200 | # The opposite sign edge aggregation does not include the current edge by definition, therefore no need for subtraction. 201 | opposite_sign += self.safe_log(1.0 - self._pi * (external_force == -sat_problem._edge_feature).float()) 202 | 203 | dont_care = same_sign + opposite_sign 204 | 205 | bias = 0 #(2 * dont_care) / 3.0 206 | same_sign = same_sign - bias 207 | opposite_sign = opposite_sign - bias 208 | dont_care = self.safe_exp(dont_care - bias) 209 | 210 | same_sign = self.safe_exp(same_sign) 211 | opposite_sign = self.safe_exp(opposite_sign) 212 | q_u = same_sign * (1 - opposite_sign) 213 | q_s = opposite_sign * (1 - same_sign) 214 | 215 | total = q_u + q_s + dont_care 216 | temp = torch.cat((q_u, q_s, dont_care), 1) / total 217 | 218 | variable_state = mask * temp + (1 - mask) * variable_state 219 | 220 | del mask 221 | return variable_state, torch.cat((function_state, external_force), 1) 222 | 223 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication): 224 | 225 | edge_num = graph_map.size(1) * batch_replication 226 | 227 | if randomized: 228 | variable_state = torch.rand(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device) 229 | variable_state = variable_state / torch.sum(variable_state, 1).unsqueeze(1) 230 | function_state = torch.rand(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device) 231 | function_state[:, 1] = 0 232 | else: 233 | variable_state = torch.ones(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device) / self._function_message_dim 234 | function_state = 0.5 * torch.ones(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device) 235 | function_state[:, 1] = 0 236 | 237 | return (variable_state, function_state) 238 | 239 | 240 | ############################################################### 241 | -------------------------------------------------------------------------------- /src/pdp/nn/pdp_decimate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define various decimators for the PDP framework. 3 | """ 4 | 5 | # Copyright (c) Microsoft. All rights reserved. 
6 | # Licensed under the MIT license. See LICENSE.md file 7 | # in the project root for full license information. 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from pdp.nn import util 14 | 15 | 16 | ############################################################### 17 | ### The Decimator Classes 18 | ############################################################### 19 | 20 | 21 | class NeuralDecimator(nn.Module): 22 | "Implements a neural decimator." 23 | 24 | def __init__(self, device, message_dimension, meta_data_dimension, hidden_dimension, mem_hidden_dimension, 25 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, dropout): 26 | 27 | super(NeuralDecimator, self).__init__() 28 | self._device = device 29 | self._module_list = nn.ModuleList() 30 | self._drop_out = dropout 31 | 32 | if isinstance(message_dimension, tuple): 33 | variable_message_dim, function_message_dim = message_dimension 34 | else: 35 | variable_message_dim = message_dimension 36 | function_message_dim = message_dimension 37 | 38 | self._variable_rnn_cell = nn.GRUCell( 39 | variable_message_dim + edge_dimension + meta_data_dimension, hidden_dimension, bias=True) 40 | self._function_rnn_cell = nn.GRUCell( 41 | function_message_dim + edge_dimension + meta_data_dimension, hidden_dimension, bias=True) 42 | 43 | self._module_list.append(self._variable_rnn_cell) 44 | self._module_list.append(self._function_rnn_cell) 45 | 46 | self._hidden_dimension = hidden_dimension 47 | self._mem_hidden_dimension = mem_hidden_dimension 48 | self._agg_hidden_dimension = agg_hidden_dimension 49 | self._mem_agg_hidden_dimension = mem_agg_hidden_dimension 50 | 51 | def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None): 52 | 53 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple 54 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = sat_problem._batch_mask_tuple 55 | 56 | if active_mask is not None: 57 | mask = torch.mm(b_variable_mask, active_mask.float()) 58 | mask = torch.mm(variable_mask_transpose, mask) 59 | else: 60 | edge_num = init_state[0].size(0) 61 | mask = torch.ones(edge_num, 1, device=self._device) 62 | 63 | if sat_problem._meta_data is not None: 64 | graph_feat = torch.mm(b_variable_mask, sat_problem._meta_data) 65 | graph_feat = torch.mm(variable_mask_transpose, graph_feat) 66 | 67 | variable_state, function_state = message_state 68 | 69 | # Variable states 70 | variable_state = torch.cat((variable_state, sat_problem._edge_feature), 1) 71 | 72 | if sat_problem._meta_data is not None: 73 | variable_state = torch.cat((variable_state, graph_feat), 1) 74 | 75 | variable_state = mask * self._variable_rnn_cell(variable_state, init_state[0]) + (1 - mask) * init_state[0] 76 | 77 | # Function states 78 | function_state = torch.cat((function_state, sat_problem._edge_feature), 1) 79 | 80 | if sat_problem._meta_data is not None: 81 | function_state = torch.cat((function_state, graph_feat), 1) 82 | 83 | function_state = mask * self._function_rnn_cell(function_state, init_state[1]) + (1 - mask) * init_state[1] 84 | 85 | del mask 86 | 87 | return variable_state, function_state 88 | 89 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication): 90 | 91 | edge_num = graph_map.size(1) * batch_replication 92 | 93 | if randomized: 94 | variable_state = 2.0*torch.rand(edge_num, 
self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0 95 | function_state = 2.0*torch.rand(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) - 1.0 96 | else: 97 | variable_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) 98 | function_state = torch.zeros(edge_num, self._hidden_dimension, dtype=torch.float32, device=self._device) 99 | 100 | return (variable_state, function_state) 101 | 102 | 103 | ############################################################### 104 | 105 | 106 | class SequentialDecimator(nn.Module): 107 | "Implements the general (greedy) sequential decimator." 108 | 109 | def __init__(self, device, message_dimension, scorer, tolerance, t_max): 110 | super(SequentialDecimator, self).__init__() 111 | 112 | self._device = device 113 | self._tolerance = tolerance 114 | self._scorer = scorer 115 | self._previous_function_state = None 116 | self._message_dimension = message_dimension 117 | self._constant = torch.ones(1, 1, device=self._device) 118 | self._t_max = t_max 119 | self._counters = None 120 | self._module_list = nn.ModuleList([self._scorer]) 121 | 122 | def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None): 123 | 124 | if self._counters is None: 125 | self._counters = torch.zeros(sat_problem._batch_size, 1, device=self._device) 126 | 127 | if active_mask is not None: 128 | survey = message_state[1][:, 0].unsqueeze(1) 129 | survey = util.sparse_smooth_max(survey, sat_problem._graph_mask_tuple[0], self._device) 130 | survey = survey * sat_problem._active_variables 131 | survey = util.sparse_max(survey.squeeze(1), sat_problem._batch_mask_tuple[0], self._device).unsqueeze(1) 132 | 133 | active_mask[survey <= 1e-10] = 0 134 | 135 | if self._previous_function_state is not None and sat_problem._active_variables.sum() > 0: 136 | function_diff = (self._previous_function_state - message_state[1][:, 0]).abs().unsqueeze(1) 137 | 138 | if sat_problem._edge_mask is not None: 139 | function_diff = function_diff * sat_problem._edge_mask 140 | 141 | sum_diff = util.sparse_smooth_max(function_diff, sat_problem._graph_mask_tuple[0], self._device) 142 | sum_diff = sum_diff * sat_problem._active_variables 143 | sum_diff = util.sparse_max(sum_diff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device).unsqueeze(1) 144 | 145 | self._counters[sum_diff[:, 0] < self._tolerance, 0] = 0 146 | sum_diff = (sum_diff < self._tolerance).float() 147 | sum_diff[self._counters[:, 0] >= self._t_max, 0] = 1 148 | self._counters[self._counters[:, 0] >= self._t_max, 0] = 0 149 | 150 | sum_diff = torch.mm(sat_problem._batch_mask_tuple[0], sum_diff) 151 | 152 | if sum_diff.sum() > 0: 153 | score, _ = self._scorer(message_state, sat_problem) 154 | 155 | # Find the variable index with max score for each instance in the batch 156 | coeff = score.abs() * sat_problem._active_variables * sum_diff 157 | 158 | if coeff.sum() > 0: 159 | max_ind = util.sparse_argmax(coeff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device) 160 | norm = torch.mm(sat_problem._batch_mask_tuple[1], coeff) 161 | 162 | if active_mask is not None: 163 | max_ind = max_ind[(active_mask * (norm != 0)).squeeze(1)] 164 | else: 165 | max_ind = max_ind[norm.squeeze(1) != 0] 166 | 167 | if max_ind.size()[0] > 0: 168 | assignment = torch.zeros(sat_problem._variable_num, 1, device=self._device) 169 | assignment[max_ind, 0] = score.sign()[max_ind, 0] 170 | 171 | sat_problem.set_variables(assignment) 172 | 173 | 
self._counters = self._counters + 1 174 | 175 | self._previous_function_state = message_state[1][:, 0] 176 | 177 | return message_state 178 | 179 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication): 180 | self._previous_function_state = None 181 | self._counters = None 182 | 183 | return self._scorer.get_init_state(graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication) 184 | 185 | 186 | ############################################################### 187 | 188 | 189 | class ReinforceDecimator(nn.Module): 190 | "Implements the (distributed) Reinforce decimator." 191 | 192 | def __init__(self, device, scorer, decimation_probability=0.5): 193 | super(ReinforceDecimator, self).__init__() 194 | 195 | self._device = device 196 | self._scorer = scorer 197 | self._decimation_probability = decimation_probability 198 | self._function_message_dim = 3 199 | self._variable_message_dim = 2 200 | self._previous_function_state = None 201 | 202 | def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None): 203 | variable_state, function_state = message_state 204 | 205 | if active_mask is not None and self._previous_function_state is not None and sat_problem._active_variables.sum() > 0: 206 | function_diff = (self._previous_function_state - message_state[1][:, 0]).abs().unsqueeze(1) 207 | 208 | if sat_problem._edge_mask is not None: 209 | function_diff = function_diff * sat_problem._edge_mask 210 | 211 | sum_diff = util.sparse_smooth_max(function_diff, sat_problem._graph_mask_tuple[0], self._device) 212 | sum_diff = sum_diff * sat_problem._active_variables 213 | sum_diff = util.sparse_max(sum_diff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device) 214 | active_mask[sum_diff <= 0.01, 0] = 0 215 | 216 | self._previous_function_state = message_state[1][:, 0] 217 | 218 | if torch.rand(1, device=self._device) < self._decimation_probability: 219 | variable_mask, variable_mask_transpose, function_mask, function_mask_transpose = sat_problem._graph_mask_tuple 220 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = sat_problem._batch_mask_tuple 221 | 222 | if active_mask is not None: 223 | mask = torch.mm(b_variable_mask, active_mask.float()) 224 | mask = torch.mm(variable_mask_transpose, mask) 225 | else: 226 | mask = torch.ones(sat_problem._edge_num, 1, device=self._device) 227 | 228 | mask = mask.squeeze(1) 229 | score, _ = self._scorer(message_state, sat_problem) 230 | score = torch.mm(variable_mask_transpose, torch.sign(score)).squeeze(1) 231 | 232 | function_state[:, 1] = mask * score + (1 - mask) * function_state[:, 1] 233 | 234 | return variable_state, function_state 235 | 236 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication): 237 | 238 | edge_num = graph_map.size(1) * batch_replication 239 | self._previous_function_state = None 240 | 241 | if randomized: 242 | variable_state = torch.rand(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device) 243 | function_state = torch.rand(edge_num, self._variable_message_dim, dtype=torch.float32, device=self._device) 244 | function_state[:, 1] = 0 245 | else: 246 | variable_state = torch.ones(edge_num, self._function_message_dim, dtype=torch.float32, device=self._device) / self._function_message_dim 247 | function_state = 0.5 * torch.ones(edge_num, 
self._variable_message_dim, dtype=torch.float32, device=self._device) 248 | function_state[:, 1] = 0 249 | 250 | return (variable_state, function_state) 251 | -------------------------------------------------------------------------------- /src/pdp/nn/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 3 | 4 | # util.py : Defines the utility functionalities for the PDP framework. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class MessageAggregator(nn.Module): 12 | "Implements a deep set function for message aggregation at variable and function nodes." 13 | 14 | def __init__(self, device, input_dimension, output_dimension, mem_hidden_dimension, 15 | mem_agg_hidden_dimension, agg_hidden_dimension, feature_dimension, include_self_message): 16 | 17 | super(MessageAggregator, self).__init__() 18 | self._device = device 19 | self._include_self_message = include_self_message 20 | self._module_list = nn.ModuleList() 21 | 22 | if mem_hidden_dimension > 0 and mem_agg_hidden_dimension > 0: 23 | 24 | self._W1_m = nn.Linear( 25 | input_dimension, mem_hidden_dimension, bias=True) # .to(self._device) 26 | 27 | self._W2_m = nn.Linear( 28 | mem_hidden_dimension, mem_agg_hidden_dimension, bias=False) # .to(self._device) 29 | 30 | self._module_list.append(self._W1_m) 31 | self._module_list.append(self._W2_m) 32 | 33 | if agg_hidden_dimension > 0 and mem_agg_hidden_dimension > 0: 34 | 35 | if mem_hidden_dimension <= 0: 36 | mem_agg_hidden_dimension = input_dimension 37 | 38 | self._W1_a = nn.Linear( 39 | mem_agg_hidden_dimension + feature_dimension, agg_hidden_dimension, bias=True) # .to(self._device) 40 | 41 | self._W2_a = nn.Linear( 42 | agg_hidden_dimension, output_dimension, bias=False) # .to(self._device) 43 | 44 | self._module_list.append(self._W1_a) 45 | self._module_list.append(self._W2_a) 46 | 47 | self._agg_hidden_dimension = agg_hidden_dimension 48 | self._mem_hidden_dimension = mem_hidden_dimension 49 | self._mem_agg_hidden_dimension = mem_agg_hidden_dimension 50 | 51 | def forward(self, state, feature, mask, mask_transpose, edge_mask=None): 52 | 53 | # Apply the pre-aggregation transform 54 | if self._mem_hidden_dimension > 0 and self._mem_agg_hidden_dimension > 0: 55 | state = F.logsigmoid(self._W2_m(F.logsigmoid(self._W1_m(state)))) 56 | 57 | if edge_mask is not None: 58 | state = state * edge_mask 59 | 60 | aggregated_state = torch.mm(mask, state) 61 | 62 | if not self._include_self_message: 63 | aggregated_state = torch.mm(mask_transpose, aggregated_state) 64 | 65 | if edge_mask is not None: 66 | aggregated_state = aggregated_state - state * edge_mask 67 | else: 68 | aggregated_state = aggregated_state - state 69 | 70 | if feature is not None: 71 | aggregated_state = torch.cat((aggregated_state, feature), 1) 72 | 73 | # Apply the post-aggregation transform 74 | if self._agg_hidden_dimension > 0 and self._mem_agg_hidden_dimension > 0: 75 | aggregated_state = F.logsigmoid(self._W2_a(F.logsigmoid(self._W1_a(aggregated_state)))) 76 | 77 | return aggregated_state 78 | 79 | 80 | ############################################################### 81 | 82 | 83 | class MultiLayerPerceptron(nn.Module): 84 | "Implements a standard fully-connected, multi-layer perceptron." 
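    # layer_dims lists the layer widths from input to output, e.g. [64, 32, 1] (illustrative)
    # builds one hidden ReLU layer followed by a bias-free sigmoid output layer.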
85 | 86 | def __init__(self, device, layer_dims): 87 | 88 | super(MultiLayerPerceptron, self).__init__() 89 | self._device = device 90 | self._module_list = nn.ModuleList() 91 | self._layer_num = len(layer_dims) - 1 92 | 93 | self._inner_layers = [] 94 | for i in range(self._layer_num - 1): 95 | self._inner_layers += [nn.Linear(layer_dims[i], layer_dims[i + 1])] 96 | self._module_list.append(self._inner_layers[i]) 97 | 98 | self._output_layer = nn.Linear(layer_dims[self._layer_num - 1], layer_dims[self._layer_num], bias=False) 99 | self._module_list.append(self._output_layer) 100 | 101 | def forward(self, inp): 102 | x = inp 103 | 104 | for layer in self._inner_layers: 105 | x = F.relu(layer(x)) 106 | 107 | return F.sigmoid(self._output_layer(x)) 108 | 109 | 110 | ########################################################################################################################## 111 | 112 | 113 | class SatLossEvaluator(nn.Module): 114 | "Implements a module to calculate the energy (i.e. the loss) for the current prediction." 115 | 116 | def __init__(self, alpha, device): 117 | super(SatLossEvaluator, self).__init__() 118 | self._alpha = alpha 119 | self._device = device 120 | 121 | @staticmethod 122 | def safe_log(x, eps): 123 | return torch.max(x, eps).log() 124 | 125 | @staticmethod 126 | def compute_masks(graph_map, batch_variable_map, batch_function_map, edge_feature, device): 127 | edge_num = graph_map.size(1) 128 | variable_num = batch_variable_map.size(0) 129 | function_num = batch_function_map.size(0) 130 | all_ones = torch.ones(edge_num, device=device) 131 | edge_num_range = torch.arange(edge_num, dtype=torch.int64, device=device) 132 | 133 | variable_sparse_ind = torch.stack([edge_num_range, graph_map[0, :].long()]) 134 | function_sparse_ind = torch.stack([graph_map[1, :].long(), edge_num_range]) 135 | 136 | if device.type == 'cuda': 137 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, edge_feature.squeeze(1), 138 | torch.Size([edge_num, variable_num]), device=device) 139 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, all_ones, 140 | torch.Size([function_num, edge_num]), device=device) 141 | else: 142 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, edge_feature.squeeze(1), 143 | torch.Size([edge_num, variable_num]), device=device) 144 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, all_ones, 145 | torch.Size([function_num, edge_num]), device=device) 146 | 147 | return variable_mask, function_mask 148 | 149 | @staticmethod 150 | def compute_batch_mask(batch_variable_map, batch_function_map, device): 151 | variable_num = batch_variable_map.size()[0] 152 | function_num = batch_function_map.size()[0] 153 | variable_all_ones = torch.ones(variable_num, device=device) 154 | function_all_ones = torch.ones(function_num, device=device) 155 | variable_range = torch.arange(variable_num, dtype=torch.int64, device=device) 156 | function_range = torch.arange(function_num, dtype=torch.int64, device=device) 157 | batch_size = (batch_variable_map.max() + 1).long().item() 158 | 159 | variable_sparse_ind = torch.stack([variable_range, batch_variable_map.long()]) 160 | function_sparse_ind = torch.stack([function_range, batch_function_map.long()]) 161 | 162 | if device.type == 'cuda': 163 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, variable_all_ones, 164 | torch.Size([variable_num, batch_size]), device=device) 165 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, 
function_all_ones, 166 | torch.Size([function_num, batch_size]), device=device) 167 | else: 168 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, variable_all_ones, 169 | torch.Size([variable_num, batch_size]), device=device) 170 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, function_all_ones, 171 | torch.Size([function_num, batch_size]), device=device) 172 | 173 | variable_mask_transpose = variable_mask.transpose(0, 1) 174 | function_mask_transpose = function_mask.transpose(0, 1) 175 | 176 | return (variable_mask, variable_mask_transpose, function_mask, function_mask_transpose) 177 | 178 | def forward(self, variable_prediction, label, graph_map, batch_variable_map, 179 | batch_function_map, edge_feature, meta_data, global_step, eps, max_coeff, loss_sharpness): 180 | 181 | coeff = torch.min(global_step.pow(self._alpha), torch.tensor([max_coeff], device=self._device)) 182 | 183 | signed_variable_mask_transpose, function_mask = \ 184 | SatLossEvaluator.compute_masks(graph_map, batch_variable_map, batch_function_map, 185 | edge_feature, self._device) 186 | 187 | edge_values = torch.mm(signed_variable_mask_transpose, variable_prediction) 188 | edge_values = edge_values + (1 - edge_feature) / 2 189 | 190 | weights = (coeff * edge_values).exp() 191 | 192 | nominator = torch.mm(function_mask, weights * edge_values) 193 | denominator = torch.mm(function_mask, weights) 194 | 195 | clause_value = denominator / torch.max(nominator, eps) 196 | clause_value = 1 + (clause_value - 1).pow(loss_sharpness) 197 | return torch.mean(SatLossEvaluator.safe_log(clause_value, eps)) 198 | 199 | 200 | ########################################################################################################################## 201 | 202 | 203 | class SatCNFEvaluator(nn.Module): 204 | "Implements a module to evaluate the current prediction." 205 | 206 | def __init__(self, device): 207 | super(SatCNFEvaluator, self).__init__() 208 | self._device = device 209 | 210 | def forward(self, variable_prediction, graph_map, batch_variable_map, 211 | batch_function_map, edge_feature, meta_data): 212 | 213 | variable_num = batch_variable_map.size(0) 214 | function_num = batch_function_map.size(0) 215 | batch_size = (batch_variable_map.max() + 1).item() 216 | all_ones = torch.ones(function_num, 1, device=self._device) 217 | 218 | signed_variable_mask_transpose, function_mask = \ 219 | SatLossEvaluator.compute_masks(graph_map, batch_variable_map, batch_function_map, 220 | edge_feature, self._device) 221 | 222 | b_variable_mask, b_variable_mask_transpose, b_function_mask, b_function_mask_transpose = \ 223 | SatLossEvaluator.compute_batch_mask( 224 | batch_variable_map, batch_function_map, self._device) 225 | 226 | edge_values = torch.mm(signed_variable_mask_transpose, variable_prediction) 227 | edge_values = edge_values + (1 - edge_feature) / 2 228 | edge_values = (edge_values > 0.5).float() 229 | 230 | clause_values = torch.mm(function_mask, edge_values) 231 | clause_values = (clause_values > 0).float() 232 | 233 | max_sat = torch.mm(b_function_mask_transpose, all_ones) 234 | batch_values = torch.mm(b_function_mask_transpose, clause_values) 235 | 236 | return (max_sat == batch_values).float(), max_sat - batch_values 237 | 238 | 239 | ########################################################################################################################## 240 | 241 | 242 | class PerceptronTanh(nn.Module): 243 | "Implements a 1-layer perceptron with Tanh activaton." 
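    # Same structure as the Perceptron in trainer.py, but with a tanh output layer,
    # so predictions lie in (-1, 1) rather than (0, 1).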
244 | 245 | def __init__(self, input_dimension, hidden_dimension, output_dimension): 246 | super(PerceptronTanh, self).__init__() 247 | self._layer1 = nn.Linear(input_dimension, hidden_dimension) 248 | self._layer2 = nn.Linear(hidden_dimension, output_dimension, bias=False) 249 | 250 | def forward(self, inp): 251 | return F.tanh(self._layer2(F.relu(self._layer1(inp)))) 252 | 253 | 254 | ########################################################################################################################## 255 | 256 | 257 | def sparse_argmax(x, mask, device): 258 | "Implements the exact, memory-inefficient argmax operation for a row vector input." 259 | 260 | if device.type == 'cuda': 261 | dense_mat = torch.cuda.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense() 262 | else: 263 | dense_mat = torch.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense() 264 | 265 | return torch.argmax(dense_mat, 0) 266 | 267 | def sparse_max(x, mask, device): 268 | "Implements the exact, memory-inefficient max operation for a row vector input." 269 | 270 | if device.type == 'cuda': 271 | dense_mat = torch.cuda.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense() 272 | else: 273 | dense_mat = torch.sparse.FloatTensor(mask._indices(), x - x.min() + 1, mask.size(), device=device).to_dense() 274 | 275 | return torch.max(dense_mat, 0)[0] + x.min() - 1 276 | 277 | def safe_exp(x, device): 278 | "Implements safe exp operation." 279 | 280 | return torch.min(x, torch.tensor([30.0], device=device)).exp() 281 | 282 | def sparse_smooth_max(x, mask, device, alpha=30): 283 | "Implements the approximate, memory-efficient max operation for a row vector input." 284 | 285 | coeff = safe_exp(alpha * x, device) 286 | return torch.mm(mask, x * coeff) / torch.max(torch.mm(mask, coeff), torch.ones(1, device=device)) 287 | -------------------------------------------------------------------------------- /src/pdp/generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Various types of CNF generators for generating real-time CNF instances. 4 | """ 5 | 6 | # Copyright (c) Microsoft. All rights reserved. 7 | # Licensed under the MIT license. See LICENSE.md file 8 | # in the project root for full license information. 9 | 10 | import numpy as np 11 | import argparse 12 | import os, sys 13 | 14 | 15 | def is_sat(_var_num, iclause_list): 16 | ## Note: Invoke your SAT solver of choice here for generating labeled data. 17 | return False 18 | 19 | ########################################################################## 20 | 21 | 22 | class CNFGeneratorBase(object): 23 | "The base class for all CNF generators." 24 | 25 | def __init__(self, min_n, max_n, min_alpha, max_alpha, alpha_resolution=10): 26 | self._min_n = min_n 27 | self._max_n = max_n 28 | self._min_alpha = min_alpha 29 | self._max_alpha = max_alpha 30 | self._alpha = min_alpha 31 | self._alpha_inc = (max_alpha - min_alpha) / alpha_resolution 32 | self._alpha_resolution = alpha_resolution 33 | 34 | def generate(self): 35 | "Generates unlabeled CNF instances." 36 | pass 37 | 38 | def generate_complete(self): 39 | "Generates labeled CNF instances." 
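        # Concrete generators return (n, m, graph_map, edge_features, None, label, clause_list),
        # where label comes from is_sat (always False unless a real SAT solver is plugged in above).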
40 | pass 41 | 42 | def _to_json(self, n, m, graph_map, edge_feature, label): 43 | return [[n, m], list(((graph_map[0, :] + 1) * edge_feature).astype(int)), list(graph_map[1, :] + 1), label] 44 | 45 | def _to_dimacs(self, n, m, clause_list): 46 | body = '' 47 | 48 | for clause in clause_list: 49 | body += (str(clause)[1:-1].replace(',', '')) + ' 0\n' 50 | 51 | return 'p cnf ' + str(n) + ' ' + str(m) + '\n' + body 52 | 53 | def generate_dataset(self, size, output_dimacs_path, json_output, name, sat_only=True): 54 | 55 | max_trial = 50 56 | 57 | if not os.path.exists(output_dimacs_path): 58 | os.makedirs(output_dimacs_path) 59 | 60 | if not os.path.exists(json_output): 61 | os.makedirs(json_output) 62 | 63 | output_dimacs_path = os.path.join(output_dimacs_path, name) 64 | json_output = os.path.join(json_output, name) 65 | 66 | for j in range(self._alpha_resolution): 67 | postfix = '_' + str(j) + '_' + str(self._alpha) + '_' + str(self._alpha + self._alpha_inc) 68 | 69 | if not os.path.exists(output_dimacs_path + postfix): 70 | os.makedirs(output_dimacs_path + postfix) 71 | 72 | with open(json_output + postfix + ".json", 'w') as f: 73 | for i in range(size): 74 | flag = False 75 | for _ in range(max_trial): 76 | n, m, graph_map, edge_feature, _, label, clause_list = self.generate_complete() 77 | 78 | if (not sat_only) or (label == 1): 79 | flag = True 80 | break 81 | 82 | if flag: 83 | f.write(str(self._to_json(n, m, graph_map, edge_feature, label)).replace("'", '"') + '\n') 84 | 85 | dimacs_file_name = 'dimacs_' + str(i) + '_sat=' + str(label) + '.DIMACS' 86 | with open(os.path.join(output_dimacs_path + postfix, dimacs_file_name), 'w') as g: 87 | g.write(self._to_dimacs(n, m, clause_list) + '\n') 88 | 89 | sys.stdout.write("Dataset {:2d}/{:2d}: {:.2f} % complete \r".format(j + 1, self._alpha_resolution, 100*float(i+1) / size)) 90 | sys.stdout.flush() 91 | 92 | self._alpha += self._alpha_inc 93 | 94 | 95 | ############################################################################################### 96 | 97 | 98 | class UniformCNFGenerator(CNFGeneratorBase): 99 | "Implements the uniformly random CNF generator." 
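    # Draws n in [min_n, max_n], the clause/variable ratio alpha in [min_alpha, max_alpha],
    # clause lengths in [min_k, min(max_k, n - 1)] and uniformly random literal polarities.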
100 | 101 | def __init__(self, min_n, max_n, min_k, max_k, min_alpha, max_alpha, alpha_resolution=10): 102 | 103 | super(UniformCNFGenerator, self).__init__(min_n, max_n, min_alpha, max_alpha, alpha_resolution) 104 | self._min_k = min_k 105 | self._max_k = max_k 106 | 107 | def generate(self): 108 | n = np.random.randint(self._min_n, self._max_n + 1) 109 | alpha = np.random.uniform(self._min_alpha, self._max_alpha) 110 | m = int(n * alpha) 111 | 112 | clause_length = [np.random.randint(self._min_k, min(self._max_k, n-1) + 1) for _ in range(m)] 113 | edge_num = np.sum(clause_length) 114 | 115 | graph_map = np.zeros((2, edge_num), dtype=np.int32) 116 | 117 | ind = 0 118 | for i in range(m): 119 | graph_map[0, ind:(ind+clause_length[i])] = np.random.choice(n, clause_length[i], replace=False) 120 | graph_map[1, ind:(ind+clause_length[i])] = i 121 | ind += clause_length[i] 122 | 123 | edge_feature = 2.0 * np.random.choice(2, edge_num) - 1 124 | 125 | return n, m, graph_map, edge_feature, None, -1.0, [] 126 | 127 | def generate_complete(self): 128 | n = np.random.randint(self._min_n, self._max_n + 1) 129 | alpha = np.random.uniform(self._alpha, self._alpha + self._alpha_inc) 130 | m = int(n * alpha) 131 | max_trial = 10 132 | 133 | clause_set = set() 134 | clause_list = [] 135 | graph_map = np.zeros((2, 0), dtype=np.int32) 136 | edge_features = np.zeros(0) 137 | 138 | i = -1 139 | for _ in range(m): 140 | for _ in range(max_trial): 141 | clause_length = np.random.randint(self._min_k, min(self._max_k, n-1) + 1) 142 | literals = np.sort(np.random.choice(n, clause_length, replace=False)) 143 | edge_feature = 2.0 * np.random.choice(2, clause_length) - 1 144 | iclause = list(((literals + 1) * edge_feature).astype(int)) 145 | 146 | if str(iclause) not in clause_set: 147 | i += 1 148 | break 149 | 150 | clause_set.add(str(iclause)) 151 | clause_list += [iclause] 152 | 153 | graph_map = np.concatenate((graph_map, np.stack((literals, i * np.ones(clause_length, dtype=np.int32)))), 1) 154 | edge_features = np.concatenate((edge_features, edge_feature)) 155 | 156 | label = is_sat(n, clause_list) 157 | return n, m, graph_map, edge_features, None, label, clause_list 158 | 159 | 160 | ############################################################################################### 161 | 162 | 163 | class ModularCNFGenerator(CNFGeneratorBase): 164 | "Implements the modular random CNF generator according to the Community Attachment model (https://www.iiia.csic.es/sites/default/files/aij16.pdf)" 165 | 166 | def __init__(self, k, min_n, max_n, min_q, max_q, min_c, max_c, min_alpha, max_alpha, alpha_resolution=10): 167 | 168 | super(ModularCNFGenerator, self).__init__(min_n, max_n, min_alpha, max_alpha, alpha_resolution) 169 | self._k = k 170 | self._min_c = min_c 171 | self._max_c = max_c 172 | self._min_q = min_q 173 | self._max_q = max_q 174 | 175 | def generate(self): 176 | n = np.random.randint(self._min_n, self._max_n + 1) 177 | alpha = np.random.uniform(self._min_alpha, self._max_alpha) 178 | m = int(n * alpha) 179 | 180 | q = np.random.uniform(self._min_q, self._max_q) 181 | c = np.random.randint(self._min_c, self._max_c + 1) 182 | c = max(1, min(c, int(n / self._k) - 1)) 183 | size = int(n / c) 184 | community_size = size * np.ones(c, dtype=np.int32) 185 | community_size[c - 1] += (n - np.sum(community_size)) 186 | 187 | p = q + 1.0 / c 188 | edge_num = m * self._k 189 | 190 | graph_map = np.zeros((2, edge_num), dtype=np.int32) 191 | index = np.random.permutation(n) 192 | 193 | ind = 0 194 | for i in 
range(m): 195 | coin = np.random.uniform() 196 | if coin <= p: # Pick from the same community 197 | community = np.random.randint(0, c) 198 | graph_map[0, ind:(ind + self._k)] = index[np.random.choice(range(size*community, size*community + community_size[community]), self._k, replace=False)] 199 | else: # Pick from different communities 200 | if c >= self._k: 201 | communities = np.random.choice(c, self._k, replace=False) 202 | temp = np.random.uniform(size = self._k) 203 | inner_offset = (temp * community_size[communities]).astype(int) 204 | graph_map[0, ind:(ind + self._k)] = index[size*communities + inner_offset] 205 | else: 206 | graph_map[0, ind:(ind + self._k)] = np.random.choice(n, self._k, replace=False) 207 | 208 | graph_map[1, ind:(ind+self._k)] = i 209 | ind += self._k 210 | 211 | edge_feature = 2.0 * np.random.choice(2, edge_num) - 1 212 | 213 | return n, m, graph_map, edge_feature, None, -1.0, [] 214 | 215 | def generate_complete(self): 216 | n = np.random.randint(self._min_n, self._max_n + 1) 217 | alpha = np.random.uniform(self._alpha, self._alpha + self._alpha_inc) 218 | m = int(n * alpha) 219 | max_trial = 10 220 | 221 | q = np.random.uniform(self._min_q, self._max_q) 222 | c = np.random.randint(self._min_c, self._max_c + 1) 223 | c = max(self._k + 1, min(c, int(n / self._k) - 1)) 224 | size = int(n / c) 225 | community_size = size * np.ones(c, dtype=np.int32) 226 | community_size[c - 1] += (n - np.sum(community_size)) 227 | 228 | p = q + 1.0 / c 229 | edge_num = m * self._k 230 | 231 | index = np.random.permutation(n) 232 | clause_set = set() 233 | clause_list = [] 234 | graph_map = np.zeros((2, 0), dtype=np.int32) 235 | edge_features = np.zeros(0) 236 | 237 | i = -1 238 | for _ in range(m): 239 | for _ in range(max_trial): 240 | coin = np.random.uniform() 241 | if coin <= p: # Pick from the same community 242 | community = np.random.randint(0, c) 243 | literals = np.sort(index[np.random.choice(range(size*community, size*community + community_size[community]), self._k, replace=False)]) 244 | else: # Pick from different communities 245 | communities = np.random.choice(c, self._k, replace=False) 246 | temp = np.random.uniform(size = self._k) 247 | inner_offset = (temp * community_size[communities]).astype(int) 248 | literals = np.sort(index[size*communities + inner_offset]) 249 | 250 | edge_feature = 2.0 * np.random.choice(2, self._k) - 1 251 | iclause = list(((literals + 1) * edge_feature).astype(int)) 252 | 253 | if str(iclause) not in clause_set: 254 | i += 1 255 | break 256 | 257 | clause_set.add(str(iclause)) 258 | clause_list += [iclause] 259 | 260 | graph_map = np.concatenate((graph_map, np.stack((literals, i * np.ones(self._k, dtype=np.int32)))), 1) 261 | edge_features = np.concatenate((edge_features, edge_feature)) 262 | 263 | label = is_sat(n, clause_list) 264 | return n, m, graph_map, edge_features, None, label, clause_list 265 | 266 | 267 | ############################################################################################### 268 | 269 | 270 | class VariableModularCNFGenerator(CNFGeneratorBase): 271 | "Implements a variation of the Community Attachment model with variable sized clauses." 
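    # Like ModularCNFGenerator, but each clause's length is drawn from [min_k, max_k]
    # (capped by the community size) rather than being fixed at k.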
272 | 273 | def __init__(self, min_k, max_k, min_n, max_n, min_q, max_q, min_c, max_c, min_alpha, max_alpha, alpha_resolution=10): 274 | 275 | super(VariableModularCNFGenerator, self).__init__(min_n, max_n, min_alpha, max_alpha, alpha_resolution) 276 | self._min_k = min_k 277 | self._max_k = max_k 278 | self._min_c = min_c 279 | self._max_c = max_c 280 | self._min_q = min_q 281 | self._max_q = max_q 282 | 283 | def generate(self): 284 | n = np.random.randint(self._min_n, self._max_n + 1) 285 | alpha = np.random.uniform(self._min_alpha, self._max_alpha) 286 | m = int(n * alpha) 287 | 288 | q = np.random.uniform(self._min_q, self._max_q) 289 | c = np.random.randint(self._min_c, self._max_c + 1) 290 | c = max(1, min(c, n)) 291 | size = int(n / c) 292 | community_size = size * np.ones(c, dtype=np.int32) 293 | community_size[c - 1] += (n - np.sum(community_size)) 294 | 295 | p = q + 1.0 / c 296 | clause_length = [np.random.randint(min(self._min_k, size), min(self._max_k, n-1, size) + 1) for _ in range(m)] 297 | edge_num = np.sum(clause_length) 298 | 299 | graph_map = np.zeros((2, edge_num), dtype=np.int32) 300 | index = np.random.permutation(n) 301 | 302 | ind = 0 303 | for i in range(m): 304 | coin = np.random.uniform() 305 | if coin <= p: # Pick from the same community 306 | community = np.random.randint(0, c) 307 | graph_map[0, ind:(ind + clause_length[i])] = index[np.random.choice(range(size*community, size*community + community_size[community]), clause_length[i], replace=False)] 308 | else: # Pick from different communities 309 | if c >= clause_length[i]: 310 | communities = np.random.choice(c, clause_length[i], replace=False) 311 | temp = np.random.uniform(size = clause_length[i]) 312 | inner_offset = (temp * community_size[communities]).astype(int) 313 | graph_map[0, ind:(ind + clause_length[i])] = index[size*communities + inner_offset] 314 | else: 315 | graph_map[0, ind:(ind + clause_length[i])] = np.random.choice(n, clause_length[i], replace=False) 316 | 317 | graph_map[1, ind:(ind+clause_length[i])] = i 318 | ind += clause_length[i] 319 | 320 | edge_feature = 2.0 * np.random.choice(2, edge_num) - 1 321 | 322 | return n, m, graph_map, edge_feature, None, -1.0, [] 323 | 324 | def generate_complete(self): 325 | n = np.random.randint(self._min_n, self._max_n + 1) 326 | alpha = np.random.uniform(self._alpha, self._alpha + self._alpha_inc) 327 | m = int(n * alpha) 328 | max_trial = 10 329 | 330 | q = np.random.uniform(self._min_q, self._max_q) 331 | c = np.random.randint(self._min_c, self._max_c + 1) 332 | c = max(self._k + 1, min(c, int(n / self._k) - 1)) 333 | size = int(n / c) 334 | community_size = size * np.ones(c, dtype=np.int32) 335 | community_size[c - 1] += (n - np.sum(community_size)) 336 | 337 | p = q + 1.0 / c 338 | edge_num = m * self._k 339 | 340 | index = np.random.permutation(n) 341 | clause_set = set() 342 | clause_list = [] 343 | graph_map = np.zeros((2, 0), dtype=np.int32) 344 | edge_features = np.zeros(0) 345 | 346 | i = -1 347 | for _ in range(m): 348 | for _ in range(max_trial): 349 | clause_length = np.random.randint(min(self._min_k, size), min(self._max_k, n-1, size) + 1) 350 | coin = np.random.uniform() 351 | if coin <= p: # Pick from the same community 352 | community = np.random.randint(0, c) 353 | literals = np.sort(index[np.random.choice(range(size*community, size*community + community_size[community]), clause_length, replace=False)]) 354 | else: # Pick from different communities 355 | if c >= clause_length: 356 | communities = np.random.choice(c, clause_length, 
replace=False) 357 | temp = np.random.uniform(size = clause_length) 358 | inner_offset = (temp * community_size[communities]).astype(int) 359 | literals = np.sort(index[size*communities + inner_offset]) 360 | else: 361 | literals = np.random.choice(n, clause_length, replace=False) 362 | 363 | edge_feature = 2.0 * np.random.choice(2, clause_length) - 1 364 | iclause = list(((literals + 1) * edge_feature).astype(int)) 365 | 366 | if str(iclause) not in clause_set: 367 | i += 1 368 | break 369 | 370 | clause_set.add(str(iclause)) 371 | clause_list += [iclause] 372 | 373 | graph_map = np.concatenate((graph_map, np.stack((literals, i * np.ones(clause_length, dtype=np.int32)))), 1) 374 | edge_features = np.concatenate((edge_features, edge_feature)) 375 | 376 | label = is_sat(n, clause_list) 377 | return n, m, graph_map, edge_features, None, label, clause_list 378 | 379 | 380 | ############################################################################################### 381 | 382 | 383 | if __name__ == '__main__': 384 | 385 | parser = argparse.ArgumentParser() 386 | parser.add_argument('out_dir', action='store', type=str) 387 | parser.add_argument('out_json', action='store', type=str) 388 | parser.add_argument('name', action='store', type=str) 389 | parser.add_argument('size', action='store', type=int) 390 | parser.add_argument('method', action='store', type=str) 391 | 392 | parser.add_argument('--min_n', action='store', dest='min_n', type=int, default=40) 393 | parser.add_argument('--max_n', action='store', dest='max_n', type=int, default=40) 394 | 395 | parser.add_argument('--min_c', action='store', dest='min_c', type=int, default=10) 396 | parser.add_argument('--max_c', action='store', dest='max_c', type=int, default=40) 397 | 398 | parser.add_argument('--min_q', action='store', dest='min_q', type=float, default=0.3) 399 | parser.add_argument('--max_q', action='store', dest='max_q', type=float, default=0.9) 400 | 401 | parser.add_argument('--min_k', action='store', dest='min_k', type=int, default=3) 402 | parser.add_argument('--max_k', action='store', dest='max_k', type=int, default=5) 403 | 404 | parser.add_argument('--min_a', action='store', dest='min_a', type=float, default=2) 405 | parser.add_argument('--max_a', action='store', dest='max_a', type=float, default=10) 406 | parser.add_argument('--res', action='store', dest='res', type=int, default=5) 407 | 408 | parser.add_argument('-s', '--sat_only', help='Include SAT examples only', required=False, action='store_true', default=False) 409 | 410 | args = vars(parser.parse_args()) 411 | 412 | if args['method'] == 'modular': 413 | generator = ModularCNFGenerator(k=args['min_k'], min_n=args['min_n'], max_n=args['max_n'], min_q=args['min_q'], 414 | max_q=args['max_q'], min_c=args['min_c'], max_c=args['max_c'], min_alpha=args['min_a'], max_alpha=args['max_a'], alpha_resolution=args['res']) 415 | elif args['method'] == 'v-modular': 416 | generator = VariableModularCNFGenerator(min_k=args['min_k'], max_k=args['max_k'], min_n=args['min_n'], max_n=args['max_n'], min_q=args['min_q'], 417 | max_q=args['max_q'], min_c=args['min_c'], max_c=args['max_c'], min_alpha=args['min_a'], max_alpha=args['max_a'], alpha_resolution=args['res']) 418 | else: 419 | generator = UniformCNFGenerator(min_n=args['min_n'], max_n=args['max_n'], min_k=args['min_k'], 420 | max_k=args['max_k'], min_alpha=args['min_a'], max_alpha=args['max_a'], alpha_resolution=args['res']) 421 | 422 | generator.generate_dataset(args['size'], args['out_dir'], args['out_json'], args['name'], 
args['sat_only']) -------------------------------------------------------------------------------- /src/pdp/factorgraph/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 3 | 4 | # factor_graph_trainer.py : Defines the trainer base class for the PDP framework. 5 | 6 | import os 7 | import time 8 | import math 9 | import multiprocessing 10 | 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | from pdp.factorgraph.dataset import FactorGraphDataset 17 | 18 | 19 | def _module(model): 20 | return model.module if isinstance(model, nn.DataParallel) else model 21 | 22 | 23 | # pylint: disable=protected-access 24 | class FactorGraphTrainerBase: 25 | "Base class of the Factor Graph trainer pipeline (abstract)." 26 | 27 | # pylint: disable=unused-argument 28 | def __init__(self, config, has_meta_data, error_dim, loss, evaluator, use_cuda, logger): 29 | 30 | self._config = config 31 | self._logger = logger 32 | self._use_cuda = use_cuda and torch.cuda.is_available() 33 | 34 | if config['verbose']: 35 | if self._use_cuda: 36 | self._logger.info('Using GPU...') 37 | else: 38 | self._logger.info('Using CPU...') 39 | 40 | self._device = torch.device("cuda" if self._use_cuda else "cpu") 41 | 42 | self._error_dim = error_dim 43 | self._num_cores = multiprocessing.cpu_count() 44 | self._loss = loss 45 | self._evaluator = evaluator 46 | 47 | if config['verbose']: 48 | self._logger.info("The number of CPU cores is %s." % self._num_cores) 49 | 50 | torch.set_num_threads(self._num_cores) 51 | 52 | # Build the network 53 | self._model_list = [self._set_device(model) for model in self._build_graph(self._config)] 54 | 55 | def _build_graph(self, config): 56 | "Builds the forward computational graph." 57 | 58 | raise NotImplementedError("Subclass must implement abstract method") 59 | 60 | # pylint: disable=unused-argument 61 | def _compute_loss(self, model, loss, prediction, label, graph_map, batch_variable_map, 62 | batch_function_map, edge_feature, meta_data): 63 | "Computes the loss function." 64 | 65 | return loss(prediction, label) 66 | 67 | # pylint: disable=unused-argument 68 | def _compute_evaluation_metrics(self, model, evaluator, prediction, label, graph_map, 69 | batch_variable_map, batch_function_map, edge_feature, meta_data): 70 | "Computes the evaluation function." 71 | 72 | return evaluator(prediction, label) 73 | 74 | def _load(self, import_path_base): 75 | "Loads the model(s) from file." 76 | 77 | for model in self._model_list: 78 | _module(model).load(import_path_base) 79 | 80 | def _save(self, export_path_base): 81 | "Saves the model(s) to file." 82 | 83 | for model in self._model_list: 84 | _module(model).save(export_path_base) 85 | 86 | def _reset_global_step(self): 87 | "Resets the global step counter." 88 | 89 | for model in self._model_list: 90 | _module(model)._global_step.data = torch.tensor( 91 | [0], dtype=torch.float, device=self._device) 92 | 93 | def _set_device(self, model): 94 | "Sets the CPU/GPU device." 
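        # When CUDA is available, the model is wrapped in nn.DataParallel and moved to the GPU
        # so that each batch is split across all visible devices; otherwise it stays on the CPU.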
95 | 96 | if self._use_cuda: 97 | return nn.DataParallel(model).cuda(self._device) 98 | return model.cpu() 99 | 100 | def _to_cuda(self, data): 101 | if isinstance(data, list): 102 | return data 103 | 104 | if data is not None and self._use_cuda: 105 | return data.cuda(self._device, non_blocking=True) 106 | return data 107 | 108 | def get_parameter_list(self): 109 | "Returns list of dictionaries with models' parameters." 110 | return [{'params': filter(lambda p: p.requires_grad, model.parameters())} 111 | for model in self._model_list] 112 | 113 | def _train_epoch(self, train_loader, optimizer): 114 | 115 | train_batch_num = math.ceil(len(train_loader.dataset) / self._config['batch_size']) 116 | 117 | total_loss = np.zeros(len(self._model_list), dtype=np.float32) 118 | total_example_num = 0 119 | 120 | for (j, data) in enumerate(train_loader, 1): 121 | segment_num = len(data[0]) 122 | 123 | for i in range(segment_num): 124 | 125 | (graph_map, batch_variable_map, batch_function_map, 126 | edge_feature, graph_feat, label, _) = [self._to_cuda(d[i]) for d in data] 127 | total_example_num += (batch_variable_map.max() + 1) 128 | 129 | self._train_batch(total_loss, optimizer, graph_map, batch_variable_map, batch_function_map, 130 | edge_feature, graph_feat, label) 131 | 132 | if self._config['verbose']: 133 | print("Training epoch with batch of size {:4d} ({:4d}/{:4d}): {:3d}% complete...".format( 134 | batch_variable_map.max().item(), total_example_num % self._config['batch_size'], self._config['batch_size'], 135 | int(j * 100.0 / train_batch_num)), end='\r') 136 | 137 | del graph_map 138 | del batch_variable_map 139 | del batch_function_map 140 | del edge_feature 141 | del graph_feat 142 | del label 143 | 144 | for model in self._model_list: 145 | _module(model)._global_step += 1 146 | 147 | return total_loss / total_example_num # max(1, len(train_loader)) 148 | 149 | def _train_batch(self, total_loss, optimizer, graph_map, batch_variable_map, batch_function_map, 150 | edge_feature, graph_feat, label): 151 | 152 | optimizer.zero_grad() 153 | lambda_value = torch.tensor([self._config['lambda']], dtype=torch.float32, device=self._device) 154 | 155 | for (i, model) in enumerate(self._model_list): 156 | 157 | state = _module(model).get_init_state(graph_map, batch_variable_map, batch_function_map, 158 | edge_feature, graph_feat, self._config['randomized']) 159 | 160 | loss = torch.zeros(1, device=self._device) 161 | 162 | for t in torch.arange(self._config['train_outer_recurrence_num'], dtype=torch.int32, device=self._device): 163 | 164 | prediction, state = model( 165 | init_state=state, graph_map=graph_map, batch_variable_map=batch_variable_map, 166 | batch_function_map=batch_function_map, edge_feature=edge_feature, 167 | meta_data=graph_feat, is_training=True, iteration_num=self._config['train_inner_recurrence_num']) 168 | 169 | loss += self._compute_loss( 170 | model=_module(model), loss=self._loss, prediction=prediction, 171 | label=label, graph_map=graph_map, batch_variable_map=batch_variable_map, 172 | batch_function_map=batch_function_map, edge_feature=edge_feature, meta_data=graph_feat) * \ 173 | lambda_value.pow((self._config['train_outer_recurrence_num'] - t - 1).float()) 174 | 175 | loss.backward() 176 | nn.utils.clip_grad_norm_(model.parameters(), self._config['clip_norm']) 177 | total_loss[i] += loss.detach().cpu().numpy() 178 | 179 | for s in state: 180 | del s 181 | 182 | optimizer.step() 183 | 184 | def _test_epoch(self, validation_loader, batch_replication): 185 | 186 | 
test_batch_num = math.ceil(len(validation_loader.dataset) / self._config['batch_size']) 187 | 188 | with torch.no_grad(): 189 | 190 | error = np.zeros( 191 | (self._error_dim, len(self._model_list)), dtype=np.float32) 192 | total_example_num = 0 193 | 194 | for (j, data) in enumerate(validation_loader, 1): 195 | segment_num = len(data[0]) 196 | 197 | for i in range(segment_num): 198 | 199 | (graph_map, batch_variable_map, batch_function_map, 200 | edge_feature, graph_feat, label, _) = [self._to_cuda(d[i]) for d in data] 201 | total_example_num += (batch_variable_map.max() + 1).detach().cpu().numpy() 202 | 203 | self._test_batch(error, graph_map, batch_variable_map, batch_function_map, 204 | edge_feature, graph_feat, label, batch_replication) 205 | 206 | if self._config['verbose']: 207 | print("Testing epoch with batch of size {:4d} ({:4d}/{:4d}): {:3d}% complete...".format( 208 | batch_variable_map.max().item(), total_example_num % self._config['batch_size'], self._config['batch_size'], 209 | int(j * 100.0 / test_batch_num)), end='\r') 210 | 211 | del graph_map 212 | del batch_variable_map 213 | del batch_function_map 214 | del edge_feature 215 | del graph_feat 216 | del label 217 | 218 | # if self._use_cuda: 219 | # torch.cuda.empty_cache() 220 | 221 | return error / total_example_num 222 | 223 | def _test_batch(self, error, graph_map, batch_variable_map, batch_function_map, 224 | edge_feature, graph_feat, label, batch_replication): 225 | 226 | this_batch_size = batch_variable_map.max() + 1 227 | edge_num = graph_map.size(1) 228 | 229 | for (i, model) in enumerate(self._model_list): 230 | 231 | state = _module(model).get_init_state(graph_map, batch_variable_map, batch_function_map, 232 | edge_feature, graph_feat, randomized=True, batch_replication=batch_replication) 233 | 234 | prediction, _ = model( 235 | init_state=state, graph_map=graph_map, batch_variable_map=batch_variable_map, 236 | batch_function_map=batch_function_map, edge_feature=edge_feature, 237 | meta_data=graph_feat, is_training=False, iteration_num=self._config['test_recurrence_num'], 238 | check_termination=self._check_recurrence_termination, batch_replication=batch_replication) 239 | 240 | error[:, i] += (this_batch_size.float() * self._compute_evaluation_metrics( 241 | model=_module(model), evaluator=self._evaluator, 242 | prediction=prediction, label=label, graph_map=graph_map, 243 | batch_variable_map=batch_variable_map, batch_function_map=batch_function_map, 244 | edge_feature=edge_feature, meta_data=graph_feat)).detach().cpu().numpy() 245 | 246 | for p in prediction: 247 | del p 248 | 249 | for s in state: 250 | del s 251 | 252 | def _predict_epoch(self, validation_loader, post_processor, batch_replication, file): 253 | 254 | test_batch_num = math.ceil(len(validation_loader.dataset) / self._config['batch_size']) 255 | 256 | with torch.no_grad(): 257 | 258 | for (j, data) in enumerate(validation_loader, 1): 259 | segment_num = len(data[0]) 260 | 261 | for i in range(segment_num): 262 | 263 | (graph_map, batch_variable_map, batch_function_map, 264 | edge_feature, graph_feat, label, misc_data) = [self._to_cuda(d[i]) for d in data] 265 | 266 | self._predict_batch(graph_map, batch_variable_map, batch_function_map, 267 | edge_feature, graph_feat, label, misc_data, post_processor, batch_replication, file) 268 | 269 | del graph_map 270 | del batch_variable_map 271 | del batch_function_map 272 | del edge_feature 273 | del graph_feat 274 | del label 275 | 276 | # if self._config['verbose']: 277 | # print("Predicting epoch: 
%3d%% complete..." 278 | # % (j * 100.0 / test_batch_num), end='\r') 279 | 280 | def _predict_batch(self, graph_map, batch_variable_map, batch_function_map, 281 | edge_feature, graph_feat, label, misc_data, post_processor, batch_replication, file): 282 | 283 | edge_num = graph_map.size(1) 284 | 285 | for (i, model) in enumerate(self._model_list): 286 | 287 | state = _module(model).get_init_state(graph_map, batch_variable_map, batch_function_map, 288 | edge_feature, graph_feat, randomized=False, batch_replication=batch_replication) 289 | 290 | prediction, _ = model( 291 | init_state=state, graph_map=graph_map, batch_variable_map=batch_variable_map, 292 | batch_function_map=batch_function_map, edge_feature=edge_feature, 293 | meta_data=graph_feat, is_training=False, iteration_num=self._config['test_recurrence_num'], 294 | check_termination=self._check_recurrence_termination, batch_replication=batch_replication) 295 | 296 | if post_processor is not None and callable(post_processor): 297 | message = post_processor(_module(model), prediction, graph_map, 298 | batch_variable_map, batch_function_map, edge_feature, graph_feat, label, misc_data) 299 | print(message, file=file) 300 | 301 | for p in prediction: 302 | del p 303 | 304 | for s in state: 305 | del s 306 | 307 | def _check_recurrence_termination(self, active, prediction, sat_problem): 308 | "De-actives the CNF examples which the model has already found a SAT solution for." 309 | pass 310 | 311 | def train(self, train_list, validation_list, optimizer, last_export_path_base=None, 312 | best_export_path_base=None, metric_index=0, load_model=None, reset_step=False, 313 | generator=None, train_epoch_size=0): 314 | "Trains the PDP model." 315 | 316 | # Build the input pipeline 317 | train_loader = FactorGraphDataset.get_loader( 318 | input_file=train_list[0], limit=self._config['train_batch_limit'], 319 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=True, 320 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'], generator=generator, 321 | epoch_size=train_epoch_size) 322 | 323 | validation_loader = FactorGraphDataset.get_loader( 324 | input_file=validation_list[0], limit=self._config['test_batch_limit'], 325 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=False, 326 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size']) 327 | 328 | model_num = len(self._model_list) 329 | 330 | errors = np.zeros( 331 | (self._error_dim, model_num, self._config['epoch_num'], 332 | self._config['repetition_num']), dtype=np.float32) 333 | 334 | losses = np.zeros( 335 | (model_num, self._config['epoch_num'], self._config['repetition_num']), 336 | dtype=np.float32) 337 | 338 | best_errors = np.repeat(np.inf, model_num) 339 | 340 | if self._use_cuda: 341 | torch.backends.cudnn.benchmark = True 342 | 343 | for rep in range(self._config['repetition_num']): 344 | 345 | if load_model == "best" and best_export_path_base is not None: 346 | self._load(best_export_path_base) 347 | elif load_model == "last" and last_export_path_base is not None: 348 | self._load(last_export_path_base) 349 | 350 | if reset_step: 351 | self._reset_global_step() 352 | 353 | for epoch in range(self._config['epoch_num']): 354 | 355 | # Training 356 | start_time = time.time() 357 | losses[:, epoch, rep] = self._train_epoch(train_loader, optimizer) 358 | 359 | if self._use_cuda: 360 | torch.cuda.empty_cache() 361 | 362 | # Validation 363 | errors[:, :, epoch, rep] = 
self._test_epoch(validation_loader, 1) 364 | duration = time.time() - start_time 365 | 366 | # Checkpoint the best models so far 367 | if last_export_path_base is not None: 368 | for (i, model) in enumerate(self._model_list): 369 | _module(model).save(last_export_path_base) 370 | 371 | if best_export_path_base is not None: 372 | for (i, model) in enumerate(self._model_list): 373 | if errors[metric_index, i, epoch, rep] < best_errors[i]: 374 | best_errors[i] = errors[metric_index, i, epoch, rep] 375 | _module(model).save(best_export_path_base) 376 | 377 | if self._use_cuda: 378 | torch.cuda.empty_cache() 379 | 380 | if self._config['verbose']: 381 | message = '' 382 | for (i, model) in enumerate(self._model_list): 383 | name = _module(model)._name 384 | message += 'Step {:d}: {:s} error={:s}, {:s} loss={:5.5f} |'.format( 385 | _module(model)._global_step.int()[0], name, 386 | np.array_str(errors[:, i, epoch, rep].flatten()), 387 | name, losses[i, epoch, rep]) 388 | 389 | self._logger.info('Rep {:2d}, Epoch {:2d}: {:s}'.format(rep + 1, epoch + 1, message)) 390 | self._logger.info('Time spent: %s seconds' % duration) 391 | 392 | if self._use_cuda: 393 | torch.backends.cudnn.benchmark = False 394 | 395 | if best_export_path_base is not None: 396 | # Save losses and errors 397 | base = os.path.relpath(best_export_path_base) 398 | np.save(os.path.join(base, "losses"), losses, allow_pickle=False) 399 | np.save(os.path.join(base, "errors"), errors, allow_pickle=False) 400 | 401 | # Save the model 402 | self._save(best_export_path_base) 403 | 404 | return self._model_list, errors, losses 405 | 406 | def test(self, test_list, import_path_base=None, batch_replication=1): 407 | "Tests the PDP model and generates test stats." 408 | 409 | if isinstance(test_list, list): 410 | test_files = test_list 411 | elif os.path.isdir(test_list): 412 | test_files = [os.path.join(test_list, f) for f in os.listdir(test_list) \ 413 | if os.path.isfile(os.path.join(test_list, f)) and f[-5:].lower() == '.json' ] 414 | elif isinstance(test_list, str): 415 | test_files = [test_list] 416 | else: 417 | return None 418 | 419 | result = [] 420 | 421 | for file in test_files: 422 | # Build the input pipeline 423 | test_loader = FactorGraphDataset.get_loader( 424 | input_file=file, limit=self._config['test_batch_limit'], 425 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=False, 426 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'], batch_replication=batch_replication) 427 | 428 | if import_path_base is not None: 429 | self._load(import_path_base) 430 | 431 | start_time = time.time() 432 | error = self._test_epoch(test_loader, batch_replication) 433 | duration = time.time() - start_time 434 | 435 | if self._use_cuda: 436 | torch.cuda.empty_cache() 437 | 438 | if self._config['verbose']: 439 | message = '' 440 | for (i, model) in enumerate(self._model_list): 441 | message += '{:s}, dataset:{:s} error={:s}|'.format( 442 | _module(model)._name, file, np.array_str(error[:, i].flatten())) 443 | 444 | self._logger.info(message) 445 | self._logger.info('Time spent: %s seconds' % duration) 446 | 447 | result += [[file, error, duration]] 448 | 449 | return result 450 | 451 | def predict(self, test_list, out_file, import_path_base=None, post_processor=None, batch_replication=1): 452 | "Produces predictions for the trained PDP model." 
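        # A minimal usage sketch (the trainer subclass and post-processor names below are
        # illustrative, not part of this file):
        #
        #     trainer = MySatFactorGraphTrainer(config, ...)   # any subclass implementing _build_graph
        #     with open('predictions.json', 'w') as out:
        #         trainer.predict(test_list='test.json', out_file=out,
        #                         import_path_base=config['model_path'],
        #                         post_processor=format_solution)
        #
        # Each batch's predictions are passed through post_processor and printed to out_file.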
453 | 454 | # Build the input pipeline 455 | test_loader = FactorGraphDataset.get_loader( 456 | input_file=test_list, limit=self._config['test_batch_limit'], 457 | hidden_dim=self._config['hidden_dim'], batch_size=self._config['batch_size'], shuffle=False, 458 | num_workers=self._num_cores, max_cache_size=self._config['max_cache_size'], batch_replication=batch_replication) 459 | 460 | if import_path_base is not None: 461 | self._load(import_path_base) 462 | 463 | start_time = time.time() 464 | self._predict_epoch(test_loader, post_processor, batch_replication, out_file) 465 | 466 | duration = time.time() - start_time 467 | 468 | if self._use_cuda: 469 | torch.cuda.empty_cache() 470 | 471 | if self._config['verbose']: 472 | self._logger.info('Time spent: %s seconds' % duration) 473 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PDP Framework for Neural Constraint Satisfaction Solving 2 | 3 | The PDP framework is a generic framework based on the idea of Propagation, Decimation and Prediction (PDP) for learning and implementing message passing-based solvers for constraint satisfaction problems (CSP). In particular, it provides an elegant unsupervised framework for training neural solvers based on the idea of energy minimization. Our SAT solver adaptation of the PDP framework, referred as **SATYR**, supports a wide spectrum of solvers from fully neural architectures to classical inference-based techniques (such as Survey Propagation) with hybrid methods in between. For further theoretical details of the framework, please refer to our paper: 4 | 5 | **Saeed Amizadeh, Sergiy Matusevych and Markus Weimer, [PDP: A General Neural Framework for Learning Constraint Satisfaction Solvers](https://arxiv.org/abs/1903.01969), arXiv preprint arXiv:1903.01969, 2019.** [**[Video]**](https://www.youtube.com/watch?v=PwuCc7Lylww&ab_channel=CP2020) 6 | 7 | ``` 8 | @article{amizadeh2019pdp, 9 | title={PDP: A General Neural Framework for Learning Constraint Satisfaction Solvers}, 10 | author={Amizadeh, Saeed and Matusevych, Sergiy and Weimer, Markus}, 11 | journal={arXiv preprint arXiv:1903.01969}, 12 | year={2019} 13 | } 14 | ``` 15 | 16 | We also note the present work is still far away from competing with modern industrial solvers; nevertheless, we believe it is a significant step in the right direction for machine learning-based methods. Hence, we are glad to open source our code to the researchers in related fields including the neuro-symbolic community as well as the classical SAT community. 17 | 18 | # SATYR 19 | 20 | SATYR is the adaptation of the PDP framework for training and deploying neural Boolean Satisfiability solvers. In particular, SATYR implements: 21 | 22 | 1. Fully or partially neural SAT solvers that can be trained toward solving SAT for a specific distribution of problem instances. The training is based on unsupervised energy minimization and can be performed on an infinite stream of unlabeled, random instances sampled from the target distribution. 23 | 24 | 2. Non-learnable classical solvers based on message passing in graphical models (e.g. Survey Propagation). Even though, these solvers are non-learnable, they still benefit from the embarrassingly parallel implementation via the PDP framework on GPUs. 25 | 26 | It should be noted that all the SATYR solvers try to find a satisfying assignment for input SAT formulas. 
However, if the SATYR solvers cannot find a satisfying solution for a given problem within their iteration number budget, it does NOT necessarily mean that the input problem is UNSAT. In other words, none of the SATYR solvers provide a proof of unsatisfiability. 27 | 28 | # Setup 29 | 30 | ## Prerequisites 31 | 32 | * Python 3.5 or higher. 33 | * PyTorch 0.4.0 or higher. 34 | 35 | Run: 36 | 37 | ``` 38 | > python setup.py install 39 | ``` 40 | 41 | # Usage 42 | 43 | The SATYR solvers can be used in two main modes: (1) apply an already-trained model or non-ML algorithm to test data, and (2) train/test new models. 44 | 45 | ## Running a (Trained) SATYR Solver 46 | 47 | The usage for running a SATYR solver against a set of SAT problems (represented in Conjunctive Normal Form (CNF)) is: 48 | 49 | ``` 50 | > python satyr.py [-h] [-b BATCH_REPLICATION] [-z BATCH_SIZE] 51 | [-m MAX_CACHE_SIZE] [-l TEST_BATCH_LIMIT] 52 | [-w LOCAL_SEARCH_ITERATION] [-e EPSILON] [-v] [-c] [-d] 53 | [-s RANDOM_SEED] [-o OUTPUT] 54 | model_config test_path test_recurrence_num 55 | ``` 56 | 57 | The commandline arguments are: 58 | 59 | + **-h** or **--help**: Shows the commandline options. 60 | 61 | + **-b BATCH_REPLICATION** or **--batch_replication BATCH_REPLICATION**: BATCH_REPLICATION is the replication factor for input problems to further benefit from parallelization (default: 1). 62 | 63 | + **-z BATCH_SIZE** or **--batch_size BATCH_SIZE**: BATCH_SIZE is the batch size (default: 5000). 64 | 65 | + **-m MAX_CACHE_SIZE** or **--max_cache_size MAX_CACHE_SIZE**: MAX_CACHE_SIZE is the maximum size of the cache containing the parsed CNFs loaded from disk (mostly useful for iterative training) (default: 100000). 66 | 67 | + **-l TEST_BATCH_LIMIT** or **--test_batch_limit TEST_BATCH_LIMIT**: TEST_BATCH_LIMIT is the memory limit used for dynamic batching. It must be empirically tuned by the user depending on the available GPU memory (default: 40000000). 68 | 69 | + **-w LOCAL_SEARCH_ITERATION** or **--local_search_iteration LOCAL_SEARCH_ITERATION**: LOCAL_SEARCH_ITERATION is the maximum number of local search (i.e. Walk-SAT) iterations that can be optionally applied as a post-processing step after the main solver terminates (default: 100). 70 | 71 | + **-e EPSILON** or **--epsilon EPSILON**: EPSILON is the probability with which the post-processing local search picks a random variable instead of the best option for flipping (default: 0.5). 72 | 73 | + **-v** or **--verbose**: Prints the log messages to STDOUT. 74 | 75 | + **-c** or **--cpu_mode**: Forces the solver to run on CPU. 76 | 77 | + **-d** or **--dimacs**: Notifies the solver that the input path is a directory of DIMACS files. 78 | 79 | + **-s RANDOM_SEED** or **--random_seed RANDOM_SEED**: RANDOM_SEED is the random seed directly affecting the randomized initial values of the messages in the PDP solvers. 80 | 81 | + **-o OUTPUT** or **--output OUTPUT**: OUTPUT is the path to the output JSON file that will contain the solutions for the input CNFs. If not specified, the output is directed to STDOUT. 82 | 83 | + **model_config**: The path to the YAML config file specifying the model used for SAT solving. A few example config files are provided [here](https://github.com/Microsoft/PDP-Solver/tree/master/config/Predict). A model config file specifies the following properties: 84 | 85 | * **model_type**: The type of solver. So far, we have implemented six different types of solvers in SATYR: 86 | 87 | * *'np-nd-np'*: A fully neural PDP solver.
88 | 89 | * *'p-d-p'*: A PDP solver that implements the classical Survey Propagation with greedy sequential decimation.[[1]](#reference-1) 90 | 91 | * *'p-nd-np'*: A PDP solver that implements the classical Survey Propagation except with neural decimation. 92 | 93 | * *'np-d-np'*: A PDP solver that implements neural propagation with a greedy sequential decimation. 94 | 95 | * *'reinforce'*: A PDP solver that implements the classical Survey Propagation with concurrent distributed decimation (The REINFORCE algorithm).[[2]](#reference-2) 96 | 97 | * *'walk-sat'*: A PDP solver that implements the classical local search Walk-SAT algorithm.[[3]](#reference-3) 98 | 99 | * **has_meta_data**: whether the input problem instances contain meta features other than the CNF itself (Note: loading of such features is not supported by the current input pipeline). 100 | 101 | * **model_name**: The name picked for the model by the user. 102 | 103 | * **model_path**: The path to the saved weights for the trained model. 104 | 105 | * **label_dim**: Always set to 1 for SATYR. 106 | 107 | * **edge_feature_dim**: Always set to 1 for SATYR. 108 | 109 | * **meta_feature_dim**: The dimensionality of the meta features (0 for now). 110 | 111 | * **prediction_dim**: Always set to 1 for SATYR. 112 | 113 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*: 114 | 115 | * **hidden_dim**: The dimensionality of messages between the propagator and the decimator. 116 | 117 | If **model_type** is *'np-nd-np'* or *'np-d-np'*: 118 | 119 | * **mem_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *before* the aggregation step in the propagator. 120 | 121 | * **agg_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *after* the aggregation step in the propagator. 122 | 123 | * **mem_agg_hidden_dim**: The output dimensionality of the perceptron that is applied to messages *before* the aggregation step in the propagator. 124 | 125 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*: 126 | 127 | * **classifier_dim**: The dimensionality of the hidden layer for the perceptron that is used as the final predictor. 128 | 129 | If **model_type** is *'p-d-p'* or *'np-d-np'*: 130 | 131 | * **tolerance**: The convergence tolerance for the propagator before sequential decimator is invoked. 132 | 133 | * **t_max**: The maximum iteration number for the propagator before sequential decimator is invoked. 134 | 135 | If **model_type** is *'reinforce'*: 136 | 137 | * **pi**: The external force magnitude parameter for the REINFORCE algorithm. 138 | 139 | * **decimation_probability**: The probability with which the distributed decimation is invoked in the REINFORCE algorithm. 140 | 141 | + **test_path**: The path to the input JSON file containing the test CNFs (in the case of '-d' option, the path to the directory containing the test DIMACS files.) 142 | 143 | + **test_recurrence_num**: The maximum number of iterations the solver is allowed to run before termination. 144 | 145 | 146 | ## Training/Testing a SATYR Solver 147 | 148 | The usage for training/testing new SATYR models is: 149 | 150 | ``` 151 | > python satyr-train-test.py [-h] [-t] [-l LOAD_MODEL] [-c] [-r] [-g] 152 | [-b BATCH_REPLICATION] 153 | config 154 | ``` 155 | 156 | The commandline arguments are: 157 | 158 | + **-h** or **--help**: Shows the commandline options. 159 | 160 | + **-t** or **--test**: Skips the training stage directly to the testing stage. 
161 | 162 | + **-l LOAD_MODEL** or **-load LOAD_MODEL**: LOAD_MODEL is: 163 | 164 | * *best*: The model is initialized by the best model (according to the validation metric) saved from the previous run. 165 | 166 | * *last*: The model is initialized by the last model saved from the previous run. 167 | 168 | * Otherwise, the model is initialized by random weights. 169 | 170 | + **-c** or **--cpu_mode**: Forces the training/testing to run on CPU. 171 | 172 | + **-r** or **--reset**: Resets the global time parameter to 0 (used for annealing the temperature). 173 | 174 | + **-g** or **--use_generator**: Makes the training process use one of the provided CNF generators to generate unlabeled training CNF instances on the fly. 175 | 176 | + **-b BATCH_REPLICATION** or **--batch_replication BATCH_REPLICATION**: BATCH_REPLICATION is the replication factor for input problems to further benefit from parallelization (default: 1). 177 | 178 | + **config**: The path to the YAML config file specifying the model as well as the training parameters. A few example training config files are provided [here](https://github.com/Microsoft/PDP-Solver/tree/master/config/Train). A training config file specifies the following properties: 179 | 180 | * **model_name**: The name picked for the model by the user. 181 | 182 | * **model_type**: The model type explained above. 183 | 184 | * **version**: The model version. 185 | 186 | * **has_meta_data**: whether the input problem instances contain meta features other than the CNF itself (Note: loading of such features is not supported by the current input pipeline). 187 | 188 | * **train_path**: A one-element list containing the path to the training JSON file. Will be ignored in the case of using option -g. 189 | 190 | * **validation_path**: A one-element list containing the path to the validation JSON file. The validation set is used for picking the best model during each training run. 191 | 192 | * **test_path**: A list containing the path(s) to test JSON files. 193 | 194 | * **model_path**: The parent directory of the location where the best and the last models are saved. 195 | 196 | * **repetition_num**: Number of repetitions for the training process (for regular scenarios: 1). 197 | 198 | * **train_epoch_size**: The size of one epoch in the case of using CNF generators via option -g. 199 | 200 | * **epoch_num**: The number of epochs for training. 201 | 202 | * **label_dim**: Always set to 1 for SATYR. 203 | 204 | * **edge_feature_dim**: Always set to 1 for SATYR. 205 | 206 | * **meta_feature_dim**: The dimensionality of the meta features (0 for now). 207 | 208 | * **error_dim**: The number of error metrics the model reports on the validation/test sets (3 for now: accuracy, recall and test loss). 209 | 210 | * **metric_index**: The 0-based index of the error metric used to pick the best model. 211 | 212 | * **prediction_dim**: Always set to 1 for SATYR. 213 | 214 | * **batch_size**: The batch size used for training/testing. 215 | 216 | * **learning_rate**: The learning rate for the ADAM optimization algorithm used for training. 217 | 218 | * **exploration**: The exploration factor used for annealing the temperature. 219 | 220 | * **verbose**: If TRUE, prints log messages to STDOUT. 221 | 222 | * **randomized**: If TRUE, initializes the propagator and the decimator messages with random values; otherwise with zeros. 223 | 224 | * **train_inner_recurrence_num**: The number of inner loop iterations before the loss function is computed (typically set to 1).
225 | 226 | * **train_outer_recurrence_num**: The number of outer loop iterations (T in the paper) used during training. 227 | 228 | * **test_recurrence_num**: The number of outer loop iterations (T in the paper) used during testing. 229 | 230 | * **max_cache_size**: The maximum size of the cache containing the parsed CNFs loaded from disk during training. 231 | 232 | * **dropout**: The dropout factor during training. 233 | 234 | * **clip_norm**: The clip norm ratio used for gradient clipping during training. 235 | 236 | * **weight_decay**: The weight decay coefficient for the ADAM optimizer used for training. 237 | 238 | * **loss_sharpness**: The sharpness of the step function used for calculating loss (the kappa parameter in the paper). 239 | 240 | * **train_batch_limit**: The memory limit used for dynamic batching during training. It must be empirically tuned by the user depending on the available GPU memory. 241 | 242 | * **test_batch_limit**: The memory limit used for dynamic batching during testing. It must be empirically tuned by the user depending on the available GPU memory. 243 | 244 | * **generator**: The type of CNF generator used in case the -g option is deployed: 245 | 246 | * *'uniform'*: The uniform random k-SAT generator. 247 | 248 | * *'modular'*: The modular random k-SAT generator with fixed k (specified by **min_k**) according to the Community Attachment model[[4]](#reference-4). 249 | 250 | * *'v-modular'*: The modular random k-SAT generator with variable size k according to the Community Attachment model. 251 | 252 | * **min_n**: The minimum number of variables for a random training CNF instance in case the -g option is deployed. 253 | 254 | * **max_n**: The maximum number of variables for a random training CNF instance in case the -g option is deployed. 255 | 256 | * **min_alpha**: The minimum clause/variable ratio for a random training CNF instance in case the -g option is deployed. 257 | 258 | * **max_alpha**: The maximum clause/variable ratio for a random training CNF instance in case the -g option is deployed. 259 | 260 | * **min_k**: The minimum clause size for a random training CNF instance in case the -g option is deployed. 261 | 262 | * **max_k**: The maximum clause size for a random training CNF instance in case the -g option is deployed (not used by the *'modular'* generator, which takes its fixed clause size from **min_k**). 263 | 264 | * **min_q**: The minimum modularity value for a random training CNF instance generated according to the Community Attachment model in case the -g option is deployed (not supported for *'uniform'* generator). 265 | 266 | * **max_q**: The maximum modularity value for a random training CNF instance generated according to the Community Attachment model in case the -g option is deployed (not supported for *'uniform'* generator). 267 | 268 | * **min_c**: The minimum number of communities for a random training CNF instance generated according to the Community Attachment model in case the -g option is deployed (not supported for *'uniform'* generator). 269 | 270 | * **max_c**: The maximum number of communities for a random training CNF instance generated according to the Community Attachment model in case the -g option is deployed (not supported for *'uniform'* generator). 271 | 272 | * **local_search_iteration**: The maximum number of local search (i.e. Walk-SAT) iterations that can be optionally applied as a post-processing step after the main solver terminates during testing.
273 | 274 | * **epsilon**: The probability with which the optional post-processing local search picks a random variable instead of the best option for flipping. 275 | 276 | * **lambda**: The discounting factor in (0, 1] used for loss calculation (the lambda parameter in the paper). 277 | 278 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*: 279 | 280 | * **hidden_dim**: The dimensionality of messages between the propagator and the decimator. 281 | 282 | If **model_type** is *'np-nd-np'* or *'np-d-np'*: 283 | 284 | * **mem_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *before* the aggregation step in the propagator. 285 | 286 | * **agg_hidden_dim**: The dimensionality of the hidden layer for the perceptron that is applied to messages *after* the aggregation step in the propagator. 287 | 288 | * **mem_agg_hidden_dim**: The output dimensionality of the perceptron that is applied to messages *before* the aggregation step in the propagator. 289 | 290 | If **model_type** is *'np-nd-np'*, *'p-nd-np'* or *'np-d-np'*: 291 | 292 | * **classifier_dim**: The dimensionality of the hidden layer for the perceptron that is used as the final predictor. 293 | 294 | If **model_type** is *'p-d-p'* or *'np-d-np'*: 295 | 296 | * **tolerance**: The convergence tolerance for the propagator before the sequential decimator is invoked. 297 | 298 | * **t_max**: The maximum iteration number for the propagator before the sequential decimator is invoked. 299 | 300 | If **model_type** is *'reinforce'*: 301 | 302 | * **pi**: The external force magnitude parameter for the REINFORCE algorithm. 303 | 304 | * **decimation_probability**: The probability with which the distributed decimation is invoked in the REINFORCE algorithm. 305 | 306 | # Input/Output Formats 307 | 308 | ## Input 309 | 310 | SATYR effectively works with the standard DIMACS format for representing CNF formulas. However, in order to make input ingestion more efficient, the actual solvers work directly with an intermediate JSON format instead of the DIMACS representation for consuming input CNF data. A key feature of the intermediate JSON format is that an entire set of DIMACS files can be represented by a single JSON file where each row in the JSON file corresponds to one DIMACS file. 311 | 312 | The train/test script assumes the train/validation/test sets are already in the JSON format. In order to convert a set of DIMACS files into a single JSON file, we have provided the following script: 313 | 314 | ``` 315 | > python dimacs2json.py [-h] [-s] [-p] in_dir out_file 316 | ``` 317 | 318 | where the commandline arguments are: 319 | 320 | + **-h** or **--help**: Shows the commandline options. 321 | 322 | + **-s** or **--simplify**: Performs elementary clause propagation simplification on the CNF formulas before converting them to JSON format. This option is not recommended for large formulas as it takes quadratic memory and time in terms of the number of clauses. 323 | 324 | + **-p** or **--positives**: Writes only the satisfiable examples to the output JSON file. This option is especially useful for creating all-SAT validation/test sets. Note that this option does NOT invoke any external solver to find out whether an example is SAT or not. Instead, it only works if the SAT/UNSAT labels are already provided in the names of the DIMACS files. In particular, if the name of an input DIMACS file ends in '1', it will be regarded as a SAT (positive) example.
325 | 326 | + **in_dir**: The path to the parent directory of the input DIMACS files. 327 | 328 | + **out_file**: The path of the output JSON file. 329 | 330 | The solver script, however, does not require the input problems to be in the JSON format; they can be in the DIMACS format as long as the -d option is deployed. Nevertheless, for repetitive applications of the solver script on the same input set, we would recommend externally converting the input DIMACS files into the JSON format once and only consume the JSON file afterwards. 331 | 332 | ## Output 333 | 334 | The output of the solver script is a JSON file where each line corresponds to one input CNF instance and is a dictionary with the following key:value pairs: 335 | 336 | + **"ID"**: The DIMACS file name associated with a CNF example. 337 | 338 | + **"label"**: The binary SAT/UNSAT (0/1) label associated with a CNF example (only if it is already provided in the DIMACS filename). 339 | 340 | + **"solved"**: The binary flag showing whether the provided solution satisfies the CNF. 341 | 342 | + **"unsat_clauses"**: The number of clauses in the CNF that are not satisfied by the provided solution (0 if the CNF is satisfied by the solution). 343 | 344 | + **"solution"**: The provided solution by the solver. The variable assignments in the list are ordered based on the increasing variable indices in the original DIMACS file. 345 | 346 | # Main Contributors 347 | 348 | + [Saeed Amizadeh](mailto:saeed.amizadeh@gmail.com), Microsoft Inc. 349 | + [Sergiy Matusevych](mailto:sergiy.matusevych@gmail.com), Microsoft Inc. 350 | 351 | # Contributing 352 | 353 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 354 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 355 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 356 | 357 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 358 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 359 | provided by the bot. You will only need to do this once across all repos using our CLA. 360 | 361 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 362 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 363 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 364 | 365 | ## Extending The PDP Framework 366 | 367 | The PDP framework supports a wide range of solvers from fully neural solvers to hybrid, neuro-symbolic models all the way to classical, non-learnable algorithms. So far, we have only implemented six different types, but there is definitely room for more. Therefore, we highly encourage contributions in the form of other types of PDP-based solvers. Furthermore, we welcome contributions with PDP-based adaptations for other types of constraint satisfaction problems beyond SAT. 368 | 369 | # References 370 | 371 | 1. Mezard, M. and Montanari, A. Information, physics, and computation. Oxford University Press, 2009. 372 | 2. Chavas, J., Furtlehner, C., Mezard, M., and Zecchina, R. Survey-propagation decimation through distributed local computations. Journal of Statistical Mechanics: Theory and Experiment, 2005(11):P11016, 2005. 373 | 3. Hoos, Holger H. 
On the Run-time Behaviour of Stochastic Local Search Algorithms for SAT. In AAAI/IAAI, pp. 661-666. 1999. 374 | 4. Giraldez-Cru, J. and Levy, J. Generating sat instances with community structure. Artificial Intelligence, 238:119–134, 2016. 375 | -------------------------------------------------------------------------------- /src/pdp/nn/solver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 3 | 4 | # solver.py : Defines the base class for all PDP Solvers as well as the various inherited solvers. 5 | 6 | import os 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from pdp.nn import pdp_propagate, pdp_decimate, pdp_predict, util 13 | 14 | 15 | ############################################################### 16 | ### The Problem Class 17 | ############################################################### 18 | 19 | class SATProblem(object): 20 | "The class that encapsulates a batch of CNF problem instances." 21 | 22 | def __init__(self, data_batch, device, batch_replication=1): 23 | self._device = device 24 | self._batch_replication = batch_replication 25 | self.setup_problem(data_batch, batch_replication) 26 | self._edge_mask = None 27 | 28 | def setup_problem(self, data_batch, batch_replication): 29 | "Setup the problem properties as well as the relevant sparse matrices." 30 | 31 | if batch_replication > 1: 32 | self._replication_mask_tuple = self._compute_batch_replication_map(data_batch[1], batch_replication) 33 | self._graph_map, self._batch_variable_map, self._batch_function_map, self._edge_feature, self._meta_data, _ = self._replicate_batch(data_batch, batch_replication) 34 | else: 35 | self._graph_map, self._batch_variable_map, self._batch_function_map, self._edge_feature, self._meta_data, _ = data_batch 36 | 37 | self._variable_num = self._batch_variable_map.size()[0] 38 | self._function_num = self._batch_function_map.size()[0] 39 | self._edge_num = self._graph_map.size()[1] 40 | 41 | self._vf_mask_tuple = self._compute_variable_function_map(self._graph_map, self._batch_variable_map, 42 | self._batch_function_map, self._edge_feature) 43 | self._batch_mask_tuple = self._compute_batch_map(self._batch_variable_map, self._batch_function_map) 44 | self._graph_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map) 45 | self._pos_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map, (self._edge_feature == 1).squeeze(1).float()) 46 | self._neg_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map, (self._edge_feature == -1).squeeze(1).float()) 47 | self._signed_mask_tuple = self._compute_graph_mask(self._graph_map, self._batch_variable_map, self._batch_function_map, self._edge_feature.squeeze(1)) 48 | 49 | self._active_variables = torch.ones(self._variable_num, 1, device=self._device) 50 | self._active_functions = torch.ones(self._function_num, 1, device=self._device) 51 | self._solution = 0.5 * torch.ones(self._variable_num, device=self._device) 52 | 53 | self._batch_size = (self._batch_variable_map.max() + 1).long().item() 54 | self._is_sat = 0.5 * torch.ones(self._batch_size, device=self._device) 55 | 56 | def _replicate_batch(self, data_batch, batch_replication): 57 | "Implements the batch 
replication." 58 | 59 | graph_map, batch_variable_map, batch_function_map, edge_feature, meta_data, label = data_batch 60 | edge_num = graph_map.size()[1] 61 | batch_size = (batch_variable_map.max() + 1).long().item() 62 | variable_num = batch_variable_map.size()[0] 63 | function_num = batch_function_map.size()[0] 64 | 65 | ind = torch.arange(batch_replication, dtype=torch.int32, device=self._device).unsqueeze(1).repeat(1, edge_num).view(1, -1) 66 | graph_map = graph_map.repeat(1, batch_replication) + ind.repeat(2, 1) * torch.tensor([[variable_num], [function_num]], dtype=torch.int32, device=self._device) 67 | 68 | ind = torch.arange(batch_replication, dtype=torch.int32, device=self._device).unsqueeze(1).repeat(1, variable_num).view(1, -1) 69 | batch_variable_map = batch_variable_map.repeat(batch_replication) + ind * batch_size 70 | 71 | ind = torch.arange(batch_replication, dtype=torch.int32, device=self._device).unsqueeze(1).repeat(1, function_num).view(1, -1) 72 | batch_function_map = batch_function_map.repeat(batch_replication) + ind * batch_size 73 | 74 | edge_feature = edge_feature.repeat(batch_replication, 1) 75 | 76 | if meta_data is not None: 77 | meta_data = meta_data.repeat(batch_replication, 1) 78 | 79 | if label is not None: 80 | label = label.repeat(batch_replication, 1) 81 | 82 | return graph_map, batch_variable_map.squeeze(0), batch_function_map.squeeze(0), edge_feature, meta_data, label 83 | 84 | def _compute_batch_replication_map(self, batch_variable_map, batch_replication): 85 | batch_size = (batch_variable_map.max() + 1).long().item() 86 | x_ind = torch.arange(batch_size * batch_replication, dtype=torch.int64, device=self._device) 87 | y_ind = torch.arange(batch_size, dtype=torch.int64, device=self._device).repeat(batch_replication) 88 | ind = torch.stack([x_ind, y_ind]) 89 | all_ones = torch.ones(batch_size * batch_replication, device=self._device) 90 | 91 | if self._device.type == 'cuda': 92 | mask = torch.cuda.sparse.FloatTensor(ind, all_ones, 93 | torch.Size([batch_size * batch_replication, batch_size]), device=self._device) 94 | else: 95 | mask = torch.sparse.FloatTensor(ind, all_ones, 96 | torch.Size([batch_size * batch_replication, batch_size]), device=self._device) 97 | 98 | mask_transpose = mask.transpose(0, 1) 99 | return (mask, mask_transpose) 100 | 101 | def _compute_variable_function_map(self, graph_map, batch_variable_map, batch_function_map, edge_feature): 102 | edge_num = graph_map.size()[1] 103 | variable_num = batch_variable_map.size()[0] 104 | function_num = batch_function_map.size()[0] 105 | all_ones = torch.ones(edge_num, device=self._device) 106 | 107 | if self._device.type == 'cuda': 108 | mask = torch.cuda.sparse.FloatTensor(graph_map.long(), all_ones, 109 | torch.Size([variable_num, function_num]), device=self._device) 110 | signed_mask = torch.cuda.sparse.FloatTensor(graph_map.long(), edge_feature.squeeze(1), 111 | torch.Size([variable_num, function_num]), device=self._device) 112 | else: 113 | mask = torch.sparse.FloatTensor(graph_map.long(), all_ones, 114 | torch.Size([variable_num, function_num]), device=self._device) 115 | signed_mask = torch.sparse.FloatTensor(graph_map.long(), edge_feature.squeeze(1), 116 | torch.Size([variable_num, function_num]), device=self._device) 117 | 118 | mask_transpose = mask.transpose(0, 1) 119 | signed_mask_transpose = signed_mask.transpose(0, 1) 120 | 121 | return (mask, mask_transpose, signed_mask, signed_mask_transpose) 122 | 123 | def _compute_batch_map(self, batch_variable_map, batch_function_map): 124 | 
variable_num = batch_variable_map.size()[0] 125 | function_num = batch_function_map.size()[0] 126 | variable_all_ones = torch.ones(variable_num, device=self._device) 127 | function_all_ones = torch.ones(function_num, device=self._device) 128 | variable_range = torch.arange(variable_num, dtype=torch.int64, device=self._device) 129 | function_range = torch.arange(function_num, dtype=torch.int64, device=self._device) 130 | batch_size = (batch_variable_map.max() + 1).long().item() 131 | 132 | variable_sparse_ind = torch.stack([variable_range, batch_variable_map.long()]) 133 | function_sparse_ind = torch.stack([function_range, batch_function_map.long()]) 134 | 135 | if self._device.type == 'cuda': 136 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, variable_all_ones, 137 | torch.Size([variable_num, batch_size]), device=self._device) 138 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, function_all_ones, 139 | torch.Size([function_num, batch_size]), device=self._device) 140 | else: 141 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, variable_all_ones, 142 | torch.Size([variable_num, batch_size]), device=self._device) 143 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, function_all_ones, 144 | torch.Size([function_num, batch_size]), device=self._device) 145 | 146 | variable_mask_transpose = variable_mask.transpose(0, 1) 147 | function_mask_transpose = function_mask.transpose(0, 1) 148 | 149 | return (variable_mask, variable_mask_transpose, function_mask, function_mask_transpose) 150 | 151 | def _compute_graph_mask(self, graph_map, batch_variable_map, batch_function_map, edge_values=None): 152 | edge_num = graph_map.size()[1] 153 | variable_num = batch_variable_map.size()[0] 154 | function_num = batch_function_map.size()[0] 155 | 156 | if edge_values is None: 157 | edge_values = torch.ones(edge_num, device=self._device) 158 | 159 | edge_num_range = torch.arange(edge_num, dtype=torch.int64, device=self._device) 160 | 161 | variable_sparse_ind = torch.stack([graph_map[0, :].long(), edge_num_range]) 162 | function_sparse_ind = torch.stack([graph_map[1, :].long(), edge_num_range]) 163 | 164 | if self._device.type == 'cuda': 165 | variable_mask = torch.cuda.sparse.FloatTensor(variable_sparse_ind, edge_values, 166 | torch.Size([variable_num, edge_num]), device=self._device) 167 | function_mask = torch.cuda.sparse.FloatTensor(function_sparse_ind, edge_values, 168 | torch.Size([function_num, edge_num]), device=self._device) 169 | else: 170 | variable_mask = torch.sparse.FloatTensor(variable_sparse_ind, edge_values, 171 | torch.Size([variable_num, edge_num]), device=self._device) 172 | function_mask = torch.sparse.FloatTensor(function_sparse_ind, edge_values, 173 | torch.Size([function_num, edge_num]), device=self._device) 174 | 175 | variable_mask_transpose = variable_mask.transpose(0, 1) 176 | function_mask_transpose = function_mask.transpose(0, 1) 177 | 178 | return (variable_mask, variable_mask_transpose, function_mask, function_mask_transpose) 179 | 180 | def _peel(self): 181 | "Implements the peeling algorithm." 
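        # In effect this is iterative pure-literal elimination: an active variable whose
        # remaining active clauses all contain it with the same polarity can safely be fixed
        # to the value that satisfies those clauses. Fixing it deactivates the variable and its
        # clauses, which may expose new pure literals, so the loop repeats until none remain.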
182 | 183 | vf_map, vf_map_transpose, signed_vf_map, _ = self._vf_mask_tuple 184 | 185 | variable_degree = torch.mm(vf_map, self._active_functions) 186 | signed_variable_degree = torch.mm(signed_vf_map, self._active_functions) 187 | 188 | while True: 189 | single_variables = (variable_degree == signed_variable_degree.abs()).float() * self._active_variables 190 | 191 | if torch.sum(single_variables) <= 0: 192 | break 193 | 194 | single_functions = (torch.mm(vf_map_transpose, single_variables) > 0).float() * self._active_functions 195 | degree_delta = torch.mm(vf_map, single_functions) * self._active_variables 196 | signed_degree_delta = torch.mm(signed_vf_map, single_functions) * self._active_variables 197 | self._solution[single_variables[:, 0] == 1] = (signed_variable_degree[single_variables[:, 0] == 1, 0].sign() + 1) / 2.0 198 | 199 | variable_degree -= degree_delta 200 | signed_variable_degree -= signed_degree_delta 201 | 202 | self._active_variables[single_variables[:, 0] == 1, 0] = 0 203 | self._active_functions[single_functions[:, 0] == 1, 0] = 0 204 | 205 | def _set_variable_core(self, assignment): 206 | "Fixes variables to certain binary values." 207 | 208 | _, vf_map_transpose, _, signed_vf_map_transpose = self._vf_mask_tuple 209 | 210 | assignment *= self._active_variables 211 | 212 | # Compute the number of inputs for each function node 213 | input_num = torch.mm(vf_map_transpose, assignment.abs()) 214 | 215 | # Compute the signed evaluation for each function node 216 | function_eval = torch.mm(signed_vf_map_transpose, assignment) 217 | 218 | # Compute the de-activated functions 219 | deactivated_functions = (function_eval > -input_num).float() * self._active_functions 220 | 221 | # De-activate functions and variables 222 | self._active_variables[assignment[:, 0].abs() == 1, 0] = 0 223 | self._active_functions[deactivated_functions[:, 0] == 1, 0] = 0 224 | 225 | # Update the solution 226 | self._solution[assignment[:, 0].abs() == 1] = (assignment[assignment[:, 0].abs() == 1, 0] + 1) / 2.0 227 | 228 | def _propagate_single_clauses(self): 229 | "Implements unit clause propagation algorithm." 
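        # Unit propagation: any active clause whose active degree is exactly one forces its sole
        # remaining variable. The loop below fixes such variables, de-activates the clauses they
        # satisfy, flags instances where a variable is forced to both polarities as UNSAT
        # (via self._is_sat), and repeats until no unit clause is left.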
230 | 231 | vf_map, vf_map_transpose, signed_vf_map, _ = self._vf_mask_tuple 232 | b_variable_mask, b_variable_mask_transpose, b_function_mask, _ = self._batch_mask_tuple 233 | 234 | while True: 235 | function_degree = torch.mm(vf_map_transpose, self._active_variables) 236 | single_functions = (function_degree == 1).float() * self._active_functions 237 | 238 | if torch.sum(single_functions) <= 0: 239 | break 240 | 241 | # Compute the number of inputs for each variable node 242 | input_num = torch.mm(vf_map, single_functions) 243 | 244 | # Compute the signed evaluation for each variable node 245 | variable_eval = torch.mm(signed_vf_map, single_functions) 246 | 247 | # Detect and de-activate the UNSAT examples 248 | conflict_variables = (variable_eval.abs() != input_num).float() * self._active_variables 249 | if torch.sum(conflict_variables) > 0: 250 | 251 | # Detect the UNSAT examples 252 | unsat_examples = torch.mm(b_variable_mask_transpose, conflict_variables) 253 | self._is_sat[unsat_examples[:, 0] >= 1] = 0 254 | 255 | # De-activate the function nodes related to unsat examples 256 | unsat_functions = torch.mm(b_function_mask, unsat_examples) * self._active_functions 257 | self._active_functions[unsat_functions[:, 0] == 1, 0] = 0 258 | 259 | # De-activate the variable nodes related to unsat examples 260 | unsat_variables = torch.mm(b_variable_mask, unsat_examples) * self._active_variables 261 | self._active_variables[unsat_variables[:, 0] == 1, 0] = 0 262 | 263 | # Compute the assigned variables 264 | assigned_variables = (variable_eval.abs() == input_num).float() * self._active_variables 265 | 266 | # Compute the variable assignment 267 | assignment = torch.sign(variable_eval) * assigned_variables 268 | 269 | # De-activate single functions 270 | self._active_functions[single_functions[:, 0] == 1, 0] = 0 271 | 272 | # Set the corresponding variables 273 | self._set_variable_core(assignment) 274 | 275 | def set_variables(self, assignment): 276 | "Fixes variables to certain binary values and simplifies the CNF accordingly." 277 | 278 | self._set_variable_core(assignment) 279 | self.simplify() 280 | 281 | def simplify(self): 282 | "Simplifies the CNF." 283 | 284 | self._propagate_single_clauses() 285 | self._peel() 286 | 287 | 288 | ############################################################### 289 | ### The Solver Classes 290 | ############################################################### 291 | 292 | 293 | class PropagatorDecimatorSolverBase(nn.Module): 294 | "The base class for all PDP SAT solvers." 
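    # A PDP solver composes three modules that operate on the factor graph of the input CNF:
    # a propagator that updates edge messages, a decimator that turns those messages into
    # (partial) variable decisions, and a predictor that reads out the final assignment.
    # The subclasses below instantiate neural and/or classical versions of each component.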
295 | 296 | def __init__(self, device, name, propagator, decimator, predictor, local_search_iterations=0, epsilon=0.05): 297 | 298 | super(PropagatorDecimatorSolverBase, self).__init__() 299 | self._device = device 300 | self._module_list = nn.ModuleList() 301 | 302 | self._propagator = propagator 303 | self._decimator = decimator 304 | self._predictor = predictor 305 | 306 | self._module_list.append(self._propagator) 307 | self._module_list.append(self._decimator) 308 | self._module_list.append(self._predictor) 309 | 310 | self._global_step = nn.Parameter(torch.tensor([0], dtype=torch.float, device=self._device), requires_grad=False) 311 | self._name = name 312 | self._local_search_iterations = local_search_iterations 313 | self._epsilon = epsilon 314 | 315 | def parameter_count(self): 316 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 317 | 318 | def save(self, export_path_base): 319 | torch.save(self.state_dict(), os.path.join(export_path_base, self._name)) 320 | 321 | def load(self, import_path_base): 322 | self.load_state_dict(torch.load(os.path.join(import_path_base, self._name))) 323 | 324 | def forward(self, init_state, 325 | graph_map, batch_variable_map, batch_function_map, edge_feature, 326 | meta_data, is_training=True, iteration_num=1, check_termination=None, simplify=True, batch_replication=1): 327 | 328 | init_propagator_state, init_decimator_state = init_state 329 | batch_replication = 1 if is_training else batch_replication 330 | sat_problem = SATProblem((graph_map, batch_variable_map, batch_function_map, edge_feature, meta_data, None), self._device, batch_replication) 331 | 332 | if simplify and not is_training: 333 | sat_problem.simplify() 334 | 335 | if self._propagator is not None and self._decimator is not None: 336 | propagator_state, decimator_state = self._forward_core(init_propagator_state, init_decimator_state, 337 | sat_problem, iteration_num, is_training, check_termination) 338 | else: 339 | decimator_state = None 340 | propagator_state = None 341 | 342 | prediction = self._predictor(decimator_state, sat_problem, True) 343 | 344 | # Post-processing local search 345 | if not is_training: 346 | prediction = self._local_search(prediction, sat_problem, batch_replication) 347 | 348 | prediction = self._update_solution(prediction, sat_problem) 349 | 350 | if batch_replication > 1: 351 | prediction, propagator_state, decimator_state = self._deduplicate(prediction, propagator_state, decimator_state, sat_problem) 352 | 353 | return (prediction, (propagator_state, decimator_state)) 354 | 355 | def _forward_core(self, init_propagator_state, init_decimator_state, sat_problem, iteration_num, is_training, check_termination): 356 | 357 | propagator_state = init_propagator_state 358 | decimator_state = init_decimator_state 359 | 360 | if check_termination is None: 361 | active_mask = None 362 | else: 363 | active_mask = torch.ones(sat_problem._batch_size, 1, dtype=torch.uint8, device=self._device) 364 | 365 | for _ in torch.arange(iteration_num, dtype=torch.int32, device=self._device): 366 | 367 | propagator_state = self._propagator(propagator_state, decimator_state, sat_problem, is_training, active_mask) 368 | decimator_state = self._decimator(decimator_state, propagator_state, sat_problem, is_training, active_mask) 369 | 370 | sat_problem._edge_mask = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) * \ 371 | torch.mm(sat_problem._graph_mask_tuple[3], sat_problem._active_functions) 372 | 373 | if sat_problem._edge_mask.sum() < 
sat_problem._edge_num: 374 | decimator_state += (sat_problem._edge_mask,) 375 | 376 | if check_termination is not None: 377 | prediction = self._predictor(decimator_state, sat_problem) 378 | prediction = self._update_solution(prediction, sat_problem) 379 | 380 | check_termination(active_mask, prediction, sat_problem) 381 | num_active = active_mask.sum() 382 | 383 | if num_active <= 0: 384 | break 385 | 386 | return propagator_state, decimator_state 387 | 388 | def _update_solution(self, prediction, sat_problem): 389 | "Updates the SAT problem object's solution according to the current prediction." 390 | 391 | if prediction[0] is not None: 392 | variable_solution = sat_problem._active_variables * prediction[0] + \ 393 | (1.0 - sat_problem._active_variables) * sat_problem._solution.unsqueeze(1) 394 | sat_problem._solution[sat_problem._active_variables[:, 0] == 1] = \ 395 | variable_solution[sat_problem._active_variables[:, 0] == 1, 0] 396 | else: 397 | variable_solution = None 398 | 399 | return variable_solution, prediction[1] 400 | 401 | def _deduplicate(self, prediction, propagator_state, decimator_state, sat_problem): 402 | "De-duplicates the current batch (to neutralize the batch replication) by finding the replica with minimum energy for each problem instance." 403 | 404 | if sat_problem._batch_replication <= 1 or sat_problem._replication_mask_tuple is None: 405 | return None, None, None 406 | 407 | assignment = 2 * prediction[0] - 1.0 408 | energy, _ = self._compute_energy(assignment, sat_problem) 409 | max_ind = util.sparse_argmax(-energy.squeeze(1), sat_problem._replication_mask_tuple[0], device=self._device) 410 | 411 | batch_flag = torch.zeros(sat_problem._batch_size, 1, device=self._device) 412 | batch_flag[max_ind, 0] = 1 413 | 414 | flag = torch.mm(sat_problem._batch_mask_tuple[0], batch_flag) 415 | variable_prediction = (flag * prediction[0]).view(sat_problem._batch_replication, -1).sum(dim=0).unsqueeze(1) 416 | 417 | flag = torch.mm(sat_problem._graph_mask_tuple[1], flag) 418 | new_propagator_state = () 419 | for x in propagator_state: 420 | new_propagator_state += ((flag * x).view(sat_problem._batch_replication, sat_problem._edge_num // sat_problem._batch_replication, -1).sum(dim=0),) 421 | 422 | new_decimator_state = () 423 | for x in decimator_state: 424 | new_decimator_state += ((flag * x).view(sat_problem._batch_replication, sat_problem._edge_num // sat_problem._batch_replication, -1).sum(dim=0),) 425 | 426 | function_prediction = None 427 | if prediction[1] is not None: 428 | flag = torch.mm(sat_problem._batch_mask_tuple[2], batch_flag) 429 | function_prediction = (flag * prediction[1]).view(sat_problem._batch_replication, -1).sum(dim=0).unsqueeze(1) 430 | 431 | return (variable_prediction, function_prediction), new_propagator_state, new_decimator_state 432 | 433 | def _local_search(self, prediction, sat_problem, batch_replication): 434 | "Implements the Walk-SAT algorithm for post-processing."
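        # Walk-SAT style local search: starting from the thresholded prediction, the loop below
        # picks, for every still-unsatisfied instance, either the variable whose flip lowers the
        # energy the most (with probability 1 - epsilon) or a random variable from an unsatisfied
        # clause (with probability epsilon), and flips it, for at most local_search_iterations rounds.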
435 | 436 | assignment = (prediction[0] > 0.5).float() 437 | assignment = sat_problem._active_variables * (2*assignment - 1.0) 438 | 439 | sat_problem._edge_mask = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) * \ 440 | torch.mm(sat_problem._graph_mask_tuple[3], sat_problem._active_functions) 441 | 442 | for _ in range(self._local_search_iterations): 443 | unsat_examples, unsat_functions = self._compute_energy(assignment, sat_problem) 444 | unsat_examples = (unsat_examples > 0).float() 445 | 446 | if batch_replication > 1: 447 | compact_unsat_examples = 1 - (torch.mm(sat_problem._replication_mask_tuple[1], 1 - unsat_examples) > 0).float() 448 | if compact_unsat_examples.sum() == 0: 449 | break 450 | elif unsat_examples.sum() == 0: 451 | break 452 | 453 | delta_energy = self._compute_energy_diff(assignment, sat_problem) 454 | max_delta_ind = util.sparse_argmax(-delta_energy.squeeze(1), sat_problem._batch_mask_tuple[0], device=self._device) 455 | 456 | unsat_variables = torch.mm(sat_problem._vf_mask_tuple[0], unsat_functions) * sat_problem._active_variables 457 | unsat_variables = (unsat_variables > 0).float() * torch.rand([sat_problem._variable_num, 1], device=self._device) 458 | random_ind = util.sparse_argmax(unsat_variables.squeeze(1), sat_problem._batch_mask_tuple[0], device=self._device) 459 | 460 | coin = (torch.rand(sat_problem._batch_size, device=self._device) > self._epsilon).long() 461 | max_ind = coin * max_delta_ind + (1 - coin) * random_ind 462 | max_ind = max_ind[unsat_examples[:, 0] > 0] 463 | 464 | # Flipping the selected variables 465 | assignment[max_ind, 0] = -assignment[max_ind, 0] 466 | 467 | return (assignment + 1) / 2.0, prediction[1] 468 | 469 | def _compute_energy_diff(self, assignment, sat_problem): 470 | "Computes the change in energy that would result from flipping each variable during the local search." 471 | 472 | distributed_assignment = torch.mm(sat_problem._signed_mask_tuple[1], assignment * sat_problem._active_variables) 473 | aggregated_assignment = torch.mm(sat_problem._graph_mask_tuple[2], distributed_assignment) 474 | aggregated_assignment = torch.mm(sat_problem._graph_mask_tuple[3], aggregated_assignment) 475 | aggregated_assignment = aggregated_assignment - distributed_assignment 476 | 477 | function_degree = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) 478 | function_degree = torch.mm(sat_problem._graph_mask_tuple[2], function_degree) 479 | function_degree = torch.mm(sat_problem._graph_mask_tuple[3], function_degree) 480 | 481 | critical_edges = (aggregated_assignment == (1 - function_degree)).float() * sat_problem._edge_mask 482 | delta = torch.mm(sat_problem._graph_mask_tuple[0], critical_edges * distributed_assignment) 483 | 484 | return delta 485 | 486 | def _compute_energy(self, assignment, sat_problem): 487 | "Computes the energy of each CNF instance present in the batch."
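        # The "energy" of an instance is the number of its active clauses that the current
        # assignment leaves unsatisfied: a clause is violated exactly when the sum of its active
        # signed literals equals minus its active degree, i.e. every literal evaluates to -1.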
488 | 489 | aggregated_assignment = torch.mm(sat_problem._signed_mask_tuple[1], assignment * sat_problem._active_variables) 490 | aggregated_assignment = torch.mm(sat_problem._graph_mask_tuple[2], aggregated_assignment) 491 | 492 | function_degree = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) 493 | function_degree = torch.mm(sat_problem._graph_mask_tuple[2], function_degree) 494 | 495 | unsat_functions = (aggregated_assignment == -function_degree).float() * sat_problem._active_functions 496 | return torch.mm(sat_problem._batch_mask_tuple[3], unsat_functions), unsat_functions 497 | 498 | def get_init_state(self, graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication=1): 499 | "Initializes the propagator and the decimator messages in each direction." 500 | 501 | if self._propagator is None: 502 | init_propagator_state = None 503 | else: 504 | init_propagator_state = self._propagator.get_init_state(graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication) 505 | 506 | if self._decimator is None: 507 | init_decimator_state = None 508 | else: 509 | init_decimator_state = self._decimator.get_init_state(graph_map, batch_variable_map, batch_function_map, edge_feature, graph_feat, randomized, batch_replication) 510 | 511 | return init_propagator_state, init_decimator_state 512 | 513 | 514 | ############################################################### 515 | 516 | 517 | class NeuralPropagatorDecimatorSolver(PropagatorDecimatorSolverBase): 518 | "Implements a fully neural PDP SAT solver with both the propagator and the decimator being neural." 519 | 520 | def __init__(self, device, name, edge_dimension, meta_data_dimension, 521 | propagator_dimension, decimator_dimension, 522 | mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension, prediction_dimension, 523 | variable_classifier=None, function_classifier=None, dropout=0, 524 | local_search_iterations=0, epsilon=0.05): 525 | 526 | super(NeuralPropagatorDecimatorSolver, self).__init__( 527 | device=device, name=name, 528 | propagator=pdp_propagate.NeuralMessagePasser(device, edge_dimension, decimator_dimension, 529 | meta_data_dimension, propagator_dimension, mem_hidden_dimension, 530 | mem_agg_hidden_dimension, agg_hidden_dimension, dropout), 531 | decimator=pdp_decimate.NeuralDecimator(device, propagator_dimension, meta_data_dimension, 532 | decimator_dimension, mem_hidden_dimension, 533 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, dropout), 534 | predictor=pdp_predict.NeuralPredictor(device, decimator_dimension, prediction_dimension, 535 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension, 536 | mem_agg_hidden_dimension, variable_classifier, function_classifier), 537 | local_search_iterations=local_search_iterations, epsilon=epsilon) 538 | 539 | 540 | ############################################################### 541 | 542 | 543 | class NeuralSurveyPropagatorSolver(PropagatorDecimatorSolverBase): 544 | "Implements a PDP solver with the SP propagator and a neural decimator."
545 | 546 | def __init__(self, device, name, edge_dimension, meta_data_dimension, 547 | decimator_dimension, 548 | mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension, prediction_dimension, 549 | variable_classifier=None, function_classifier=None, dropout=0, 550 | local_search_iterations=0, epsilon=0.05): 551 | 552 | super(NeuralSurveyPropagatorSolver, self).__init__( 553 | device=device, name=name, 554 | propagator=pdp_propagate.SurveyPropagator(device, decimator_dimension, include_adaptors=True), 555 | decimator=pdp_decimate.NeuralDecimator(device, (3, 1), meta_data_dimension, 556 | decimator_dimension, mem_hidden_dimension, 557 | mem_agg_hidden_dimension, agg_hidden_dimension, edge_dimension, dropout), 558 | predictor=pdp_predict.NeuralPredictor(device, decimator_dimension, prediction_dimension, 559 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension, 560 | mem_agg_hidden_dimension, variable_classifier, function_classifier), 561 | local_search_iterations=local_search_iterations, epsilon=epsilon) 562 | 563 | 564 | ############################################################### 565 | 566 | 567 | class SurveyPropagatorSolver(PropagatorDecimatorSolverBase): 568 | "Implements the classical SP-guided decimation solver via the PDP framework." 569 | 570 | def __init__(self, device, name, tolerance, t_max, local_search_iterations=0, epsilon=0.05): 571 | 572 | super(SurveyPropagatorSolver, self).__init__( 573 | device=device, name=name, 574 | propagator=pdp_propagate.SurveyPropagator(device, decimator_dimension=1, include_adaptors=False), 575 | decimator=pdp_decimate.SequentialDecimator(device, message_dimension=(3, 1), 576 | scorer=pdp_predict.SurveyScorer(device, message_dimension=1, include_adaptors=False), tolerance=tolerance, t_max=t_max), 577 | predictor=pdp_predict.IdentityPredictor(device=device, random_fill=True), 578 | local_search_iterations=local_search_iterations, epsilon=epsilon) 579 | 580 | 581 | ############################################################### 582 | 583 | 584 | class WalkSATSolver(PropagatorDecimatorSolverBase): 585 | "Implements the classical Walk-SAT solver via the PDP framework." 586 | 587 | def __init__(self, device, name, iteration_num, epsilon=0.05): 588 | 589 | super(WalkSATSolver, self).__init__( 590 | device=device, name=name, propagator=None, decimator=None, 591 | predictor=pdp_predict.IdentityPredictor(device=device, random_fill=True), 592 | local_search_iterations=iteration_num, epsilon=epsilon) 593 | 594 | 595 | ############################################################### 596 | 597 | 598 | class ReinforceSurveyPropagatorSolver(PropagatorDecimatorSolverBase): 599 | "Implements the classical Reinforce solver via the PDP framework." 
600 | 601 | def __init__(self, device, name, pi=0.1, decimation_probability=0.5, local_search_iterations=0, epsilon=0.05): 602 | 603 | super(ReinforceSurveyPropagatorSolver, self).__init__( 604 | device=device, name=name, 605 | propagator=pdp_propagate.SurveyPropagator(device, decimator_dimension=1, include_adaptors=False, pi=pi), 606 | decimator=pdp_decimate.ReinforceDecimator(device, 607 | scorer=pdp_predict.SurveyScorer(device, message_dimension=1, include_adaptors=False, pi=pi), 608 | decimation_probability=decimation_probability), 609 | predictor=pdp_predict.ReinforcePredictor(device=device), 610 | local_search_iterations=local_search_iterations, epsilon=epsilon) 611 | 612 | 613 | ############################################################### 614 | 615 | 616 | class NeuralSequentialDecimatorSolver(PropagatorDecimatorSolverBase): 617 | "Implements a PDP solver with a neural propagator and the sequential decimator." 618 | 619 | def __init__(self, device, name, edge_dimension, meta_data_dimension, 620 | propagator_dimension, decimator_dimension, 621 | mem_hidden_dimension, agg_hidden_dimension, mem_agg_hidden_dimension, 622 | classifier_dimension, dropout, tolerance, t_max, 623 | local_search_iterations=0, epsilon=0.05): 624 | 625 | super(NeuralSequentialDecimatorSolver, self).__init__( 626 | device=device, name=name, 627 | propagator=pdp_propagate.NeuralMessagePasser(device, edge_dimension, decimator_dimension, 628 | meta_data_dimension, propagator_dimension, mem_hidden_dimension, 629 | mem_agg_hidden_dimension, agg_hidden_dimension, dropout), 630 | decimator=pdp_decimate.SequentialDecimator(device, message_dimension=(3, 1), 631 | scorer=pdp_predict.NeuralPredictor(device, decimator_dimension, 1, 632 | edge_dimension, meta_data_dimension, mem_hidden_dimension, agg_hidden_dimension, 633 | mem_agg_hidden_dimension, variable_classifier=util.PerceptronTanh(decimator_dimension, 634 | classifier_dimension, 1), function_classifier=None), 635 | tolerance=tolerance, t_max=t_max), 636 | predictor=pdp_predict.IdentityPredictor(device=device, random_fill=True), 637 | local_search_iterations=local_search_iterations, epsilon=epsilon) 638 | --------------------------------------------------------------------------------
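A minimal usage sketch (not from the repository, whose own entry points are satyr.py and satyr-train-test.py): it assumes the package is importable as pdp.nn.solver, that graph_map stacks (variable index, clause index) pairs per edge, and that edge_feature holds the literal signs as an [edge_num, 1] float tensor; the toy CNF below and its tensor shapes are illustrative and may need adapting to the conventions expected by the SATProblem constructor.

import torch
from pdp.nn import solver as pdp_solver  # assumed import path for src/pdp/nn/solver.py

device = torch.device('cpu')

# Walk-SAT needs no trained propagator/decimator, so it can run without a checkpoint.
sat_solver = pdp_solver.WalkSATSolver(device, name='walksat-demo', iteration_num=1000, epsilon=0.05)

# Toy CNF (x1 OR NOT x2) AND (x2 OR x3), a single instance in the batch.
graph_map = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]], dtype=torch.int32)  # rows: variable, clause
batch_variable_map = torch.zeros(3, dtype=torch.int32)  # variable -> instance index
batch_function_map = torch.zeros(2, dtype=torch.int32)  # clause -> instance index
edge_feature = torch.tensor([[1.0], [-1.0], [1.0], [1.0]])  # literal signs per edge

init_state = sat_solver.get_init_state(graph_map, batch_variable_map, batch_function_map,
                                       edge_feature, graph_feat=None, randomized=True)
prediction, _ = sat_solver(init_state, graph_map, batch_variable_map, batch_function_map,
                           edge_feature, meta_data=None, is_training=False, iteration_num=1)
print(prediction[0])  # per-variable values in [0, 1]; threshold at 0.5 for the assignment

WalkSATSolver is used here only because it requires no trained weights; the neural solvers above follow the same get_init_state / forward protocol once their parameters are restored via load().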