├── src ├── demo.py ├── models │ ├── __init__.py │ ├── gated_sum_conv.py │ ├── mlp_aggr.py │ ├── deepset_conv.py │ ├── gcn_conv.py │ ├── mlp.py │ ├── tfmlp.py │ ├── convgnn.py │ ├── aggnmlp.py │ ├── layernorm_gru.py │ ├── model.py │ ├── layernorm_lstm.py │ ├── gat_conv.py │ ├── mlpgate_merge.py │ ├── vectorgate.py │ ├── mlpgate.py │ ├── losses.py │ ├── dag_convgnn.py │ └── recgnn.py ├── trains │ ├── __init__.py │ ├── recgnn.py │ ├── convgnn.py │ ├── train_factory.py │ ├── base_trainer.py │ └── mlpgnn_trainer.py ├── utils │ ├── __init__.py │ ├── random_seed.py │ ├── logger.py │ ├── data_utils.py │ ├── dag_utils.py │ ├── utils.py │ ├── batch.py │ ├── aiger_utils.py │ └── sat_utils.py ├── datasets │ ├── __init__.py │ ├── dataset_factory.py │ ├── ordered_data.py │ ├── circuit_dataset.py │ ├── mig_dataset.py │ └── mlpgate_dataset.py ├── detectors │ ├── __init__.py │ ├── detector_factory.py │ └── base_detector.py ├── .DS_Store ├── reset_pth.py ├── get_emb_aig.py ├── test_acc_bin.py ├── get_emb_bench.py ├── prepare_dataset.py └── main.py ├── .gitignore ├── dataset └── EPFL │ ├── adder.aig │ ├── bar.aig │ ├── cavlc.aig │ ├── ctrl.aig │ ├── dec.aig │ ├── div.aig │ ├── hyp.aig │ ├── i2c.aig │ ├── log2.aig │ ├── max.aig │ ├── sin.aig │ ├── sqrt.aig │ ├── voter.aig │ ├── arbiter.aig │ ├── router.aig │ ├── square.aig │ ├── int2float.aig │ ├── mem_ctrl.aig │ ├── multiplier.aig │ └── priority.aig ├── run ├── stage2_train.sh ├── stage1_train.sh └── test.sh ├── requirements.txt └── README.md /src/demo.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/trains/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | emb/* 2 | exp/* 3 | data/* 4 | 5 | *__pycache__ 6 | *.pyc -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /dataset/EPFL/adder.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/adder.aig -------------------------------------------------------------------------------- /dataset/EPFL/bar.aig: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/bar.aig -------------------------------------------------------------------------------- /dataset/EPFL/cavlc.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/cavlc.aig -------------------------------------------------------------------------------- /dataset/EPFL/ctrl.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/ctrl.aig -------------------------------------------------------------------------------- /dataset/EPFL/dec.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/dec.aig -------------------------------------------------------------------------------- /dataset/EPFL/div.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/div.aig -------------------------------------------------------------------------------- /dataset/EPFL/hyp.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/hyp.aig -------------------------------------------------------------------------------- /dataset/EPFL/i2c.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/i2c.aig -------------------------------------------------------------------------------- /dataset/EPFL/log2.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/log2.aig -------------------------------------------------------------------------------- /dataset/EPFL/max.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/max.aig -------------------------------------------------------------------------------- /dataset/EPFL/sin.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/sin.aig -------------------------------------------------------------------------------- /dataset/EPFL/sqrt.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/sqrt.aig -------------------------------------------------------------------------------- /dataset/EPFL/voter.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/voter.aig -------------------------------------------------------------------------------- /dataset/EPFL/arbiter.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/arbiter.aig -------------------------------------------------------------------------------- /dataset/EPFL/router.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/router.aig 
-------------------------------------------------------------------------------- /dataset/EPFL/square.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/square.aig -------------------------------------------------------------------------------- /dataset/EPFL/int2float.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/int2float.aig -------------------------------------------------------------------------------- /dataset/EPFL/mem_ctrl.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/mem_ctrl.aig -------------------------------------------------------------------------------- /dataset/EPFL/multiplier.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/multiplier.aig -------------------------------------------------------------------------------- /dataset/EPFL/priority.aig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cure-lab/DeepGate2/HEAD/dataset/EPFL/priority.aig -------------------------------------------------------------------------------- /src/detectors/detector_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .base_detector import BaseDetector 6 | 7 | detector_factory = { 8 | 'base': BaseDetector 9 | } -------------------------------------------------------------------------------- /src/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .circuit_dataset import CircuitDataset 6 | 7 | dataset_factory = { 8 | 'benchmarks': CircuitDataset, 9 | 'random': CircuitDataset 10 | } -------------------------------------------------------------------------------- /src/trains/recgnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .base_trainer import BaseTrainer 6 | 7 | class RecGNNTrainer(BaseTrainer): 8 | def __init__(self, opt, model, optimizer=None): 9 | super(RecGNNTrainer, self).__init__(opt, model, optimizer=optimizer) -------------------------------------------------------------------------------- /src/trains/convgnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .base_trainer import BaseTrainer 6 | 7 | class ConvGNNTrainer(BaseTrainer): 8 | def __init__(self, opt, model, optimizer=None): 9 | super(ConvGNNTrainer, self).__init__(opt, model, optimizer=optimizer) -------------------------------------------------------------------------------- /run/stage2_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=2 3 | GPUS=0,1 4 | 5 | cd src 6 | 
shift 7 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC ./main.py prob \ 8 | --exp_id train \ 9 | --data_dir ../data/train \ 10 | --reg_loss l1 --cls_loss bce \ 11 | --arch mlpgnn \ 12 | --Prob_weight 3 --RC_weight 1 --Func_weight 2 \ 13 | --num_rounds 1 \ 14 | --gpus ${GPUS} --batch_size 16 \ 15 | --resume 16 | 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.5.1 2 | matplotlib-inline==0.1.3 3 | pytorch-transformers==1.2.0 4 | scikit-learn==1.0.2 5 | scipy==1.8.0 6 | seaborn==0.12.2 7 | torch==1.9.1 8 | torch-cluster==1.5.9 9 | torch-geometric==2.0.2 10 | torch-scatter==2.0.8 11 | torch-sparse==0.6.12 12 | torch-spline-conv==1.2.1 13 | torch-tb-profiler==0.4.0 14 | torchsummary==1.5.1 15 | torchvision==0.10.0 16 | tqdm==4.32.1 17 | xlrd==2.0.1 18 | -------------------------------------------------------------------------------- /run/stage1_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=2 3 | GPUS=0,1 4 | 5 | cd src 6 | shift 7 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC ./main.py prob \ 8 | --exp_id train \ 9 | --data_dir ../data/train \ 10 | --reg_loss l1 --cls_loss bce \ 11 | --arch mlpgnn \ 12 | --Prob_weight 1 --RC_weight 0 --Func_weight 0 \ 13 | --num_rounds 1 \ 14 | --small_train \ 15 | --gpus ${GPUS} --batch_size 16 \ 16 | 17 | -------------------------------------------------------------------------------- /run/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SPC_EXPID=train 3 | DATASET=test 4 | AGGR=tfmlp 5 | NO_ROUNDS=1 6 | TEST_SCRIPT=test_acc_bin 7 | 8 | GPU=-1 9 | 10 | cd src 11 | python3 ${TEST_SCRIPT}.py prob --exp_id ${SPC_EXPID} --spc_exp_id ${SPC_EXPID} \ 12 | --data_dir ../data/${DATASET} \ 13 | --num_rounds ${NO_ROUNDS} \ 14 | --reg_loss l1 --cls_loss bce \ 15 | --arch mlpgnn \ 16 | --aggr_function ${AGGR} \ 17 | --gpu ${GPU} --batch_size 1 \ 18 | --no_rc \ 19 | --resume 20 | -------------------------------------------------------------------------------- /src/utils/random_seed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | import random 6 | 7 | def set_seed(args): 8 | # fix randomseed for reproducing the results 9 | print('Setting random seed for reproductivity..') 10 | random_seed = args.random_seed 11 | torch.manual_seed(random_seed) 12 | torch.cuda.manual_seed(random_seed) 13 | np.random.seed(random_seed) 14 | random.seed(random_seed) 15 | os.environ['PYTHONHASHSEED'] = str(random_seed) 16 | cudnn.benchmark = not args.not_cuda_benchmark -------------------------------------------------------------------------------- /src/trains/train_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | # from .recgnn import RecGNNTrainer 7 | # from .convgnn import ConvGNNTrainer 8 | from .base_trainer import BaseTrainer 9 | from .mlpgnn_trainer import MLPGNNTrainer 10 | 11 | train_factory = { 12 | # 'recgnn': RecGNNTrainer, 13 | 'recgnn': BaseTrainer, 14 | # 'convgnn': ConvGNNTrainer, 15 | 'convgnn': BaseTrainer, 16 | 'dagconvgnn': BaseTrainer, 17 | 'base': 
BaseTrainer, 18 | 'mlpgnn': MLPGNNTrainer, 19 | 'mlpgnn_merge': MLPGNNTrainer 20 | } -------------------------------------------------------------------------------- /src/models/gated_sum_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch_geometric.nn.glob import * 5 | from torch_geometric.nn import MessagePassing 6 | 7 | 8 | class GatedSumConv(MessagePassing): # dvae needs outdim parameter 9 | def __init__(self, in_channels, ouput_channels=None, wea=False, mlp=None, reverse=False, mapper=None, gate=None): 10 | super(GatedSumConv, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 11 | if ouput_channels is None: 12 | ouput_channels = in_channels 13 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the Gated Sum should be larger than 0.' 14 | 15 | self.wea = wea 16 | 17 | self.mapper = nn.Linear(in_channels, ouput_channels) if mapper is None else mapper 18 | self.gate = nn.Sequential(nn.Linear(in_channels, ouput_channels), nn.Sigmoid()) if gate is None else gate 19 | 20 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 21 | 22 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 23 | 24 | def message(self, x_j, edge_attr=None): 25 | if self.wea: 26 | h_j = torch.cat((x_j, edge_attr), dim=1) 27 | else: 28 | h_j = x_j 29 | return self.gate(h_j) * self.mapper(h_j) 30 | 31 | def update(self, aggr_out): 32 | return aggr_out 33 | -------------------------------------------------------------------------------- /src/datasets/ordered_data.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data 2 | 3 | class OrderedData(Data): 4 | def __init__(self, edge_index=None, x=None, y=None, \ 5 | tt_pair_index=None, tt_dis=None, min_tt_dis=None, \ 6 | rc_pair_index=None, is_rc=None, \ 7 | forward_level=None, forward_index=None, backward_level=None, backward_index=None): 8 | super().__init__() 9 | self.edge_index = edge_index 10 | self.tt_pair_index = tt_pair_index 11 | self.x = x 12 | self.y = y 13 | self.tt_dis = tt_dis 14 | self.min_tt_dis = min_tt_dis 15 | self.forward_level = forward_level 16 | self.forward_index = forward_index 17 | self.backward_level = backward_level 18 | self.backward_index = backward_index 19 | self.rc_pair_index = rc_pair_index 20 | self.is_rc = is_rc 21 | 22 | def __inc__(self, key, value, *args, **kwargs): 23 | if 'index' in key or 'face' in key: 24 | return self.num_nodes 25 | else: 26 | return 0 27 | 28 | def __cat_dim__(self, key, value, *args, **kwargs): 29 | if key == 'forward_index' or key == 'backward_index': 30 | return 0 31 | elif key == "edge_index" or key == 'tt_pair_index' or key == 'rc_pair_index': 32 | return 1 33 | else: 34 | return 0 35 | -------------------------------------------------------------------------------- /src/models/mlp_aggr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | 4 | from .mlp import MLP 5 | 6 | class MlpAggr(MessagePassing): 7 | ''' 8 | The message propagation methods described in NeuroSAT (2 layers without dropout) and CircuitSAT (2 layers, dim = 50, dropout - 20%). 
9 | Cite from NeuroSAT: 10 | `we sum the outgoing messages of each of a node’s neighbors to form the incoming message.` 11 | ''' 12 | def __init__(self, in_channels, mlp_channels=512, ouput_channels=64, num_layer=3, p_drop=0.2, act_layer=None, norm_layer=None, reverse=False, mlp_post=None): 13 | super(MlpAggr, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 14 | if ouput_channels is None: 15 | ouput_channels = in_channels 16 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the DeepSetConv should be larger than 0.' 17 | 18 | self.msg = MLP(in_channels, mlp_channels, ouput_channels, 19 | num_layer=num_layer, p_drop=p_drop, act_layer=act_layer, norm_layer=norm_layer) 20 | self.msg_post = None if mlp_post is None else mlp_post 21 | 22 | 23 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 24 | # x has shape [N, in_channels] 25 | # edge_index has shape [2, E] 26 | 27 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 28 | 29 | def message(self, x_j): 30 | # x_j has shape [E, dim_emb] 31 | return self.msg(x_j) 32 | 33 | def update(self, aggr_out): 34 | if self.msg_post is not None: 35 | return self.msg_post(aggr_out) 36 | else: 37 | return aggr_out 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/reset_pth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import torch 7 | from torch_geometric import data 8 | 9 | from config import get_parse_args 10 | from models.model import create_model, load_model, save_model 11 | from utils.logger import Logger 12 | from utils.random_seed import set_seed 13 | from utils.circuit_utils import check_difference 14 | from trains.train_factory import train_factory 15 | 16 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 17 | 18 | src_pth_filepath = 'exp/prob/aggr_exp_deepset/model_stage1.pth' 19 | dst_pth_filepath = 'exp/prob/aggr_exp_deepset/model_last.pth' 20 | 21 | def main(args): 22 | ################# 23 | # Device 24 | ################# 25 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str 26 | args.device = torch.device('cpu') 27 | args.world_size = 1 28 | args.rank = 0 # global rank 29 | 30 | ################# 31 | # Model 32 | ################# 33 | model = create_model(args) 34 | if args.local_rank == 0: 35 | print('==> Creating model...') 36 | print(model) 37 | 38 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 39 | model, optimizer, start_epoch = load_model( 40 | model, src_pth_filepath, optimizer, args.resume, args.lr, args.lr_step, args.local_rank, args.device) 41 | 42 | for param_group in optimizer.param_groups: 43 | param_group['lr'] = 1e-4 44 | 45 | 46 | save_model(dst_pth_filepath, 0, model, optimizer) 47 | print('Load: ', src_pth_filepath) 48 | print('Save: ', dst_pth_filepath) 49 | 50 | if __name__ == '__main__': 51 | args = get_parse_args() 52 | 53 | main(args) 54 | -------------------------------------------------------------------------------- /src/models/deepset_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | 4 | from .mlp import MLP 5 | 6 | class DeepSetConv(MessagePassing): 7 | ''' 8 | The message propagation methods described in NeuroSAT (2 layers without dropout) and CircuitSAT 
(2 layers, dim = 50, dropout - 20%). 9 | Cite from NeuroSAT: 10 | `we sum the outgoing messages of each of a node’s neighbors to form the incoming message.` 11 | ''' 12 | def __init__(self, in_channels, ouput_channels=None, wea=False, mlp=None, reverse=False, mlp_post=None): 13 | super(DeepSetConv, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 14 | if ouput_channels is None: 15 | ouput_channels = in_channels 16 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the DeepSetConv should be larger than 0.' 17 | 18 | self.wea = wea 19 | 20 | self.msg = MLP(in_channels, ouput_channels, ouput_channels, num_layer=3, p_drop=0.2) if mlp is None else mlp 21 | self.msg_post = None if mlp_post is None else mlp_post 22 | 23 | 24 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 25 | # x has shape [N, in_channels] 26 | # edge_index has shape [2, E] 27 | 28 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 29 | 30 | def message(self, x_j, edge_attr=None): 31 | # TODO: add the normalization part like AggConv 32 | # x_j has shape [E, dim_emb] 33 | if self.wea: 34 | return self.msg(torch.cat((x_j, edge_attr), dim=1)) 35 | else: 36 | return self.msg(x_j) 37 | 38 | def update(self, aggr_out): 39 | if self.msg_post is not None: 40 | return self.msg_post(aggr_out) 41 | else: 42 | return aggr_out 43 | 44 | 45 | -------------------------------------------------------------------------------- /src/models/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Optional 5 | from torch import Tensor 6 | import torch_geometric as tg 7 | from torch_geometric.typing import OptTensor 8 | from torch_geometric.utils import softmax 9 | from torch_geometric.utils import add_self_loops, degree 10 | from torch_scatter import scatter_add 11 | from torch_geometric.nn.glob import * 12 | from torch_geometric.nn import MessagePassing 13 | 14 | 15 | class AggConv(MessagePassing): 16 | ''' 17 | Modified based on GCNConv implementation in PyG. 18 | https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv 19 | ''' 20 | def __init__(self, in_channels, ouput_channels=None, wea=False, mlp=None, reverse=False): 21 | super().__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') # "Add" aggregation (Step 5). 22 | if ouput_channels is None: 23 | ouput_channels = in_channels 24 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the AggConv should be larger than 0.' 
25 | 26 | self.wea = wea 27 | 28 | self.msg = nn.Linear(in_channels, ouput_channels) if mlp is None else mlp 29 | 30 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 31 | # x has shape [N, in_channels] 32 | # edge_index has shape [2, E] 33 | 34 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 35 | 36 | def message(self, x_j, edge_attr=None): 37 | # TODO: add the normalization part like AggConv 38 | # x_j has shape [E, dim_emb] 39 | if self.wea: 40 | return self.msg(torch.cat((x_j, edge_attr), dim=1)) 41 | else: 42 | return self.msg(x_j) 43 | 44 | def update(self, aggr_out): 45 | return aggr_out -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | _norm_layer_factory = { 5 | 'batchnorm': nn.BatchNorm1d, 6 | } 7 | 8 | _act_layer_factory = { 9 | 'relu': nn.ReLU, 10 | 'relu6': nn.ReLU6, 11 | 'sigmoid': nn.Sigmoid, 12 | } 13 | 14 | class MLP(nn.Module): 15 | def __init__(self, dim_in=256, dim_hidden=32, dim_pred=1, num_layer=3, norm_layer=None, act_layer=None, p_drop=0.5, sigmoid=False, tanh=False): 16 | super(MLP, self).__init__() 17 | ''' 18 | The basic structure is refered from 19 | ''' 20 | assert num_layer >= 2, 'The number of layers shoud be larger or equal to 2.' 21 | if norm_layer in _norm_layer_factory.keys(): 22 | self.norm_layer = _norm_layer_factory[norm_layer] 23 | if act_layer in _act_layer_factory.keys(): 24 | self.act_layer = _act_layer_factory[act_layer] 25 | if p_drop > 0: 26 | self.dropout = nn.Dropout 27 | 28 | fc = [] 29 | # 1st layer 30 | fc.append(nn.Linear(dim_in, dim_hidden)) 31 | if norm_layer: 32 | fc.append(self.norm_layer(dim_hidden)) 33 | if act_layer: 34 | fc.append(self.act_layer(inplace=True)) 35 | if p_drop > 0: 36 | fc.append(self.dropout(p_drop)) 37 | for _ in range(num_layer - 2): 38 | fc.append(nn.Linear(dim_hidden, dim_hidden)) 39 | if norm_layer: 40 | fc.append(self.norm_layer(dim_hidden)) 41 | if act_layer: 42 | fc.append(self.act_layer(inplace=True)) 43 | if p_drop > 0: 44 | fc.append(self.dropout(p_drop)) 45 | # last layer 46 | fc.append(nn.Linear(dim_hidden, dim_pred)) 47 | # sigmoid 48 | if sigmoid: 49 | fc.append(nn.Sigmoid()) 50 | if tanh: 51 | fc.append(nn.Tanh()) 52 | self.fc = nn.Sequential(*fc) 53 | 54 | def forward(self, x): 55 | out = self.fc(x) 56 | return out -------------------------------------------------------------------------------- /src/models/tfmlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | import torch.nn as nn 4 | 5 | from typing import Optional 6 | from torch import Tensor 7 | from torch_geometric.utils import softmax 8 | from torch_geometric.typing import Adj, OptTensor 9 | from .mlp import MLP 10 | 11 | class TFMLP(MessagePassing): 12 | ''' 13 | The message propagation methods described in NeuroSAT (2 layers without dropout) and CircuitSAT (2 layers, dim = 50, dropout - 20%). 
14 | Cite from NeuroSAT: 15 | `we sum the outgoing messages of each of a node’s neighbors to form the incoming message.` 16 | ''' 17 | def __init__(self, in_channels, ouput_channels=64, reverse=False, mlp_post=None): 18 | super(TFMLP, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 19 | if ouput_channels is None: 20 | ouput_channels = in_channels 21 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the DeepSetConv should be larger than 0.' 22 | 23 | self.msg_post = None if mlp_post is None else mlp_post 24 | self.attn_lin = nn.Linear(ouput_channels + ouput_channels, 1) 25 | 26 | self.msg_q = nn.Linear(in_channels, ouput_channels) 27 | self.msg_k = nn.Linear(in_channels, ouput_channels) 28 | self.msg_v = nn.Linear(in_channels, ouput_channels) 29 | 30 | 31 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 32 | # x has shape [N, in_channels] 33 | # edge_index has shape [2, E] 34 | 35 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 36 | 37 | def message(self, x_i, x_j, edge_attr, index: Tensor, ptr: OptTensor, size_i: Optional[int]): 38 | # h_i: query, h_j: key 39 | h_attn_q_i = self.msg_q(x_i) 40 | h_attn = self.msg_k(x_j) 41 | # see comment in above self attention why this is done here and not in forward 42 | a_j = self.attn_lin(torch.cat([h_attn_q_i, h_attn], dim=-1)) 43 | a_j = softmax(a_j, index, ptr, size_i) 44 | # x_j -> value 45 | t = self.msg_v(x_j) * a_j 46 | return t 47 | 48 | def update(self, aggr_out): 49 | if self.msg_post is not None: 50 | return self.msg_post(aggr_out) 51 | else: 52 | return aggr_out 53 | 54 | 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepGate2: Functionality-Aware Circuit Representation Learning 2 | 3 | Code repository for the paper: 4 | **DeepGate2: Functionality-Aware Circuit Representation Learning**, 5 | 6 | Zhengyuan Shi, Hongyang Pan, Sadaf Khan, Min Li, Yi Liu, Junhua Huang, Hui-Ling Zhen, Mingxuan Yuan, Zhufei Chu and Qiang Xu 7 | 8 | ## Abstract 9 | Circuit representation learning aims to obtain neural representations of circuit elements and has emerged as a promising research direction that can be applied to various EDA and logic reasoning tasks. Existing solutions, such as DeepGate, have the potential to embed both circuit structural information and functional behavior. However, their capabilities are limited due to weak supervision or flawed model design, resulting in unsatisfactory performance in downstream tasks. In this paper, we introduce **DeepGate2**, a novel functionality-aware learning framework that significantly improves upon the original DeepGate solution in terms of both learning effectiveness and efficiency. Our approach involves using pairwise truth table differences between sampled logic gates as training supervision, along with a well-designed and scalable loss function that explicitly considers circuit functionality. Additionally, we consider inherent circuit characteristics and design an efficient one-round graph neural network (GNN), resulting in an order of magnitude faster learning speed than the original DeepGate solution. Experimental results demonstrate significant improvements in two practical downstream tasks: logic synthesis and Boolean satisfiability solving. 
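To make the training signal concrete: the sketch below shows one way a pairwise truth-table difference label can be approximated, by simulating a toy AND/NOT netlist on random input patterns and taking the normalized Hamming distance between the responses of two gates. This is a minimal illustration under assumed names (`simulate`, `tt_distance`, and the toy netlist are hypothetical), not the repository's data-generation pipeline; see `./src/prepare_dataset.py` for the actual preparation flow.

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate(num_inputs, gates, num_patterns=4096):
    # `gates` maps a node name to ('AND', a, b) or ('NOT', a); primary inputs are 'i0', 'i1', ...
    # Gates must be listed in topological order (fanins before fanouts).
    values = {'i{}'.format(k): rng.integers(0, 2, num_patterns).astype(bool)
              for k in range(num_inputs)}
    for name, spec in gates.items():
        if spec[0] == 'AND':
            values[name] = values[spec[1]] & values[spec[2]]
        else:  # 'NOT'
            values[name] = ~values[spec[1]]
    return values

def tt_distance(values, a, b):
    # Normalized Hamming distance between the sampled responses of two gates.
    return float(np.mean(values[a] != values[b]))

# A 3-input toy AIG: g2 = NOT(i0 AND i1) AND i2
gates = {'g0': ('AND', 'i0', 'i1'), 'g1': ('NOT', 'g0'), 'g2': ('AND', 'g1', 'i2')}
vals = simulate(num_inputs=3, gates=gates)
print(tt_distance(vals, 'g1', 'g2'))  # plays the role of a `tt_dis` supervision target
```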
10 | 11 | ## Installation 12 | ```sh 13 | conda create -n deepgate2 python=3.8.9 14 | conda activate deepgate2 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Prepare Dataset 19 | Please refer to the script `./src/prepare_dataset.py` for dataset preparation. You should set the folder containing the circuit netlists in .bench format. 20 | ```sh 21 | cd src 22 | python ./prepare_dataset.py # Use the default settings 23 | ``` 24 | 25 | ## Model Training 26 | The model training is separated into two stages. 27 | In Stage 1, the model learns to predict the logic probability of each gate, which is the same task as in the previous version of DeepGate. 28 | In Stage 2, the model learns to predict the pairwise truth table difference between two gates. 29 | ```sh 30 | bash ./run/stage1_train.sh 31 | bash ./run/stage2_train.sh 32 | ``` 33 | ## Model Inference 34 | You can test the model with the testing dataset. Please remember to generate the testing dataset beforehand. 35 | ```sh 36 | bash ./run/test.sh 37 | ``` 38 | 39 | The model can also parse raw data (in .aig or .bench format) and write the gate embeddings to file, so that you can apply these embeddings to downstream tasks. 40 | ```sh 41 | cd src 42 | python get_emb_bench.py 43 | ``` 44 | 45 | 46 | -------------------------------------------------------------------------------- /src/models/convgnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | from torch import nn 7 | 8 | from .gat_conv import AGNNConv 9 | from .gcn_conv import AggConv 10 | from .deepset_conv import DeepSetConv 11 | from .gated_sum_conv import GatedSumConv 12 | from .mlp import MLP 13 | 14 | 15 | _aggr_function_factory = { 16 | 'aggnconv': AGNNConv, 17 | 'deepset': DeepSetConv, 18 | 'gated_sum': GatedSumConv, 19 | 'conv_sum': AggConv, 20 | } 21 | 22 | 23 | class ConvGNN(nn.Module): 24 | ''' 25 | Convolutional Graph Neural Networks for Circuits. 26 | ''' 27 | 28 | def __init__(self, args): 29 | super(ConvGNN, self).__init__() 30 | 31 | self.args = args 32 | 33 | # configuration 34 | self.device = self.args.device 35 | self.predict_diff = self.args.predict_diff 36 | 37 | # dimensions 38 | self.num_aggr = args.num_aggr 39 | self.dim_node_feature = args.dim_node_feature 40 | self.dim_hidden = args.dim_hidden 41 | self.dim_mlp = args.dim_mlp 42 | self.dim_pred = args.dim_pred 43 | self.num_fc = args.num_fc 44 | self.wx_mlp = args.wx_mlp 45 | 46 | # 1. message/aggr-related 47 | if self.args.aggr_function in _aggr_function_factory.keys(): 48 | # TODO: consider the inconsistent dim of node feature and hidden state 49 | # corner case: gat_conv 50 | self.aggr = self.aggr_forward = nn.ModuleList([_aggr_function_factory[self.args.aggr_function](self.dim_node_feature, self.dim_hidden) if l == 0 else _aggr_function_factory[self.args.aggr_function](self.dim_hidden) for l in range(self.num_aggr)]) 51 | else: 52 | raise KeyError('aggr function {} is not supported.'.format(self.args.aggr_function)) 53 | 54 | # 2. predictor-related 55 | # TODO: support multiple predictors. Use a nn.ModuleList to handle it.
56 | self.norm_layer = args.norm_layer 57 | self.activation_layer = args.activation_layer 58 | if self.wx_mlp: 59 | self.predictor = MLP(self.dim_hidden+self.dim_node_feature, self.dim_mlp, self.dim_pred, 60 | num_layer=self.num_fc, norm_layer=self.norm_layer, act_layer=self.activation_layer, sigmoid=False, tanh=False) 61 | else: 62 | self.predictor = MLP(self.dim_hidden, self.dim_mlp, self.dim_pred, 63 | num_layer=self.num_fc, norm_layer=self.norm_layer, act_layer=self.activation_layer, sigmoid=False, tanh=False) 64 | 65 | def forward(self, G): 66 | x, edge_index = G.x, G.edge_index 67 | 68 | preds = [] 69 | for i in range(self.num_aggr): 70 | x = self.aggr[i](x, edge_index) 71 | 72 | pred = self.predictor(x) 73 | preds.append(pred) 74 | 75 | return preds 76 | 77 | 78 | def get_conv_gnn(args): 79 | return ConvGNN(args) 80 | -------------------------------------------------------------------------------- /src/models/aggnmlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | import torch.nn as nn 4 | 5 | from typing import Optional 6 | from torch import Tensor 7 | from torch_geometric.utils import softmax 8 | from torch_geometric.typing import Adj, OptTensor 9 | from .mlp import MLP 10 | 11 | class AttnMLP(MessagePassing): 12 | ''' 13 | The message propagation methods described in NeuroSAT (2 layers without dropout) and CircuitSAT (2 layers, dim = 50, dropout - 20%). 14 | Cite from NeuroSAT: 15 | `we sum the outgoing messages of each of a node’s neighbors to form the incoming message.` 16 | ''' 17 | def __init__(self, in_channels, mlp_channels=512, ouput_channels=64, num_layer=3, p_drop=0.2, act_layer=None, norm_layer=None, reverse=False, mlp_post=None): 18 | super(AttnMLP, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 19 | if ouput_channels is None: 20 | ouput_channels = in_channels 21 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the DeepSetConv should be larger than 0.' 
22 | 23 | self.msg_pre = MLP(in_channels, mlp_channels, ouput_channels, 24 | num_layer=num_layer, p_drop=p_drop, act_layer=act_layer, norm_layer=norm_layer) 25 | self.msg = nn.Linear(in_channels, ouput_channels) 26 | self.msg_post = None if mlp_post is None else mlp_post 27 | self.attn_lin = nn.Linear(ouput_channels + ouput_channels, 1) 28 | 29 | 30 | self.msg_q = MLP(in_channels, mlp_channels, ouput_channels, 31 | num_layer=num_layer, p_drop=p_drop, act_layer=act_layer, norm_layer=norm_layer) 32 | self.msg_k = MLP(in_channels, mlp_channels, ouput_channels, 33 | num_layer=num_layer, p_drop=p_drop, act_layer=act_layer, norm_layer=norm_layer) 34 | self.msg_v = MLP(in_channels, mlp_channels, ouput_channels, 35 | num_layer=num_layer, p_drop=p_drop, act_layer=act_layer, norm_layer=norm_layer) 36 | 37 | 38 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 39 | # x has shape [N, in_channels] 40 | # edge_index has shape [2, E] 41 | 42 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 43 | 44 | def message(self, x_i, x_j, edge_attr, index: Tensor, ptr: OptTensor, size_i: Optional[int]): 45 | # h_i: query, h_j: key, value 46 | h_attn_q_i = self.msg_q(x_i) 47 | h_attn = self.msg_k(x_j) 48 | # see comment in above self attention why this is done here and not in forward 49 | a_j = self.attn_lin(torch.cat([h_attn_q_i, h_attn], dim=-1)) 50 | a_j = softmax(a_j, index, ptr, size_i) 51 | t = self.msg_v(x_j) * a_j 52 | return t 53 | 54 | def update(self, aggr_out): 55 | if self.msg_post is not None: 56 | return self.msg_post(aggr_out) 57 | else: 58 | return aggr_out 59 | 60 | 61 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 6 | import os 7 | import time 8 | import sys 9 | import torch 10 | USE_TENSORBOARD = True 11 | try: 12 | import tensorboardX 13 | print('Using tensorboardX') 14 | except: 15 | USE_TENSORBOARD = False 16 | 17 | 18 | class Logger(object): 19 | def __init__(self, opt, gpu): 20 | """Create a summary writer logging to log_dir.""" 21 | if gpu == 0: 22 | if not os.path.exists(opt.save_dir): 23 | os.makedirs(opt.save_dir) 24 | if not os.path.exists(opt.debug_dir): 25 | os.makedirs(opt.debug_dir) 26 | 27 | time_str = time.strftime('%Y-%m-%d-%H-%M') 28 | 29 | args = dict((name, getattr(opt, name)) for name in dir(opt) 30 | if not name.startswith('_')) 31 | file_name = os.path.join(opt.save_dir, 'opt.txt') 32 | with open(file_name, 'wt') as opt_file: 33 | opt_file.write('==> torch version: {}\n'.format(torch.__version__)) 34 | opt_file.write('==> cudnn version: {}\n'.format( 35 | torch.backends.cudnn.version())) 36 | opt_file.write('==> Cmd:\n') 37 | opt_file.write(str(sys.argv)) 38 | opt_file.write('\n==> Opt:\n') 39 | for k, v in sorted(args.items()): 40 | opt_file.write(' %s: %s\n' % (str(k), str(v))) 41 | 42 | if gpu == 0: 43 | log_dir = opt.save_dir + '/logs_{}'.format(time_str) 44 | if USE_TENSORBOARD: 45 | self.writer = tensorboardX.SummaryWriter(log_dir=log_dir) 46 | else: 47 | if not os.path.exists(os.path.dirname(log_dir)): 48 | os.mkdir(os.path.dirname(log_dir)) 49 | if not os.path.exists(log_dir): 50 | os.mkdir(log_dir) 51 | self.log = open(log_dir + '/log.txt', 'w') 52 | try: 53 | os.system('cp {}/opt.txt 
{}/'.format(opt.save_dir, log_dir)) 54 | except: 55 | pass 56 | self.start_line = True 57 | 58 | def write(self, txt,local_rank): 59 | if local_rank == 0: 60 | if self.start_line: 61 | time_str = time.strftime('%Y-%m-%d-%H-%M') 62 | self.log.write('{}: {}'.format(time_str, txt)) 63 | else: 64 | self.log.write(txt) 65 | self.start_line = False 66 | if '\n' in txt: 67 | self.start_line = True 68 | self.log.flush() 69 | 70 | def close(self): 71 | self.log.close() 72 | 73 | def scalar_summary(self, tag, value, step, local_rank): 74 | if local_rank == 0: 75 | """Log a scalar variable.""" 76 | if USE_TENSORBOARD: 77 | self.writer.add_scalar(tag, value, step) 78 | -------------------------------------------------------------------------------- /src/models/layernorm_gru.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code from https://gist.github.com/denisyarats/2074e6f302dc6998a9f6f9051334e3bd 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.init 6 | 7 | 8 | class LayerNormGRUCell(nn.GRUCell): 9 | def __init__(self, input_size, hidden_size, bias=True): 10 | super(LayerNormGRUCell, self).__init__(input_size, hidden_size, bias) 11 | 12 | self.gamma_ih = nn.Parameter(torch.ones(3 * self.hidden_size)) 13 | self.gamma_hh = nn.Parameter(torch.ones(3 * self.hidden_size)) 14 | self.eps = 0 15 | 16 | def _layer_norm_x(self, x, g, b): 17 | mean = x.mean(1).expand_as(x) 18 | std = x.std(1).expand_as(x) 19 | return g.expand_as(x) * ((x - mean) / (std + self.eps)) + b.expand_as(x) 20 | 21 | def _layer_norm_h(self, x, g, b): 22 | mean = x.mean(1).expand_as(x) 23 | return g.expand_as(x) * (x - mean) + b.expand_as(x) 24 | 25 | def forward(self, x, h): 26 | 27 | ih_rz = self._layer_norm_x( 28 | torch.mm(x, self.weight_ih.narrow(0, 0, 2 * self.hidden_size).transpose(0, 1)), 29 | self.gamma_ih.narrow(0, 0, 2 * self.hidden_size), 30 | self.bias_ih.narrow(0, 0, 2 * self.hidden_size)) 31 | 32 | hh_rz = self._layer_norm_h( 33 | torch.mm(h, self.weight_hh.narrow(0, 0, 2 * self.hidden_size).transpose(0, 1)), 34 | self.gamma_hh.narrow(0, 0, 2 * self.hidden_size), 35 | self.bias_hh.narrow(0, 0, 2 * self.hidden_size)) 36 | 37 | rz = torch.sigmoid(ih_rz + hh_rz) 38 | r = rz.narrow(1, 0, self.hidden_size) 39 | z = rz.narrow(1, self.hidden_size, self.hidden_size) 40 | 41 | ih_n = self._layer_norm_x( 42 | torch.mm(x, self.weight_ih.narrow(0, 2 * self.hidden_size, self.hidden_size).transpose(0, 1)), 43 | self.gamma_ih.narrow(0, 2 * self.hidden_size, self.hidden_size), 44 | self.bias_ih.narrow(0, 2 * self.hidden_size, self.hidden_size)) 45 | 46 | hh_n = self._layer_norm_h( 47 | torch.mm(h, self.weight_hh.narrow(0, 2 * self.hidden_size, self.hidden_size).transpose(0, 1)), 48 | self.gamma_hh.narrow(0, 2 * self.hidden_size, self.hidden_size), 49 | self.bias_hh.narrow(0, 2 * self.hidden_size, self.hidden_size)) 50 | 51 | n = torch.tanh(ih_n + r * hh_n) 52 | h = (1 - z) * n + z * h 53 | return h 54 | 55 | class LayerNormGRU(nn.Module): 56 | def __init__(self, input_size, hidden_size, bias=True): 57 | super(LayerNormGRU, self).__init__() 58 | self.cell = LayerNormGRUCell(input_size, hidden_size, bias) 59 | self.weight_ih_l0 = self.cell.weight_ih 60 | self.weight_hh_l0 = self.cell.weight_hh 61 | self.bias_ih_l0 = self.cell.bias_ih 62 | self.bias_hh_l0 = self.cell.bias_hh 63 | 64 | def forward(self, xs, h): 65 | h = h.squeeze(0) 66 | ys = [] 67 | for i in range(xs.size(0)): 68 | x = xs.narrow(0, i, 1).squeeze(0) 69 | h = self.cell(x, h) 70 | ys.append(h.unsqueeze(0)) 71 | y = 
torch.cat(ys, 0) 72 | h = h.unsqueeze(0) 73 | return y, h -------------------------------------------------------------------------------- /src/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Utility functions for data handling 3 | ''' 4 | 5 | import torch 6 | import math 7 | import os 8 | 9 | import numpy as np 10 | 11 | 12 | def read_file(file_name): 13 | f = open(file_name, "r") 14 | data = f.readlines() 15 | return data 16 | 17 | def write_file(filename, dir, y): 18 | path = os.path.join(dir,filename) 19 | f = open(path, "w") 20 | for val in y: 21 | f.write(str(val[0]) + "\n") 22 | f.close() 23 | 24 | 25 | def read_npz_file(filename, dir): 26 | path = os.path.join(dir, filename) 27 | data = np.load(path, allow_pickle=True) 28 | return data 29 | 30 | 31 | 32 | def write_subcircuits(filename, dir, x_data, edge_data): 33 | # format : "Node name:Gate Type:Logic Level:C1-3Circuits:C0:O:Fanout:Reconvergence" 34 | path = os.path.join(dir,filename) 35 | f = open(path, "w") 36 | 37 | # x_data = x_data.numpy() 38 | 39 | for node in x_data: 40 | for n in node: 41 | f.write(str(n) + ":") 42 | f.write(";") 43 | f.write("\n") 44 | 45 | for edge in edge_data: 46 | f.write("(" + str(edge[0]) + "," + str(edge[1]) + ");") 47 | f.write("\n") 48 | f.close() 49 | 50 | 51 | 52 | def update_labels(x, y): 53 | for idx, val in enumerate(x): 54 | y[idx] = [y[idx][0] - val[3]] 55 | 56 | return y 57 | 58 | 59 | def remove(initial_sources): 60 | final_list = [] 61 | for num in initial_sources: 62 | if num not in final_list: 63 | final_list.append(num) 64 | return final_list 65 | 66 | 67 | def one_hot(idx, length): 68 | if type(idx) is int: 69 | idx = torch.LongTensor([idx]).unsqueeze(0) 70 | else: 71 | idx = torch.LongTensor(idx).unsqueeze(0).t() 72 | x = torch.zeros((len(idx), length)).scatter_(1, idx, 1) 73 | return x 74 | 75 | 76 | 77 | def construct_node_feature(x, no_node_cop, node_reconv, num_gate_types): 78 | # the one-hot embedding for the gate types 79 | gate_list = x[:, 1] 80 | gate_list = np.float32(gate_list) 81 | x_torch = one_hot(gate_list, num_gate_types) 82 | # if node_reconv: 83 | # reconv = torch.tensor(x[:, 7], dtype=torch.float).unsqueeze(1) 84 | # x_torch = torch.cat([x_torch, reconv], dim=1) 85 | return x_torch 86 | 87 | 88 | def add_skip_connection(x, edge_index, edge_attr, ehs): 89 | for (ind, node) in enumerate(x): 90 | if node[7] == 1: 91 | d = ind 92 | s = node[8] 93 | new_edge = torch.tensor([s, d], dtype=torch.long).unsqueeze(0) 94 | edge_index = torch.cat((edge_index, new_edge), dim=0) 95 | ll_diff = node[2] - x[int(node[8])][2] 96 | new_attr = add_edge_attr(1, ehs, ll_diff) 97 | edge_attr = torch.cat([edge_attr, new_attr], dim=0) 98 | return edge_index, edge_attr 99 | 100 | 101 | def add_edge_attr(num_edge, ehs, ll_diff=1): 102 | positional_embeddings = torch.zeros(num_edge, ehs) 103 | for position in range(num_edge): 104 | for i in range(0, ehs, 2): 105 | positional_embeddings[position, i] = ( 106 | math.sin(ll_diff / (10000 ** ((2 * i) / ehs))) 107 | ) 108 | positional_embeddings[position, i + 1] = ( 109 | math.cos(ll_diff / (10000 ** ((2 * (i + 1)) / ehs))) 110 | ) 111 | 112 | return positional_embeddings 113 | -------------------------------------------------------------------------------- /src/get_emb_aig.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 
| 6 | import os 7 | from progress.bar import Bar 8 | import random 9 | import time 10 | import torch 11 | import glob 12 | import shutil 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | 16 | import utils.circuit_utils as circuit_utils 17 | from config import get_parse_args 18 | from utils.utils import AverageMeter, pyg_simulation, get_function_acc 19 | from utils.random_seed import set_seed 20 | from utils.sat_utils import solve_sat_iteratively 21 | from utils.aiger_utils import aig_to_xdata 22 | from datasets.dataset_factory import dataset_factory 23 | from detectors.detector_factory import detector_factory 24 | from datasets.mlpgate_dataset import MLPGateDataset 25 | from datasets.load_data import parse_pyg_mlpgate 26 | 27 | AIG_DIR = "../dataset/EPFL/" 28 | TMP_DIR = "./tmp" 29 | EMB_DIR = "./emb" 30 | NEW_AIG_DIR = './tmp/aig' 31 | AIG_NAMELIST = [] 32 | 33 | def save_emb(emb, prob, path): 34 | f = open(path, 'w') 35 | f.write('{} {}\n'.format(len(emb), len(emb[0]))) 36 | for i in range(len(emb)): 37 | for j in range(len(emb[i])): 38 | f.write('{:.6f} '.format(float(emb[i][j]))) 39 | f.write('\n') 40 | for i in range(len(prob)): 41 | f.write('{:.6f}\n'.format(float(prob[i]))) 42 | f.close() 43 | 44 | def save_aig(aig_src, aig_dst): 45 | shutil.copy(aig_src, aig_dst) 46 | 47 | def test(args): 48 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str 49 | detector = detector_factory['base'](args) 50 | if len(AIG_NAMELIST) == 0: 51 | for filename in glob.glob(os.path.join(AIG_DIR, '*.aig')): 52 | aig_name = filename.split('/')[-1].split('.')[0] 53 | AIG_NAMELIST.append(aig_name) 54 | 55 | for aig_name in AIG_NAMELIST: 56 | aig_filepath = os.path.join(AIG_DIR, aig_name + '.aiger') 57 | tmp_aag_filepath = os.path.join(TMP_DIR, aig_name + '.aag') 58 | x_data, edge_index = aig_to_xdata(aig_filepath, tmp_aag_filepath, args.gate_to_index) 59 | os.remove(tmp_aag_filepath) 60 | if len(x_data) == 0: 61 | continue 62 | fanin_list, fanout_list = circuit_utils.get_fanin_fanout(x_data, edge_index) 63 | level_list = circuit_utils.get_level(x_data, fanin_list, fanout_list) 64 | print('Parse AIG: ', aig_filepath) 65 | 66 | # Generate graph 67 | x_data = np.array(x_data) 68 | edge_index = np.array(edge_index) 69 | tt_dis = [] 70 | min_tt_dis = [] 71 | tt_pair_index = [] 72 | prob = [0] * len(x_data) 73 | rc_pair_index = [[0, 1]] 74 | is_rc = [] 75 | g = parse_pyg_mlpgate( 76 | x_data, edge_index, tt_dis, min_tt_dis, tt_pair_index, prob, rc_pair_index, is_rc, 77 | args.use_edge_attr, args.reconv_skip_connection, args.no_node_cop, 78 | args.node_reconv, args.un_directed, args.num_gate_types, 79 | args.dim_edge_feature, args.logic_implication, args.mask 80 | ) 81 | g.to(args.device) 82 | 83 | # Model 84 | start_time = time.time() 85 | res = detector.run(g) 86 | end_time = time.time() 87 | hs, hf, prob, is_rc = res['results'] 88 | print("Circuit: {}, Size: {:}, Time: {:.2f}".format(aig_name, len(x_data), end_time-start_time)) 89 | # acc = get_function_acc(g, hf) 90 | # print("ACC: {:.2f}%".format(acc/100)) 91 | print() 92 | 93 | # Save emb 94 | emb_filepath = os.path.join(EMB_DIR, aig_name + '.txt') 95 | save_emb(hf.detach().cpu().numpy(), prob.detach().cpu().numpy(), emb_filepath) 96 | newaig_filepath = os.path.join(NEW_AIG_DIR, aig_name + '.aig') 97 | save_aig(aig_filepath, newaig_filepath) 98 | 99 | if __name__ == '__main__': 100 | args = get_parse_args() 101 | set_seed(args) 102 | test(args) 103 | -------------------------------------------------------------------------------- 
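The embedding files written by `save_emb` above use a simple layout: a header line `N D`, then `N` rows of `D` floats (one embedding per gate), then `N` lines holding the predicted probability of each gate (assuming the probability vector has one entry per gate, as in the call above). A minimal reader sketch for that layout follows; `load_emb` is a hypothetical helper, not part of the repository.

```python
import numpy as np

def load_emb(path):
    # Reads a file produced by save_emb(): header "N D", then N embedding rows, then N probability lines.
    with open(path) as f:
        n, d = map(int, f.readline().split())
        emb = np.array([[float(v) for v in f.readline().split()] for _ in range(n)])
        prob = np.array([float(f.readline()) for _ in range(n)])
    assert emb.shape == (n, d) and prob.shape == (n,)
    return emb, prob

# emb, prob = load_emb('./emb/adder.txt')  # one file per circuit under EMB_DIR
```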
/src/test_acc_bin.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import os 7 | from progress.bar import Bar 8 | import random 9 | import torch 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | from sklearn.metrics import roc_curve, auc 13 | 14 | from config import get_parse_args 15 | from utils.logger import Logger 16 | from utils.utils import AverageMeter, pyg_simulation 17 | from utils.random_seed import set_seed 18 | from utils.circuit_utils import check_difference 19 | from utils.sat_utils import solve_sat_iteratively 20 | from datasets.dataset_factory import dataset_factory 21 | from detectors.detector_factory import detector_factory 22 | from datasets.mlpgate_dataset import MLPGateDataset 23 | 24 | MIN_DIST_PROB = 0.5 25 | MIN_DIST_TT = 0.5 26 | NAME_LIST = [] 27 | THREHOLD = 0.91 28 | 29 | def test(args): 30 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str 31 | 32 | print(args) 33 | acc_list = [] 34 | 35 | dataset = MLPGateDataset(args.data_dir, args)[:100] 36 | detector = detector_factory['base'](args) 37 | 38 | num_cir = len(dataset) 39 | print('EXP ID: ', args.exp_id) 40 | print('Tot num of circuits: ', num_cir) 41 | 42 | for ind, g in enumerate(dataset): 43 | if g.tt_dis.sum() == 0: 44 | continue 45 | if len(NAME_LIST) > 0 and g.name not in NAME_LIST: 46 | continue 47 | res = detector.run(g) 48 | hs, hf, prob, is_rc = res['results'] 49 | node_emb = hf 50 | tp = 0 51 | tn = 0 52 | fp = 0 53 | fn = 0 54 | tot = 0 55 | pd_list = [] 56 | gt_list = [] 57 | 58 | for pair_index in range(len(g.tt_pair_index[0])): 59 | pair_A = g.tt_pair_index[0][pair_index] 60 | pair_B = g.tt_pair_index[1][pair_index] 61 | pair_gt = g.tt_dis[pair_index] 62 | pair_pd_sim = torch.cosine_similarity(node_emb[pair_A].unsqueeze(0), node_emb[pair_B].unsqueeze(0), eps=1e-8) 63 | # Skip 64 | if pair_gt != 0 and pair_gt < MIN_DIST_TT: 65 | continue 66 | if pair_gt != 0 and abs(g.prob[pair_A] - g.prob[pair_B]) < MIN_DIST_PROB: 67 | continue 68 | 69 | pd_list.append(pair_pd_sim.item()) 70 | gt_list.append(pair_gt.item() == 0) 71 | tot += 1 72 | 73 | if pair_pd_sim > 0.9 and pair_gt > 0: 74 | tmp_a = 0 75 | 76 | pd_list = np.array(pd_list) 77 | gt_list = np.array(gt_list) 78 | fpr, tpr, thresholds = roc_curve(gt_list, pd_list) 79 | roc_auc = auc(fpr, tpr) 80 | opt_thro = thresholds[np.argmax(tpr - fpr)] 81 | # Threshold 82 | # pd_list_bin = pd_list > THREHOLD 83 | pd_list_bin = pd_list > opt_thro 84 | 85 | tp = np.sum(pd_list_bin & gt_list) 86 | tn = np.sum((~pd_list_bin) & (~gt_list)) 87 | fp = np.sum(pd_list_bin & (~gt_list)) 88 | fn = np.sum((~pd_list_bin) & gt_list) 89 | 90 | print('Circuit: {}, Size: {:}'.format( 91 | g.name, len(g.x) 92 | )) 93 | print('Threshold: {:.2f}'.format(opt_thro)) 94 | print('TP: {:.2f}%, TN: {:.2f}%, FP: {:.2f}%, FN: {:.2f}%'.format( 95 | tp / tot * 100, tn / tot * 100, fp / tot * 100, fn / tot * 100 96 | )) 97 | print('Accuracy: {:.2f}%, Precision: {:.2f}%, Recall: {:.2f}%'.format( 98 | (tp + tn) / tot * 100, tp / (tp + fp) * 100, tp / (tp + fn) * 100 99 | )) 100 | print('F1 Score: {:.3f}'.format( 101 | 2 * tp / (2 * tp + fp + fn) 102 | )) 103 | print('AUC: {:.6f}'.format( 104 | roc_auc 105 | )) 106 | print() 107 | 108 | 109 | if __name__ == '__main__': 110 | args = get_parse_args() 111 | set_seed(args) 112 | 113 | test(args) 114 | 
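The equivalence check in `test_acc_bin.py` reduces to thresholding cosine similarities, with the operating point chosen where `tpr - fpr` is maximal (Youden's J statistic). The snippet below isolates that step on a few made-up scores, purely to illustrate how `opt_thro` and the reported metrics relate; the numbers are not from the repository.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# Made-up similarity scores; gt == 1 marks functionally equivalent pairs (tt_dis == 0).
gt = np.array([1, 1, 1, 0, 0, 0, 1, 0])
scores = np.array([0.95, 0.88, 0.92, 0.40, 0.75, 0.30, 0.97, 0.55])

fpr, tpr, thresholds = roc_curve(gt, scores)
opt_thro = thresholds[np.argmax(tpr - fpr)]  # Youden's J: maximize TPR - FPR
pred = scores > opt_thro

tp = np.sum(pred & (gt == 1)); fp = np.sum(pred & (gt == 0))
fn = np.sum(~pred & (gt == 1)); tn = np.sum(~pred & (gt == 0))
print('AUC {:.3f}, threshold {:.2f}, accuracy {:.2f}, F1 {:.3f}'.format(
    auc(fpr, tpr), opt_thro, (tp + tn) / len(gt), 2 * tp / (2 * tp + fp + fn)))
```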
-------------------------------------------------------------------------------- /src/detectors/base_detector.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from progress.bar import Bar 11 | import time 12 | import torch 13 | 14 | from models.model import create_model, load_model 15 | # from utils.debugger import Debugger 16 | 17 | class BaseDetector(object): 18 | def __init__(self, args): 19 | if args.gpus[0] >= 0: 20 | args.device = torch.device('cuda') 21 | else: 22 | args.device = torch.device('cpu') 23 | self.cop_only = args.cop_only 24 | if not self.cop_only: 25 | print('Creating model...') 26 | self.model = create_model(args) 27 | self.model = load_model(self.model, args.load_model, device = args.device) 28 | self.model = self.model.to(args.device) 29 | self.model.eval() 30 | 31 | self.args = args 32 | self.pause = True 33 | 34 | def run(self, graph): 35 | net_time, post_time = 0, 0 36 | tot_time = 0 37 | # debugger = Debugger(dataset=self.args.dataset, ipynb=(self.args.debug == 3), 38 | # theme=self.args.debugger_theme) 39 | start_time = time.time() 40 | 41 | graph = graph.to(self.args.device) 42 | 43 | output, forward_time = self.process(graph, return_time=True) 44 | net_time += forward_time - start_time 45 | # if self.opt.debug >= 2: 46 | # self.debug(debugger, graph, output) 47 | post_process_time = time.time() 48 | post_time += post_process_time - forward_time 49 | 50 | tot_time += post_process_time - start_time 51 | 52 | if self.args.debug == 1: 53 | # self.show_results(debugger, graph, output, self.args.debug_dir) 54 | self.show_results(graph, output, self.args.debug_dir) 55 | 56 | return {'results': output, 'tot': tot_time, 'net': net_time, 57 | 'post': post_time} 58 | 59 | def process(self, graph, return_time=False): 60 | 61 | with torch.no_grad(): 62 | output = self.model(graph)[0] 63 | 64 | forward_time = time.time() 65 | 66 | 67 | if return_time: 68 | return output, forward_time 69 | else: 70 | return output 71 | 72 | 73 | 74 | def pre_process(self, graph, meta=None): 75 | raise NotImplementedError 76 | 77 | def post_process(self, output, graph): 78 | if self.args.predict_diff: 79 | output = output / self.args.diff_multiplier + graph.c1.to(self.args.device) 80 | output = torch.clamp(output, min=-1., max=1.) 81 | else: 82 | output = torch.clamp(output, min=0., max=1.) 
83 | return output 84 | 85 | def merge_outputs(self, detections): 86 | raise NotImplementedError 87 | 88 | def debug(self, debugger, graph, dets, output, scale=1): 89 | raise NotImplementedError 90 | 91 | # def show_results(self, debugger, image, results): 92 | def show_results(self, graph, output, path): 93 | # c1 94 | file_name = os.path.join(path, '{}_c1.png'.format(graph.name)) 95 | x = graph.gt[:, 0].cpu().numpy() 96 | y = graph.c1[:, 0].cpu().numpy() 97 | 98 | plt.scatter(x,y) 99 | 100 | plt.title('{} - gt vs c1'.format(graph.name)) 101 | plt.xlabel("gt") 102 | plt.ylabel("c1") 103 | 104 | plt.savefig(file_name) 105 | plt.close() 106 | 107 | if not self.cop_only: 108 | # pred 109 | file_name = os.path.join(path, '{}_pred.png'.format(graph.name)) 110 | x = graph.gt[:, 0].cpu().numpy() 111 | y = output[:, 0].cpu().numpy() 112 | 113 | plt.scatter(x,y) 114 | 115 | plt.title('{} - gt vs pred'.format(graph.name)) 116 | plt.xlabel("gt") 117 | plt.ylabel("pred") 118 | 119 | plt.savefig(file_name) 120 | plt.close() 121 | 122 | -------------------------------------------------------------------------------- /src/datasets/circuit_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, List 2 | import os.path as osp 3 | 4 | 5 | import torch 6 | from torch_geometric.data import Data 7 | from torch_geometric.data import InMemoryDataset 8 | 9 | from utils.data_utils import read_npz_file 10 | from .load_data import circuit_parse_pyg 11 | 12 | 13 | class CircuitDataset(InMemoryDataset): 14 | r""" 15 | A variety of circuit graph datasets, *e.g.*, open-sourced benchmarks, 16 | random circuits. 17 | 18 | Args: 19 | root (string): Root directory where the dataset should be saved. 20 | args (object): The arguments specified by the main program. 21 | transform (callable, optional): A function/transform that takes in an 22 | :obj:`torch_geometric.data.Data` object and returns a transformed 23 | version. The data object will be transformed before every access. 24 | (default: :obj:`None`) 25 | pre_transform (callable, optional): A function/transform that takes in 26 | an :obj:`torch_geometric.data.Data` object and returns a 27 | transformed version. The data object will be transformed before 28 | being saved to disk. (default: :obj:`None`) 29 | pre_filter (callable, optional): A function that takes in an 30 | :obj:`torch_geometric.data.Data` object and returns a boolean 31 | value, indicating whether the data object should be included in the 32 | final dataset. (default: :obj:`None`) 33 | """ 34 | 35 | def __init__(self, root, args, transform=None, pre_transform=None, pre_filter=None): 36 | self.name = args.dataset 37 | self.args = args 38 | 39 | assert (transform == None) and (pre_transform == None) and (pre_filter == None), "Cannot accept the transform, pre_transfrom and pre_filter args now." 
40 | 41 | super().__init__(root, transform, pre_transform, pre_filter) 42 | self.data, self.slices = torch.load(self.processed_paths[0]) 43 | 44 | @property 45 | def raw_dir(self): 46 | return self.root 47 | 48 | @property 49 | def processed_dir(self): 50 | name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(int(self.args.use_edge_attr), int(self.args.reconv_skip_connection), int(self.args.predict_diff), self.args.diff_multiplier, int(self.args.no_node_cop), int(self.args.node_reconv), self.args.num_gate_types, self.args.dim_edge_feature, int(self.args.small_train), int(self.args.un_directed), int(self.args.logic_implication), int(self.args.mask)) 51 | return osp.join(self.root, name) 52 | 53 | @property 54 | def raw_file_names(self) -> List[str]: 55 | return [self.args.circuit_file, self.args.label_file] 56 | 57 | @property 58 | def processed_file_names(self) -> str: 59 | return ['data.pt'] 60 | 61 | def download(self): 62 | pass 63 | 64 | def process(self): 65 | data_list = [] 66 | circuits = read_npz_file(self.args.circuit_file, self.args.data_dir)['circuits'].item() 67 | labels = read_npz_file(self.args.label_file, self.args.data_dir)['labels'].item() 68 | 69 | if self.args.small_train: 70 | subset = 100 71 | 72 | for cir_idx, cir_name in enumerate(circuits): 73 | print('Parse circuit: ', cir_name) 74 | x = circuits[cir_name]["x"] 75 | edge_index = circuits[cir_name]["edge_index"] 76 | y = labels[cir_name]["y"] 77 | # check the gate types 78 | # assert (x[:, 1].max() == (len(self.args.gate_to_index)) - 1), 'The gate types are not consistent.' 79 | graph = circuit_parse_pyg(x, edge_index, y, self.args.use_edge_attr, \ 80 | self.args.reconv_skip_connection, self.args.logic_diff_embedding, self.args.predict_diff, \ 81 | self.args.diff_multiplier, self.args.no_node_cop, self.args.node_reconv, self.args.un_directed, self.args.num_gate_types, self.args.dim_edge_feature, self.args.logic_implication, self.args.mask) 82 | graph.name = cir_name 83 | data_list.append(graph) 84 | if self.args.small_train and cir_idx > subset: 85 | break 86 | 87 | 88 | data, slices = self.collate(data_list) 89 | torch.save((data, slices), self.processed_paths[0]) 90 | 91 | def __repr__(self) -> str: 92 | return f'{self.name}({len(self)})' -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torchvision.models as models 6 | import torch 7 | import torch.nn as nn 8 | import os 9 | 10 | from .recgnn import get_recurrent_gnn 11 | from .convgnn import get_conv_gnn 12 | from .dag_convgnn import get_dag_recurrent_gnn 13 | from .mlpgate import get_mlp_gate 14 | from .mlpgate_merge import get_mlp_gate_merged 15 | 16 | _model_factory = { 17 | 'recgnn': get_recurrent_gnn, 18 | 'convgnn': get_conv_gnn, 19 | 'dagconvgnn': get_dag_recurrent_gnn, 20 | 'mlpgnn': get_mlp_gate, 21 | 'mlpgnn_merge': get_mlp_gate_merged, 22 | 23 | } 24 | 25 | 26 | def create_model(args): 27 | get_model = _model_factory[args.arch] 28 | model = get_model(args) 29 | return model 30 | 31 | def load_model(model, model_path, optimizer=None, resume=False, 32 | lr=None, lr_step=None, local_rank = 0, device='cuda'): 33 | start_epoch = 0 34 | # checkpoint = torch.load( 35 | # model_path, map_location=lambda storage, loc: storage) 36 | 37 | if device == 'cuda': 38 | map = {'cuda:%d' % 0: 'cuda:%d' % 
local_rank} 39 | checkpoint = torch.load( 40 | model_path, map_location=map) 41 | else: 42 | checkpoint = torch.load( 43 | model_path, map_location=lambda storage, loc: storage) 44 | 45 | if local_rank == 0: 46 | print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch'])) 47 | state_dict_ = checkpoint['state_dict'] 48 | state_dict = {} 49 | 50 | # convert data_parallal to model 51 | for k in state_dict_: 52 | if k.startswith('module') and not k.startswith('module_list'): 53 | state_dict[k[7:]] = state_dict_[k] 54 | else: 55 | state_dict[k] = state_dict_[k] 56 | model_state_dict = model.state_dict() 57 | 58 | # check loaded parameters and created model parameters 59 | msg = 'If you see this, your model does not fully load the ' + \ 60 | 'pre-trained weight. Please make sure ' + \ 61 | 'you have correctly specified --arch xxx ' + \ 62 | 'or set the correct --num_classes for your own dataset.' 63 | for k in state_dict: 64 | if k in model_state_dict: 65 | if state_dict[k].shape != model_state_dict[k].shape: 66 | if local_rank == 0: 67 | print('Skip loading parameter {}, required shape{}, ' 68 | 'loaded shape{}. {}'.format( 69 | k, model_state_dict[k].shape, state_dict[k].shape, msg)) 70 | state_dict[k] = model_state_dict[k] 71 | else: 72 | if local_rank == 0: 73 | print('Drop parameter {}.'.format(k) + msg) 74 | for k in model_state_dict: 75 | if not (k in state_dict): 76 | if local_rank == 0: 77 | print('No param {}.'.format(k) + msg) 78 | state_dict[k] = model_state_dict[k] 79 | model.load_state_dict(state_dict, strict=False) 80 | 81 | # resume optimizer parameters 82 | if optimizer is not None and resume: 83 | if 'optimizer' in checkpoint: 84 | optimizer.load_state_dict(checkpoint['optimizer']) 85 | start_epoch = checkpoint['epoch'] 86 | start_lr = lr 87 | for step in lr_step: 88 | if start_epoch >= step: 89 | start_lr *= 0.1 90 | for param_group in optimizer.param_groups: 91 | param_group['lr'] = start_lr 92 | if local_rank == 0: 93 | print('Resumed optimizer with start lr', start_lr) 94 | else: 95 | if local_rank == 0: 96 | print('No optimizer parameters in checkpoint.') 97 | if optimizer is not None: 98 | return model, optimizer, start_epoch 99 | else: 100 | return model 101 | 102 | 103 | def save_model(path, epoch, model, optimizer=None): 104 | if isinstance(model, torch.nn.DataParallel): 105 | state_dict = model.module.state_dict() 106 | else: 107 | state_dict = model.state_dict() 108 | data = {'epoch': epoch, 109 | 'state_dict': state_dict} 110 | if not (optimizer is None): 111 | data['optimizer'] = optimizer.state_dict() 112 | torch.save(data, path) 113 | -------------------------------------------------------------------------------- /src/models/layernorm_lstm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code from https://github.com/pytorch/pytorch/issues/11335 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class LayerNormLSTMCell(nn.LSTMCell): 11 | 12 | def __init__(self, input_size, hidden_size, bias=True): 13 | super().__init__(input_size, hidden_size, bias) 14 | 15 | self.ln_ih = nn.LayerNorm(4 * hidden_size) 16 | self.ln_hh = nn.LayerNorm(4 * hidden_size) 17 | self.ln_ho = nn.LayerNorm(hidden_size) 18 | 19 | def forward(self, input, hidden=None): 20 | # self.check_forward_input(input) 21 | if hidden is None: 22 | hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False) 23 | cx = input.new_zeros(input.size(0), self.hidden_size, 
requires_grad=False) 24 | else: 25 | hx, cx = hidden 26 | # self.check_forward_hidden(input, hx, '[0]') 27 | # self.check_forward_hidden(input, cx, '[1]') 28 | 29 | gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \ 30 | + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh)) 31 | i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1) 32 | g = gates[:, (3 * self.hidden_size):].tanh() 33 | 34 | cy = (f * cx) + (i * g) 35 | hy = o * self.ln_ho(cy).tanh() 36 | return hy, cy 37 | 38 | 39 | class LayerNormLSTM(nn.Module): 40 | 41 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False): 42 | super().__init__() 43 | self.input_size = input_size 44 | self.hidden_size = hidden_size 45 | self.num_layers = num_layers 46 | self.bidirectional = bidirectional 47 | 48 | 49 | num_directions = 2 if bidirectional else 1 50 | self.hidden0 = nn.ModuleList([ 51 | LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions), 52 | hidden_size=hidden_size, bias=bias) 53 | for layer in range(num_layers) 54 | ]) 55 | 56 | if self.bidirectional: 57 | self.hidden1 = nn.ModuleList([ 58 | LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions), 59 | hidden_size=hidden_size, bias=bias) 60 | for layer in range(num_layers) 61 | ]) 62 | 63 | def forward(self, input, hidden=None): 64 | seq_len, batch_size, hidden_size = input.size() # supports TxNxH only 65 | num_directions = 2 if self.bidirectional else 1 66 | if hidden is None: 67 | hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False) 68 | cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False) 69 | else: 70 | hx, cx = hidden 71 | 72 | ht = [[None, ] * (self.num_layers * num_directions)] * seq_len 73 | ct = [[None, ] * (self.num_layers * num_directions)] * seq_len 74 | 75 | if self.bidirectional: 76 | xs = input 77 | for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)): 78 | l0, l1 = 2 * l, 2 * l + 1 79 | h0, c0, h1, c1 = hx[l0], cx[l0], hx[l1], cx[l1] 80 | for t, (x0, x1) in enumerate(zip(xs, reversed(xs))): 81 | ht[t][l0], ct[t][l0] = layer0(x0, (h0, c0)) 82 | h0, c0 = ht[t][l0], ct[t][l0] 83 | t = seq_len - 1 - t 84 | ht[t][l1], ct[t][l1] = layer1(x1, (h1, c1)) 85 | h1, c1 = ht[t][l1], ct[t][l1] 86 | xs = [torch.cat((h[l0], h[l1]), dim=1) for h in ht] 87 | y = torch.stack(xs) 88 | hy = torch.stack(ht[-1]) 89 | cy = torch.stack(ct[-1]) 90 | else: 91 | h, c = hx, cx 92 | for t, x in enumerate(input): 93 | for l, layer in enumerate(self.hidden0): 94 | ht[t][l], ct[t][l] = layer(x, (h[l], c[l])) 95 | x = ht[t][l] 96 | h, c = ht[t], ct[t] 97 | y = torch.stack([h[-1] for h in ht]) 98 | hy = torch.stack(ht[-1]) 99 | cy = torch.stack(ct[-1]) 100 | 101 | return y, (hy, cy) -------------------------------------------------------------------------------- /src/datasets/mig_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, List 2 | import os.path as osp 3 | 4 | import torch 5 | import shutil 6 | import os 7 | from torch_geometric.data import Data 8 | from torch_geometric.data import InMemoryDataset 9 | 10 | from utils.data_utils import read_npz_file 11 | from .load_data import circuit_parse_pyg 12 | 13 | 14 | class MIGDataset(InMemoryDataset): 15 | r""" 16 | A variety of circuit graph datasets, *e.g.*, open-sourced benchmarks, 17 | random circuits. 
18 | 19 | Args: 20 | root (string): Root directory where the dataset should be saved. 21 | args (object): The arguments specified by the main program. 22 | transform (callable, optional): A function/transform that takes in an 23 | :obj:`torch_geometric.data.Data` object and returns a transformed 24 | version. The data object will be transformed before every access. 25 | (default: :obj:`None`) 26 | pre_transform (callable, optional): A function/transform that takes in 27 | an :obj:`torch_geometric.data.Data` object and returns a 28 | transformed version. The data object will be transformed before 29 | being saved to disk. (default: :obj:`None`) 30 | pre_filter (callable, optional): A function that takes in an 31 | :obj:`torch_geometric.data.Data` object and returns a boolean 32 | value, indicating whether the data object should be included in the 33 | final dataset. (default: :obj:`None`) 34 | """ 35 | 36 | def __init__(self, root, args, transform=None, pre_transform=None, pre_filter=None): 37 | self.name = 'MIG' 38 | self.args = args 39 | 40 | assert (transform == None) and (pre_transform == None) and (pre_filter == None), "Cannot accept the transform, pre_transfrom and pre_filter args now." 41 | 42 | # Reload 43 | inmemory_dir = os.path.join(args.data_dir, 'inmemory') 44 | if args.reload_dataset and os.path.exists(inmemory_dir): 45 | shutil.rmtree(inmemory_dir) 46 | 47 | super().__init__(root, transform, pre_transform, pre_filter) 48 | self.data, self.slices = torch.load(self.processed_paths[0]) 49 | 50 | @property 51 | def raw_dir(self): 52 | return self.root 53 | 54 | @property 55 | def processed_dir(self): 56 | if self.args.small_train: 57 | name = 'inmemory_small' 58 | else: 59 | name = 'inmemory' 60 | return osp.join(self.root, name) 61 | 62 | @property 63 | def raw_file_names(self) -> List[str]: 64 | return [self.args.circuit_file, self.args.label_file] 65 | 66 | @property 67 | def processed_file_names(self) -> str: 68 | return ['data.pt'] 69 | 70 | def download(self): 71 | pass 72 | 73 | def process(self): 74 | data_list = [] 75 | circuits = read_npz_file(self.args.circuit_file, self.args.data_dir)['circuits'].item() 76 | labels = read_npz_file(self.args.label_file, self.args.data_dir)['labels'].item() 77 | 78 | if self.args.small_train: 79 | subset = 100 80 | 81 | for cir_idx, cir_name in enumerate(circuits): 82 | print('Parse circuit: ', cir_name) 83 | x = circuits[cir_name]["x"] 84 | edge_index = circuits[cir_name]["edge_index"] 85 | y = labels[cir_name]["y"] 86 | # check the gate types 87 | # assert (x[:, 1].max() == (len(self.args.gate_to_index)) - 1), 'The gate types are not consistent.' 
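# Here x is the per-node feature array and edge_index the connectivity read
# from the circuits .npz, while y carries the per-node labels from the labels
# .npz; circuit_parse_pyg (datasets/load_data.py) packs them into a single
# torch_geometric Data object below. The exact column layout of x depends on
# the preparation script that produced the .npz (prepare_dataset.py in this
# repo stores [name, gate_type, level, ...] with
# gate_to_index = {'INPUT': 0, 'AND': 1, 'NOT': 2}).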
88 | graph = circuit_parse_pyg(x, edge_index, y, self.args.use_edge_attr, \ 89 | self.args.reconv_skip_connection, self.args.logic_diff_embedding, self.args.predict_diff, \ 90 | self.args.diff_multiplier, self.args.no_node_cop, self.args.node_reconv, self.args.un_directed, self.args.num_gate_types, self.args.dim_edge_feature, self.args.logic_implication, self.args.mask) 91 | graph.name = cir_name 92 | data_list.append(graph) 93 | if self.args.small_train and cir_idx > subset: 94 | break 95 | 96 | 97 | data, slices = self.collate(data_list) 98 | torch.save((data, slices), self.processed_paths[0]) 99 | print('[INFO] Inmemory dataset save: ', self.processed_paths[0]) 100 | 101 | def __repr__(self) -> str: 102 | return f'{self.name}({len(self)})' -------------------------------------------------------------------------------- /src/get_emb_bench.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import os 7 | from progress.bar import Bar 8 | import random 9 | import time 10 | import torch 11 | import glob 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | 15 | import utils.circuit_utils as circuit_utils 16 | from config import get_parse_args, update_dir 17 | from utils.random_seed import set_seed 18 | from detectors.detector_factory import detector_factory 19 | from datasets.load_data import parse_pyg_mlpgate 20 | 21 | BENCH_DIR = "../dataset/EPFL/" 22 | TMP_DIR = "../tmp" 23 | EMB_DIR = "../emb" 24 | BENCH_NAMELIST = [] 25 | 26 | def save_emb(emb, prob, path): 27 | f = open(path, 'w') 28 | f.write('{} {}\n'.format(len(emb), len(emb[0]))) 29 | for i in range(len(emb)): 30 | for j in range(len(emb[i])): 31 | f.write('{:.6f} '.format(float(emb[i][j]))) 32 | f.write('\n') 33 | for i in range(len(prob)): 34 | f.write('{:.6f}\n'.format(float(prob[i]))) 35 | f.close() 36 | 37 | def gen_graph(args, x_data, edge_index): 38 | x_data = np.array(x_data) 39 | edge_index = np.array(edge_index) 40 | tt_dis = [] 41 | min_tt_dis = [] 42 | tt_pair_index = [] 43 | prob = [0] * len(x_data) 44 | rc_pair_index = [[0, 1]] 45 | is_rc = [] 46 | g = parse_pyg_mlpgate( 47 | x_data, edge_index, tt_dis, min_tt_dis, tt_pair_index, prob, rc_pair_index, is_rc, 48 | args.use_edge_attr, args.reconv_skip_connection, args.no_node_cop, 49 | args.node_reconv, args.un_directed, args.num_gate_types, 50 | args.dim_edge_feature, args.logic_implication, args.mask 51 | ) 52 | return g 53 | 54 | ################################################################## 55 | # API 56 | ################################################################## 57 | def get_emb(exp_id, bench_filepath, emb_filepath, arch='mlpgnn', aggr='tfmlp', display=1): 58 | args = get_parse_args() 59 | args = update_dir(args, exp_id) 60 | args.arch = arch 61 | args.aggr_function = aggr 62 | args.batch_size = 1 63 | args.resume = True 64 | 65 | detector = detector_factory['base'](args) 66 | bench_name = bench_filepath.split('/')[-1].split('.')[0] 67 | x_data, edge_index, fanin_list, fanout_list, level_list = circuit_utils.parse_bench(bench_filepath, args.gate_to_index) 68 | if len(x_data) == 0: 69 | return 0 70 | if display: 71 | print('Parse AIG: ', bench_filepath) 72 | 73 | # Generate graph 74 | g = gen_graph(args, x_data, edge_index) 75 | g.to(args.device) 76 | 77 | # Model 78 | start_time = time.time() 79 | res = detector.run(g) 80 | end_time = time.time() 81 | hs, hf, prob, is_rc = 
res['results'] 82 | if display: 83 | print("Circuit: {}, Size: {:}, Time: {:.2f} s".format(bench_name, len(x_data), end_time-start_time)) 84 | print() 85 | 86 | # Save emb 87 | save_emb(hf.detach().cpu().numpy(), prob.detach().cpu().numpy(), emb_filepath) 88 | return 1 89 | ################################################################## 90 | 91 | def test(args): 92 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str 93 | detector = detector_factory['base'](args) 94 | if len(BENCH_NAMELIST) == 0: 95 | for filename in glob.glob(os.path.join(BENCH_DIR, '*.bench')): 96 | bench_name = filename.split('/')[-1].split('.')[0] 97 | BENCH_NAMELIST.append(bench_name) 98 | 99 | for bench_name in BENCH_NAMELIST: 100 | bench_filepath = os.path.join(BENCH_DIR, bench_name + '.bench') 101 | x_data, edge_index, fanin_list, fanout_list, level_list = circuit_utils.parse_bench(bench_filepath, args.gate_to_index) 102 | if len(x_data) == 0: 103 | continue 104 | # fanin_list, fanout_list = circuit_utils.get_fanin_fanout(x_data, edge_index) 105 | # level_list = circuit_utils.get_level(x_data, fanin_list, fanout_list) 106 | print('Parse AIG: ', bench_filepath) 107 | 108 | # Generate graph 109 | g = gen_graph(args, x_data, edge_index) 110 | g.to(args.device) 111 | 112 | # Model 113 | start_time = time.time() 114 | res = detector.run(g) 115 | end_time = time.time() 116 | hs, hf, prob, is_rc = res['results'] 117 | print("Circuit: {}, Size: {:}, Time: {:.2f} s".format(bench_name, len(x_data), end_time-start_time)) 118 | print() 119 | 120 | # Save emb 121 | emb_path = os.path.join(EMB_DIR, bench_name + '.txt') 122 | save_emb(hf.detach().cpu().numpy(), prob.detach().cpu().numpy(), emb_path) 123 | 124 | if __name__ == '__main__': 125 | args = get_parse_args() 126 | set_seed(args) 127 | test(args) 128 | -------------------------------------------------------------------------------- /src/models/gat_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional 4 | from torch import Tensor 5 | from torch_geometric.typing import Adj, OptTensor 6 | from torch_geometric.utils import softmax 7 | from torch_geometric.nn.glob import * 8 | from torch_geometric.nn import MessagePassing 9 | 10 | from .mlp import MLP 11 | 12 | 13 | 14 | class AGNNConv(MessagePassing): 15 | ''' 16 | Additive form of GAT from DAGNN paper. 17 | 18 | In order to do the fair comparison with DeepSet. I add a FC-based layer before doing the attention. 19 | ''' 20 | def __init__(self, in_channels, ouput_channels=None, wea=False, mlp=None, reverse=False): 21 | super(AGNNConv, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 22 | if ouput_channels is None: 23 | ouput_channels = in_channels 24 | assert (in_channels > 0) and (ouput_channels > 0), 'The dimension for the DeepSetConv should be larger than 0.' 
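# The attention implemented in message() below is additive (GAT-style): both
# endpoints of an edge are projected by self.msg, the encoded edge attribute is
# optionally added to the source term, self.attn_lin scores the concatenated
# pair with a single output unit, and softmax over each node's incoming edges
# turns those scores into the weights used by the 'add' aggregation.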
25 | 26 | self.wea = wea 27 | if self.wea: 28 | # fix the size of edge_attributes now 29 | self.edge_encoder = nn.Linear(16, ouput_channels) 30 | 31 | # linear transformation 32 | # self.msg = MLP(in_channels, ouput_channels, ouput_channels, num_layer=3, p_drop=0.2) if mlp is None else mlp 33 | self.msg = nn.Linear(in_channels, ouput_channels) 34 | 35 | # attention 36 | attn_dim = ouput_channels 37 | self.attn_lin = nn.Linear(ouput_channels + attn_dim, 1) 38 | 39 | 40 | # h_attn_q is needed; h_attn, edge_attr are optional (we just use kwargs to be able to switch node aggregator above) 41 | def forward(self, x, edge_index, edge_attr=None, **kwargs): 42 | 43 | # Step 2: Linearly transform node feature matrix. 44 | # h = self.msg(x) 45 | 46 | return self.propagate(edge_index, x=x, edge_attr=edge_attr) 47 | 48 | def message(self, x_i, x_j, edge_attr, index: Tensor, ptr: OptTensor, size_i: Optional[int]): 49 | # h_i: query, h_j: key, value 50 | h_attn_q_i = self.msg(x_i) 51 | h_attn = self.msg(x_j) 52 | # see comment in above self attention why this is done here and not in forward 53 | if self.wea: 54 | edge_embedding = self.edge_encoder(edge_attr) 55 | h_attn = h_attn + edge_embedding 56 | a_j = self.attn_lin(torch.cat([h_attn_q_i, h_attn], dim=-1)) 57 | a_j = softmax(a_j, index, ptr, size_i) 58 | t = h_attn * a_j 59 | return t 60 | 61 | def update(self, aggr_out): 62 | return aggr_out 63 | 64 | # ''' 65 | # The attention in dot-product form. Modified version of: 66 | # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.AGNNConv 67 | # ''' 68 | # class AGNNConv(MessagePassing): 69 | # r"""The graph attentional propagation layer from the 70 | # `"Attention-based Graph Neural Network for Semi-Supervised Learning" 71 | # `_ paper 72 | 73 | # .. math:: 74 | # \mathbf{X}^{\prime} = \mathbf{P} \mathbf{X}, 75 | 76 | # where the propagation matrix :math:`\mathbf{P}` is computed as 77 | 78 | # .. math:: 79 | # P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} 80 | # {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot 81 | # \cos(\mathbf{x}_i, \mathbf{x}_k))} 82 | 83 | # with trainable parameter :math:`\beta`. 84 | 85 | # Args: 86 | # requires_grad (bool, optional): If set to :obj:`False`, :math:`\beta` 87 | # will not be trainable. (default: :obj:`True`) 88 | # add_self_loops (bool, optional): If set to :obj:`False`, will not add 89 | # self-loops to the input graph. (default: :obj:`True`) 90 | # **kwargs (optional): Additional arguments of 91 | # :class:`torch_geometric.nn.conv.MessagePassing`. 
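# Quick shape check for the additive AGNNConv defined above (a sketch only;
# the hidden width of 64 and the 4-node toy fan-in are made-up values):
#
#     conv = AGNNConv(in_channels=64)
#     x = torch.rand(4, 64)
#     edge_index = torch.tensor([[0, 1, 2], [3, 3, 3]])  # nodes 0-2 feed node 3
#     out = conv(x, edge_index)                          # -> torch.Size([4, 64])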
92 | # """ 93 | # def __init__(self, dim_emb, reverse=False): 94 | # super(AGNNConv, self).__init__(aggr='add', flow='target_to_source' if reverse else 'source_to_target') 95 | 96 | # self.lin = torch.nn.Linear(dim_emb, dim_emb) 97 | 98 | 99 | # def forward(self, x: Tensor, edge_index: Adj) -> Tensor: 100 | 101 | # x_norm = F.normalize(self.lin(x), p=2., dim=-1) 102 | 103 | # # propagate_type: (x: Tensor, x_norm: Tensor) 104 | # return self.propagate(edge_index, x=x, x_norm=x_norm, size=None) 105 | 106 | 107 | # def message(self, x_j: Tensor, x_norm_i: Tensor, x_norm_j: Tensor, 108 | # index: Tensor, ptr: OptTensor, 109 | # size_i: Optional[int]) -> Tensor: 110 | 111 | # alpha = (x_norm_i * x_norm_j).sum(dim=-1) 112 | # alpha = softmax(alpha, index, ptr, size_i) 113 | # return x_j * alpha.view(-1, 1) 114 | 115 | # def __repr__(self): 116 | # return '{}()'.format(self.__class__.__name__) 117 | -------------------------------------------------------------------------------- /src/datasets/mlpgate_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, List 2 | import os.path as osp 3 | 4 | import torch 5 | import shutil 6 | import os 7 | import copy 8 | from torch_geometric.data import Data 9 | from torch_geometric.data import InMemoryDataset 10 | 11 | from utils.data_utils import read_npz_file 12 | from .load_data import parse_pyg_mlpgate 13 | 14 | 15 | class MLPGateDataset(InMemoryDataset): 16 | r""" 17 | A variety of circuit graph datasets, *e.g.*, open-sourced benchmarks, 18 | random circuits. 19 | 20 | Args: 21 | root (string): Root directory where the dataset should be saved. 22 | args (object): The arguments specified by the main program. 23 | transform (callable, optional): A function/transform that takes in an 24 | :obj:`torch_geometric.data.Data` object and returns a transformed 25 | version. The data object will be transformed before every access. 26 | (default: :obj:`None`) 27 | pre_transform (callable, optional): A function/transform that takes in 28 | an :obj:`torch_geometric.data.Data` object and returns a 29 | transformed version. The data object will be transformed before 30 | being saved to disk. (default: :obj:`None`) 31 | pre_filter (callable, optional): A function that takes in an 32 | :obj:`torch_geometric.data.Data` object and returns a boolean 33 | value, indicating whether the data object should be included in the 34 | final dataset. (default: :obj:`None`) 35 | """ 36 | 37 | def __init__(self, root, args, transform=None, pre_transform=None, pre_filter=None): 38 | self.name = 'MIG' 39 | self.args = args 40 | 41 | assert (transform == None) and (pre_transform == None) and (pre_filter == None), "Cannot accept the transform, pre_transfrom and pre_filter args now." 
42 | 43 | # Reload 44 | inmemory_dir = os.path.join(args.data_dir, 'inmemory') 45 | if args.reload_dataset and os.path.exists(inmemory_dir): 46 | shutil.rmtree(inmemory_dir) 47 | 48 | super().__init__(root, transform, pre_transform, pre_filter) 49 | self.data, self.slices = torch.load(self.processed_paths[0]) 50 | 51 | @property 52 | def raw_dir(self): 53 | return self.root 54 | 55 | @property 56 | def processed_dir(self): 57 | if self.args.small_train: 58 | name = 'inmemory_small' 59 | else: 60 | name = 'inmemory' 61 | if self.args.no_rc: 62 | name += '_norc' 63 | return osp.join(self.root, name) 64 | 65 | @property 66 | def raw_file_names(self) -> List[str]: 67 | return [self.args.circuit_file, self.args.label_file] 68 | 69 | @property 70 | def processed_file_names(self) -> str: 71 | return ['data.pt'] 72 | 73 | def download(self): 74 | pass 75 | 76 | def process(self): 77 | data_list = [] 78 | tot_pairs = 0 79 | circuits = read_npz_file(self.args.circuit_file, self.args.data_dir)['circuits'].item() 80 | labels = read_npz_file(self.args.label_file, self.args.data_dir)['labels'].item() 81 | 82 | if self.args.small_train: 83 | subset = 100 84 | 85 | for cir_idx, cir_name in enumerate(circuits): 86 | print('Parse circuit: {}, {:} / {:} = {:.2f}%'.format(cir_name, cir_idx, len(circuits), cir_idx / len(circuits) * 100)) 87 | x = circuits[cir_name]["x"] 88 | edge_index = circuits[cir_name]["edge_index"] 89 | 90 | tt_dis = labels[cir_name]['tt_dis'] 91 | min_tt_dis = labels[cir_name]['min_tt_dis'] 92 | tt_pair_index = labels[cir_name]['tt_pair_index'] 93 | prob = labels[cir_name]['prob'] 94 | 95 | if self.args.no_rc: 96 | rc_pair_index = [[0, 1]] 97 | is_rc = [0] 98 | else: 99 | rc_pair_index = labels[cir_name]['rc_pair_index'] 100 | is_rc = labels[cir_name]['is_rc'] 101 | 102 | if len(tt_pair_index) == 0 or len(rc_pair_index) == 0: 103 | print('No tt or rc pairs: ', cir_name) 104 | continue 105 | 106 | tot_pairs += len(tt_dis) 107 | 108 | # check the gate types 109 | # assert (x[:, 1].max() == (len(self.args.gate_to_index)) - 1), 'The gate types are not consistent.' 
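# parse_pyg_mlpgate below bundles the circuit structure with its supervision
# labels (truth-table distance pairs, per-node signal probabilities and the
# rc_pair_index / is_rc pairs) into one Data object; graph.name is attached as
# well, which BaseDetector.show_results() later uses to name and title its
# scatter plots.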
110 | graph = parse_pyg_mlpgate( 111 | x, edge_index, tt_dis, min_tt_dis, tt_pair_index, prob, rc_pair_index, is_rc, 112 | self.args.use_edge_attr, self.args.reconv_skip_connection, self.args.no_node_cop, 113 | self.args.node_reconv, self.args.un_directed, self.args.num_gate_types, 114 | self.args.dim_edge_feature, self.args.logic_implication, self.args.mask 115 | ) 116 | graph.name = cir_name 117 | data_list.append(graph) 118 | if self.args.small_train and cir_idx > subset: 119 | break 120 | 121 | data, slices = self.collate(data_list) 122 | torch.save((data, slices), self.processed_paths[0]) 123 | print('[INFO] Inmemory dataset save: ', self.processed_paths[0]) 124 | print('Total Circuits: {:} Total pairs: {:}'.format(len(data_list), tot_pairs)) 125 | 126 | def __repr__(self) -> str: 127 | return f'{self.name}({len(self)})' -------------------------------------------------------------------------------- /src/utils/dag_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy 4 | 5 | 6 | 7 | # see https://github.com/unbounce/pytorch-tree-lstm/blob/66f29a44e98c7332661b57d22501107bcb193f90/treelstm/util.py#L8 8 | # assume nodes consecutively named starting at 0 9 | # 10 | def top_sort(edge_index, graph_size): 11 | 12 | node_ids = numpy.arange(graph_size, dtype=int) 13 | 14 | node_order = numpy.zeros(graph_size, dtype=int) 15 | unevaluated_nodes = numpy.ones(graph_size, dtype=bool) 16 | 17 | parent_nodes = edge_index[0] 18 | child_nodes = edge_index[1] 19 | 20 | n = 0 21 | while unevaluated_nodes.any(): 22 | # Find which parent nodes have not been evaluated 23 | unevaluated_mask = unevaluated_nodes[parent_nodes] 24 | 25 | # Find the child nodes of unevaluated parents 26 | unready_children = child_nodes[unevaluated_mask] 27 | 28 | # Mark nodes that have not yet been evaluated 29 | # and which are not in the list of children with unevaluated parent nodes 30 | nodes_to_evaluate = unevaluated_nodes & ~numpy.isin(node_ids, unready_children) 31 | 32 | node_order[nodes_to_evaluate] = n 33 | unevaluated_nodes[nodes_to_evaluate] = False 34 | 35 | n += 1 36 | 37 | return torch.from_numpy(node_order).long() 38 | 39 | 40 | # to be able to use pyg's batch split everything into 1-dim tensors 41 | def add_order_info_01(graph): 42 | 43 | l0 = top_sort(graph.edge_index, graph.num_nodes) 44 | ei2 = torch.LongTensor([list(graph.edge_index[1]), list(graph.edge_index[0])]) 45 | l1 = top_sort(ei2, graph.num_nodes) 46 | ns = torch.LongTensor([i for i in range(graph.num_nodes)]) 47 | 48 | graph.__setattr__("_bi_layer_idx0", l0) 49 | graph.__setattr__("_bi_layer_index0", ns) 50 | graph.__setattr__("_bi_layer_idx1", l1) 51 | graph.__setattr__("_bi_layer_index1", ns) 52 | 53 | assert_order(graph.edge_index, l0, ns) 54 | assert_order(ei2, l1, ns) 55 | 56 | 57 | def assert_order(edge_index, o, ns): 58 | # already processed 59 | proc = [] 60 | for i in range(max(o)+1): 61 | # nodes in position i in order 62 | l = o == i 63 | l = ns[l].tolist() 64 | for n in l: 65 | # predecessors 66 | ps = edge_index[0][edge_index[1] == n].tolist() 67 | for p in ps: 68 | assert p in proc 69 | proc += l 70 | 71 | 72 | def add_order_info(graph): 73 | ns = torch.LongTensor([i for i in range(graph.num_nodes)]) 74 | layers = torch.stack([top_sort(graph.edge_index, graph.num_nodes), ns], dim=0) 75 | ei2 = torch.LongTensor([list(graph.edge_index[1]), list(graph.edge_index[0])]) 76 | layers2 = torch.stack([top_sort(ei2, graph.num_nodes), ns], dim=0) 77 | 78 | 
graph.__setattr__("bi_layer_index", torch.stack([layers, layers2], dim=0)) 79 | 80 | def return_order_info(edge_index, num_nodes): 81 | ns = torch.LongTensor([i for i in range(num_nodes)]) 82 | forward_level = top_sort(edge_index, num_nodes) 83 | ei2 = torch.LongTensor([list(edge_index[1]), list(edge_index[0])]) 84 | backward_level = top_sort(ei2, num_nodes) 85 | forward_index = ns 86 | backward_index = torch.LongTensor([i for i in range(num_nodes)]) 87 | 88 | return forward_level, forward_index, backward_level, backward_index 89 | 90 | 91 | def subgraph(target_idx, edge_index, edge_attr=None, dim=0): 92 | ''' 93 | function from DAGNN 94 | ''' 95 | le_idx = [] 96 | for n in target_idx: 97 | ne_idx = edge_index[dim] == n 98 | le_idx += [ne_idx.nonzero().squeeze(-1)] 99 | le_idx = torch.cat(le_idx, dim=-1) 100 | lp_edge_index = edge_index[:, le_idx] 101 | if edge_attr is not None: 102 | lp_edge_attr = edge_attr[le_idx, :] 103 | else: 104 | lp_edge_attr = None 105 | return lp_edge_index, lp_edge_attr 106 | 107 | def custom_backward_subgraph(l_node, edge_index, device, dim=0): 108 | ''' 109 | The custom backward subgraph extraction. 110 | During backwarding, we consider the side inputs of the target nodes as well. 111 | 112 | This function hasn't been checked yet. 113 | ''' 114 | 115 | # Randomly choose one predecessor in edges 116 | lp_edge_index = torch.Tensor().to(device=device) 117 | for n in l_node: 118 | ne_idx = edge_index[dim] == n 119 | 120 | subset_edges = torch.masked_select(edge_index, ne_idx).reshape(edge_index.shape[0], -1) 121 | 122 | pos_count = torch.count_nonzero(ne_idx) 123 | random_predecessor = random.randint(0, pos_count - 1) 124 | 125 | indices = torch.tensor([random_predecessor], device=device) 126 | subset_edges = torch.index_select(subset_edges, 1, indices) 127 | 128 | lp_edge_index = torch.cat((lp_edge_index, subset_edges), dim=1) 129 | 130 | lp_edge_index = lp_edge_index.to(torch.long) 131 | 132 | # collect successors of selected (random) predecessor 133 | 134 | updated_edges = lp_edge_index 135 | for n in l_node: 136 | n_vec = torch.tensor([n], device=device) 137 | 138 | ne = lp_edge_index[0] == n 139 | predecessor = lp_edge_index[1][ne] 140 | 141 | se = edge_index[1] == predecessor 142 | successors = edge_index[0][se] 143 | 144 | for s in successors: 145 | 146 | if s != n: 147 | s_vec = torch.tensor([s], device=device) 148 | new_edge = (torch.stack((n_vec, s_vec), dim=0)) 149 | updated_edges = torch.cat((updated_edges, new_edge), dim=1) 150 | 151 | updated_edges = updated_edges.to(torch.long) 152 | return updated_edges -------------------------------------------------------------------------------- /src/trains/base_trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | from torch_geometric.nn import DataParallel 9 | from progress.bar import Bar 10 | from utils.utils import AverageMeter 11 | 12 | _loss_factory = { 13 | 'l1': nn.L1Loss, 14 | 'sl1': nn.SmoothL1Loss, 15 | 'l2': nn.MSELoss, 16 | } 17 | 18 | class ModelWithLoss(torch.nn.Module): 19 | def __init__(self, model, loss, gpus, device): 20 | super(ModelWithLoss, self).__init__() 21 | self.model = model 22 | self.loss = loss 23 | self.gpus = gpus 24 | self.device = device 25 | 26 | def forward(self, batch): 27 | outputs = self.model(batch) 28 | loss = self.loss(outputs.to(self.device), 
batch.y.to(self.device)) 29 | loss_stats = {'loss': loss} 30 | return outputs, loss, loss_stats 31 | 32 | 33 | class BaseTrainer(object): 34 | def __init__( 35 | self, args, model, optimizer=None): 36 | self.args = args 37 | self.optimizer = optimizer 38 | self.loss_stats, self.loss = self._get_losses(args.reg_loss) 39 | self.loss = self.loss.to(self.args.device) 40 | self.model_with_loss = ModelWithLoss(model, self.loss, args.gpus, args.device) 41 | 42 | def set_device(self, device, local_rank, gpus): 43 | if len(gpus)> 1: 44 | self.model_with_loss = self.model_with_loss.to(device) 45 | self.model_with_loss = nn.parallel.DistributedDataParallel(self.model_with_loss, 46 | device_ids=[local_rank]) 47 | else: 48 | self.model_with_loss = self.model_with_loss.to(device) 49 | 50 | for state in self.optimizer.state.values(): 51 | for k, v in state.items(): 52 | if isinstance(v, torch.Tensor): 53 | state[k] = v.to(device=device, non_blocking=True) 54 | 55 | def run_epoch(self, phase, epoch, dataset, local_rank): 56 | model_with_loss = self.model_with_loss 57 | if phase == 'train': 58 | model_with_loss.train() 59 | else: 60 | if len(self.args.gpus) > 1: 61 | model_with_loss = self.model_with_loss.module 62 | model_with_loss.eval() 63 | torch.cuda.empty_cache() 64 | 65 | args = self.args 66 | results = {} 67 | data_time, batch_time = AverageMeter(), AverageMeter() 68 | avg_loss_stats = {l: AverageMeter() for l in self.loss_stats} 69 | num_iters = len(dataset) if args.num_iters < 0 else args.num_iters 70 | if local_rank == 0: 71 | bar = Bar('{}/{}'.format(args.task, args.exp_id), max=num_iters) 72 | end = time.time() 73 | for iter_id, batch in enumerate(dataset): 74 | if iter_id >= num_iters: 75 | break 76 | if len(self.args.gpus) == 1: 77 | batch = batch.to(self.args.device) 78 | data_time.update(time.time() - end) 79 | output, loss, loss_stats = model_with_loss(batch) 80 | loss = loss.mean() 81 | if phase == 'train': 82 | self.optimizer.zero_grad() 83 | loss.backward() 84 | if args.grad_clip > 0: 85 | torch.nn.utils.clip_grad_norm_(model_with_loss.parameters(), args.grad_clip) 86 | self.optimizer.step() 87 | batch_time.update(time.time() - end) 88 | end = time.time() 89 | if local_rank == 0: 90 | Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format( 91 | epoch, iter_id, num_iters, phase=phase, 92 | total=bar.elapsed_td, eta=bar.eta_td) 93 | for l in avg_loss_stats: 94 | avg_loss_stats[l].update( 95 | loss_stats[l].mean().item(), batch.num_graphs * len(output)) 96 | Bar.suffix = Bar.suffix + \ 97 | '|{} {:.4f} '.format(l, avg_loss_stats[l].avg) 98 | if not args.hide_data_time: 99 | Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \ 100 | '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time) 101 | if args.print_iter > 0: 102 | if iter_id % args.print_iter == 0: 103 | print('{}/{}| {}'.format(args.task, args.exp_id, Bar.suffix)) 104 | else: 105 | bar.next() 106 | del output, loss, loss_stats 107 | 108 | 109 | ret = {k: v.avg for k, v in avg_loss_stats.items()} 110 | if local_rank == 0: 111 | bar.finish() 112 | ret['time'] = bar.elapsed_td.total_seconds() / 60. 
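# ret maps every tracked loss term to its epoch average (plus 'time' in
# minutes); main.py logs these values as train_/val_ scalars and compares
# log_dict_val[args.metric] against the best value seen so far when deciding
# whether to write model_best.pth.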
113 | return ret, results 114 | 115 | def debug(self, batch, output, iter_id): 116 | raise NotImplementedError 117 | 118 | def save_result(self, output, batch, results): 119 | raise NotImplementedError 120 | 121 | def _get_losses(self, loss): 122 | if loss in _loss_factory.keys(): 123 | loss = _loss_factory[loss]() 124 | else: 125 | raise KeyError 126 | loss_states = ['loss'] 127 | return loss_states, loss 128 | 129 | def val(self, epoch, data_loader, local_rank): 130 | return self.run_epoch('val', epoch, data_loader, local_rank) 131 | 132 | def train(self, epoch, data_loader, local_rank): 133 | return self.run_epoch('train', epoch, data_loader, local_rank) 134 | -------------------------------------------------------------------------------- /src/models/mlpgate_merge.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | from utils.dag_utils import subgraph, custom_backward_subgraph 8 | from utils.utils import generate_hs_init 9 | 10 | from .mlp import MLP 11 | from .mlp_aggr import MlpAggr 12 | 13 | from torch.nn import LSTM, GRU 14 | 15 | _update_function_factory = { 16 | 'lstm': LSTM, 17 | 'gru': GRU, 18 | } 19 | 20 | class MLPGate_Merge(nn.Module): 21 | ''' 22 | Recurrent Graph Neural Networks for Circuits. 23 | ''' 24 | def __init__(self, args): 25 | super(MLPGate_Merge, self).__init__() 26 | 27 | self.args = args 28 | 29 | # configuration 30 | self.num_rounds = args.num_rounds 31 | self.device = args.device 32 | self.predict_diff = args.predict_diff 33 | self.intermediate_supervision = args.intermediate_supervision 34 | self.reverse = args.reverse 35 | self.custom_backward = args.custom_backward 36 | self.use_edge_attr = args.use_edge_attr 37 | self.mask = args.mask 38 | 39 | # dimensions 40 | self.num_aggr = args.num_aggr 41 | self.dim_node_feature = args.dim_node_feature 42 | self.dim_hidden = args.dim_hidden 43 | self.dim_mlp = args.dim_mlp 44 | self.dim_pred = args.dim_pred 45 | self.num_fc = args.num_fc 46 | self.wx_update = args.wx_update 47 | self.wx_mlp = args.wx_mlp 48 | self.dim_edge_feature = args.dim_edge_feature 49 | 50 | # Network 51 | self.aggr_and = MlpAggr(self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 52 | self.aggr_not = MlpAggr(self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 53 | self.update_and = GRU(self.dim_hidden, self.dim_hidden) 54 | self.update_not = GRU(self.dim_hidden, self.dim_hidden) 55 | 56 | # Readout 57 | self.readout_prob = MLP(self.dim_hidden, args.dim_mlp, 1, num_layer=3, p_drop=0.2, norm_layer='batchnorm', act_layer='relu') 58 | self.readout_rc = MLP(self.dim_hidden * 2, args.dim_mlp, 1, num_layer=3, p_drop=0.2, norm_layer='batchnorm', sigmoid=True) 59 | 60 | # consider the embedding for the LSTM/GRU model initialized by non-zeros 61 | self.one = torch.ones(1).to(self.device) 62 | # self.hs_emd_int = nn.Linear(1, self.dim_hidden) 63 | self.hf_emd_int = nn.Linear(1, self.dim_hidden) 64 | self.one.requires_grad = False 65 | 66 | def forward(self, G): 67 | num_nodes = G.num_nodes 68 | num_layers_f = max(G.forward_level).item() + 1 69 | num_layers_b = max(G.backward_level).item() + 1 70 | 71 | # initialize the hidden state 72 | if self.args.disable_encode: 73 | h_init = torch.zeros(num_nodes, self.dim_hidden) 74 | max_sim = 0 75 | h_init = h_init.to(self.device) 76 | else: 77 | h_init = 
torch.zeros(num_nodes, self.dim_hidden) 78 | h_init, max_sim = generate_hs_init(G, h_init, self.dim_hidden) 79 | h_init = h_init.to(self.device) 80 | 81 | preds = self._gru_forward(G, h_init, num_layers_f, num_layers_b) 82 | 83 | return preds, max_sim 84 | 85 | def _gru_forward(self, G, h_init, num_layers_f, num_layers_b, h_true=None, h_false=None): 86 | G = G.to(self.device) 87 | x, edge_index = G.x, G.edge_index 88 | edge_attr = G.edge_attr if self.use_edge_attr else None 89 | 90 | node_state = h_init.to(self.device) 91 | and_mask = G.gate.squeeze(1) == 1 92 | not_mask = G.gate.squeeze(1) == 2 93 | 94 | for _ in range(self.num_rounds): 95 | for level in range(1, num_layers_f): 96 | # forward layer 97 | layer_mask = G.forward_level == level 98 | 99 | # AND Gate 100 | l_and_node = G.forward_index[layer_mask & and_mask] 101 | if l_and_node.size(0) > 0: 102 | and_edge_index, and_edge_attr = subgraph(l_and_node, edge_index, edge_attr, dim=1) 103 | # Update structure hidden state 104 | msg = self.aggr_and(node_state, and_edge_index, and_edge_attr) 105 | and_msg = torch.index_select(msg, dim=0, index=l_and_node) 106 | hs_and = torch.index_select(node_state, dim=0, index=l_and_node) 107 | _, hs_and = self.update_and(and_msg.unsqueeze(0), hs_and.unsqueeze(0)) 108 | node_state[l_and_node, :] = hs_and.squeeze(0) 109 | 110 | # NOT Gate 111 | l_not_node = G.forward_index[layer_mask & not_mask] 112 | if l_not_node.size(0) > 0: 113 | not_edge_index, not_edge_attr = subgraph(l_not_node, edge_index, edge_attr, dim=1) 114 | # Update structure hidden state 115 | msg = self.aggr_not(node_state, not_edge_index, not_edge_attr) 116 | not_msg = torch.index_select(msg, dim=0, index=l_not_node) 117 | hs_not = torch.index_select(node_state, dim=0, index=l_not_node) 118 | _, hs_not = self.update_not(not_msg.unsqueeze(0), hs_not.unsqueeze(0)) 119 | node_state[l_not_node, :] = hs_not.squeeze(0) 120 | 121 | node_embedding = node_state.squeeze(0) 122 | 123 | # Readout 124 | prob = self.readout_prob(node_embedding) 125 | rc_emb = torch.cat([node_embedding[G.rc_pair_index[0]], node_embedding[G.rc_pair_index[1]]], dim=1) 126 | is_rc = self.readout_rc(rc_emb) 127 | 128 | return [], node_embedding, prob, is_rc 129 | 130 | 131 | def imply_mask(self, G, h, h_true, h_false): 132 | true_mask = (G.mask == 1.0).unsqueeze(0) 133 | false_mask = (G.mask == 0.0).unsqueeze(0) 134 | normal_mask = (G.mask == -1.0).unsqueeze(0) 135 | h_mask = h * normal_mask + h_true * true_mask + h_false * false_mask 136 | return h_mask 137 | 138 | 139 | 140 | def get_mlp_gate_merged(args): 141 | return MLPGate_Merge(args) -------------------------------------------------------------------------------- /src/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Parse the AIG (in bench format) and truth table for each nodes 3 | 16-11-2022 4 | Note: 5 | gate_to_index = {'PI': 0, 'AND': 1, 'NOT': 2} 6 | x_data: 0 - Name, 1 - gate type, 2 - level, 3 - is RC, 4 - RC source node 7 | ''' 8 | 9 | import argparse 10 | import glob 11 | import os 12 | import sys 13 | import platform 14 | import time 15 | import numpy as np 16 | from collections import Counter 17 | 18 | import utils.circuit_utils as circuit_utils 19 | import utils.utils as utils 20 | 21 | # aig_folder = './rawaig/' 22 | NO_PATTERNS = 15000 23 | 24 | gate_to_index = {'INPUT': 0, 'AND': 1, 'NOT': 2} 25 | MIN_LEVEL = 3 26 | MIN_PI_SIZE = 4 27 | MAX_INCLUDE = 1.5 28 | MAX_PROB_GAP = 0.05 29 | MAX_LEVEL_GAP = 5 30 | 31 | MIDDLE_DIST_IGNORE = 
[0.2, 0.8] 32 | 33 | def get_parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--exp_id', default='train') 36 | parser.add_argument('--start_idx', default=0, type=int) 37 | parser.add_argument('--end_idx', default=100, type=int) 38 | parser.add_argument('--aig_folder', default='~/studio/dataset/rawaig') 39 | 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def gen_tt_pair(x_data, fanin_list, fanout_list, level_list, tt_prob): 45 | tt_len = len(tt[0]) 46 | pi_cone_list = [] 47 | for idx in range(len(x_data)): 48 | pi_cone_list.append([]) 49 | 50 | # Get pre fanout 51 | for level in range(len(level_list)): 52 | if level == 0: 53 | for idx in level_list[level]: 54 | pi_cone_list[idx].append(idx) 55 | else: 56 | for idx in level_list[level]: 57 | for fanin_idx in fanin_list[idx]: 58 | pi_cone_list[idx] += pi_cone_list[fanin_idx] 59 | pre_dist = Counter(pi_cone_list[idx]) 60 | pi_cone_list[idx] = list(pre_dist.keys()) 61 | 62 | # Pair 63 | tt_pair_index = [] 64 | tt_dis = [] 65 | min_tt_dis = [] 66 | for i in range(len(x_data)): 67 | if x_data[i][2] < MIN_LEVEL or len(pi_cone_list[i]) < MIN_PI_SIZE: 68 | continue 69 | for j in range(i+1, len(x_data), 1): 70 | if x_data[j][2] < MIN_LEVEL or len(pi_cone_list[j]) < MIN_PI_SIZE: 71 | continue 72 | # Cond. 2: probability 73 | if abs(tt_prob[i] - tt_prob[j]) > MAX_PROB_GAP: 74 | continue 75 | # Cond. 1: Level 76 | if abs(x_data[i][2] - x_data[j][2]) > MAX_LEVEL_GAP: 77 | continue 78 | 79 | # Cond. 5: Include 80 | if pi_cone_list[i] != pi_cone_list[j]: 81 | continue 82 | 83 | distance = np.array(tt[i]) - np.array(tt[j]) 84 | distance_value = np.linalg.norm(distance, ord=1) / tt_len 85 | 86 | # Cond. 4: Extreme distance 87 | if distance_value > MIDDLE_DIST_IGNORE[0] and distance_value < MIDDLE_DIST_IGNORE[1]: 88 | continue 89 | 90 | tt_pair_index.append([i, j]) 91 | tt_dis.append(distance_value) 92 | distance_e = (1-np.array(tt[i])) - np.array(tt[j]) 93 | min_distance = min(np.linalg.norm(distance, ord=1), np.linalg.norm(distance_e, ord=1)) 94 | min_tt_dis.append(min_distance / tt_len) 95 | 96 | return tt_pair_index, tt_dis, min_tt_dis 97 | 98 | if __name__ == '__main__': 99 | graphs = {} 100 | labels = {} 101 | args = get_parse_args() 102 | output_folder = '../data/{}'.format(args.exp_id) 103 | if not os.path.exists(output_folder): 104 | os.mkdir(output_folder) 105 | 106 | tot_circuit = 0 107 | cir_idx = 0 108 | tot_nodes = 0 109 | tot_pairs = 0 110 | name_list = [] 111 | for mig_filename in glob.glob(os.path.join(args.aig_folder, '*.bench')): 112 | tot_circuit += 1 113 | name_list.append(mig_filename) 114 | for mig_filename in name_list[args.start_idx: min(args.end_idx, len(name_list))]: 115 | circuit_name = mig_filename.split('/')[-1].split('.')[0] 116 | 117 | x_data, edge_index, fanin_list, fanout_list, level_list = circuit_utils.parse_bench(mig_filename, gate_to_index) 118 | # PI 119 | PI_index = level_list[0] 120 | 121 | # Simulation 122 | start_time = time.time() 123 | if len(PI_index) < 13: 124 | tt = circuit_utils.simulator_truth_table(x_data, PI_index, level_list, fanin_list, gate_to_index) 125 | else: 126 | tt = circuit_utils.simulator_truth_table_random(x_data, PI_index, level_list, fanin_list, gate_to_index, NO_PATTERNS) 127 | y = [0] * len(x_data) 128 | for idx in range(len(x_data)): 129 | y[idx] = np.sum(tt[idx]) / len(tt[idx]) 130 | 131 | # Pair 132 | tt_pair_index, tt_dis, min_tt_dis = gen_tt_pair(x_data, fanin_list, fanout_list, level_list, y) 133 | end_time = time.time() 134 | 135 | # Save 
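# Worked example of the two distance labels produced by gen_tt_pair and saved
# into labels.npz below (the 8-entry truth tables are made up for
# illustration): for tt[i] = [0,0,1,1,0,0,1,1] and tt[j] = [0,0,1,1,0,0,1,0]
# the tables differ in 1 of 8 positions, so tt_dis = 1/8 = 0.125, while the
# complement of tt[i] differs from tt[j] in 7 positions, so
# min_tt_dis = min(1, 7)/8 = 0.125 as well.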
136 | x_data = utils.rename_node(x_data) 137 | graphs[circuit_name] = {'x': np.array(x_data).astype('float32'), "edge_index": np.array(edge_index)} 138 | labels[circuit_name] = { 139 | 'tt_pair_index': np.array(tt_pair_index), 'tt_dis': np.array(tt_dis).astype('float32'), 140 | 'prob': np.array(y).astype('float32'), 141 | 'min_tt_dis': np.array(min_tt_dis).astype('float32'), 142 | } 143 | tot_nodes += len(x_data) 144 | tot_pairs += len(tt_dis) 145 | print('Save: {}, # PI: {:}, Tot Pairs: {:.1f}k, time: {:.2f} s ({:} / {:})'.format( 146 | circuit_name, len(PI_index), tot_pairs/1000, end_time - start_time, cir_idx, args.end_idx - args.start_idx 147 | )) 148 | 149 | if cir_idx != 0 and cir_idx % 1000 == 0: 150 | output_filename_circuit = os.path.join(output_folder, 'tmp_{:}_graphs.npz'.format(cir_idx)) 151 | output_filename_labels = os.path.join(output_folder, 'tmp_{:}_labels.npz'.format(cir_idx)) 152 | np.savez_compressed(output_filename_circuit, circuits=graphs) 153 | np.savez_compressed(output_filename_labels, labels=labels) 154 | cir_idx += 1 155 | 156 | output_filename_circuit = os.path.join(output_folder, 'graphs.npz') 157 | output_filename_labels = os.path.join(output_folder, 'labels.npz') 158 | print('# Graphs: {:}, # Nodes: {:}'.format(len(graphs), tot_nodes)) 159 | print('Total pairs: ', tot_pairs) 160 | np.savez_compressed(output_filename_circuit, circuits=graphs) 161 | np.savez_compressed(output_filename_labels, labels=labels) 162 | print(output_filename_circuit) 163 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch.nn as nn 6 | import math 7 | import copy 8 | import torch 9 | import random 10 | import numpy as np 11 | 12 | from .circuit_utils import random_pattern_generator, logic 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | if self.count > 0: 30 | self.avg = self.sum / self.count 31 | 32 | def zero_normalization(x): 33 | mean_x = torch.mean(x) 34 | std_x = torch.std(x) 35 | z_x = (x - mean_x) / std_x 36 | return z_x 37 | 38 | class custom_DataParallel(nn.parallel.DataParallel): 39 | # define a custom DataParallel class to accomodate igraph inputs 40 | def __init__(self, module, device_ids=None, output_device=None, dim=0): 41 | super(custom_DataParallel, self).__init__(module, device_ids, output_device, dim) 42 | 43 | def scatter(self, inputs, kwargs, device_ids): 44 | # to overwride nn.parallel.scatter() to adapt igraph batch inputs 45 | G = inputs[0] 46 | scattered_G = [] 47 | n = math.ceil(len(G) / len(device_ids)) 48 | mini_batch = [] 49 | for i, g in enumerate(G): 50 | mini_batch.append(g) 51 | if len(mini_batch) == n or i == len(G)-1: 52 | scattered_G.append((mini_batch, )) 53 | mini_batch = [] 54 | return tuple(scattered_G), tuple([{}]*len(scattered_G)) 55 | 56 | def collate_fn(G): 57 | return [copy.deepcopy(g) for g in G] 58 | 59 | def pyg_simulation(g, pattern=[]): 60 | # PI, Level list 61 | max_level = 0 62 | PI_indexes = [] 63 | fanin_list = [] 64 | for idx, ele in enumerate(g.forward_level): 65 
| level = int(ele) 66 | fanin_list.append([]) 67 | if level > max_level: 68 | max_level = level 69 | if level == 0: 70 | PI_indexes.append(idx) 71 | level_list = [] 72 | for level in range(max_level + 1): 73 | level_list.append([]) 74 | for idx, ele in enumerate(g.forward_level): 75 | level_list[int(ele)].append(idx) 76 | # Fanin list 77 | for k in range(len(g.edge_index[0])): 78 | src = g.edge_index[0][k] 79 | dst = g.edge_index[1][k] 80 | fanin_list[dst].append(src) 81 | 82 | ###################### 83 | # Simulation 84 | ###################### 85 | y = [0] * len(g.x) 86 | if len(pattern) == 0: 87 | pattern = random_pattern_generator(len(PI_indexes)) 88 | j = 0 89 | for i in PI_indexes: 90 | y[i] = pattern[j] 91 | j = j + 1 92 | for level in range(1, len(level_list), 1): 93 | for node_idx in level_list[level]: 94 | source_signals = [] 95 | for pre_idx in fanin_list[node_idx]: 96 | source_signals.append(y[pre_idx]) 97 | if len(source_signals) > 0: 98 | if int(g.x[node_idx][1]) == 1: 99 | gate_type = 1 100 | elif int(g.x[node_idx][2]) == 1: 101 | gate_type = 5 102 | else: 103 | raise("This is PI") 104 | y[node_idx] = logic(gate_type, source_signals) 105 | 106 | # Output 107 | if len(level_list[-1]) > 1: 108 | raise('Too many POs') 109 | return y[level_list[-1][0]], pattern 110 | 111 | def get_function_acc(g, node_emb): 112 | MIN_GAP = 0.05 113 | # Sample 114 | retry = 10000 115 | tri_sample_idx = 0 116 | correct = 0 117 | total = 0 118 | while tri_sample_idx < 100 and retry > 0: 119 | retry -= 1 120 | sample_pair_idx = torch.LongTensor(random.sample(range(len(g.tt_pair_index[0])), 2)) 121 | pair_0 = sample_pair_idx[0] 122 | pair_1 = sample_pair_idx[1] 123 | pair_0_gt = g.tt_dis[pair_0] 124 | pair_1_gt = g.tt_dis[pair_1] 125 | if pair_0_gt == pair_1_gt: 126 | continue 127 | if abs(pair_0_gt - pair_1_gt) < MIN_GAP: 128 | continue 129 | 130 | total += 1 131 | tri_sample_idx += 1 132 | pair_0_sim = torch.cosine_similarity(node_emb[g.tt_pair_index[0][pair_0]].unsqueeze(0), node_emb[g.tt_pair_index[1][pair_0]].unsqueeze(0), eps=1e-8) 133 | pair_1_sim = torch.cosine_similarity(node_emb[g.tt_pair_index[0][pair_1]].unsqueeze(0), node_emb[g.tt_pair_index[1][pair_1]].unsqueeze(0), eps=1e-8) 134 | pair_0_predDis = 1 - pair_0_sim 135 | pair_1_predDis = 1 - pair_1_sim 136 | succ = False 137 | if pair_0_gt > pair_1_gt and pair_0_predDis > pair_1_predDis: 138 | succ = True 139 | elif pair_0_gt < pair_1_gt and pair_0_predDis < pair_1_predDis: 140 | succ = True 141 | if succ: 142 | correct += 1 143 | 144 | if total > 0: 145 | acc = correct * 1.0 / total 146 | return acc 147 | return -1 148 | 149 | def generate_orthogonal_vectors(n, dim): 150 | # Generate an initial random vector 151 | v0 = np.random.randn(dim) 152 | v0 /= np.linalg.norm(v0) 153 | 154 | # Generate n-1 additional vectors 155 | vectors = [v0] 156 | for i in range(n-1): 157 | # Generate a random vector 158 | v = np.random.randn(dim) 159 | 160 | # Project the vector onto the subspace spanned by the previous vectors 161 | for j in range(i+1): 162 | v -= np.dot(v, vectors[j]) * vectors[j] 163 | 164 | # Normalize the vector 165 | v /= np.linalg.norm(v) 166 | 167 | # Append the vector to the list 168 | vectors.append(v) 169 | 170 | # calculate the max cosine similarity between any two vectors 171 | max_cos_sim = 0 172 | for i in range(n): 173 | for j in range(i+1, n): 174 | vi = vectors[i] 175 | vj = vectors[j] 176 | cos_sim = np.dot(vi, vj) / (np.linalg.norm(vi) * np.linalg.norm(vj)) 177 | if cos_sim > max_cos_sim: 178 | max_cos_sim = cos_sim 179 | 
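# The loop above is a Gram-Schmidt orthogonalisation followed by a brute-force
# scan of the worst pairwise cosine similarity. Note that when n exceeds dim
# there cannot be n mutually orthogonal vectors, so max_cos_sim (returned next)
# grows; generate_hs_init below propagates it as max_sim to report how far the
# PI initialisation is from orthogonal.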
180 | return vectors, max_cos_sim 181 | 182 | def generate_hs_init(G, hs, no_dim): 183 | max_sim = 0 184 | if G.batch == None: 185 | batch_size = 1 186 | else: 187 | batch_size = G.batch.max().item() + 1 188 | for batch_idx in range(batch_size): 189 | if G.batch == None: 190 | pi_mask = (G.forward_level == 0) 191 | else: 192 | pi_mask = (G.batch == batch_idx) & (G.forward_level == 0) 193 | pi_node = G.forward_index[pi_mask] 194 | pi_vec, batch_max_sim = generate_orthogonal_vectors(len(pi_node), no_dim) 195 | if batch_max_sim > max_sim: 196 | max_sim = batch_max_sim 197 | hs[pi_node] = torch.tensor(pi_vec, dtype=torch.float) 198 | 199 | return hs, max_sim 200 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import shutil 7 | import torch 8 | from torch_geometric import data 9 | from torch_geometric.loader import DataLoader, DataListLoader 10 | 11 | from config import get_parse_args 12 | from models.model import create_model, load_model, save_model 13 | from utils.logger import Logger 14 | from utils.random_seed import set_seed 15 | from utils.circuit_utils import check_difference 16 | from trains.train_factory import train_factory 17 | from datasets.mig_dataset import MIGDataset 18 | from datasets.mlpgate_dataset import MLPGateDataset 19 | 20 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 21 | 22 | def main(args): 23 | print('==> Using settings {}'.format(args)) 24 | 25 | ################# 26 | # Device 27 | ################# 28 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str 29 | args.device = torch.device('cuda:0' if args.gpus[0] >= 0 else 'cpu') 30 | args.world_size = 1 31 | args.rank = 0 # global rank 32 | if args.device != 'cpu' and len(args.gpus) > 1: 33 | args.distributed = len(args.gpus) 34 | else: 35 | args.distributed = False 36 | if args.distributed: 37 | if 'LOCAL_RANK' in os.environ: 38 | args.local_rank = int(os.getenv('LOCAL_RANK')) 39 | args.device = 'cuda:%d' % args.local_rank 40 | torch.cuda.set_device(args.local_rank) 41 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 42 | args.world_size = torch.distributed.get_world_size() 43 | args.rank = torch.distributed.get_rank() 44 | print('Training in distributed mode. 
Device {}, Process {:}, total {:}.'.format( 45 | args.device, args.rank, args.world_size 46 | )) 47 | else: 48 | print('Training in single device: ', args.device) 49 | if args.local_rank == 0: 50 | logger = Logger(args, args.local_rank) 51 | load_model_path = os.path.join(args.save_dir, 'model_last.pth') 52 | if args.resume and not os.path.exists(load_model_path): 53 | if args.pretrained_path == '': 54 | raise "No pretrained model (.pth) found" 55 | else: 56 | shutil.copy(args.pretrained_path, load_model_path) 57 | print('Copy pth from: ', args.pretrained_path) 58 | 59 | ################# 60 | # Dataset 61 | ################# 62 | if args.local_rank == 0: 63 | print('==> Loading dataset from: ', args.data_dir) 64 | dataset = MLPGateDataset(args.data_dir, args) 65 | perm = torch.randperm(len(dataset)) 66 | dataset = dataset[perm] 67 | data_len = len(dataset) 68 | if args.local_rank == 0: 69 | print("Size: ", len(dataset)) 70 | print('Splitting the dataset into training and validation sets..') 71 | training_cutoff = int(data_len * args.trainval_split) 72 | if args.local_rank == 0: 73 | print('# training circuits: ', training_cutoff) 74 | print('# validation circuits: ', data_len - training_cutoff) 75 | train_dataset = [] 76 | val_dataset = [] 77 | train_dataset = dataset[:training_cutoff] 78 | val_dataset = dataset[training_cutoff:] 79 | 80 | train_sampler = torch.utils.data.distributed.DistributedSampler( 81 | train_dataset, 82 | num_replicas=args.world_size, 83 | rank=args.rank 84 | ) 85 | val_sampler = torch.utils.data.distributed.DistributedSampler( 86 | val_dataset, 87 | num_replicas=args.world_size, 88 | rank=args.rank 89 | ) 90 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, 91 | num_workers=args.num_workers, sampler=train_sampler) 92 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, 93 | sampler=val_sampler) 94 | 95 | ################# 96 | # Model 97 | ################# 98 | model = create_model(args) 99 | if args.local_rank == 0: 100 | print('==> Creating model...') 101 | print(model) 102 | 103 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 104 | start_epoch = 0 105 | if args.load_model != '': 106 | model, optimizer, start_epoch = load_model( 107 | model, args.load_model, optimizer, args.resume, args.lr, args.lr_step, args.local_rank, args.device) 108 | 109 | Trainer = train_factory[args.arch] 110 | trainer = Trainer(args, model, optimizer) 111 | trainer.set_device(args.device, args.local_rank, args.gpus) 112 | 113 | if args.val_only: 114 | log_dict_val, _ = trainer.val(0, val_loader) 115 | return 116 | 117 | if args.local_rank == 0: 118 | print('==> Starting training...') 119 | best = 1e10 120 | for epoch in range(start_epoch + 1, args.num_epochs + 1): 121 | mark = epoch if args.save_all else 'last' 122 | train_loader.sampler.set_epoch(epoch) 123 | log_dict_train, _ = trainer.train(epoch, train_loader, args.local_rank) 124 | if args.local_rank == 0: 125 | logger.write('epoch: {} |'.format(epoch), args.local_rank) 126 | for k, v in log_dict_train.items(): 127 | logger.scalar_summary('train_{}'.format(k), v, epoch, args.local_rank) 128 | logger.write('{} {:8f} | '.format(k, v), args.local_rank) 129 | if args.save_intervals > 0 and epoch % args.save_intervals == 0: 130 | save_model(os.path.join(args.save_dir, 'model_{}.pth'.format(mark)), 131 | epoch, model, optimizer) 132 | with torch.no_grad(): 133 | 
val_loader.sampler.set_epoch(0) 134 | log_dict_val, _ = trainer.val(epoch, val_loader, args.local_rank) 135 | 136 | if args.local_rank == 0: 137 | for k, v in log_dict_val.items(): 138 | logger.scalar_summary('val_{}'.format(k), v, epoch, args.local_rank) 139 | logger.write('{} {:8f} | '.format(k, v), args.local_rank) 140 | if log_dict_val[args.metric] < best: 141 | best = log_dict_val[args.metric] 142 | save_model(os.path.join(args.save_dir, 'model_best.pth'), 143 | epoch, model) 144 | else: 145 | save_model(os.path.join(args.save_dir, 'model_last.pth'), 146 | epoch, model, optimizer) 147 | if args.local_rank == 0: 148 | logger.write('\n', args.local_rank) 149 | if epoch in args.lr_step: 150 | if args.local_rank == 0: 151 | save_model(os.path.join(args.save_dir, 'model_{}.pth'.format(epoch)), 152 | epoch, model, optimizer) 153 | lr = args.lr * (0.1 ** (args.lr_step.index(epoch) + 1)) 154 | if args.local_rank == 0: 155 | print('Drop LR to', lr) 156 | for param_group in optimizer.param_groups: 157 | param_group['lr'] = lr 158 | 159 | if args.local_rank == 0: 160 | logger.close() 161 | 162 | 163 | if __name__ == '__main__': 164 | args = get_parse_args() 165 | set_seed(args) 166 | 167 | main(args) 168 | -------------------------------------------------------------------------------- /src/models/vectorgate.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | from utils.dag_utils import subgraph, custom_backward_subgraph 8 | from utils.utils import generate_hs_init 9 | 10 | from .mlp import MLP 11 | from .mlp_aggr import MlpAggr 12 | 13 | from torch.nn import LSTM, GRU 14 | 15 | _update_function_factory = { 16 | 'lstm': LSTM, 17 | 'gru': GRU, 18 | } 19 | 20 | class VectorGate(nn.Module): 21 | ''' 22 | Recurrent Graph Neural Networks for Circuits. 
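    Two per-node hidden states are kept and updated level by level with MlpAggr
    aggregators and GRU cells: hs (structure) and hf (function). forward(G) returns
    ((hs, hf, prob, is_rc), max_sim), where prob is the per-node probability readout;
    in this variant the rc-pair readout is commented out, so is_rc comes back as an
    empty list. A rough usage sketch (args and graph_batch are assumed to come from
    config.get_parse_args() and the project's dataset pipeline):

        from models.vectorgate import get_vector_gate
        model = get_vector_gate(args)
        (hs, hf, prob, is_rc), max_sim = model(graph_batch)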
23 | ''' 24 | def __init__(self, args): 25 | super(VectorGate, self).__init__() 26 | 27 | self.args = args 28 | 29 | # configuration 30 | self.num_rounds = args.num_rounds 31 | self.device = args.device 32 | self.predict_diff = args.predict_diff 33 | self.intermediate_supervision = args.intermediate_supervision 34 | self.reverse = args.reverse 35 | self.custom_backward = args.custom_backward 36 | self.use_edge_attr = args.use_edge_attr 37 | self.mask = args.mask 38 | 39 | # dimensions 40 | self.num_aggr = args.num_aggr 41 | self.dim_node_feature = args.dim_node_feature 42 | self.dim_hidden = args.dim_hidden 43 | self.dim_mlp = args.dim_mlp 44 | self.dim_pred = args.dim_pred 45 | self.num_fc = args.num_fc 46 | self.wx_update = args.wx_update 47 | self.wx_mlp = args.wx_mlp 48 | self.dim_edge_feature = args.dim_edge_feature 49 | 50 | # Network 51 | self.aggr_and_strc = MlpAggr(self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 52 | self.aggr_and_func = MlpAggr(self.dim_hidden*2, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 53 | self.aggr_not_strc = MlpAggr(self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 54 | self.aggr_not_func = MlpAggr(self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 55 | self.update_and_strc = GRU(self.dim_hidden, self.dim_hidden) 56 | self.update_and_func = GRU(self.dim_hidden, self.dim_hidden) 57 | self.update_not_strc = GRU(self.dim_hidden, self.dim_hidden) 58 | self.update_not_func = GRU(self.dim_hidden, self.dim_hidden) 59 | 60 | # Readout 61 | self.readout_prob = MLP(self.dim_hidden, args.dim_mlp, 1, num_layer=3, p_drop=0.2, norm_layer='batchnorm', act_layer='relu') 62 | self.readout_rc = MLP(self.dim_hidden * 2, args.dim_mlp, 1, num_layer=3, p_drop=0.2, norm_layer='batchnorm', sigmoid=True) 63 | 64 | # consider the embedding for the LSTM/GRU model initialized by non-zeros 65 | self.one = torch.ones(1).to(self.device) 66 | # self.hs_emd_int = nn.Linear(1, self.dim_hidden) 67 | self.hf_emd_int = nn.Linear(1, self.dim_hidden) 68 | self.one.requires_grad = False 69 | 70 | def forward(self, G): 71 | num_nodes = G.num_nodes 72 | num_layers_f = max(G.forward_level).item() + 1 73 | num_layers_b = max(G.backward_level).item() + 1 74 | 75 | # initialize the structure hidden state 76 | hs_init = torch.zeros(num_nodes, self.dim_hidden) 77 | hs_init, max_sim = generate_hs_init(G, hs_init, self.dim_hidden) 78 | hs_init = hs_init.to(self.device) 79 | 80 | # initialize the function hidden state 81 | hf_init = self.hf_emd_int(self.one).view(1, -1) # (1 x 1 x dim_hidden) 82 | hf_init = hf_init.repeat(num_nodes, 1) # (1 x num_nodes x dim_hidden) 83 | 84 | preds = self._gru_forward(G, hs_init, hf_init, num_layers_f, num_layers_b) 85 | 86 | return preds, max_sim 87 | 88 | def _gru_forward(self, G, hs_init, hf_init, num_layers_f, num_layers_b, h_true=None, h_false=None): 89 | G = G.to(self.device) 90 | x, edge_index = G.x, G.edge_index 91 | edge_attr = G.edge_attr if self.use_edge_attr else None 92 | 93 | hs = hs_init.to(self.device) 94 | hf = hf_init.to(self.device) 95 | node_state = torch.cat([hs, hf], dim=-1) 96 | and_mask = G.gate.squeeze(1) == 1 97 | not_mask = G.gate.squeeze(1) == 2 98 | 99 | for _ in range(self.num_rounds): 100 | for level in range(1, num_layers_f): 101 | # forward layer 102 | layer_mask = G.forward_level == level 103 | 104 | # AND Gate 105 | l_and_node = G.forward_index[layer_mask & and_mask] 106 | if l_and_node.size(0) > 0: 107 | and_edge_index, 
and_edge_attr = subgraph(l_and_node, edge_index, edge_attr, dim=1) 108 | # Update structure hidden state 109 | msg = self.aggr_and_strc(hs, and_edge_index, and_edge_attr) 110 | and_msg = torch.index_select(msg, dim=0, index=l_and_node) 111 | hs_and = torch.index_select(hs, dim=0, index=l_and_node) 112 | _, hs_and = self.update_and_strc(and_msg.unsqueeze(0), hs_and.unsqueeze(0)) 113 | hs[l_and_node, :] = hs_and.squeeze(0) 114 | # Update function hidden state 115 | msg = self.aggr_and_func(node_state, and_edge_index, and_edge_attr) 116 | and_msg = torch.index_select(msg, dim=0, index=l_and_node) 117 | hf_and = torch.index_select(hf, dim=0, index=l_and_node) 118 | _, hf_and = self.update_and_func(and_msg.unsqueeze(0), hf_and.unsqueeze(0)) 119 | hf[l_and_node, :] = hf_and.squeeze(0) 120 | 121 | # NOT Gate 122 | l_not_node = G.forward_index[layer_mask & not_mask] 123 | if l_not_node.size(0) > 0: 124 | not_edge_index, not_edge_attr = subgraph(l_not_node, edge_index, edge_attr, dim=1) 125 | # Update structure hidden state 126 | msg = self.aggr_not_strc(hs, not_edge_index, not_edge_attr) 127 | not_msg = torch.index_select(msg, dim=0, index=l_not_node) 128 | hs_not = torch.index_select(hs, dim=0, index=l_not_node) 129 | _, hs_not = self.update_not_strc(not_msg.unsqueeze(0), hs_not.unsqueeze(0)) 130 | hs[l_not_node, :] = hs_not.squeeze(0) 131 | # Update function hidden state 132 | msg = self.aggr_not_func(hf, not_edge_index, not_edge_attr) 133 | not_msg = torch.index_select(msg, dim=0, index=l_not_node) 134 | hf_not = torch.index_select(hf, dim=0, index=l_not_node) 135 | _, hf_not = self.update_not_func(not_msg.unsqueeze(0), hf_not.unsqueeze(0)) 136 | hf[l_not_node, :] = hf_not.squeeze(0) 137 | 138 | # Update node state 139 | node_state = torch.cat([hs, hf], dim=-1) 140 | 141 | node_embedding = node_state.squeeze(0) 142 | hs = node_embedding[:, :self.dim_hidden] 143 | hf = node_embedding[:, self.dim_hidden:] 144 | 145 | # Readout 146 | prob = self.readout_prob(hf) 147 | # rc_emb = torch.cat([hs[G.rc_pair_index[0]], hs[G.rc_pair_index[1]]], dim=1) 148 | # is_rc = self.readout_rc(rc_emb) 149 | is_rc = [] 150 | 151 | return hs, hf, prob, is_rc 152 | 153 | 154 | def imply_mask(self, G, h, h_true, h_false): 155 | true_mask = (G.mask == 1.0).unsqueeze(0) 156 | false_mask = (G.mask == 0.0).unsqueeze(0) 157 | normal_mask = (G.mask == -1.0).unsqueeze(0) 158 | h_mask = h * normal_mask + h_true * true_mask + h_false * false_mask 159 | return h_mask 160 | 161 | 162 | 163 | def get_vector_gate(args): 164 | return VectorGate(args) -------------------------------------------------------------------------------- /src/trains/mlpgnn_trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from torch_geometric.nn import DataParallel 10 | from progress.bar import Bar 11 | from utils.utils import AverageMeter, zero_normalization, get_function_acc 12 | 13 | _loss_factory = { 14 | # Regression 15 | 'l1': nn.L1Loss, 16 | 'sl1': nn.SmoothL1Loss, 17 | 'l2': nn.MSELoss, 18 | # Classification 19 | 'bce': nn.BCELoss, 20 | } 21 | 22 | class ModelWithLoss(torch.nn.Module): 23 | def __init__(self, model, reg_loss, cls_loss, gpus, device): 24 | super(ModelWithLoss, self).__init__() 25 | self.model = model 26 | self.reg_loss = reg_loss 27 | self.cls_loss = cls_loss 28 | self.gpus = gpus 29 
| self.device = device 30 | self.sigmoid = nn.Sigmoid() 31 | 32 | def forward(self, batch): 33 | preds, max_sim = self.model(batch) 34 | hs, hf, prob, is_rc = preds 35 | # Task 1: Probability Prediction 36 | prob_loss = self.reg_loss(prob.to(self.device), batch.prob.to(self.device)) 37 | # Task 2: Structural Prediction 38 | # node_a = hs[batch.rc_pair_index[0]] 39 | # node_b = hs[batch.rc_pair_index[1]] 40 | # emb_rc_sim = torch.cosine_similarity(node_a, node_b, eps=1e-8) 41 | # threshold = (1-max_sim)/2 + max_sim 42 | # # threshold = max_sim 43 | # rc_pred = self.sigmoid(emb_rc_sim - threshold) 44 | # rc_pred = rc_pred.unsqueeze(1).float() 45 | # rc_loss = self.cls_loss(rc_pred.to(self.device), batch.is_rc.to(self.device)) 46 | rc_loss = self.cls_loss(is_rc.to(self.device), batch.is_rc.to(self.device)) 47 | # Task 3: Function Prediction 48 | # emb = torch.cat([hs, hf], dim=-1) 49 | node_a = hf[batch.tt_pair_index[0]] 50 | node_b = hf[batch.tt_pair_index[1]] 51 | emb_dis = 1 - torch.cosine_similarity(node_a, node_b, eps=1e-8) 52 | emb_dis_z = zero_normalization(emb_dis) 53 | tt_dis_z = zero_normalization(batch.tt_dis) 54 | func_loss = self.reg_loss(emb_dis_z.to(self.device), tt_dis_z.to(self.device)) 55 | loss_stats = {'LProb': prob_loss, 'LRC': rc_loss, 'LFunc': func_loss} 56 | 57 | return hs, hf, loss_stats 58 | 59 | class MLPGNNTrainer(object): 60 | def __init__( 61 | self, args, model, optimizer=None): 62 | self.args = args 63 | self.optimizer = optimizer 64 | self.loss_stats, self.reg_loss, self.cls_loss = self._get_losses(args.reg_loss, args.cls_loss) 65 | self.reg_loss = self.reg_loss.to(self.args.device) 66 | self.cls_loss = self.cls_loss.to(self.args.device) 67 | self.model_with_loss = ModelWithLoss(model, self.reg_loss, self.cls_loss, args.gpus, args.device) 68 | 69 | def set_weight(self, w_prob, w_rc, w_func): 70 | self.args.Prob_weight = w_prob 71 | self.args.RC_weight = w_rc 72 | self.args.Func_weight = w_func 73 | 74 | def set_device(self, device, local_rank, gpus): 75 | if len(gpus)> 1: 76 | self.model_with_loss = self.model_with_loss.to(device) 77 | self.model_with_loss = nn.parallel.DistributedDataParallel(self.model_with_loss, 78 | device_ids=[local_rank], 79 | find_unused_parameters=True) 80 | else: 81 | self.model_with_loss = self.model_with_loss.to(device) 82 | 83 | for state in self.optimizer.state.values(): 84 | for k, v in state.items(): 85 | if isinstance(v, torch.Tensor): 86 | state[k] = v.to(device=device, non_blocking=True) 87 | 88 | def run_epoch(self, phase, epoch, dataset, local_rank): 89 | model_with_loss = self.model_with_loss 90 | if phase == 'train': 91 | model_with_loss.train() 92 | else: 93 | if len(self.args.gpus) > 1: 94 | model_with_loss = self.model_with_loss.module 95 | model_with_loss.eval() 96 | torch.cuda.empty_cache() 97 | 98 | args = self.args 99 | results = {} 100 | acc_list = [] 101 | data_time, batch_time = AverageMeter(), AverageMeter() 102 | avg_loss_stats = {l: AverageMeter() for l in self.loss_stats} 103 | num_iters = len(dataset) if args.num_iters < 0 else args.num_iters 104 | if local_rank == 0: 105 | bar = Bar('{}/{}'.format(args.task, args.exp_id), max=num_iters) 106 | end = time.time() 107 | for iter_id, batch in enumerate(dataset): 108 | if iter_id >= num_iters: 109 | break 110 | if len(self.args.gpus) == 1: 111 | batch = batch.to(self.args.device) 112 | data_time.update(time.time() - end) 113 | hs, hf, loss_stats = model_with_loss(batch) 114 | loss = loss_stats['LProb'] * args.Prob_weight + loss_stats['LRC'] * args.RC_weight + 
loss_stats['LFunc'] * args.Func_weight 115 | loss /= (args.Prob_weight + args.RC_weight + args.Func_weight) 116 | loss = loss.mean() 117 | loss_stats['loss'] = loss 118 | if phase == 'train': 119 | self.optimizer.zero_grad() 120 | loss.backward() 121 | if args.grad_clip > 0: 122 | torch.nn.utils.clip_grad_norm_(model_with_loss.parameters(), args.grad_clip) 123 | self.optimizer.step() 124 | batch_time.update(time.time() - end) 125 | end = time.time() 126 | if local_rank == 0: 127 | Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format( 128 | epoch, iter_id, num_iters, phase=phase, 129 | total=bar.elapsed_td, eta=bar.eta_td) 130 | for l in avg_loss_stats: 131 | avg_loss_stats[l].update( 132 | loss_stats[l].mean().item(), batch.num_graphs * len(hf)) 133 | Bar.suffix = Bar.suffix + \ 134 | '|{} {:.4f} '.format(l, avg_loss_stats[l].avg) 135 | 136 | # Get Acc 137 | if phase == 'val': 138 | acc = get_function_acc(batch, hf) 139 | Bar.suffix = Bar.suffix + '|Acc {:}%%'.format(acc*100) 140 | acc_list.append(acc) 141 | 142 | if not args.hide_data_time: 143 | Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \ 144 | '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time) 145 | if args.print_iter > 0: 146 | if iter_id % args.print_iter == 0: 147 | print('{}/{}| {}'.format(args.task, args.exp_id, Bar.suffix)) 148 | else: 149 | bar.next() 150 | del hs, hf, loss, loss_stats 151 | 152 | 153 | ret = {k: v.avg for k, v in avg_loss_stats.items()} 154 | if local_rank == 0: 155 | bar.finish() 156 | ret['time'] = bar.elapsed_td.total_seconds() / 60. 157 | if phase == 'val': 158 | ret['ACC'] = np.average(acc_list) 159 | return ret, results 160 | 161 | def debug(self, batch, output, iter_id): 162 | raise NotImplementedError 163 | 164 | def save_result(self, output, batch, results): 165 | raise NotImplementedError 166 | 167 | def _get_losses(self, reg_loss, cls_loss): 168 | if reg_loss in _loss_factory.keys(): 169 | reg_loss_func = _loss_factory[reg_loss]() 170 | if cls_loss in _loss_factory.keys(): 171 | cls_loss_func = _loss_factory[cls_loss]() 172 | loss_states = ['loss', 'LProb', 'LRC', 'LFunc'] 173 | return loss_states, reg_loss_func, cls_loss_func 174 | 175 | def val(self, epoch, data_loader, local_rank): 176 | return self.run_epoch('val', epoch, data_loader, local_rank) 177 | 178 | def train(self, epoch, data_loader, local_rank): 179 | return self.run_epoch('train', epoch, data_loader, local_rank) 180 | -------------------------------------------------------------------------------- /src/utils/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch_sparse import SparseTensor, cat 4 | import torch_geometric 5 | from torch_geometric.data import Data 6 | 7 | 8 | class Batch(Data): 9 | r"""A plain old python object modeling a batch of graphs as one big 10 | (disconnected) graph. With :class:`torch_geometric.data.Data` being the 11 | base class, all its methods can also be used here. 12 | In addition, single graphs can be reconstructed via the assignment vector 13 | :obj:`batch`, which maps each node to its respective graph identifier. 
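    In this copy, the DAG-specific index fields (:obj:`bi_layer_index`,
    :obj:`bi_layer_parent_index`, :obj:`layer_index`, :obj:`layer_parent_index`) are
    additionally offset during collation (see the `# DAGNN` branch in
    :meth:`from_data_list`) so that node ids stay valid after concatenation.
    A rough usage sketch (the Data objects are illustrative):

        from utils.batch import Batch
        big = Batch.from_data_list([g1, g2, g3])
        assert big.num_graphs == 3
        graphs = big.to_data_list()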
14 | """ 15 | def __init__(self, batch=None, **kwargs): 16 | super(Batch, self).__init__(**kwargs) 17 | 18 | self.batch = batch 19 | self.__data_class__ = Data 20 | self.__slices__ = None 21 | self.__cumsum__ = None 22 | self.__cat_dims__ = None 23 | self.__num_nodes_list__ = None 24 | 25 | @staticmethod 26 | def from_data_list(data_list, follow_batch=[]): 27 | r"""Constructs a batch object from a python list holding 28 | :class:`torch_geometric.data.Data` objects. 29 | The assignment vector :obj:`batch` is created on the fly. 30 | Additionally, creates assignment batch vectors for each key in 31 | :obj:`follow_batch`.""" 32 | 33 | keys = [set(data.keys) for data in data_list] 34 | keys = list(set.union(*keys)) 35 | assert 'batch' not in keys 36 | 37 | batch = Batch() 38 | batch.__data_class__ = data_list[0].__class__ 39 | for key in keys + ['batch']: 40 | batch[key] = [] 41 | 42 | slices = {key: [0] for key in keys} 43 | cumsum = {key: [0] for key in keys} 44 | cat_dims = {} 45 | num_nodes_list = [] 46 | mm = 0 47 | for i, data in enumerate(data_list): 48 | for key in keys: 49 | item = data[key] 50 | 51 | # Increase values by `cumsum` value. 52 | cum = cumsum[key][-1] 53 | # DAGNN 54 | if key in ["bi_layer_index", "bi_layer_parent_index"]: 55 | # if key == "bi_layer_index": 56 | # # print("now") 57 | # mm = torch.max(item).item() 58 | # print(mm, data.x.shape[0]) 59 | item[:, 1] = item[:, 1] + cum 60 | # if key == "bi_layer_index": 61 | # # print("now") 62 | # mm = torch.max(item).item() 63 | # print(mm) #, data.x.shape[0]) 64 | elif key in ["layer_index", "layer_parent_index"]: 65 | # keep layer dimension 0 and only increase ids in dimension 1 (we merge layers of all data objects) 66 | item[1] = item[1] + cum 67 | elif isinstance(item, Tensor) and item.dtype != torch.bool: 68 | item = item + cum if cum != 0 else item 69 | elif isinstance(item, SparseTensor): 70 | value = item.storage.value() 71 | if value is not None and value.dtype != torch.bool: 72 | value = value + cum if cum != 0 else value 73 | item = item.set_value(value, layout='coo') 74 | elif isinstance(item, (int, float)): 75 | item = item + cum 76 | 77 | # Treat 0-dimensional tensors as 1-dimensional. 78 | if isinstance(item, Tensor) and item.dim() == 0: 79 | item = item.unsqueeze(0) 80 | 81 | batch[key].append(item) 82 | 83 | # Gather the size of the `cat` dimension. 
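# (The per-key `slices` accumulated below, together with `cumsum`, are exactly what
# `to_data_list()` later uses to narrow the concatenated tensors back into per-graph pieces.)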
84 | size = 1 85 | cat_dim = data.__cat_dim__(key, data[key]) 86 | cat_dims[key] = cat_dim 87 | if isinstance(item, Tensor): 88 | size = item.size(cat_dim) 89 | elif isinstance(item, SparseTensor): 90 | size = torch.tensor(item.sizes())[torch.tensor(cat_dim)] 91 | 92 | slices[key].append(size + slices[key][-1]) 93 | inc = data.__inc__(key, item) 94 | if isinstance(inc, (tuple, list)): 95 | inc = torch.tensor(inc) 96 | cumsum[key].append(inc + cumsum[key][-1]) 97 | 98 | if key in follow_batch: 99 | if isinstance(size, Tensor): 100 | for j, size in enumerate(size.tolist()): 101 | tmp = f'{key}_{j}_batch' 102 | batch[tmp] = [] if i == 0 else batch[tmp] 103 | batch[tmp].append( 104 | torch.full((size, ), i, dtype=torch.long)) 105 | else: 106 | tmp = f'{key}_batch' 107 | batch[tmp] = [] if i == 0 else batch[tmp] 108 | batch[tmp].append( 109 | torch.full((size, ), i, dtype=torch.long)) 110 | 111 | if hasattr(data, '__num_nodes__'): 112 | num_nodes_list.append(data.__num_nodes__) 113 | else: 114 | num_nodes_list.append(None) 115 | 116 | num_nodes = data.num_nodes 117 | # print(mm, num_nodes, sum(num_nodes_list)) 118 | if num_nodes is not None: 119 | item = torch.full((num_nodes, ), i, dtype=torch.long) 120 | batch.batch.append(item) 121 | 122 | # Fix initial slice values: 123 | for key in keys: 124 | slices[key][0] = slices[key][1] - slices[key][1] 125 | 126 | batch.batch = None if len(batch.batch) == 0 else batch.batch 127 | batch.__slices__ = slices 128 | batch.__cumsum__ = cumsum 129 | batch.__cat_dims__ = cat_dims 130 | batch.__num_nodes_list__ = num_nodes_list 131 | 132 | ref_data = data_list[0] 133 | for key in batch.keys: 134 | items = batch[key] 135 | item = items[0] 136 | if isinstance(item, Tensor): 137 | batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item)) 138 | elif isinstance(item, SparseTensor): 139 | batch[key] = cat(items, ref_data.__cat_dim__(key, item)) 140 | elif isinstance(item, (int, float)): 141 | batch[key] = torch.tensor(items) 142 | 143 | if torch_geometric.is_debug_enabled(): 144 | batch.debug() 145 | 146 | return batch.contiguous() 147 | 148 | def to_data_list(self): 149 | r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects 150 | from the batch object. 151 | The batch object must have been created via :meth:`from_data_list` in 152 | order to be able to reconstruct the initial objects.""" 153 | 154 | if self.__slices__ is None: 155 | raise RuntimeError( 156 | ('Cannot reconstruct data list from batch because the batch ' 157 | 'object was not created using `Batch.from_data_list()`.')) 158 | 159 | data_list = [] 160 | for i in range(len(list(self.__slices__.values())[0]) - 1): 161 | data = self.__data_class__() 162 | 163 | for key in self.__slices__.keys(): 164 | item = self[key] 165 | # Narrow the item based on the values in `__slices__`. 166 | if isinstance(item, Tensor): 167 | dim = self.__cat_dims__[key] 168 | start = self.__slices__[key][i] 169 | end = self.__slices__[key][i + 1] 170 | item = item.narrow(dim, start, end - start) 171 | elif isinstance(item, SparseTensor): 172 | for j, dim in enumerate(self.__cat_dims__[key]): 173 | start = self.__slices__[key][i][j].item() 174 | end = self.__slices__[key][i + 1][j].item() 175 | item = item.narrow(dim, start, end - start) 176 | else: 177 | item = item[self.__slices__[key][i]:self. 
178 | __slices__[key][i + 1]] 179 | item = item[0] if len(item) == 1 else item 180 | 181 | # Decrease its value by `cumsum` value: 182 | cum = self.__cumsum__[key][i] 183 | if isinstance(item, Tensor): 184 | item = item - cum if cum != 0 else item 185 | elif isinstance(item, SparseTensor): 186 | value = item.storage.value() 187 | if value is not None and value.dtype != torch.bool: 188 | value = value - cum if cum != 0 else value 189 | item = item.set_value(value, layout='coo') 190 | elif isinstance(item, (int, float)): 191 | item = item - cum 192 | 193 | data[key] = item 194 | 195 | if self.__num_nodes_list__[i] is not None: 196 | data.num_nodes = self.__num_nodes_list__[i] 197 | 198 | data_list.append(data) 199 | 200 | return data_list 201 | 202 | @property 203 | def num_graphs(self): 204 | """Returns the number of graphs in the batch.""" 205 | return self.batch[-1].item() + 1 -------------------------------------------------------------------------------- /src/models/mlpgate.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | from utils.dag_utils import subgraph, custom_backward_subgraph 8 | from utils.utils import generate_hs_init 9 | 10 | from .mlp import MLP 11 | from .mlp_aggr import MlpAggr 12 | from .gat_conv import AGNNConv 13 | from .gcn_conv import AggConv 14 | from .deepset_conv import DeepSetConv 15 | from .gated_sum_conv import GatedSumConv 16 | from .aggnmlp import AttnMLP 17 | from .tfmlp import TFMLP 18 | 19 | from torch.nn import LSTM, GRU 20 | 21 | _aggr_function_factory = { 22 | 'mlp': MlpAggr, # MLP, similar as NeuroSAT 23 | 'attnmlp': AttnMLP, # MLP with attention 24 | 'tfmlp': TFMLP, # MLP with transformer 25 | 'aggnconv': AGNNConv, # DeepGate with attention 26 | 'conv_sum': AggConv, # GCN 27 | 'deepset': DeepSetConv, # DeepSet, similar as NeuroSAT 28 | } 29 | 30 | _update_function_factory = { 31 | 'lstm': LSTM, 32 | 'gru': GRU, 33 | } 34 | 35 | class MLPGate(nn.Module): 36 | ''' 37 | Recurrent Graph Neural Networks for Circuits. 
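    Same two-state (hs: structure, hf: function) message passing as VectorGate, but the
    aggregator is selected via args.aggr_function ('mlp', 'attnmlp', 'tfmlp', 'aggnconv',
    'conv_sum', 'deepset'), hs can be left all-zero with args.disable_encode, and the
    rc-pair readout over G.rc_pair_index is active. A rough usage sketch (args and
    graph_batch are assumed to come from the project's config and dataset code):

        from models.mlpgate import get_mlp_gate
        model = get_mlp_gate(args)
        (hs, hf, prob, is_rc), max_sim = model(graph_batch)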
38 | ''' 39 | def __init__(self, args): 40 | super(MLPGate, self).__init__() 41 | 42 | self.args = args 43 | 44 | # configuration 45 | self.num_rounds = args.num_rounds 46 | self.device = args.device 47 | self.predict_diff = args.predict_diff 48 | self.intermediate_supervision = args.intermediate_supervision 49 | self.reverse = args.reverse 50 | self.custom_backward = args.custom_backward 51 | self.use_edge_attr = args.use_edge_attr 52 | self.mask = args.mask 53 | 54 | # dimensions 55 | self.num_aggr = args.num_aggr 56 | self.dim_node_feature = args.dim_node_feature 57 | self.dim_hidden = args.dim_hidden 58 | self.dim_mlp = args.dim_mlp 59 | self.dim_pred = args.dim_pred 60 | self.num_fc = args.num_fc 61 | self.wx_update = args.wx_update 62 | self.wx_mlp = args.wx_mlp 63 | self.dim_edge_feature = args.dim_edge_feature 64 | 65 | # Network 66 | if args.aggr_function == 'mlp' or args.aggr_function == 'attnmlp': 67 | self.aggr_and_strc = _aggr_function_factory[args.aggr_function](self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 68 | self.aggr_and_func = _aggr_function_factory[args.aggr_function](self.dim_hidden*2, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 69 | self.aggr_not_strc = _aggr_function_factory[args.aggr_function](self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 70 | self.aggr_not_func = _aggr_function_factory[args.aggr_function](self.dim_hidden*1, args.dim_mlp, self.dim_hidden, num_layer=3, act_layer='relu') 71 | else: 72 | self.aggr_and_strc = _aggr_function_factory[args.aggr_function](self.dim_hidden*1, self.dim_hidden) 73 | self.aggr_and_func = _aggr_function_factory[args.aggr_function](self.dim_hidden*2, self.dim_hidden) 74 | self.aggr_not_strc = _aggr_function_factory[args.aggr_function](self.dim_hidden*1, self.dim_hidden) 75 | self.aggr_not_func = _aggr_function_factory[args.aggr_function](self.dim_hidden*1, self.dim_hidden) 76 | 77 | self.update_and_strc = GRU(self.dim_hidden, self.dim_hidden) 78 | self.update_and_func = GRU(self.dim_hidden, self.dim_hidden) 79 | self.update_not_strc = GRU(self.dim_hidden, self.dim_hidden) 80 | self.update_not_func = GRU(self.dim_hidden, self.dim_hidden) 81 | 82 | # Readout 83 | self.readout_prob = MLP(self.dim_hidden, args.dim_mlp, 1, num_layer=3, p_drop=0.2, norm_layer='batchnorm', act_layer='relu') 84 | self.readout_rc = MLP(self.dim_hidden * 2, args.dim_mlp, 1, num_layer=3, p_drop=0.2, norm_layer='batchnorm', sigmoid=True) 85 | 86 | # consider the embedding for the LSTM/GRU model initialized by non-zeros 87 | self.one = torch.ones(1).to(self.device) 88 | # self.hs_emd_int = nn.Linear(1, self.dim_hidden) 89 | self.hf_emd_int = nn.Linear(1, self.dim_hidden) 90 | self.one.requires_grad = False 91 | 92 | def forward(self, G): 93 | num_nodes = G.num_nodes 94 | num_layers_f = max(G.forward_level).item() + 1 95 | num_layers_b = max(G.backward_level).item() + 1 96 | 97 | # initialize the structure hidden state 98 | if self.args.disable_encode: 99 | hs_init = torch.zeros(num_nodes, self.dim_hidden) 100 | max_sim = 0 101 | hs_init = hs_init.to(self.device) 102 | else: 103 | hs_init = torch.zeros(num_nodes, self.dim_hidden) 104 | hs_init, max_sim = generate_hs_init(G, hs_init, self.dim_hidden) 105 | hs_init = hs_init.to(self.device) 106 | 107 | # initialize the function hidden state 108 | hf_init = self.hf_emd_int(self.one).view(1, -1) # (1 x 1 x dim_hidden) 109 | hf_init = hf_init.repeat(num_nodes, 1) # (1 x num_nodes x dim_hidden) 110 | 111 | preds = 
self._gru_forward(G, hs_init, hf_init, num_layers_f, num_layers_b) 112 | 113 | return preds, max_sim 114 | 115 | def _gru_forward(self, G, hs_init, hf_init, num_layers_f, num_layers_b, h_true=None, h_false=None): 116 | G = G.to(self.device) 117 | x, edge_index = G.x, G.edge_index 118 | edge_attr = G.edge_attr if self.use_edge_attr else None 119 | 120 | hs = hs_init.to(self.device) 121 | hf = hf_init.to(self.device) 122 | node_state = torch.cat([hs, hf], dim=-1) 123 | and_mask = G.gate.squeeze(1) == 1 124 | not_mask = G.gate.squeeze(1) == 2 125 | 126 | for _ in range(self.num_rounds): 127 | for level in range(1, num_layers_f): 128 | # forward layer 129 | layer_mask = G.forward_level == level 130 | 131 | # AND Gate 132 | l_and_node = G.forward_index[layer_mask & and_mask] 133 | if l_and_node.size(0) > 0: 134 | and_edge_index, and_edge_attr = subgraph(l_and_node, edge_index, edge_attr, dim=1) 135 | # Update structure hidden state 136 | msg = self.aggr_and_strc(hs, and_edge_index, and_edge_attr) 137 | and_msg = torch.index_select(msg, dim=0, index=l_and_node) 138 | hs_and = torch.index_select(hs, dim=0, index=l_and_node) 139 | _, hs_and = self.update_and_strc(and_msg.unsqueeze(0), hs_and.unsqueeze(0)) 140 | hs[l_and_node, :] = hs_and.squeeze(0) 141 | # Update function hidden state 142 | msg = self.aggr_and_func(node_state, and_edge_index, and_edge_attr) 143 | and_msg = torch.index_select(msg, dim=0, index=l_and_node) 144 | hf_and = torch.index_select(hf, dim=0, index=l_and_node) 145 | _, hf_and = self.update_and_func(and_msg.unsqueeze(0), hf_and.unsqueeze(0)) 146 | hf[l_and_node, :] = hf_and.squeeze(0) 147 | 148 | # NOT Gate 149 | l_not_node = G.forward_index[layer_mask & not_mask] 150 | if l_not_node.size(0) > 0: 151 | not_edge_index, not_edge_attr = subgraph(l_not_node, edge_index, edge_attr, dim=1) 152 | # Update structure hidden state 153 | msg = self.aggr_not_strc(hs, not_edge_index, not_edge_attr) 154 | not_msg = torch.index_select(msg, dim=0, index=l_not_node) 155 | hs_not = torch.index_select(hs, dim=0, index=l_not_node) 156 | _, hs_not = self.update_not_strc(not_msg.unsqueeze(0), hs_not.unsqueeze(0)) 157 | hs[l_not_node, :] = hs_not.squeeze(0) 158 | # Update function hidden state 159 | msg = self.aggr_not_func(hf, not_edge_index, not_edge_attr) 160 | not_msg = torch.index_select(msg, dim=0, index=l_not_node) 161 | hf_not = torch.index_select(hf, dim=0, index=l_not_node) 162 | _, hf_not = self.update_not_func(not_msg.unsqueeze(0), hf_not.unsqueeze(0)) 163 | hf[l_not_node, :] = hf_not.squeeze(0) 164 | 165 | # Update node state 166 | node_state = torch.cat([hs, hf], dim=-1) 167 | 168 | node_embedding = node_state.squeeze(0) 169 | hs = node_embedding[:, :self.dim_hidden] 170 | hf = node_embedding[:, self.dim_hidden:] 171 | 172 | # Readout 173 | prob = self.readout_prob(hf) 174 | rc_emb = torch.cat([hs[G.rc_pair_index[0]], hs[G.rc_pair_index[1]]], dim=1) 175 | is_rc = self.readout_rc(rc_emb) 176 | 177 | return hs, hf, prob, is_rc 178 | 179 | 180 | def imply_mask(self, G, h, h_true, h_false): 181 | true_mask = (G.mask == 1.0).unsqueeze(0) 182 | false_mask = (G.mask == 0.0).unsqueeze(0) 183 | normal_mask = (G.mask == -1.0).unsqueeze(0) 184 | h_mask = h * normal_mask + h_true * true_mask + h_false * false_mask 185 | return h_mask 186 | 187 | 188 | 189 | def get_mlp_gate(args): 190 | return MLPGate(args) -------------------------------------------------------------------------------- /src/models/losses.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | 11 | class SmoothStep(nn.Module): 12 | def __init__(self, kstep=10.0): 13 | super(SmoothStep, self).__init__() 14 | self.kstep = kstep 15 | 16 | def forward(self, inputs): 17 | outputs = torch.pow(1-inputs, self.kstep) / (torch.pow(1-inputs, self.kstep) + torch.pow(inputs, self.kstep)) 18 | return outputs 19 | 20 | 21 | def _gather_feat(feat, ind, mask=None): 22 | dim = feat.size(2) 23 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 24 | feat = feat.gather(1, ind) 25 | if mask is not None: 26 | mask = mask.unsqueeze(2).expand_as(feat) 27 | feat = feat[mask] 28 | feat = feat.view(-1, dim) 29 | return feat 30 | 31 | 32 | def _transpose_and_gather_feat(feat, ind): 33 | feat = feat.permute(0, 2, 3, 1).contiguous() 34 | feat = feat.view(feat.size(0), -1, feat.size(3)) 35 | feat = _gather_feat(feat, ind) 36 | return feat 37 | 38 | 39 | def _slow_neg_loss(pred, gt): 40 | '''focal loss from CornerNet''' 41 | pos_inds = gt.eq(1) 42 | neg_inds = gt.lt(1) 43 | 44 | neg_weights = torch.pow(1 - gt[neg_inds], 4) 45 | 46 | loss = 0 47 | pos_pred = pred[pos_inds] 48 | neg_pred = pred[neg_inds] 49 | 50 | pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) 51 | neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights 52 | 53 | num_pos = pos_inds.float().sum() 54 | pos_loss = pos_loss.sum() 55 | neg_loss = neg_loss.sum() 56 | 57 | if pos_pred.nelement() == 0: 58 | loss = loss - neg_loss 59 | else: 60 | loss = loss - (pos_loss + neg_loss) / num_pos 61 | return loss 62 | 63 | 64 | def _neg_loss(pred, gt): 65 | ''' Modified focal loss. Exactly the same as CornerNet. 
66 | Runs faster and costs a little bit more memory 67 | Arguments: 68 | pred (batch x c x h x w) 69 | gt_regr (batch x c x h x w) 70 | ''' 71 | pos_inds = gt.eq(1).float() 72 | neg_inds = gt.lt(1).float() 73 | 74 | neg_weights = torch.pow(1 - gt, 4) 75 | 76 | loss = 0 77 | 78 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 79 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \ 80 | neg_weights * neg_inds 81 | 82 | num_pos = pos_inds.float().sum() 83 | pos_loss = pos_loss.sum() 84 | neg_loss = neg_loss.sum() 85 | 86 | if num_pos == 0: 87 | loss = loss - neg_loss 88 | else: 89 | loss = loss - (pos_loss + neg_loss) / num_pos 90 | return loss 91 | 92 | 93 | def _not_faster_neg_loss(pred, gt): 94 | pos_inds = gt.eq(1).float() 95 | neg_inds = gt.lt(1).float() 96 | num_pos = pos_inds.float().sum() 97 | neg_weights = torch.pow(1 - gt, 4) 98 | 99 | loss = 0 100 | trans_pred = pred * neg_inds + (1 - pred) * pos_inds 101 | weight = neg_weights * neg_inds + pos_inds 102 | all_loss = torch.log(1 - trans_pred) * torch.pow(trans_pred, 2) * weight 103 | all_loss = all_loss.sum() 104 | 105 | if num_pos > 0: 106 | all_loss /= num_pos 107 | loss -= all_loss 108 | return loss 109 | 110 | 111 | def _slow_reg_loss(regr, gt_regr, mask): 112 | num = mask.float().sum() 113 | mask = mask.unsqueeze(2).expand_as(gt_regr) 114 | 115 | regr = regr[mask] 116 | gt_regr = gt_regr[mask] 117 | 118 | regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False) 119 | regr_loss = regr_loss / (num + 1e-4) 120 | return regr_loss 121 | 122 | 123 | def _reg_loss(regr, gt_regr, mask): 124 | ''' L1 regression loss 125 | Arguments: 126 | regr (batch x max_objects x dim) 127 | gt_regr (batch x max_objects x dim) 128 | mask (batch x max_objects) 129 | ''' 130 | num = mask.float().sum() 131 | mask = mask.unsqueeze(2).expand_as(gt_regr).float() 132 | 133 | regr = regr * mask 134 | gt_regr = gt_regr * mask 135 | 136 | regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False) 137 | regr_loss = regr_loss / (num + 1e-4) 138 | return regr_loss 139 | 140 | 141 | class FocalLoss(nn.Module): 142 | '''nn.Module warpper for focal loss''' 143 | 144 | def __init__(self): 145 | super(FocalLoss, self).__init__() 146 | self.neg_loss = _neg_loss 147 | 148 | def forward(self, out, target): 149 | return self.neg_loss(out, target) 150 | 151 | 152 | class RegLoss(nn.Module): 153 | '''Regression loss for an output tensor 154 | Arguments: 155 | output (batch x dim x h x w) 156 | mask (batch x max_objects) 157 | ind (batch x max_objects) 158 | target (batch x max_objects x dim) 159 | ''' 160 | 161 | def __init__(self): 162 | super(RegLoss, self).__init__() 163 | 164 | def forward(self, output, mask, ind, target): 165 | pred = _transpose_and_gather_feat(output, ind) 166 | loss = _reg_loss(pred, target, mask) 167 | return loss 168 | 169 | 170 | class RegL1Loss(nn.Module): 171 | def __init__(self): 172 | super(RegL1Loss, self).__init__() 173 | 174 | def forward(self, output, mask, ind, target): 175 | pred = _transpose_and_gather_feat(output, ind) 176 | mask = mask.unsqueeze(2).expand_as(pred).float() 177 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 178 | loss = F.l1_loss(pred * mask, target * mask, size_average=False) 179 | loss = loss / (mask.sum() + 1e-4) 180 | return loss 181 | 182 | 183 | class NormRegL1Loss(nn.Module): 184 | def __init__(self): 185 | super(NormRegL1Loss, self).__init__() 186 | 187 | def forward(self, output, mask, ind, target): 188 | pred = 
_transpose_and_gather_feat(output, ind) 189 | mask = mask.unsqueeze(2).expand_as(pred).float() 190 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 191 | pred = pred / (target + 1e-4) 192 | target = target * 0 + 1 193 | loss = F.l1_loss(pred * mask, target * mask, size_average=False) 194 | loss = loss / (mask.sum() + 1e-4) 195 | return loss 196 | 197 | 198 | class RegWeightedL1Loss(nn.Module): 199 | def __init__(self): 200 | super(RegWeightedL1Loss, self).__init__() 201 | 202 | def forward(self, output, mask, ind, target): 203 | pred = _transpose_and_gather_feat(output, ind) 204 | mask = mask.float() 205 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 206 | loss = F.l1_loss(pred * mask, target * mask, size_average=False) 207 | loss = loss / (mask.sum() + 1e-4) 208 | return loss 209 | 210 | 211 | class L1Loss(nn.Module): 212 | def __init__(self): 213 | super(L1Loss, self).__init__() 214 | 215 | def forward(self, output, mask, ind, target): 216 | pred = _transpose_and_gather_feat(output, ind) 217 | mask = mask.unsqueeze(2).expand_as(pred).float() 218 | loss = F.l1_loss(pred * mask, target * mask, 219 | reduction='elementwise_mean') 220 | return loss 221 | 222 | 223 | class BinRotLoss(nn.Module): 224 | def __init__(self): 225 | super(BinRotLoss, self).__init__() 226 | 227 | def forward(self, output, mask, ind, rotbin, rotres): 228 | pred = _transpose_and_gather_feat(output, ind) 229 | loss = compute_rot_loss(pred, rotbin, rotres, mask) 230 | return loss 231 | 232 | 233 | def compute_res_loss(output, target): 234 | return F.smooth_l1_loss(output, target, reduction='elementwise_mean') 235 | 236 | # TODO: weight 237 | 238 | 239 | def compute_bin_loss(output, target, mask): 240 | mask = mask.expand_as(output) 241 | output = output * mask.float() 242 | return F.cross_entropy(output, target, reduction='elementwise_mean') 243 | 244 | 245 | def compute_rot_loss(output, target_bin, target_res, mask): 246 | # output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos, 247 | # bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos] 248 | # target_bin: (B, 128, 2) [bin1_cls, bin2_cls] 249 | # target_res: (B, 128, 2) [bin1_res, bin2_res] 250 | # mask: (B, 128, 1) 251 | # import pdb; pdb.set_trace() 252 | output = output.view(-1, 8) 253 | target_bin = target_bin.view(-1, 2) 254 | target_res = target_res.view(-1, 2) 255 | mask = mask.view(-1, 1) 256 | loss_bin1 = compute_bin_loss(output[:, 0:2], target_bin[:, 0], mask) 257 | loss_bin2 = compute_bin_loss(output[:, 4:6], target_bin[:, 1], mask) 258 | loss_res = torch.zeros_like(loss_bin1) 259 | if target_bin[:, 0].nonzero().shape[0] > 0: 260 | idx1 = target_bin[:, 0].nonzero()[:, 0] 261 | valid_output1 = torch.index_select(output, 0, idx1.long()) 262 | valid_target_res1 = torch.index_select(target_res, 0, idx1.long()) 263 | loss_sin1 = compute_res_loss( 264 | valid_output1[:, 2], torch.sin(valid_target_res1[:, 0])) 265 | loss_cos1 = compute_res_loss( 266 | valid_output1[:, 3], torch.cos(valid_target_res1[:, 0])) 267 | loss_res += loss_sin1 + loss_cos1 268 | if target_bin[:, 1].nonzero().shape[0] > 0: 269 | idx2 = target_bin[:, 1].nonzero()[:, 0] 270 | valid_output2 = torch.index_select(output, 0, idx2.long()) 271 | valid_target_res2 = torch.index_select(target_res, 0, idx2.long()) 272 | loss_sin2 = compute_res_loss( 273 | valid_output2[:, 6], torch.sin(valid_target_res2[:, 1])) 274 | loss_cos2 = compute_res_loss( 275 | valid_output2[:, 7], torch.cos(valid_target_res2[:, 1])) 276 | loss_res += loss_sin2 
+ loss_cos2 277 | return loss_bin1 + loss_bin2 + loss_res 278 | -------------------------------------------------------------------------------- /src/utils/aiger_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import subprocess 4 | 5 | def cnf_to_xdata(cnf_filename, tmp_aig_filename, tmp_aag_filename, gate_to_index): 6 | cnf2aig_cmd = 'cnf2aig {} {}'.format(cnf_filename, tmp_aig_filename) 7 | info = os.popen(cnf2aig_cmd).readlines() 8 | aig2aag_cmd = 'aigtoaig {} {}'.format(tmp_aig_filename, tmp_aag_filename) 9 | info = os.popen(aig2aag_cmd).readlines() 10 | 11 | # read aag 12 | f = open(tmp_aag_filename, 'r') 13 | lines = f.readlines() 14 | f.close() 15 | header = lines[0].strip().split(" ") 16 | assert header[0] == 'aag', 'The header of AIG file is wrong.' 17 | # “M”, “I”, “L”, “O”, “A” separated by spaces. 18 | n_variables = eval(header[1]) 19 | n_inputs = eval(header[2]) 20 | n_outputs = eval(header[4]) 21 | n_and = eval(header[5]) 22 | if n_outputs != 1 or n_variables != (n_inputs + n_and) or n_variables == n_inputs: 23 | return None 24 | assert n_outputs == 1, 'The AIG has multiple outputs.' 25 | assert n_variables == (n_inputs + n_and), 'There are unused AND gates.' 26 | assert n_variables != n_inputs, '# variable equals to # inputs' 27 | # Construct AIG graph 28 | x_data = [] 29 | edge_index = [] 30 | # node_labels = [] 31 | not_dict = {} 32 | 33 | # Add Literal node 34 | for i in range(n_inputs): 35 | x_data.append([len(x_data), gate_to_index['PI']]) 36 | # node_labels += [0] 37 | 38 | # Add AND node 39 | for i in range(n_inputs+1, n_inputs+1+n_and): 40 | x_data.append([len(x_data), gate_to_index['AND']]) 41 | # node_labels += [1] 42 | 43 | # sanity-check 44 | for (i, line) in enumerate(lines[1:1+n_inputs]): 45 | literal = line.strip().split(" ") 46 | assert len(literal) == 1, 'The literal of input should be single.' 47 | assert int(literal[0]) == 2 * (i + 1), 'The value of a input literal should be the index of variables mutiplying by two.' 48 | 49 | literal = lines[1+n_inputs].strip().split(" ")[0] 50 | assert int(literal) == (n_variables * 2) or int(literal) == (n_variables * 2) + 1, 'The value of the output literal shoud be (n_variables * 2)' 51 | sign_final = int(literal) % 2 52 | index_final_and = int(literal) // 2 - 1 53 | 54 | for (i, line) in enumerate(lines[2+n_inputs: 2+n_inputs+n_and]): 55 | literals = line.strip().split(" ") 56 | assert len(literals) == 3, 'invalidate the definition of two-input AND gate.' 57 | assert int(literals[0]) == 2 * (i + 1 + n_inputs) 58 | var_def = lines[2+n_variables].strip().split(" ")[0] 59 | 60 | assert var_def == 'i0', 'The definition of variables is wrong.' 61 | # finish sanity-check 62 | 63 | # Add edge 64 | for (i, line) in enumerate(lines[2+n_inputs: 2+n_inputs+n_and]): 65 | line = line.strip().split(" ") 66 | # assert len(line) == 3, 'The length of AND lines should be 3.' 67 | output_idx = int(line[0]) // 2 - 1 68 | # assert (int(line[0]) % 2) == 0, 'There is inverter sign in output literal.' 69 | 70 | # 1. 
First edge 71 | input1_idx = int(line[1]) // 2 - 1 72 | sign1_idx = int(line[1]) % 2 73 | # If there's a NOT node 74 | if sign1_idx == 1: 75 | if input1_idx in not_dict.keys(): 76 | not_idx = not_dict[input1_idx] 77 | else: 78 | x_data.append([len(x_data), gate_to_index['NOT']]) 79 | # node_labels += [2] 80 | not_idx = len(x_data) - 1 81 | not_dict[input1_idx] = not_idx 82 | edge_index += [[input1_idx, not_idx]] 83 | edge_index += [[not_idx, output_idx]] 84 | else: 85 | edge_index += [[input1_idx, output_idx]] 86 | 87 | 88 | # 2. Second edge 89 | input2_idx = int(line[2]) // 2 - 1 90 | sign2_idx = int(line[2]) % 2 91 | # If there's a NOT node 92 | if sign2_idx == 1: 93 | if input2_idx in not_dict.keys(): 94 | not_idx = not_dict[input2_idx] 95 | else: 96 | x_data.append([len(x_data), gate_to_index['NOT']]) 97 | # node_labels += [2] 98 | not_idx = len(x_data) - 1 99 | not_dict[input2_idx] = not_idx 100 | edge_index += [[input2_idx, not_idx]] 101 | edge_index += [[not_idx, output_idx]] 102 | else: 103 | edge_index += [[input2_idx, output_idx]] 104 | 105 | 106 | if sign_final == 1: 107 | x_data.append([len(x_data), gate_to_index['NOT']]) 108 | # node_labels += [2] 109 | not_idx = len(x_data) - 1 110 | edge_index += [[index_final_and, not_idx]] 111 | 112 | return x_data, edge_index 113 | 114 | def aig_to_xdata(aig_filename, tmp_aag_filename, gate_to_index): 115 | aig2aag_cmd = 'aigtoaig {} {}'.format(aig_filename, tmp_aag_filename) 116 | info = os.popen(aig2aag_cmd).readlines() 117 | 118 | # read aag 119 | f = open(tmp_aag_filename, 'r') 120 | lines = f.readlines() 121 | f.close() 122 | header = lines[0].strip().split(" ") 123 | assert header[0] == 'aag', 'The header of AIG file is wrong.' 124 | # “M”, “I”, “L”, “O”, “A” separated by spaces. 125 | n_variables = eval(header[1]) 126 | n_inputs = eval(header[2]) 127 | n_outputs = eval(header[4]) 128 | n_and = eval(header[5]) 129 | # if n_outputs != 1 or n_variables != (n_inputs + n_and) or n_variables == n_inputs: 130 | # return [], [] 131 | # assert n_outputs == 1, 'The AIG has multiple outputs.' 132 | # assert n_variables == (n_inputs + n_and), 'There are unused AND gates.' 
133 | # assert n_variables != n_inputs, '# variable equals to # inputs' 134 | # Construct AIG graph 135 | x_data = [] 136 | edge_index = [] 137 | # node_labels = [] 138 | 139 | # PI 140 | for i in range(n_inputs): 141 | x_data.append([len(x_data), gate_to_index['PI']]) 142 | # AND 143 | for i in range(n_and): 144 | x_data.append([len(x_data), gate_to_index['AND']]) 145 | 146 | # AND Connections 147 | has_not = [-1] * (len(x_data) + 1) 148 | for (i, line) in enumerate(lines[1+n_inputs+n_outputs: ]): 149 | arr = line.replace('\n', '').split(' ') 150 | if len(arr) != 3: 151 | continue 152 | and_index = int(int(arr[0]) / 2) - 1 153 | fanin_1_index = int(int(arr[1]) / 2) - 1 154 | fanin_2_index = int(int(arr[2]) / 2) - 1 155 | fanin_1_not = int(arr[1]) % 2 156 | fanin_2_not = int(arr[2]) % 2 157 | if fanin_1_not == 1: 158 | if has_not[fanin_1_index] == -1: 159 | x_data.append([len(x_data), gate_to_index['NOT']]) 160 | not_index = len(x_data) - 1 161 | edge_index.append([fanin_1_index, not_index]) 162 | has_not[fanin_1_index] = not_index 163 | fanin_1_index = has_not[fanin_1_index] 164 | if fanin_2_not == 1: 165 | if has_not[fanin_2_index] == -1: 166 | x_data.append([len(x_data), gate_to_index['NOT']]) 167 | not_index = len(x_data) - 1 168 | edge_index.append([fanin_2_index, not_index]) 169 | has_not[fanin_2_index] = not_index 170 | fanin_2_index = has_not[fanin_2_index] 171 | edge_index.append([fanin_1_index, and_index]) 172 | edge_index.append([fanin_2_index, and_index]) 173 | 174 | # PO NOT check 175 | for (i, line) in enumerate(lines[1+n_inputs: 1+n_inputs+n_outputs]): 176 | arr = line.replace('\n', '').split(' ') 177 | if len(arr) != 1: 178 | continue 179 | po_index = int(int(arr[0]) / 2) - 1 180 | po_not = int(arr[0]) % 2 181 | if po_not == 1: 182 | if has_not[po_index] == -1: 183 | x_data.append([len(x_data), gate_to_index['NOT']]) 184 | not_index = len(x_data) - 1 185 | edge_index.append([po_index, not_index]) 186 | has_not[po_index] = not_index 187 | 188 | return x_data, edge_index 189 | 190 | def aig_to_cnf(data, fanin_list, po_idx, gate_to_index, const_0=[], const_1=[]): 191 | cnf = [] 192 | for idx, x_data_info in enumerate(data): 193 | if x_data_info[1] == gate_to_index['PI']: 194 | continue 195 | elif x_data_info[1] == gate_to_index['NOT']: 196 | var_C = idx + 1 197 | var_A = fanin_list[idx][0] + 1 198 | cnf.append([-1 * var_C, -1 * var_A]) 199 | cnf.append([var_C, var_A]) 200 | elif x_data_info[1] == gate_to_index['AND']: 201 | var_C = idx + 1 202 | var_A = fanin_list[idx][0] + 1 203 | var_B = fanin_list[idx][1] + 1 204 | cnf.append([var_C, -1*var_A, -1*var_B]) 205 | cnf.append([-1*var_C, var_A]) 206 | cnf.append([-1*var_C, var_B]) 207 | # Const 208 | cnf.append([po_idx + 1]) 209 | for const_0_idx in const_0: 210 | var = const_0_idx + 1 211 | cnf.append([-1 * var]) 212 | for const_1_idx in const_1: 213 | var = const_1_idx + 1 214 | cnf.append([var]) 215 | return cnf 216 | 217 | def aigcone_to_cnf(data, fanin_list, cone_po, cone_po_val, gate_to_index): 218 | # Mask 219 | mask = [0] * len(data) 220 | bfs_q = [cone_po] 221 | while len(bfs_q) > 0: 222 | idx = bfs_q[-1] 223 | mask[idx] = 1 224 | bfs_q.pop() 225 | for fanin_idx in fanin_list[idx]: 226 | if mask[fanin_idx] == 0: 227 | bfs_q.insert(0, fanin_idx) 228 | 229 | # Build CNF 230 | cnf = [] 231 | for idx, x_data_info in enumerate(data): 232 | if mask[idx] == 0: 233 | continue 234 | if x_data_info[1] == gate_to_index['PI']: 235 | continue 236 | elif x_data_info[1] == gate_to_index['NOT']: 237 | var_C = idx + 1 238 | var_A = 
fanin_list[idx][0] + 1 239 | cnf.append([-1 * var_C, -1 * var_A]) 240 | cnf.append([var_C, var_A]) 241 | elif x_data_info[1] == gate_to_index['AND']: 242 | var_C = idx + 1 243 | var_A = fanin_list[idx][0] + 1 244 | var_B = fanin_list[idx][1] + 1 245 | cnf.append([var_C, -1*var_A, -1*var_B]) 246 | cnf.append([-1*var_C, var_A]) 247 | cnf.append([-1*var_C, var_B]) 248 | if cone_po_val: 249 | cnf.append([cone_po + 1]) 250 | else: 251 | cnf.append([-1 * (cone_po + 1)]) 252 | 253 | return cnf, np.sum(mask) 254 | -------------------------------------------------------------------------------- /src/models/dag_convgnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | from utils.dag_utils import subgraph, custom_backward_subgraph 8 | 9 | from .gat_conv import AGNNConv 10 | from .gcn_conv import AggConv 11 | from .deepset_conv import DeepSetConv 12 | from .gated_sum_conv import GatedSumConv 13 | from .mlp import MLP 14 | 15 | from torch.nn import LSTM, GRU 16 | 17 | 18 | _aggr_function_factory = { 19 | 'aggnconv': AGNNConv, 20 | 'deepset': DeepSetConv, 21 | 'gated_sum': GatedSumConv, 22 | 'conv_sum': AggConv, 23 | } 24 | 25 | _update_function_factory = { 26 | 'lstm': LSTM, 27 | 'gru': GRU, 28 | } 29 | 30 | 31 | class DAGConvGNN(nn.Module): 32 | ''' 33 | Recurrent Graph Neural Networks for Circuits. 34 | ''' 35 | def __init__(self, args): 36 | super(DAGConvGNN, self).__init__() 37 | 38 | self.args = args 39 | 40 | # configuration 41 | self.num_rounds = args.num_rounds 42 | self.device = args.device 43 | self.predict_diff = args.predict_diff 44 | self.intermediate_supervision = args.intermediate_supervision 45 | self.reverse = args.reverse 46 | self.custom_backward = args.custom_backward 47 | self.use_edge_attr = args.use_edge_attr 48 | 49 | # dimensions 50 | self.num_aggr = args.num_aggr 51 | self.dim_node_feature = args.dim_node_feature 52 | self.dim_hidden = args.dim_hidden 53 | self.dim_mlp = args.dim_mlp 54 | self.dim_pred = args.dim_pred 55 | self.num_fc = args.num_fc 56 | self.wx_update = args.wx_update 57 | self.wx_mlp = args.wx_mlp 58 | self.dim_edge_feature = args.dim_edge_feature 59 | 60 | # 1. 
message/aggr-related 61 | dim_aggr = self.dim_hidden# + self.dim_edge_feature if self.use_edge_attr else self.dim_hidden 62 | if self.args.aggr_function in _aggr_function_factory.keys(): 63 | # if self.use_edge_attr: 64 | # aggr_forward_pre = MLP(self.dim_hidden, self.dim_hidden, self.dim_hidden, num_layer=3, p_drop=0.2) 65 | # else: 66 | aggr_forward_pre = nn.Linear(dim_aggr, self.dim_hidden) 67 | if self.args.aggr_function == 'deepset': 68 | aggr_forward_post = nn.Linear(self.dim_hidden, self.dim_hidden) 69 | self.aggr_forward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_forward_pre, mlp_post=aggr_forward_post, wea=self.use_edge_attr) 70 | else: 71 | self.aggr_forward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_forward_pre, wea=self.use_edge_attr) 72 | if self.reverse: 73 | # if self.use_edge_attr: 74 | # aggr_backward_pre = MLP(self.dim_hidden, self.dim_hidden, self.dim_hidden, num_layer=3, p_drop=0.2) 75 | # else: 76 | aggr_backward_pre = nn.Linear(dim_aggr, self.dim_hidden) 77 | if self.args.aggr_function == 'deepset': 78 | aggr_backward_post = nn.Linear(self.dim_hidden, self.dim_hidden) 79 | self.aggr_backward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_backward_pre, mlp_post=aggr_backward_post, wea=self.use_edge_attr) 80 | else: 81 | self.aggr_backward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_backward_pre, reverse=True, wea=self.use_edge_attr) 82 | else: 83 | raise KeyError('no support {} aggr function.'.format(self.args.aggr_function)) 84 | 85 | 86 | # 2. update-related 87 | # if self.args.update_function in _update_function_factory.keys(): 88 | # # Here only consider the inputs as the concatenated vector from embedding and feature vector. 89 | # if self.wx_update: 90 | # self.update_forward = _update_function_factory[self.args.update_function](self.dim_node_feature+self.dim_hidden, self.dim_hidden) 91 | # if self.reverse: 92 | # self.update_backward = _update_function_factory[self.args.update_function](self.dim_node_feature+self.dim_hidden, self.dim_hidden) 93 | # else: 94 | # self.update_forward = _update_function_factory[self.args.update_function](self.dim_hidden, self.dim_hidden) 95 | # if self.reverse: 96 | # self.update_backward = _update_function_factory[self.args.update_function](self.dim_hidden, self.dim_hidden) 97 | # else: 98 | # raise KeyError('no support {} update function.'.format(self.args.update_function)) 99 | # consider the embedding for the LSTM/GRU model initialized by non-zeros 100 | self.one = torch.ones(1).to(self.device) 101 | self.emd_int = nn.Linear(1, self.dim_hidden) 102 | self.one.requires_grad = False 103 | 104 | 105 | # 3. predictor-related 106 | # TODO: support multiple predictors. Use a nn.ModuleList to handle it. 
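# (One possible shape for the TODO above, not implemented here: hold the heads in
#  `nn.ModuleList([MLP(...) for _ in range(num_predictors)])` and pick one per prediction
#  task in forward(); the single `self.predictor` defined below is what the current code uses.)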
107 | self.norm_layer = args.norm_layer 108 | self.activation_layer = args.activation_layer 109 | if self.wx_mlp: 110 | self.predictor = MLP(self.dim_hidden+self.dim_node_feature, self.dim_mlp, self.dim_pred, 111 | num_layer=self.num_fc, norm_layer=self.norm_layer, act_layer=self.activation_layer, sigmoid=False, tanh=False) 112 | else: 113 | self.predictor = MLP(self.dim_hidden, self.dim_mlp, self.dim_pred, 114 | num_layer=self.num_fc, norm_layer=self.norm_layer, act_layer=self.activation_layer, sigmoid=False, tanh=False) 115 | 116 | 117 | 118 | def forward(self, G): 119 | num_nodes = G.num_nodes 120 | num_layers_f = max(G.forward_level).item() + 1 121 | num_layers_b = max(G.backward_level).item() + 1 122 | one = self.one 123 | h_init = self.emd_int(one).view(1, 1, -1) # (1 x 1 x dim_hidden) 124 | h_init = h_init.repeat(1, num_nodes, 1) # (1 x num_nodes x dim_hidden) 125 | 126 | if self.args.update_function == 'lstm': 127 | preds = self._lstm_forward(G, h_init, num_layers_f, num_layers_b, num_nodes) 128 | elif self.args.update_function == 'gru': 129 | preds = self._gru_forward(G, h_init, num_layers_f, num_layers_b, num_nodes) 130 | else: 131 | raise NotImplementedError('The update function should be specified as one of lstm and gru.') 132 | 133 | return preds 134 | 135 | 136 | def _lstm_forward(self, G, h_init, num_layers_f, num_layers_b, num_nodes): 137 | x, edge_index = G.x, G.edge_index 138 | edge_attr = G.edge_attr if self.use_edge_attr else None 139 | 140 | node_state = (h_init, torch.zeros(1, num_nodes, self.dim_hidden).to(self.device)) # (h_0, c_0). here we only initialize h_0. TODO: option of not initializing the hidden state of LSTM. 141 | 142 | # TODO: add supports for modified attention and customized backward design. 143 | preds = [] 144 | for _ in range(self.num_rounds): 145 | for l_idx in range(1, num_layers_f): 146 | # forward layer 147 | layer_mask = G.forward_level == l_idx 148 | l_node = G.forward_index[layer_mask] 149 | 150 | l_state = (torch.index_select(node_state[0], dim=1, index=l_node), 151 | torch.index_select(node_state[1], dim=1, index=l_node)) 152 | 153 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=1) 154 | msg = self.aggr_forward(node_state[0].squeeze(0), l_edge_index, l_edge_attr) 155 | l_msg = torch.index_select(msg, dim=0, index=l_node) 156 | l_x = torch.index_select(x, dim=0, index=l_node) 157 | 158 | if self.args.wx_update: 159 | _, l_state = self.update_forward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 160 | else: 161 | _, l_state = self.update_forward(l_msg.unsqueeze(0), l_state) 162 | 163 | node_state[0][:, l_node, :] = l_state[0] 164 | node_state[1][:, l_node, :] = l_state[1] 165 | if self.reverse: 166 | for l_idx in range(1, num_layers_b): 167 | # backward layer 168 | layer_mask = G.backward_level == l_idx 169 | l_node = G.backward_index[layer_mask] 170 | 171 | l_state = (torch.index_select(node_state[0], dim=1, index=l_node), 172 | torch.index_select(node_state[1], dim=1, index=l_node)) 173 | if self.custom_backward: 174 | l_edge_index = custom_backward_subgraph(l_node, edge_index, device=self.device, dim=0) 175 | else: 176 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=0) 177 | msg = self.aggr_backward(node_state[0].squeeze(0), l_edge_index, l_edge_attr) 178 | l_msg = torch.index_select(msg, dim=0, index=l_node) 179 | l_x = torch.index_select(x, dim=0, index=l_node) 180 | 181 | if self.args.wx_update: 182 | _, l_state = self.update_backward(torch.cat([l_msg, l_x], 
dim=1).unsqueeze(0), l_state) 183 | else: 184 | _, l_state = self.update_backward(l_msg.unsqueeze(0), l_state) 185 | 186 | node_state[0][:, l_node, :] = l_state[0] 187 | node_state[1][:, l_node, :] = l_state[1] 188 | 189 | if self.intermediate_supervision: 190 | preds.append(self.predictor(node_state[0].squeeze(0))) 191 | 192 | node_embedding = node_state[0].squeeze(0) 193 | if self.wx_mlp: 194 | pred = self.predictor(torch.cat([node_embedding, x], dim=1)) 195 | else: 196 | pred = self.predictor(node_embedding) 197 | preds.append(pred) 198 | 199 | return preds 200 | 201 | def _gru_forward(self, G, h_init, num_layers_f, num_layers_b, num_nodes): 202 | x, edge_index = G.x, G.edge_index 203 | edge_attr = G.edge_attr if self.use_edge_attr else None 204 | 205 | node_state = h_init # (h_0). here we initialize h_0. TODO: option of not initializing the hidden state of GRU. 206 | 207 | # TODO: add supports for modified attention and customized backward design. 208 | preds = [] 209 | for _ in range(self.num_rounds): 210 | for l_idx in range(1, num_layers_f): 211 | # forward layer 212 | layer_mask = G.forward_level == l_idx 213 | l_node = G.forward_index[layer_mask] 214 | 215 | l_state = torch.index_select(node_state, dim=1, index=l_node) 216 | 217 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=1) 218 | msg = self.aggr_forward(node_state.squeeze(0), l_edge_index, l_edge_attr) 219 | l_msg = torch.index_select(msg, dim=0, index=l_node) 220 | l_x = torch.index_select(x, dim=0, index=l_node) 221 | 222 | # if self.args.wx_update: 223 | # _, l_state = self.update_forward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 224 | # else: 225 | # _, l_state = self.update_forward(l_msg.unsqueeze(0), l_state) 226 | node_state[:, l_node, :] = l_msg 227 | 228 | if self.reverse: 229 | for l_idx in range(1, num_layers_b): 230 | # backward layer 231 | layer_mask = G.backward_level == l_idx 232 | l_node = G.backward_index[layer_mask] 233 | 234 | l_state = torch.index_select(node_state, dim=1, index=l_node) 235 | 236 | if self.custom_backward: 237 | l_edge_index = custom_backward_subgraph(l_node, edge_index, device=self.device, dim=0) 238 | else: 239 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=0) 240 | msg = self.aggr_backward(node_state.squeeze(0), l_edge_index, l_edge_attr) 241 | l_msg = torch.index_select(msg, dim=0, index=l_node) 242 | l_x = torch.index_select(x, dim=0, index=l_node) 243 | 244 | # if self.args.wx_update: 245 | # _, l_state = self.update_backward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 246 | # else: 247 | # _, l_state = self.update_backward(l_msg.unsqueeze(0), l_state) 248 | 249 | node_state[:, l_node, :] = l_msg 250 | 251 | if self.intermediate_supervision: 252 | preds.append(self.predictor(node_state.squeeze(0))) 253 | 254 | node_embedding = node_state.squeeze(0) 255 | 256 | if self.wx_mlp: 257 | pred = self.predictor(torch.cat([node_embedding, x], dim=1)) 258 | else: 259 | pred = self.predictor(node_embedding) 260 | preds.append(pred) 261 | 262 | return preds 263 | 264 | 265 | 266 | 267 | def get_dag_recurrent_gnn(args): 268 | return DAGConvGNN(args) -------------------------------------------------------------------------------- /src/models/recgnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | from utils.dag_utils 
import subgraph, custom_backward_subgraph 8 | 9 | from .gat_conv import AGNNConv 10 | from .gcn_conv import AggConv 11 | from .deepset_conv import DeepSetConv 12 | from .gated_sum_conv import GatedSumConv 13 | from .mlp import MLP 14 | 15 | from torch.nn import LSTM, GRU 16 | 17 | 18 | _aggr_function_factory = { 19 | 'aggnconv': AGNNConv, 20 | 'deepset': DeepSetConv, 21 | 'gated_sum': GatedSumConv, 22 | 'conv_sum': AggConv, 23 | } 24 | 25 | _update_function_factory = { 26 | 'lstm': LSTM, 27 | 'gru': GRU, 28 | } 29 | 30 | 31 | class RecGNN(nn.Module): 32 | ''' 33 | Recurrent Graph Neural Networks for Circuits. 34 | ''' 35 | def __init__(self, args): 36 | super(RecGNN, self).__init__() 37 | 38 | self.args = args 39 | 40 | # configuration 41 | self.num_rounds = args.num_rounds 42 | self.device = args.device 43 | self.predict_diff = args.predict_diff 44 | self.intermediate_supervision = args.intermediate_supervision 45 | self.reverse = args.reverse 46 | self.custom_backward = args.custom_backward 47 | self.use_edge_attr = args.use_edge_attr 48 | self.mask = args.mask 49 | 50 | # dimensions 51 | self.num_aggr = args.num_aggr 52 | self.dim_node_feature = args.dim_node_feature 53 | self.dim_hidden = args.dim_hidden 54 | self.dim_mlp = args.dim_mlp 55 | self.dim_pred = args.dim_pred 56 | self.num_fc = args.num_fc 57 | self.wx_update = args.wx_update 58 | self.wx_mlp = args.wx_mlp 59 | self.dim_edge_feature = args.dim_edge_feature 60 | 61 | # 1. message/aggr-related 62 | dim_aggr = self.dim_hidden# + self.dim_edge_feature if self.use_edge_attr else self.dim_hidden 63 | if self.args.aggr_function in _aggr_function_factory.keys(): 64 | # if self.use_edge_attr: 65 | # aggr_forward_pre = MLP(self.dim_hidden, self.dim_hidden, self.dim_hidden, num_layer=3, p_drop=0.2) 66 | # else: 67 | aggr_forward_pre = nn.Linear(dim_aggr, self.dim_hidden) 68 | if self.args.aggr_function == 'deepset': 69 | aggr_forward_post = nn.Linear(self.dim_hidden, self.dim_hidden) 70 | self.aggr_forward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_forward_pre, mlp_post=aggr_forward_post, wea=self.use_edge_attr) 71 | else: 72 | self.aggr_forward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_forward_pre, wea=self.use_edge_attr) 73 | if self.reverse: 74 | # if self.use_edge_attr: 75 | # aggr_backward_pre = MLP(self.dim_hidden, self.dim_hidden, self.dim_hidden, num_layer=3, p_drop=0.2) 76 | # else: 77 | aggr_backward_pre = nn.Linear(dim_aggr, self.dim_hidden) 78 | if self.args.aggr_function == 'deepset': 79 | aggr_backward_post = nn.Linear(self.dim_hidden, self.dim_hidden) 80 | self.aggr_backward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_backward_pre, mlp_post=aggr_backward_post, wea=self.use_edge_attr) 81 | else: 82 | self.aggr_backward = _aggr_function_factory[self.args.aggr_function](dim_aggr, self.dim_hidden, mlp=aggr_backward_pre, reverse=True, wea=self.use_edge_attr) 83 | else: 84 | raise KeyError('no support {} aggr function.'.format(self.args.aggr_function)) 85 | 86 | 87 | # 2. update-related 88 | if self.args.update_function in _update_function_factory.keys(): 89 | # Here only consider the inputs as the concatenated vector from embedding and feature vector. 
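        # With wx_update enabled, each recurrent step consumes the aggregated message
        # concatenated with the raw node feature, so the GRU/LSTM input size is
        # dim_node_feature + dim_hidden; otherwise the input is the message alone
        # (size dim_hidden). The hidden size is dim_hidden in both cases.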
90 | if self.wx_update: 91 | self.update_forward = _update_function_factory[self.args.update_function](self.dim_node_feature+self.dim_hidden, self.dim_hidden) 92 | if self.reverse: 93 | self.update_backward = _update_function_factory[self.args.update_function](self.dim_node_feature+self.dim_hidden, self.dim_hidden) 94 | else: 95 | self.update_forward = _update_function_factory[self.args.update_function](self.dim_hidden, self.dim_hidden) 96 | if self.reverse: 97 | self.update_backward = _update_function_factory[self.args.update_function](self.dim_hidden, self.dim_hidden) 98 | else: 99 | raise KeyError('no support {} update function.'.format(self.args.update_function)) 100 | # consider the embedding for the LSTM/GRU model initialized by non-zeros 101 | self.one = torch.ones(1).to(self.device) 102 | self.emd_int = nn.Linear(1, self.dim_hidden) 103 | self.one.requires_grad = False 104 | 105 | 106 | # 3. predictor-related 107 | # TODO: support multiple predictors. Use a nn.ModuleList to handle it. 108 | self.norm_layer = args.norm_layer 109 | self.activation_layer = args.activation_layer 110 | if self.wx_mlp: 111 | self.predictor = MLP(self.dim_hidden+self.dim_node_feature, self.dim_mlp, self.dim_pred, 112 | num_layer=self.num_fc, norm_layer=self.norm_layer, act_layer=self.activation_layer, sigmoid=False, tanh=False) 113 | else: 114 | self.predictor = MLP(self.dim_hidden, self.dim_mlp, self.dim_pred, 115 | num_layer=self.num_fc, norm_layer=self.norm_layer, act_layer=self.activation_layer, sigmoid=False, tanh=False) 116 | 117 | 118 | 119 | def forward(self, G): 120 | num_nodes = G.num_nodes 121 | num_layers_f = max(G.forward_level).item() + 1 122 | num_layers_b = max(G.backward_level).item() + 1 123 | one = self.one 124 | h_init = self.emd_int(one).view(1, 1, -1) # (1 x 1 x dim_hidden) 125 | h_init = h_init.repeat(1, num_nodes, 1) # (1 x num_nodes x dim_hidden) 126 | # h_init = torch.empty(1, num_nodes, self.dim_hidden).to(self.device) 127 | # nn.init.normal_(h_init) 128 | 129 | if self.mask: 130 | h_true = torch.ones_like(h_init).to(self.device) 131 | h_false = -torch.ones_like(h_init).to(self.device) 132 | h_true.requires_grad = False 133 | h_false.requires_grad = False 134 | h_init = self.imply_mask(G, h_init, h_true, h_false) 135 | else: 136 | h_true = None 137 | h_false = None 138 | 139 | if self.args.update_function == 'lstm': 140 | preds = self._lstm_forward(G, h_init, num_layers_f, num_layers_b, num_nodes) 141 | elif self.args.update_function == 'gru': 142 | preds = self._gru_forward(G, h_init, num_layers_f, num_layers_b, h_true, h_false) 143 | else: 144 | raise NotImplementedError('The update function should be specified as one of lstm and gru.') 145 | 146 | return preds 147 | 148 | 149 | def _lstm_forward(self, G, h_init, num_layers_f, num_layers_b, num_nodes): 150 | x, edge_index = G.x, G.edge_index 151 | edge_attr = G.edge_attr if self.use_edge_attr else None 152 | 153 | node_state = (h_init, torch.zeros(1, num_nodes, self.dim_hidden).to(self.device)) # (h_0, c_0). here we only initialize h_0. TODO: option of not initializing the hidden state of LSTM. 
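        # Shapes: h_0 and c_0 are both (1 x num_nodes x dim_hidden); the node axis is
        # treated as the batch dimension of nn.LSTM. The loops below sweep the circuit
        # level by level (forward_level 1 .. num_layers_f - 1, then the backward pass
        # if self.reverse is set), updating only the slice of node_state that belongs
        # to the current level; level 0 (the primary inputs) keeps its initial embedding.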
154 | 155 | # TODO: add supports for modified attention 156 | preds = [] 157 | for _ in range(self.num_rounds): 158 | for l_idx in range(1, num_layers_f): 159 | # forward layer 160 | layer_mask = G.forward_level == l_idx 161 | l_node = G.forward_index[layer_mask] 162 | 163 | l_state = (torch.index_select(node_state[0], dim=1, index=l_node), 164 | torch.index_select(node_state[1], dim=1, index=l_node)) 165 | 166 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=1) 167 | msg = self.aggr_forward(node_state[0].squeeze(0), l_edge_index, l_edge_attr) 168 | l_msg = torch.index_select(msg, dim=0, index=l_node) 169 | l_x = torch.index_select(x, dim=0, index=l_node) 170 | 171 | if self.args.wx_update: 172 | _, l_state = self.update_forward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 173 | else: 174 | _, l_state = self.update_forward(l_msg.unsqueeze(0), l_state) 175 | 176 | node_state[0][:, l_node, :] = l_state[0] 177 | node_state[1][:, l_node, :] = l_state[1] 178 | if self.reverse: 179 | for l_idx in range(1, num_layers_b): 180 | # backward layer 181 | layer_mask = G.backward_level == l_idx 182 | l_node = G.backward_index[layer_mask] 183 | 184 | l_state = (torch.index_select(node_state[0], dim=1, index=l_node), 185 | torch.index_select(node_state[1], dim=1, index=l_node)) 186 | if self.custom_backward: 187 | l_edge_index = custom_backward_subgraph(l_node, edge_index, device=self.device, dim=0) 188 | else: 189 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=0) 190 | msg = self.aggr_backward(node_state[0].squeeze(0), l_edge_index, l_edge_attr) 191 | l_msg = torch.index_select(msg, dim=0, index=l_node) 192 | l_x = torch.index_select(x, dim=0, index=l_node) 193 | 194 | if self.args.wx_update: 195 | _, l_state = self.update_backward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 196 | else: 197 | _, l_state = self.update_backward(l_msg.unsqueeze(0), l_state) 198 | 199 | node_state[0][:, l_node, :] = l_state[0] 200 | node_state[1][:, l_node, :] = l_state[1] 201 | 202 | if self.intermediate_supervision: 203 | preds.append(self.predictor(node_state[0].squeeze(0))) 204 | 205 | node_embedding = node_state[0].squeeze(0) 206 | if self.wx_mlp: 207 | pred = self.predictor(torch.cat([node_embedding, x], dim=1)) 208 | else: 209 | pred = self.predictor(node_embedding) 210 | preds.append(pred) 211 | 212 | return preds 213 | 214 | def _gru_forward(self, G, h_init, num_layers_f, num_layers_b, h_true=None, h_false=None): 215 | G = G.to(self.device) 216 | x, edge_index = G.x, G.edge_index 217 | edge_attr = G.edge_attr if self.use_edge_attr else None 218 | 219 | node_state = h_init.to(self.device) 220 | 221 | # TODO: add supports for modified attention 222 | preds = [] 223 | for _ in range(self.num_rounds): 224 | for l_idx in range(1, num_layers_f): 225 | # forward layer 226 | layer_mask = G.forward_level == l_idx 227 | l_node = G.forward_index[layer_mask] 228 | 229 | l_state = torch.index_select(node_state, dim=1, index=l_node) 230 | 231 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=1) 232 | msg = self.aggr_forward(node_state.squeeze(0), l_edge_index, l_edge_attr) 233 | l_msg = torch.index_select(msg, dim=0, index=l_node) 234 | l_x = torch.index_select(x, dim=0, index=l_node) 235 | 236 | if self.args.wx_update: 237 | _, l_state = self.update_forward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 238 | else: 239 | _, l_state = self.update_forward(l_msg.unsqueeze(0), l_state) 240 | node_state[:, l_node, :] = l_state 241 | 
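            # imply_mask (defined below) re-clamps nodes whose logic value is fixed by
            # the problem instance: entries with G.mask == 1.0 are overwritten with the
            # constant +1 embedding h_true, entries with G.mask == 0.0 with the constant
            # -1 embedding h_false, and entries with G.mask == -1.0 keep their learned
            # state. It is re-applied after the forward and backward sweeps of every
            # round so that constrained nodes stay pinned.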
242 | # TODO: Add the masking 243 | if self.mask: 244 | node_state = self.imply_mask(G, node_state, h_true, h_false) 245 | 246 | if self.reverse: 247 | for l_idx in range(1, num_layers_b): 248 | # backward layer 249 | layer_mask = G.backward_level == l_idx 250 | l_node = G.backward_index[layer_mask] 251 | 252 | l_state = torch.index_select(node_state, dim=1, index=l_node) 253 | 254 | if self.custom_backward: 255 | l_edge_index = custom_backward_subgraph(l_node, edge_index, device=self.device, dim=0) 256 | else: 257 | l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=0) 258 | msg = self.aggr_backward(node_state.squeeze(0), l_edge_index, l_edge_attr) 259 | l_msg = torch.index_select(msg, dim=0, index=l_node) 260 | l_x = torch.index_select(x, dim=0, index=l_node) 261 | 262 | if self.args.wx_update: 263 | _, l_state = self.update_backward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state) 264 | else: 265 | _, l_state = self.update_backward(l_msg.unsqueeze(0), l_state) 266 | 267 | node_state[:, l_node, :] = l_state 268 | 269 | # TODO: Add the masking 270 | if self.mask: 271 | node_state = self.imply_mask(G, node_state, h_true, h_false) 272 | 273 | if self.intermediate_supervision: 274 | preds.append(self.predictor(node_state.squeeze(0))) 275 | 276 | node_embedding = node_state.squeeze(0) 277 | 278 | if self.wx_mlp: 279 | pred = self.predictor(torch.cat([node_embedding, x], dim=1)) 280 | else: 281 | pred = self.predictor(node_embedding) 282 | preds.append(pred) 283 | 284 | return preds 285 | 286 | 287 | def imply_mask(self, G, h, h_true, h_false): 288 | true_mask = (G.mask == 1.0).unsqueeze(0) 289 | false_mask = (G.mask == 0.0).unsqueeze(0) 290 | normal_mask = (G.mask == -1.0).unsqueeze(0) 291 | h_mask = h * normal_mask + h_true * true_mask + h_false * false_mask 292 | return h_mask 293 | 294 | 295 | 296 | def get_recurrent_gnn(args): 297 | return RecGNN(args) -------------------------------------------------------------------------------- /src/utils/sat_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | from .utils import pyg_simulation 6 | 7 | 8 | def generate_k_iclause(n, k): 9 | vs = np.random.choice(n, size=min(n, k), replace=False) 10 | return [v + 1 if random.random() < 0.5 else -(v + 1) for v in vs] 11 | 12 | 13 | 14 | # the utility function for Circuit-SAT 15 | def get_sub_cnf(cnf, var, is_inv): 16 | res_cnf = [] 17 | if not is_inv: 18 | for clause in cnf: 19 | if not var in clause: 20 | tmp_clause = clause.copy() 21 | for idx, ele in enumerate(tmp_clause): 22 | if ele == -var: 23 | del tmp_clause[idx] 24 | res_cnf.append(tmp_clause) 25 | else: 26 | for clause in cnf: 27 | if not -var in clause: 28 | tmp_clause = clause.copy() 29 | for idx, ele in enumerate(tmp_clause): 30 | if ele == var: 31 | del tmp_clause[idx] 32 | res_cnf.append(tmp_clause) 33 | return res_cnf 34 | 35 | def two_fanin_gate(po_idx, fan_in_list, x, edge_index, gate_type): 36 | gate_list = fan_in_list.copy() 37 | new_gate_list = [] 38 | 39 | while True: 40 | if len(gate_list) + len(new_gate_list) == 2: 41 | for gate_idx in gate_list: 42 | edge_index.append([gate_idx, po_idx]) 43 | for gate_idx in new_gate_list: 44 | edge_index.append([gate_idx, po_idx]) 45 | break 46 | if len(gate_list) == 0: 47 | gate_list = new_gate_list.copy() 48 | new_gate_list.clear() 49 | elif len(gate_list) == 1: 50 | new_gate_list.append(gate_list[0]) 51 | gate_list = new_gate_list.copy() 52 | new_gate_list.clear() 
53 | else: 54 | new_gate_idx = len(x) 55 | x.append(gate_type) 56 | edge_index.append([gate_list[0], new_gate_idx]) 57 | edge_index.append([gate_list[1], new_gate_idx]) 58 | gate_list = gate_list[2:] 59 | new_gate_list.append(new_gate_idx) 60 | 61 | 62 | def save_cnf(cnf, cnf_idx, x, edge_index, inv2idx): 63 | cnf_fan_in_list = [] 64 | for clause in cnf: 65 | if len(clause) == 0: 66 | continue 67 | elif len(clause) == 1: 68 | if clause[0] < 0: 69 | cnf_fan_in_list.append(inv2idx[abs(clause[0])]) 70 | else: 71 | cnf_fan_in_list.append(clause[0]) 72 | else: 73 | clause_idx = len(x) 74 | x.append(one_hot_gate_type('OR')) 75 | cnf_fan_in_list.append(clause_idx) 76 | clause_fan_in_list = [] 77 | for ele in clause: 78 | if ele < 0: 79 | clause_fan_in_list.append(inv2idx[abs(ele)]) 80 | else: 81 | clause_fan_in_list.append(ele) 82 | two_fanin_gate(clause_idx, clause_fan_in_list, x, edge_index, x[clause_idx]) 83 | 84 | x[cnf_idx] = one_hot_gate_type('AND') 85 | two_fanin_gate(cnf_idx, cnf_fan_in_list, x, edge_index, x[cnf_idx]) 86 | 87 | def merge_cnf(cnf): 88 | res = [] 89 | clause2bool = {} 90 | for clause in cnf: 91 | tmp_clause = tuple(clause) 92 | if not tmp_clause in clause2bool: 93 | clause2bool[tmp_clause] = True 94 | res.append(clause) 95 | return res 96 | 97 | def recursion_generation(cnf, cnf_idx, current_depth, max_depth, n_vars, x, edge_index, inv2idx): 98 | ''' 99 | Expand the CNF as binary tree 100 | The expanded CNF can be writen as: 101 | CNF = OR(B_T, B_F) 102 | B_T = AND(exp_T_CNF, var) 103 | B_F = AND(exp_F_CNF, var_inv) 104 | # exp_T_CNF, exp_F_CNF are new CNFs 105 | Input: 106 | cnf: iclauses 107 | cnf_idx: the cnf PO index in x 108 | current_depth: current expand time 109 | max_depth: maximum expand time 110 | n_vars: number of variables 111 | x: nodes 112 | edge_index: edge 113 | inv2idx: PI_inv index 114 | ''' 115 | 116 | #################### 117 | # Store as CNF 118 | #################### 119 | if current_depth == max_depth: 120 | save_cnf(cnf, cnf_idx, x, edge_index, inv2idx) 121 | return 122 | 123 | #################### 124 | # Sort 125 | #################### 126 | var_times = [0] * (n_vars + 1) 127 | for idx in range(1, n_vars + 1, 1): 128 | for clause in cnf: 129 | if idx in clause: 130 | var_times[abs(idx)] += 1 131 | 132 | var_sort = np.argsort(var_times) 133 | most_var = var_sort[-1] 134 | if var_times[most_var] == 0: 135 | save_cnf(cnf, cnf_idx, x, edge_index, inv2idx) 136 | return 137 | 138 | 139 | #################### 140 | # Expansion 141 | #################### 142 | for most_var in var_sort[::-1]: 143 | var_idx = most_var 144 | next_var = False 145 | # Get sub-CNFs 146 | exp_T_cnf = get_sub_cnf(cnf, most_var, 0) 147 | exp_F_cnf = get_sub_cnf(cnf, most_var, 1) 148 | 149 | for clause in exp_T_cnf: 150 | if len(clause) == 0: 151 | next_var = True 152 | break 153 | for clause in exp_F_cnf: 154 | if len(clause) == 0: 155 | next_var = True 156 | break 157 | if not next_var: 158 | break 159 | if most_var == 0: 160 | save_cnf(cnf, cnf_idx, x, edge_index, inv2idx) 161 | return 162 | 163 | if not most_var in inv2idx: 164 | inv2idx[most_var] = len(x) 165 | x.append(one_hot_gate_type('NOT')) 166 | edge_index.append([most_var, inv2idx[most_var]]) 167 | var_inv_idx = inv2idx[most_var] 168 | 169 | exp_T_cnf = merge_cnf(exp_T_cnf) 170 | exp_F_cnf = merge_cnf(exp_F_cnf) 171 | 172 | # ------------------------------------------ 173 | # Construct (exp_T_CNF) and (B_T) 174 | if len(exp_T_cnf) == 0: 175 | edge_index.append([var_idx, cnf_idx]) 176 | elif len(exp_T_cnf) == 1: 177 | # 
Construct (B_T): B_T = AND(var_idx, exp_T) 178 | B_T_idx = len(x) 179 | x.append(one_hot_gate_type('AND')) 180 | exp_T_cnf = exp_T_cnf[0] 181 | if len(exp_T_cnf) == 1: # The clause only have one var 182 | exp_T_idx = exp_T_cnf[0] 183 | if exp_T_idx < 0: 184 | exp_T_idx = inv2idx[abs(exp_T_idx)] 185 | else: # The clause have many vars 186 | exp_T_idx = len(x) 187 | x.append(one_hot_gate_type('OR')) 188 | for ele in exp_T_cnf: 189 | if ele < 0: 190 | ele_idx = inv2idx[abs(ele)] 191 | else: 192 | ele_idx = ele 193 | edge_index.append([ele_idx, exp_T_idx]) 194 | edge_index.append([exp_T_idx, B_T_idx]) 195 | edge_index.append([var_idx, B_T_idx]) 196 | edge_index.append([B_T_idx, cnf_idx]) 197 | else: 198 | # Construct(exp_T_CNF) 199 | exp_T_cnf_idx = len(x) 200 | x.append(one_hot_gate_type('OR')) 201 | recursion_generation(exp_T_cnf, exp_T_cnf_idx, current_depth + 1, max_depth, 202 | n_vars, x, edge_index, inv2idx) 203 | # Construct (B_T) 204 | B_T_idx = len(x) 205 | x.append(one_hot_gate_type('AND')) 206 | edge_index.append([exp_T_cnf_idx, B_T_idx]) 207 | edge_index.append([var_idx, B_T_idx]) 208 | edge_index.append([B_T_idx, cnf_idx]) 209 | 210 | # ------------------------------------------ 211 | # Construct (exp_F_CNF) and (B_F) 212 | if len(exp_F_cnf) == 0: 213 | edge_index.append([var_inv_idx, cnf_idx]) 214 | elif len(exp_F_cnf) == 1: 215 | # Construct (B_F): B_F = AND(var_idx, exp_F) 216 | B_F_idx = len(x) 217 | x.append(one_hot_gate_type('AND')) 218 | exp_F_cnf = exp_F_cnf[0] 219 | if len(exp_F_cnf) == 1: # The clause only have one var 220 | exp_F_idx = exp_F_cnf[0] 221 | if exp_F_idx < 0: 222 | exp_F_idx = inv2idx[abs(exp_F_idx)] 223 | else: # The clause have many vars 224 | exp_F_idx = len(x) 225 | x.append(one_hot_gate_type('OR')) 226 | for ele in exp_F_cnf: 227 | if ele < 0: 228 | ele_idx = inv2idx[abs(ele)] 229 | else: 230 | ele_idx = ele 231 | edge_index.append([ele_idx, exp_F_idx]) 232 | edge_index.append([exp_F_idx, B_F_idx]) 233 | edge_index.append([var_inv_idx, B_F_idx]) 234 | edge_index.append([B_F_idx, cnf_idx]) 235 | else: 236 | # Construct(exp_F_CNF) 237 | exp_F_cnf_idx = len(x) 238 | x.append(one_hot_gate_type('OR')) 239 | recursion_generation(exp_F_cnf, exp_F_cnf_idx, current_depth + 1, max_depth, 240 | n_vars, x, edge_index, inv2idx) 241 | # Construct (B_F) 242 | B_F_idx = len(x) 243 | x.append(one_hot_gate_type('AND')) 244 | edge_index.append([exp_F_cnf_idx, B_F_idx]) 245 | edge_index.append([var_inv_idx, B_F_idx]) 246 | edge_index.append([B_F_idx, cnf_idx]) 247 | 248 | 249 | def one_hot_gate_type(gate_type): 250 | res = [] 251 | if gate_type == 'PI': 252 | res = [1, 0, 0, 0] 253 | elif gate_type == 'AND': res = [0, 1, 0, 0] 254 | elif gate_type == 'OR': 255 | res = [0, 0, 1, 0] 256 | elif gate_type == 'NOT': 257 | res = [0, 0, 0, 1] 258 | else: 259 | print('[ERROR] Unknown gate type') 260 | return res 261 | 262 | 263 | 264 | def write_dimacs_to(n_vars, iclauses, out_filename): 265 | with open(out_filename, 'w') as f: 266 | f.write("p cnf %d %d\n" % (n_vars, len(iclauses))) 267 | for c in iclauses: 268 | for x in c: 269 | f.write("%d " % x) 270 | f.write("0\n") 271 | 272 | 273 | def solve_sat_iteratively(g, detector): 274 | # here we consider the PO has already been masked 275 | print('Name of circuit: ', g.name) 276 | # print(g) 277 | 278 | # set PO as 1. 
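    # Nodes with backward_level == 0 are the primary outputs; setting their mask to 1.0
    # encodes the query "find a primary-input assignment under which the output
    # evaluates to true". The greedy loop below repeatedly runs the detector, picks the
    # still-unassigned primary input whose predicted probability is most decisive
    # (compared against both 0 and 1), fixes it to the rounded value, and records the
    # decision in ORDER so the backtracking phase can later flip it.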
279 | layer_mask = g.backward_level == 0 280 | l_node = g.backward_index[layer_mask] 281 | g.mask[l_node] = torch.tensor(1.0) 282 | 283 | # check # PIs 284 | # literal index 285 | layer_mask = g.forward_level == 0 286 | l_node = g.forward_index[layer_mask] 287 | print('# PIs: ', len(l_node)) 288 | 289 | # random solution generation. 290 | # sol = (torch.rand(len(l_node)) > 0.5).int().unsqueeze(1) 291 | # sat = pyg_simulation(g, sol)[0] 292 | # return sol, sat 293 | # only one forward 294 | # ret = detector.run(g) 295 | # output = ret['results'].cpu() 296 | # sol = (output[l_node] > 0.5).int() 297 | # sat = pyg_simulation(g, sol)[0] 298 | # return sol, sat 299 | 300 | # for backtracking 301 | ORDER = [] 302 | change_ind = -1 303 | mask_backup = g.mask.clone().detach() 304 | 305 | 306 | for i in range(len(l_node)): 307 | print('==> # ', i+1, 'solving..') 308 | ret = detector.run(g) 309 | output = ret['results'].cpu() 310 | 311 | # mask 312 | one_mask = torch.zeros(g.y.size(0)) 313 | one_mask = one_mask.scatter(dim=0, index=l_node, src=torch.ones(len(l_node))).unsqueeze(1) 314 | 315 | max_val, max_ind = torch.max(output * one_mask, dim=0) 316 | min_val, min_ind = torch.min(output + (1 - one_mask), dim=0) 317 | 318 | ext_val, ext_ind = (max_val, max_ind) if (max_val > (1 - min_val)) else (min_val, min_ind) 319 | print('Assign No. ', ext_ind.item(), 'with prob: ', ext_val.item(), 'as value: ', 1.0 if ext_val > 0.5 else 0.0) 320 | # g.mask[ext_ind] = torch.tensor(np.random.binomial(n=1, p=ext_val), dtype=torch.float) 321 | g.mask[ext_ind] = torch.tensor(1.0) if ext_val > 0.5 else torch.tensor(0.0) 322 | # push the current index to Q 323 | ORDER.append(ext_ind) 324 | 325 | l_node_new = [] 326 | for i in l_node: 327 | if i != ext_ind: 328 | l_node_new.append(i) 329 | l_node = torch.tensor(l_node_new) 330 | 331 | 332 | 333 | # literal index 334 | layer_mask = g.forward_level == 0 335 | l_node = g.forward_index[layer_mask] 336 | print('Prob: ', output[l_node]) 337 | 338 | sol = g.mask[l_node] 339 | print('Solution: ', sol) 340 | # check the satifiability 341 | sat = pyg_simulation(g, sol)[0] 342 | if sat: 343 | return sol, sat 344 | 345 | print('=====> Step into the backtracking...') 346 | # do the backtracking 347 | while ORDER: 348 | # renew the mask 349 | g.mask = mask_backup.clone().detach() 350 | change_ind = ORDER.pop() 351 | print('Change the values when solving No. ', change_ind.item(), 'PIs') 352 | # literal index 353 | layer_mask = g.forward_level == 0 354 | l_node = g.forward_index[layer_mask] 355 | 356 | for i in range(len(l_node)): 357 | # print('==> # ', i+1, 'solving..') 358 | ret = detector.run(g) 359 | output = ret['results'].cpu() 360 | 361 | # mask 362 | one_mask = torch.zeros(g.y.size(0)) 363 | one_mask = one_mask.scatter(dim=0, index=l_node, src=torch.ones(len(l_node))).unsqueeze(1) 364 | 365 | max_val, max_ind = torch.max(output * one_mask, dim=0) 366 | min_val, min_ind = torch.min(output + (1 - one_mask), dim=0) 367 | 368 | ext_val, ext_ind = (max_val, max_ind) if (max_val > (1 - min_val)) else (min_val, min_ind) 369 | g.mask[ext_ind] = torch.tensor(1.0) if ext_val > 0.5 else torch.tensor(0.0) 370 | # push the current index to Q 371 | if ext_ind == change_ind: 372 | g.mask[ext_ind] = 1 - g.mask[ext_ind] 373 | print('Assign No. 
', ext_ind.item(), 'with prob: ', ext_val.item(), 'as value: ', g.mask[ext_ind].item()) 374 | 375 | l_node_new = [] 376 | for i in l_node: 377 | if i != ext_ind: 378 | l_node_new.append(i) 379 | l_node = torch.tensor(l_node_new) 380 | 381 | # literal index 382 | layer_mask = g.forward_level == 0 383 | l_node = g.forward_index[layer_mask] 384 | print('Prob: ', output[l_node]) 385 | 386 | sol = g.mask[l_node] 387 | # check the satisfiability 388 | sat = pyg_simulation(g, sol)[0] 389 | print('Solution: ', sol) 390 | if sat: 391 | print('====> Hit the correct solution during the backtracking...') 392 | return sol, sat 393 | else: 394 | print('Wrong..') 395 | 396 | return None, 0 397 | --------------------------------------------------------------------------------
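A minimal usage sketch of write_dimacs_to (defined in sat_utils.py above), assuming src/ is on the Python path; the three-variable formula and the output filename are illustrative, not taken from the repository:

    from utils.sat_utils import write_dimacs_to

    # (x1 or not x2) and (x2 or x3) and (not x1 or not x3)
    iclauses = [[1, -2], [2, 3], [-1, -3]]
    write_dimacs_to(3, iclauses, 'tiny.cnf')
    # tiny.cnf now holds standard DIMACS CNF:
    # p cnf 3 3
    # 1 -2 0
    # 2 3 0
    # -1 -3 0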