├── .gitattributes
├── LICENSE
├── README.md
├── docker
│   ├── Dockerfile
│   ├── build.sh
│   └── run.sh
├── exps
│   ├── parity.py
│   └── sudoku.py
├── images
│   ├── forward_pass.png
│   ├── mnist_sudoku.png
│   └── poster_forward.png
├── notebooks
│   └── Learning and Solving Sudoku via SATNet.ipynb
├── requirements.txt
├── satnet
│   ├── __init__.py
│   └── models.py
├── setup.py
└── src
    ├── satnet.cpp
    ├── satnet.h
    ├── satnet_cpu.cpp
    └── satnet_cuda.cu
/.gitattributes:
--------------------------------------------------------------------------------
1 | notebooks/*.ipynb linguist-documentation
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Po-Wei Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SATNet • [![PyPi][pypi-image]][pypi] [![colab][colab-image]][colab] [![License][license-image]][license]
2 |
3 | [license-image]: https://img.shields.io/badge/License-MIT-yellow.svg
4 | [license]: LICENSE
5 |
6 | [pypi-image]: https://img.shields.io/pypi/v/satnet.svg
7 | [pypi]: https://pypi.python.org/pypi/satnet
8 |
9 | [colab-image]: https://colab.research.google.com/assets/colab-badge.svg
10 | [colab]: https://colab.research.google.com/drive/1dRfepPLEE8N6BBZhXz8bbLDcPnRKaOcJ#forceEdit=true&offline=true&sandboxMode=true
11 |
12 | *Bridging deep learning and logical reasoning using a differentiable satisfiability solver.*
13 |
14 | This repository contains the source code to reproduce the experiments in the ICML 2019 paper [SATNet: Bridging deep learning and logical reasoning using a differentiable satisfiability solver](https://arxiv.org/abs/1905.12149) by [Po-Wei Wang](https://powei.tw/), [Priya L. Donti](https://priyadonti.com/), [Bryan Wilder](http://teamcore.usc.edu/people/bryanwilder/default.htm), and [J. Zico Kolter](http://zicokolter.com/).
15 |
16 |
17 | ## What is SATNet
18 |
19 | SATNet is a differentiable (smoothed) maximum satisfiability (MAXSAT) solver that can be integrated into the loop of larger deep learning systems. This (approximate) solver is based upon a fast coordinate descent approach to solving the semidefinite program (SDP) associated with the MAXSAT problem.
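To give a feel for the coordinate descent step, here is a minimal NumPy sketch on a toy problem. It assumes the low-rank SDP form described in the paper: minimize `<S^T S, V^T V>` over matrices `V` whose columns `v_i` have unit norm, updating each free column as `v_o = -g_o / ||g_o||` with `g_o = V S^T s_o - ||s_o||^2 v_o`. The function names `sdp_objective` and `mixing_sweep` are hypothetical, `S` is laid out here as an `(m, n)` array with one column per variable (the transpose of the `(n, m)` `S` attribute documented for `satnet.SATNet`), and the sketch leaves out auxiliary variables and the backward pass implemented in `src/`.

```python
import numpy as np


def sdp_objective(S, V):
    """Return <S^T S, V^T V>, computed as ||V S^T||_F^2."""
    return float(np.linalg.norm(V @ S.T) ** 2)


def mixing_sweep(S, V, free):
    """One pass of exact coordinate updates over the free variables.

    S: (m, n) clause matrix with one column s_i per variable (assumed layout).
    V: (k, n) matrix whose columns v_i are unit vectors; updated in place.
    free: indices of the variables the solver may change (the non-inputs).
    """
    sq_norms = (S ** 2).sum(axis=0)  # ||s_i||^2 for every column
    for o in free:
        # g_o: half the gradient w.r.t. v_o, with the self term dropped
        # (that term is constant once ||v_o|| = 1).
        g = V @ (S.T @ S[:, o]) - sq_norms[o] * V[:, o]
        norm = np.linalg.norm(g)
        if norm > 0:
            V[:, o] = -g / norm  # minimizer of 2 v_o^T g_o on the unit sphere
    return V


rng = np.random.default_rng(0)
m, n, k = 4, 6, 3
S = rng.standard_normal((m, n))
V = rng.standard_normal((k, n))
V /= np.linalg.norm(V, axis=0, keepdims=True)  # start on the unit sphere

history = [sdp_objective(S, V)]
for _ in range(10):
    mixing_sweep(S, V, free=range(n))
    history.append(sdp_objective(S, V))
print(history[0], history[-1])
```

Because each update minimizes the objective exactly over one column while the others are held fixed, the values in `history` never increase from one sweep to the next.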
20 |
21 | #### How SATNet works
22 |
23 | A SATNet layer takes as input the discrete or probabilistic assignments of known MAXSAT variables, and outputs guesses for the assignments of unknown variables via a MAXSAT SDP relaxation with weights *S*. A schematic depicting the forward pass of this layer is shown below. To obtain the backward pass, we analytically differentiate through the SDP relaxation (see the paper for more details).
24 |
25 | 
26 |
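The input and output mappings around the SDP can be sketched in NumPy as follows. This assumes the relaxation and rounding rules as we read them in the paper: a probability `z` becomes the unit vector `-cos(pi z) v_top + sin(pi z) u`, where `u` is a random unit vector orthogonal to a fixed "truth" direction `v_top`, and an output vector `v` is read back as `arccos(-v^T v_top) / pi`. The names `relax` and `round_to_prob` are hypothetical, normalizing the projected random vector is our own choice, and the SDP solve that would sit between the two steps is omitted, so the round trip simply returns the inputs.

```python
import numpy as np


def relax(z, v_top, rng):
    """Map probabilities z in [0, 1] to unit vectors (assumed paper convention)."""
    k = v_top.shape[0]
    U = rng.standard_normal((k, len(z)))
    U -= np.outer(v_top, v_top @ U)                # project out the v_top direction
    U /= np.linalg.norm(U, axis=0, keepdims=True)  # unit vectors orthogonal to v_top
    return -np.cos(np.pi * z) * v_top[:, None] + np.sin(np.pi * z) * U


def round_to_prob(V, v_top):
    """Map unit vectors back to probabilities: z = arccos(-v^T v_top) / pi."""
    cos = np.clip(-(v_top @ V), -1.0, 1.0)
    return np.arccos(cos) / np.pi


rng = np.random.default_rng(0)
k = 5
v_top = np.zeros(k)
v_top[0] = 1.0                       # the "truth" direction
z = np.array([0.0, 0.25, 0.5, 1.0])
V = relax(z, v_top, rng)             # one unit-norm column per variable
print(round_to_prob(V, v_top))       # recovers z up to floating point
```

Under this convention a variable with `z = 1` is placed exactly on `v_top` and one with `z = 0` on `-v_top`, so the angle to the truth direction encodes the probability.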
27 | #### Overview of experiments
28 |
29 | We show that by integrating SATNet into end-to-end learning systems, we can learn the logical structure of challenging problems in a minimally supervised fashion. In particular, we show that we can:
30 | * Learn the **parity function** using single-bit supervision (a traditionally hard task for deep networks).
31 | * Learn how to play **9×9 Sudoku (original and permuted)** solely from examples.
32 | * Solve a **"visual Sudoku"** problem that maps images of Sudoku puzzles to their associated logical solutions. (A sample "visual Sudoku" input is shown below.)
33 |
34 |
35 |
36 |
37 |
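In the visual Sudoku model, `MNISTSudokuSolver` in `exps/sudoku.py` flattens the batch of cell images, classifies each one with `DigitConv`, keeps the softmax probabilities of the first nine classes, and reshapes them into one 729-long vector per board for the SATNet layer. The NumPy sketch below traces only those shapes; random numbers stand in for the convolutional network's logits, an assumption made so the example runs without PyTorch or MNIST data.

```python
import numpy as np

batch, cells, classes = 2, 81, 10

# Stand-in for DigitConv logits: one 10-way score vector per cell image.
rng = np.random.default_rng(0)
logits = rng.standard_normal((batch * cells, classes))

# Softmax over the 10 MNIST classes, then keep the first 9 columns,
# mirroring F.softmax(x, dim=1)[:, :9] in DigitConv.forward.
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
digit_guess = probs[:, :9]

# Regroup per puzzle: 81 cells x 9 digits = 729 SATNet inputs per board.
puzzles = digit_guess.reshape(batch, cells * 9)
print(puzzles.shape)  # (2, 729)
```

Dropping the tenth column means each cell's nine values can sum to less than one; the sketch keeps that behavior rather than renormalizing.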
38 | ## Installation
39 |
40 | ### Via pip
41 | ```bash
42 | pip install satnet
43 | ```
44 |
45 |
46 | ### From source
47 | ```bash
48 | git clone https://github.com/locuslab/SATNet
49 | cd SATNet && python setup.py install
50 | ```
51 |
52 | #### Package Dependencies
53 | ```
54 | conda install -c pytorch tqdm
55 | ```
56 | The package also requires the nvcc compiler. If `nvcc` is not available on your command line, you can install it via
57 | ```
58 | conda install -c conda-forge cudatoolkit-dev
59 | ```
60 |
61 |
62 |
63 | ### Via Docker image
64 | ```bash
65 | cd docker
66 | sh ./build.sh
67 | sh ./run.sh
68 | ```
69 |
70 | ## Running experiments
71 | ### Jupyter Notebook and Google Colab
72 | The Sudoku experiments can also be run interactively in the [Jupyter notebook](https://github.com/locuslab/SATNet/blob/master/notebooks/Learning%20and%20Solving%20Sudoku%20via%20SATNet.ipynb)
73 | or on [Google Colab](https://colab.research.google.com/drive/1dRfepPLEE8N6BBZhXz8bbLDcPnRKaOcJ#forceEdit=true&offline=true&sandboxMode=true).
74 |
75 | ### Run them manually
76 |
77 | #### Getting the datasets
78 | The [Sudoku dataset](https://powei.tw/sudoku.zip) and [Parity dataset](https://powei.tw/parity.zip) can be downloaded via
79 |
80 | ```bash
81 | wget -cq powei.tw/sudoku.zip && unzip -qq sudoku.zip
82 | wget -cq powei.tw/parity.zip && unzip -qq parity.zip
83 | ```
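Both experiment scripts load these tensors and split them in order (the first `int(N * (1 - testPct))` examples for training, the rest for testing), and they assert that each part divides evenly by `--batchSz` and `--testBatchSz`. The hypothetical helper below mirrors that arithmetic so a batch size can be checked before launching a run; the dataset size of 10,000 in the example is only illustrative.

```python
def split_sizes(n_examples, test_pct, batch_sz, test_batch_sz):
    """Return (n_train, n_test) the way exps/sudoku.py and exps/parity.py do.

    The first int(N * (1 - testPct)) examples are used for training and the
    remainder for testing. The scripts assert that each part is divisible by
    its batch size; this sketch raises ValueError instead.
    """
    n_train = int(n_examples * (1.0 - test_pct))
    n_test = n_examples - n_train
    if n_train % batch_sz != 0:
        raise ValueError(f"{n_train} training examples not divisible by batchSz={batch_sz}")
    if n_test % test_batch_sz != 0:
        raise ValueError(f"{n_test} test examples not divisible by testBatchSz={test_batch_sz}")
    return n_train, n_test


print(split_sizes(10000, 0.1, 40, 40))  # (9000, 1000)
```

For instance, `split_sizes(10000, 0.1, 70, 40)` raises, because 9,000 is not a multiple of 70.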
84 | #### Sudoku experiments (original, permuted, and visual)
85 | ```bash
86 | python exps/sudoku.py
87 | python exps/sudoku.py --perm
88 | python exps/sudoku.py --mnist --batchSz=50
89 | ```
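Both Sudoku variants feed SATNet a flat bit vector: `SudokuSolver` builds a layer with `boardSz**6` variables (729 for a 9×9 board, i.e. 81 cells times 9 digits), and `process_inputs` marks a cell's nine bits as inputs exactly when the cell is filled. The NumPy sketch below reproduces that encoding for a single hypothetical puzzle; it assumes digit `d` is stored at index `d - 1` of a cell's one-hot vector, which is an illustration choice rather than a documented property of the dataset.

```python
import numpy as np

board_sz = 3
nsq = board_sz ** 2          # 9 rows, 9 columns, 9 digits
n_vars = board_sz ** 6       # 729 = 81 cells x 9 digits, as in SudokuSolver

# Hypothetical puzzle: 0 marks an empty cell, 1..9 a given digit.
puzzle = np.zeros((nsq, nsq), dtype=int)
puzzle[0, 0], puzzle[4, 7] = 5, 9

# One-hot encode each cell over 9 digits; empty cells stay all-zero.
X = np.zeros((nsq, nsq, nsq))
rows, cols = np.nonzero(puzzle)
X[rows, cols, puzzle[rows, cols] - 1] = 1.0

# Same rule as process_inputs: a cell's 9 bits are inputs iff the cell is filled.
is_input = np.sign(X.sum(axis=2, keepdims=True)).astype(int)
is_input = np.broadcast_to(is_input, X.shape)

x_flat, mask_flat = X.reshape(-1), is_input.reshape(-1)
print(x_flat.shape, mask_flat.sum())  # (729,) 18
```

Two filled cells contribute 2 × 9 = 18 input bits; the remaining 711 bits are the unknowns SATNet has to fill in.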
90 |
91 | #### Parity experiments
92 | ```bash
93 | python exps/parity.py --seq=20
94 | python exps/parity.py --seq=40
95 | ```
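The parity experiments train a single three-variable `satnet.SATNet` layer (two inputs, one output) and apply it recurrently: `apply_seq` in `exps/parity.py` feeds in the first two bits, then repeatedly combines the rounded running output with the next bit, and the loss only looks at the final output. The pure-Python sketch below restates that loop for one unbatched sequence, with a hypothetical exact XOR function standing in for the trained layer, to show why learning XOR on pairs is enough to compute the parity of the whole string.

```python
import random


def xor_layer(y):
    """Hypothetical stand-in for a trained 3-variable SATNet layer.

    Keeps the two inputs and fills the output slot y[2] with their XOR.
    """
    a, b, _ = y
    return [a, b, float(int(a) ^ int(b))]


def round_bit(p):
    # apply_seq rounds with ((y - 0.5).sign() + 1) / 2; a plain threshold
    # gives the same result for every value other than exactly 0.5.
    return 1.0 if p > 0.5 else 0.0


def apply_chain(layer, bits):
    """Unbatched re-statement of apply_seq in exps/parity.py."""
    y = layer([bits[0], bits[1], 0.0])         # parity of the first two bits
    for b in bits[2:]:
        y = layer([round_bit(y[-1]), b, 0.0])  # carry the running parity forward
    return y[-1]


random.seed(0)
bits = [float(random.randint(0, 1)) for _ in range(20)]
print(apply_chain(xor_layer, bits) == sum(bits) % 2)  # True
```

With a learned layer in place of `xor_layer`, only the final output receives supervision, which is what makes the task a single-bit-supervision problem.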
96 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel
2 | RUN pip install setproctitle
3 |
4 | ARG USER_ID
5 | ARG GROUP_ID
6 | ARG USER_NAME
7 | ARG HOME_DIR
8 |
9 | RUN addgroup --gid ${GROUP_ID} ${USER_NAME} || groupmod -n ${USER_NAME} $(getent group ${GROUP_ID})
10 | RUN apt-get -q update; apt-get -q -y install sudo vim
11 | RUN conda install -y -q jupyter matplotlib
12 | RUN adduser --quiet --disabled-password --system --no-create-home --uid ${USER_ID} --gid ${GROUP_ID} --gecos '' --shell /bin/bash ${USER_NAME}
13 | RUN usermod -d ${HOME_DIR} ${USER_NAME}
14 | RUN adduser --quiet ${USER_NAME} sudo ; echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
15 | RUN apt-get install -y git
16 |
17 | RUN mkdir -p /data
18 | WORKDIR /data
19 | USER ${USER_NAME}
20 |
--------------------------------------------------------------------------------
/docker/build.sh:
--------------------------------------------------------------------------------
1 | docker image build \
2 | --build-arg USER_ID=$(id -u ${USER}) \
3 | --build-arg GROUP_ID=$(id -g ${USER}) \
4 | --build-arg USER_NAME=$(whoami) \
5 | --build-arg HOME_DIR=$HOME \
6 | -t satnet .
7 |
--------------------------------------------------------------------------------
/docker/run.sh:
--------------------------------------------------------------------------------
1 | DATA_VOLUME="-v $(pwd)/..:/data"
2 | HOME_VOLUME="-v $HOME:$HOME"
3 | docker run --rm --runtime=nvidia -it --net=host --ipc=host ${DATA_VOLUME} ${HOME_VOLUME} --name=satnet satnet
4 |
--------------------------------------------------------------------------------
/exps/parity.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 |
5 | import os
6 | import sys
7 | import csv
8 | import shutil
9 |
10 | import numpy.random as npr
11 |
12 | import torch
13 | import torch.optim as optim
14 | import torch.nn.functional as F
15 | from torch.utils.data import TensorDataset, DataLoader
16 |
17 | import satnet
18 | from tqdm.auto import tqdm
19 |
20 | class CSVLogger(object):
21 |     def __init__(self, fname):
22 |         self.f = open(fname, 'w')
23 |         self.logger = csv.writer(self.f)
24 |
25 |     def log(self, fields):
26 |         self.logger.writerow(fields)
27 |         self.f.flush()
28 |
29 | def main():
30 |     parser = argparse.ArgumentParser()
31 |     parser.add_argument('--data_dir', type=str, default='parity')
32 |     parser.add_argument('--testPct', type=float, default=0.1)
33 |     parser.add_argument('--batchSz', type=int, default=100)
34 |     parser.add_argument('--testBatchSz', type=int, default=500)
35 |     parser.add_argument('--nEpoch', type=int, default=100)
36 |     parser.add_argument('--lr', type=float, default=1e-1)
37 |     parser.add_argument('--seq', type=int, default=20)
38 |     parser.add_argument('--save', type=str)
39 |     parser.add_argument('--m', type=int, default=4)
40 |     parser.add_argument('--aux', type=int, default=4)
41 |     parser.add_argument('--no_cuda', action='store_true')
42 |     parser.add_argument('--adam', action='store_true')
43 |
44 |     args = parser.parse_args()
45 |
46 |     # For debugging: fix the random seed
47 |     npr.seed(1)
48 |     torch.manual_seed(7)
49 |
50 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
51 |     if args.cuda:
52 |         print('Using', torch.cuda.get_device_name(0))
53 |         torch.backends.cudnn.deterministic = True
54 |         torch.backends.cudnn.benchmark = False
55 |         torch.cuda.init()
57 |     save = 'parity.seq{}-aux{}-m{}-lr{}-bsz{}'.format(
58 |         args.seq, args.aux, args.m, args.lr, args.batchSz)
59 |
60 |     if args.save: save = '{}-{}'.format(args.save, save)
61 |     save = os.path.join('logs', save)
62 |     if os.path.isdir(save): shutil.rmtree(save)
63 |     os.makedirs(save)
64 |
65 |     L = args.seq
66 |
67 |     with open(os.path.join(args.data_dir, str(L), 'features.pt'), 'rb') as f:
68 |         X = torch.load(f).float()
69 |     with open(os.path.join(args.data_dir, str(L), 'labels.pt'), 'rb') as f:
70 |         Y = torch.load(f).float()
71 |
72 |     if args.cuda: X, Y = X.cuda(), Y.cuda()
73 |
74 |     N = X.size(0)
75 |
76 |     nTrain = int(N*(1-args.testPct))
77 |     nTest = N-nTrain
78 |
79 |     assert(nTrain % args.batchSz == 0)
80 |     assert(nTest % args.testBatchSz == 0)
81 |
82 |     train_is_input = torch.IntTensor([1,1,0]).repeat(nTrain,1)
83 |     test_is_input = torch.IntTensor([1,1,0]).repeat(nTest,1)
84 |     if args.cuda: train_is_input, test_is_input = train_is_input.cuda(), test_is_input.cuda()
85 |
86 |     train_set = TensorDataset(X[:nTrain], train_is_input, Y[:nTrain])
87 |     test_set = TensorDataset(X[nTrain:], test_is_input, Y[nTrain:])
88 |
89 |     model = satnet.SATNet(3, args.m, args.aux, prox_lam=1e-1)
90 |     if args.cuda: model = model.cuda()
91 |
92 |     if args.adam:
93 |         optimizer = optim.Adam(model.parameters(), lr=args.lr)
94 |     else:
95 |         optimizer = optim.SGD(model.parameters(), lr=args.lr)
96 |
97 |     train_logger = CSVLogger(os.path.join(save, 'train.csv'))
98 |     test_logger = CSVLogger(os.path.join(save, 'test.csv'))
99 |     fields = ['epoch', 'loss', 'err']
100 |     train_logger.log(fields)
101 |     test_logger.log(fields)
102 |
103 |     test(0, model, optimizer, test_logger, test_set, args.testBatchSz)
104 |     for epoch in range(1, args.nEpoch+1):
105 |         train(epoch, model, optimizer, train_logger, train_set, args.batchSz)
106 |         test(epoch, model, optimizer, test_logger, test_set, args.testBatchSz)
107 |
108 | def apply_seq(net, zeros, batch_data, batch_is_inputs, batch_targets):
109 |     y = torch.cat([batch_data[:,:2], zeros], dim=1)
110 |     y = net(y, batch_is_inputs)
111 |     L = batch_data.size(1)
112 |     for i in range(L-2):
113 |         y = torch.cat([y[:,-1].unsqueeze(1), batch_data[:,i+2].unsqueeze(1), zeros], dim=1)
114 |         y = net(((y-0.5).sign()+1)/2, batch_is_inputs)
115 |     loss = F.binary_cross_entropy(y[:,-1], batch_targets[:,-1])
116 |     return loss, y
117 |
118 | def run(epoch, model, optimizer, logger, dataset, batchSz, to_train):
119 |     loss_final, err_final = 0, 0
120 |
121 |     loader = DataLoader(dataset, batch_size=batchSz)
122 |     tloader = tqdm(enumerate(loader), total=len(loader))
123 |
124 |     start = torch.zeros(batchSz, 1)
125 |     if next(model.parameters()).is_cuda: start = start.cuda()
126 |
127 |     for i,(data,is_input, label) in tloader:
128 |         if to_train: optimizer.zero_grad()
129 |
130 |         loss, pred = apply_seq(model, start, data, is_input, label)
131 |
132 |         if to_train:
133 |             loss.backward()
134 |             optimizer.step()
135 |
136 |         err = computeErr(pred, label)
137 |         tloader.set_description('Epoch {} {} Loss {:.4f} Err: {:.4f}'.format(
138 |             epoch, ('Train' if to_train else 'Test '), loss.item(), err))
139 |         loss_final += loss.item()
140 |         err_final += err
141 |
142 |     loss_final, err_final = loss_final/len(loader), err_final/len(loader)
143 |     logger.log((epoch, loss_final, err_final))
144 |
145 |     if not to_train:
146 |         print('TESTING SET RESULTS: Average loss: {:.4f} Err: {:.4f}'.format(loss_final, err_final))
147 |
148 | def train(epoch, model, optimizer, logger, dataset, batchSz):
149 |     run(epoch, model, optimizer, logger, dataset, batchSz, True)
150 |
151 | @torch.no_grad()
152 | def test(epoch, model, optimizer, logger, dataset, batchSz):
153 |     run(epoch, model, optimizer, logger, dataset, batchSz, False)
154 |
155 | @torch.no_grad()
156 | def computeErr(pred, target):
157 |     y = (pred[:,-1]-0.5)
158 |     t = (target[:,-1]-0.5)
159 |     correct = ((y * t).sign()+1.)/2
160 |     acc = correct.sum().float()/target.size(0)
161 |
162 |     return 1-float(acc)
163 |
164 | if __name__ == '__main__':
165 |     main()
166 |
--------------------------------------------------------------------------------
/exps/sudoku.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Partly derived from:
4 | # https://github.com/locuslab/optnet/blob/master/sudoku/train.py
5 |
6 | import argparse
7 |
8 | import os
9 | import shutil
10 | import csv
11 |
12 | import numpy as np
13 | import numpy.random as npr
14 | #import setproctitle
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | import torch.nn.functional as F
20 | from torch.utils.data import TensorDataset, DataLoader
21 | from tqdm.auto import tqdm
22 |
23 | import satnet
24 |
25 | class SudokuSolver(nn.Module):
26 |     def __init__(self, boardSz, aux, m):
27 |         super(SudokuSolver, self).__init__()
28 |         n = boardSz**6
29 |         self.sat = satnet.SATNet(n, m, aux)
30 |
31 |     def forward(self, y_in, mask):
32 |         out = self.sat(y_in, mask)
33 |         return out
34 |
35 | class DigitConv(nn.Module):
36 |     '''
37 |     Convolutional neural network for MNIST digit recognition. From:
38 |     https://github.com/pytorch/examples/blob/master/mnist/main.py
39 |     '''
40 |     def __init__(self):
41 |         super(DigitConv, self).__init__()
42 |         self.conv1 = nn.Conv2d(1, 20, 5, 1)
43 |         self.conv2 = nn.Conv2d(20, 50, 5, 1)
44 |         self.fc1 = nn.Linear(4*4*50, 500)
45 |         self.fc2 = nn.Linear(500, 10)
46 |
47 |     def forward(self, x):
48 |         x = F.relu(self.conv1(x))
49 |         x = F.max_pool2d(x, 2, 2)
50 |         x = F.relu(self.conv2(x))
51 |         x = F.max_pool2d(x, 2, 2)
52 |         x = x.view(-1, 4*4*50)
53 |         x = F.relu(self.fc1(x))
54 |         x = self.fc2(x)
55 |         return F.softmax(x, dim=1)[:,:9].contiguous()
56 |
57 | class MNISTSudokuSolver(nn.Module):
58 |     def __init__(self, boardSz, aux, m):
59 |         super(MNISTSudokuSolver, self).__init__()
60 |         self.digit_convnet = DigitConv()
61 |         self.sudoku_solver = SudokuSolver(boardSz, aux, m)
62 |         self.boardSz = boardSz
63 |         self.nSq = boardSz**2
64 |
65 |     def forward(self, x, is_inputs):
66 |         nBatch = x.shape[0]
67 |         x = x.flatten(start_dim = 0, end_dim = 1)
68 |         digit_guess = self.digit_convnet(x)
69 |         puzzles = digit_guess.view(nBatch, self.nSq * self.nSq * self.nSq)
70 |
71 |         solution = self.sudoku_solver(puzzles, is_inputs)
72 |         return solution
73 |
74 | class CSVLogger(object):
75 |     def __init__(self, fname):
76 |         self.f = open(fname, 'w')
77 |         self.logger = csv.writer(self.f)
78 |
79 |     def log(self, fields):
80 |         self.logger.writerow(fields)
81 |         self.f.flush()
82 |
83 | class FigLogger(object):
84 |     def __init__(self, fig, base_ax, title):
85 |         self.colors = ['tab:red', 'tab:blue']
86 |         self.labels = ['Loss (entropy)', 'Error']
87 |         self.markers = ['d', '.']
88 |         self.axes = [base_ax, base_ax.twinx()]
89 |         base_ax.set_xlabel('Epochs')
90 |         base_ax.set_title(title)
91 |
92 |         for i, ax in enumerate(self.axes):
93 |             ax.set_ylabel(self.labels[i], color=self.colors[i])
94 |             ax.tick_params(axis='y', labelcolor=self.colors[i])
95 |
96 |         self.reset()
97 |         self.fig = fig
98 |
99 |     def log(self, args):
100 |         for i, arg in enumerate(args[-2:]):
101 |             self.curves[i].append(arg)
102 |             x = list(range(len(self.curves[i])))
103 |             self.axes[i].plot(x, self.curves[i], self.colors[i], marker=self.markers[i])
104 |             self.axes[i].set_ylim(0, 1.05)
105 |
106 |         self.fig.canvas.draw()
107 |
108 |     def reset(self):
109 |         for ax in self.axes:
110 |             for line in list(ax.lines):  # copy first: removing while iterating skips lines
111 |                 line.remove()
112 |         self.curves = [[], []]
113 |
114 | def print_header(msg):
115 |     print('===>', msg)
116 |
117 | def find_unperm(perm):
118 |     unperm = torch.zeros_like(perm)
119 |     for i in range(perm.size(0)):
120 |         unperm[perm[i]] = i
121 |     return unperm
122 |
123 | def main():
124 |     parser = argparse.ArgumentParser()
125 |     parser.add_argument('--data_dir', type=str, default='sudoku')
126 |     parser.add_argument('--boardSz', type=int, default=3)
127 |     parser.add_argument('--batchSz', type=int, default=40)
128 |     parser.add_argument('--testBatchSz', type=int, default=40)
129 |     parser.add_argument('--aux', type=int, default=300)
130 |     parser.add_argument('--m', type=int, default=600)
131 |     parser.add_argument('--nEpoch', type=int, default=100)
132 |     parser.add_argument('--testPct', type=float, default=0.1)
133 |     parser.add_argument('--lr', type=float, default=2e-3)
134 |     parser.add_argument('--save', type=str)
135 |     parser.add_argument('--model', type=str)
136 |     parser.add_argument('--no_cuda', action='store_true')
137 |     parser.add_argument('--mnist', action='store_true')
138 |     parser.add_argument('--perm', action='store_true')
139 |
140 |     args = parser.parse_args()
141 |
142 |     # For debugging: fix the random seed
143 |     npr.seed(1)
144 |     torch.manual_seed(7)
145 |
146 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
147 |     if args.cuda:
148 |         print('Using', torch.cuda.get_device_name(0))
149 |         torch.backends.cudnn.deterministic = True
150 |         torch.backends.cudnn.benchmark = False
151 |         torch.cuda.init()
152 |
153 |     save = 'sudoku{}{}.boardSz{}-aux{}-m{}-lr{}-bsz{}'.format(
154 |         '.perm' if args.perm else '', '.mnist' if args.mnist else '',
155 |         args.boardSz, args.aux, args.m, args.lr, args.batchSz)
156 |     if args.save: save = '{}-{}'.format(args.save, save)
157 |     save = os.path.join('logs', save)
158 |     if os.path.isdir(save): shutil.rmtree(save)
159 |     os.makedirs(save)
160 |
161 |     #setproctitle.setproctitle('sudoku.{}'.format(save))
162 |
163 |     print_header('Loading data')
164 |
165 |     with open(os.path.join(args.data_dir, 'features.pt'), 'rb') as f:
166 |         X_in = torch.load(f)
167 |     with open(os.path.join(args.data_dir, 'features_img.pt'), 'rb') as f:
168 |         Ximg_in = torch.load(f)
169 |     with open(os.path.join(args.data_dir, 'labels.pt'), 'rb') as f:
170 |         Y_in = torch.load(f)
171 |     with open(os.path.join(args.data_dir, 'perm.pt'), 'rb') as f:
172 |         perm = torch.load(f)
173 |
174 |     N = X_in.size(0)
175 |     nTrain = int(N*(1.-args.testPct))
176 |     nTest = N-nTrain
177 |     assert(nTrain % args.batchSz == 0)
178 |     assert(nTest % args.testBatchSz == 0)
179 |
180 |     print_header('Forming inputs')
181 |     X, Ximg, Y, is_input = process_inputs(X_in, Ximg_in, Y_in, args.boardSz)
182 |     data = Ximg if args.mnist else X
183 |     if args.cuda: data, is_input, Y = data.cuda(), is_input.cuda(), Y.cuda()
184 |
185 |     unperm = None
186 |     if args.perm and not args.mnist:
187 |         print('Applying permutation')
188 |         data[:,:], Y[:,:], is_input[:,:] = data[:,perm], Y[:,perm], is_input[:,perm]
189 |         unperm = find_unperm(perm)
190 |
191 |     train_set = TensorDataset(data[:nTrain], is_input[:nTrain], Y[:nTrain])
192 |     test_set = TensorDataset(data[nTrain:], is_input[nTrain:], Y[nTrain:])
193 |
194 |     print_header('Building model')
195 |     if args.mnist:
196 |         model = MNISTSudokuSolver(args.boardSz, args.aux, args.m)
197 |     else:
198 |         model = SudokuSolver(args.boardSz, args.aux, args.m)
199 |
200 |     if args.cuda: model = model.cuda()
201 |
202 |     if args.mnist:
203 |         optimizer = optim.Adam([
204 |             {'params': model.sudoku_solver.parameters(), 'lr': args.lr},
205 |             {'params': model.digit_convnet.parameters(), 'lr': 1e-5},
206 |         ])
207 |     else:
208 |         optimizer = optim.Adam(model.parameters(), lr=args.lr)
209 |
210 |     if args.model:
211 |         model.load_state_dict(torch.load(args.model))
212 |
213 |     train_logger = CSVLogger(os.path.join(save, 'train.csv'))
214 |     test_logger = CSVLogger(os.path.join(save, 'test.csv'))
215 |     fields = ['epoch', 'loss', 'err']
216 |     train_logger.log(fields)
217 |     test_logger.log(fields)
218 |
219 |     test(args.boardSz, 0, model, optimizer, test_logger, test_set, args.testBatchSz, unperm)
220 |     for epoch in range(1, args.nEpoch+1):
221 |         train(args.boardSz, epoch, model, optimizer, train_logger, train_set, args.batchSz, unperm)
222 |         test(args.boardSz, epoch, model, optimizer, test_logger, test_set, args.testBatchSz, unperm)
223 |         #torch.save(model.state_dict(), os.path.join(save, 'it'+str(epoch)+'.pth'))
224 |
225 | def process_inputs(X, Ximg, Y, boardSz):
226 |     is_input = X.sum(dim=3, keepdim=True).expand_as(X).int().sign()
227 |
228 |     Ximg = Ximg.flatten(start_dim=1, end_dim=2)
229 |     Ximg = Ximg.unsqueeze(2).float()
230 |
231 |     X = X.view(X.size(0), -1)
232 |     Y = Y.view(Y.size(0), -1)
233 |     is_input = is_input.view(is_input.size(0), -1)
234 |
235 |     return X, Ximg, Y, is_input
236 |
237 | def run(boardSz, epoch, model, optimizer, logger, dataset, batchSz, to_train=False, unperm=None):
238 |
239 |     loss_final, err_final = 0, 0
240 |
241 |     loader = DataLoader(dataset, batch_size=batchSz)
242 |     tloader = tqdm(enumerate(loader), total=len(loader))
243 |
244 |     for i,(data,is_input,label) in tloader:
245 |         if to_train: optimizer.zero_grad()
246 |         preds = model(data.contiguous(), is_input.contiguous())
247 |         loss = nn.functional.binary_cross_entropy(preds, label)
248 |
249 |         if to_train:
250 |             loss.backward()
251 |             optimizer.step()
252 |
253 |         err = computeErr(preds.data, boardSz, unperm)/batchSz
254 |         tloader.set_description('Epoch {} {} Loss {:.4f} Err: {:.4f}'.format(epoch, ('Train' if to_train else 'Test '), loss.item(), err))
255 |         loss_final += loss.item()
256 |         err_final += err
257 |
258 |     loss_final, err_final = loss_final/len(loader), err_final/len(loader)
259 |     logger.log((epoch, loss_final, err_final))
260 |
261 |     if not to_train:
262 |         print('TESTING SET RESULTS: Average loss: {:.4f} Err: {:.4f}'.format(loss_final, err_final))
263 |
264 |     #print('memory: {:.2f} MB, cached: {:.2f} MB'.format(torch.cuda.memory_allocated()/2.**20, torch.cuda.memory_cached()/2.**20))
265 |     torch.cuda.empty_cache()
266 |
267 | def train(boardSz, epoch, model, optimizer, logger, dataset, batchSz, unperm=None):
268 |     run(boardSz, epoch, model, optimizer, logger, dataset, batchSz, True, unperm)
269 |
270 | @torch.no_grad()
271 | def test(boardSz, epoch, model, optimizer, logger, dataset, batchSz, unperm=None):
272 |     run(boardSz, epoch, model, optimizer, logger, dataset, batchSz, False, unperm)
273 |
274 | @torch.no_grad()
275 | def computeErr(pred_flat, n, unperm):
276 |     if unperm is not None: pred_flat = pred_flat[:,unperm]  # un-permute without mutating the caller's tensor
277 |
278 |     nsq = n ** 2
279 |     pred = pred_flat.view(-1, nsq, nsq, nsq)
280 |
281 |     batchSz = pred.size(0)
282 |     s = (nsq-1)*nsq//2 # 0 + 1 + ... + n^2-1
283 |     I = torch.max(pred, 3)[1].squeeze().view(batchSz, nsq, nsq)
284 |
285 |     def invalidGroups(x):
286 |         valid = (x.min(1)[0] == 0)
287 |         valid *= (x.max(1)[0] == nsq-1)
288 |         valid *= (x.sum(1) == s)
289 |         return valid.bitwise_not()
290 |
291 |     boardCorrect = torch.ones(batchSz).type_as(pred)
292 |     for j in range(nsq):
293 |         # Check the jth row and column.
294 |         boardCorrect[invalidGroups(I[:,j,:])] = 0
295 |         boardCorrect[invalidGroups(I[:,:,j])] = 0
296 |
297 |         # Check the jth block.
298 |         row, col = n*(j // n), n*(j % n)
299 |         M = invalidGroups(I[:,row:row+n,col:col+n].contiguous().view(batchSz,-1))
300 |         boardCorrect[M] = 0
301 |
302 |         if boardCorrect.sum() == 0:
303 |             return batchSz
304 |
305 |     return float(batchSz-boardCorrect.sum())
306 |
307 | if __name__=='__main__':
308 |     main()
309 |
--------------------------------------------------------------------------------
/images/forward_pass.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/SATNet/50897688eae47bf765c1d9ed9a7c6f5419d62a9a/images/forward_pass.png
--------------------------------------------------------------------------------
/images/mnist_sudoku.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/SATNet/50897688eae47bf765c1d9ed9a7c6f5419d62a9a/images/mnist_sudoku.png
--------------------------------------------------------------------------------
/images/poster_forward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/SATNet/50897688eae47bf765c1d9ed9a7c6f5419d62a9a/images/poster_forward.png
--------------------------------------------------------------------------------
/notebooks/Learning and Solving Sudoku via SATNet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import shutil\n",
11 | "import argparse\n",
12 | "from collections import namedtuple\n",
13 | "\n",
14 | "import numpy as np\n",
15 | "import numpy.random as npr\n",
16 | "\n",
17 | "import torch\n",
18 | "import torch.nn as nn\n",
19 | "import torch.optim as optim\n",
20 | "import torch.nn.functional as F\n",
21 | "from torch.utils.data import TensorDataset, DataLoader\n",
22 | "\n",
23 | "%matplotlib inline\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "from IPython.display import display, Markdown, Latex, clear_output"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "# Introduction to SATNet"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "SATNet is a differentiable (smoothed) maximum satisfiability (MAXSAT) solver that can be integrated into the loop of larger deep learning systems. Our (approximate) solver is based upon a fast coordinate descent approach to solving the semidefinite program (SDP) associated with the MAXSAT problem.\n",
40 | "\n",
41 | "The code below reproduces the Sudoku experiments from our paper \"SATNet: Bridging deep learning and logical reasoning using a differentiable satisfiability solver.\" These experiments show that by integrating the SATNet solver into end-to-end learning systems, we can learn the logical structure of challenging problems in a minimally supervised fashion. In particular, this notebook shows how we can learn to:\n",
42 | "* Play **9×9 Sudoku (original and permuted)** solely from examples.\n",
43 | "* Solve a **\"visual Sudoku\"** problem that maps images of Sudoku puzzles to their associated logical solutions. \n",
44 | "\n",
45 | "For more details and discussion about these experiments, please see the [SATNet paper](https://icml.cc/Conferences/2019/Schedule?showEvent=3947)."
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 2,
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "SATNet document\n",
58 | " Apply a SATNet layer to complete the input probabilities.\n",
59 | "\n",
60 | " Args:\n",
61 | " n: Number of input variables.\n",
62 | " m: Rank of the clause matrix.\n",
63 | " aux: Number of auxiliary variables.\n",
64 | "\n",
65 | " max_iter: Maximum number of iterations for solving\n",
66 | " the inner optimization problem.\n",
67 | " Default: 40\n",
68 | " eps: The stopping threshold for the inner optimizaiton problem.\n",
69 | " The inner Mixing method will stop when the function decrease\n",
70 | " is less then eps times the initial function decrease.\n",
71 | " Default: 1e-4\n",
72 | " prox_lam: The diagonal increment in the backward linear system\n",
73 | " to make the backward pass more stable.\n",
74 | " Default: 1e-2\n",
75 | " weight_normalize: Set true to perform normlization for init weights.\n",
76 | " Default: True\n",
77 | "\n",
78 | " Inputs: (z, is_input)\n",
79 | " **z** of shape `(batch, n)`: \n",
80 | " Float tensor containing the probabilities (must be in [0,1]).\n",
81 | " **is_input** of shape `(batch, n)`: \n",
82 | " Int tensor indicating which **z** is a input.\n",
83 | "\n",
84 | " Outputs: z\n",
85 | " **z** of shape `(batch, n)`:\n",
86 | " The prediction probabiolities.\n",
87 | "\n",
88 | " Attributes: S\n",
89 | " **S** of shape `(n, m)`:\n",
90 | " The learnable clauses matrix containing `m` clauses \n",
91 | " for the `n` variables.\n",
92 | "\n",
93 | " Examples:\n",
94 | " >>> sat = satnet.SATNet(3, 4, aux=5)\n",
95 | " >>> z = torch.randn(2, 3)\n",
96 | " >>> is_input = torch.IntTensor([[1, 1, 0], [1,0,1]])\n",
97 | " >>> pred = sat(z, is_input)\n",
98 | " \n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "import satnet\n",
104 | "print('SATNet document\\n', satnet.SATNet.__doc__)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "# Building SATNet-based Models"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "To solve **Sudoku** and a **permuted version of Sudoku**: We construct a SATNet-based SudokuSolver layer that takes as input a logical (bit) representation of the initial Sudoku board along with a mask representing which bits must be learned (i.e. all bits in empty Sudoku cells). This input is vectorized. Given this input, the SudokuSolver layer then outputs a bit representation of the Sudoku board with guesses for the unknown bits."
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 3,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "class SudokuSolver(nn.Module):\n",
128 | "    def __init__(self, boardSz, aux, m):\n",
129 | "        super(SudokuSolver, self).__init__()\n",
130 | "        n = boardSz**6\n",
131 | "        self.sat = satnet.SATNet(n, m, aux)\n",
132 | "\n",
133 | "    def forward(self, y_in, mask):\n",
134 | "        out = self.sat(y_in, mask)\n",
135 | "        del y_in, mask\n",
136 | "        return out"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "To solve **\"visual Sudoku\"**: We construct a (standard) convolutional neural network for MNIST digit recognition and train it end-to-end with our SudokuSolver layer. This architecture takes in an image representation of a Sudoku board constructed with MNIST digits. Each MNIST digit is classified by the convolutional network, and the resulting (estimated) logical representation of the initial Sudoku board is then fed as input to the SudokuSolver layer. (As described earlier, the SudokuSolver layer then outputs a bit representation of the Sudoku board with guesses for the unknown bits.)"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 4,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "class DigitConv(nn.Module):\n",
153 | " '''\n",
154 | " Convolutional neural network for MNIST digit recognition. From:\n",
155 | " https://github.com/pytorch/examples/blob/master/mnist/main.py\n",
156 | " '''\n",
157 | " def __init__(self):\n",
158 | " super(DigitConv, self).__init__()\n",
159 | " self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
160 | " self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
161 | " self.fc1 = nn.Linear(4*4*50, 500)\n",
162 | " self.fc2 = nn.Linear(500, 10)\n",
163 | "\n",
164 | " def forward(self, x):\n",
165 | " x = F.relu(self.conv1(x))\n",
166 | " x = F.max_pool2d(x, 2, 2)\n",
167 | " x = F.relu(self.conv2(x))\n",
168 | " x = F.max_pool2d(x, 2, 2)\n",
169 | " x = x.view(-1, 4*4*50)\n",
170 | " x = F.relu(self.fc1(x))\n",
171 | " x = self.fc2(x)\n",
172 | " return F.softmax(x, dim=1)[:,:9].contiguous()\n",
173 | "\n",
174 | "class MNISTSudokuSolver(nn.Module):\n",
175 | " def __init__(self, boardSz, aux, m):\n",
176 | " super(MNISTSudokuSolver, self).__init__()\n",
177 | " self.digit_convnet = DigitConv()\n",
178 | " self.sudoku_solver = SudokuSolver(boardSz, aux, m)\n",
179 | " self.boardSz = boardSz\n",
180 | " self.nSq = boardSz**2\n",
181 | " \n",
182 | " def forward(self, x, is_inputs):\n",
183 | " nBatch = x.shape[0]\n",
184 | " x = x.flatten(start_dim = 0, end_dim = 1)\n",
185 | " digit_guess = self.digit_convnet(x)\n",
186 | " puzzles = digit_guess.view(nBatch, self.nSq * self.nSq * self.nSq)\n",
187 | "\n",
188 | " solution = self.sudoku_solver(puzzles, is_inputs)\n",
189 | " return solution"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 | "The experimental parameters we use in the paper are below."
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 5,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "from exps.sudoku import train, test, FigLogger, find_unperm\n",
206 | "args_dict = {'lr': 2e-3, \n",
207 | " 'cuda': torch.cuda.is_available(), \n",
208 | " 'batchSz': 40,\n",
209 | " 'mnistBatchSz': 50,\n",
210 | " 'boardSz': 3, # for 9x9 Sudoku\n",
211 | " 'm': 600,\n",
212 | " 'aux': 300,\n",
213 | " 'nEpoch': 100\n",
214 | " }\n",
215 | "args = namedtuple('Args', args_dict.keys())(*args_dict.values())"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {},
221 | "source": [
222 | "# The Sudoku Datasets"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {},
228 | "source": [
229 | "We use and/or create the following datasets:\n",
230 | "* **Sudoku:** We generate 10K 9x9 Sudoku boards (9K train/1K test) using code available [here](https://github.com/Kyubyong/sudoku) and represent each board as a vector of one-hot bits.\n",
231 | "* **Permuted Sudoku:** We apply a fixed permutation to the 10K Sudoku board bit representations generated for the Sudoku experiment.\n",
232 | "* **Visual Sudoku:** We construct versions of the 10K Sudoku boards generated for the Sudoku experiment in which each board cell is represented by a (randomly-selected) MNIST digit. (MNIST digits are also split into train/test sets, with train and test MNIST digits applied only to train and test Sudoku boards, respectively.)\n",
233 | "\n",
234 | "The code below reads and processes these datasets for use with the architectures constructed above. A sample Sudoku board, its associated bit representation, and its associated MNIST representation are displayed below."
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 6,
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "def process_inputs(X, Ximg, Y, boardSz):\n",
244 | " is_input = X.sum(dim=3, keepdim=True).expand_as(X).int().sign()\n",
245 | "\n",
246 | " Ximg = Ximg.flatten(start_dim=1, end_dim=2)\n",
247 | " Ximg = Ximg.unsqueeze(2).float()\n",
248 | "\n",
249 | " X = X.view(X.size(0), -1)\n",
250 | " Y = Y.view(Y.size(0), -1)\n",
251 | " is_input = is_input.view(is_input.size(0), -1)\n",
252 | "\n",
253 | " return X, Ximg, Y, is_input\n",
254 | "\n",
255 | "with open('sudoku/features.pt', 'rb') as f:\n",
256 | " X_in = torch.load(f)\n",
257 | "with open('sudoku/features_img.pt', 'rb') as f:\n",
258 | " Ximg_in = torch.load(f)\n",
259 | "with open('sudoku/labels.pt', 'rb') as f:\n",
260 | " Y_in = torch.load(f)\n",
261 | "with open('sudoku/perm.pt', 'rb') as f:\n",
262 | " perm = torch.load(f)\n",
263 | "\n",
264 | "X, Ximg, Y, is_input = process_inputs(X_in, Ximg_in, Y_in, args.boardSz)\n",
265 | "if args.cuda: X, Ximg, is_input, Y = X.cuda(), Ximg.cuda(), is_input.cuda(), Y.cuda()\n",
266 | "\n",
267 | "N = X_in.size(0)\n",
268 | "nTrain = int(N*0.9)\n",
269 | "\n",
270 | "sudoku_train = TensorDataset(X[:nTrain], is_input[:nTrain], Y[:nTrain])\n",
271 | "sudoku_test = TensorDataset(X[nTrain:], is_input[nTrain:], Y[nTrain:])\n",
272 | "perm_train = TensorDataset(X[:nTrain,perm], is_input[:nTrain,perm], Y[:nTrain,perm])\n",
273 | "perm_test = TensorDataset(X[nTrain:,perm], is_input[nTrain:,perm], Y[nTrain:,perm])\n",
274 | "mnist_train = TensorDataset(Ximg[:nTrain], is_input[:nTrain], Y[:nTrain])\n",
275 | "mnist_test = TensorDataset(Ximg[nTrain:], is_input[nTrain:], Y[nTrain:])"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 7,
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "data": {
285 | "text/markdown": [
286 | "## Sudoku"
287 | ],
288 | "text/plain": [
289 | ""
290 | ]
291 | },
292 | "metadata": {},
293 | "output_type": "display_data"
294 | },
295 | {
296 | "name": "stdout",
297 | "output_type": "stream",
298 | "text": [
299 | "tensor([[6, 7, 0, 0, 0, 0, 0, 0, 5],\n",
300 | " [0, 0, 3, 0, 4, 0, 0, 8, 2],\n",
301 | " [0, 4, 0, 0, 0, 5, 1, 3, 6],\n",
302 | " [0, 0, 0, 7, 3, 0, 0, 9, 0],\n",
303 | " [3, 0, 4, 2, 0, 6, 0, 7, 0],\n",
304 | " [0, 0, 1, 0, 9, 0, 6, 0, 0],\n",
305 | " [5, 0, 9, 0, 0, 8, 0, 0, 0],\n",
306 | " [0, 0, 0, 9, 5, 0, 2, 0, 8],\n",
307 | " [0, 0, 0, 1, 2, 7, 4, 0, 0]])\n",
308 | "\n"
309 | ]
310 | },
311 | {
312 | "data": {
313 | "text/markdown": [
314 | "## One-hot encoded Boolean Sudoku"
315 | ],
316 | "text/plain": [
317 | ""
318 | ]
319 | },
320 | "metadata": {},
321 | "output_type": "display_data"
322 | },
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
328 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
329 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
330 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
331 | " 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
332 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
333 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
334 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
335 | " 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n",
336 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
337 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
338 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
339 | " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
340 | " 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
341 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
342 | " 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
343 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
344 | " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
345 | " 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
346 | " 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n",
347 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
348 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
349 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
350 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
351 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,\n",
352 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
353 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
354 | " 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
355 | " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
356 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
357 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
358 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
359 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
360 | " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
361 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n",
362 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
363 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
364 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
365 | " 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
366 | " 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
367 | " 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')\n"
368 | ]
369 | },
370 | {
371 | "data": {
372 | "text/markdown": [
373 | "## MNIST Sudoku"
374 | ],
375 | "text/plain": [
376 | ""
377 | ]
378 | },
379 | "metadata": {},
380 | "output_type": "display_data"
381 | },
382 | {
383 | "data": {
384 | "image/png": "\n",
385 | "text/plain": [
386 | ""
387 | ]
388 | },
389 | "metadata": {
390 | "needs_background": "light"
391 | },
392 | "output_type": "display_data"
393 | }
394 | ],
395 | "source": [
396 | "def show_sudoku(raw):\n",
397 | " return (torch.argmax(raw,2)+1)*(raw.sum(2).long())\n",
398 | "\n",
399 | "def show_mnist_sudoku(raw):\n",
400 | " A = raw.numpy()\n",
401 | " digits = np.concatenate(np.concatenate(A,axis=1), axis=1).astype(np.uint8)\n",
402 | " linewidth = 2\n",
403 | " board = np.zeros((digits.shape[0]+linewidth*4, digits.shape[1]+linewidth*4), dtype=np.uint8)\n",
404 | " gridwidth = digits.shape[0]//3\n",
405 | "\n",
406 | " board[:] = 255\n",
407 | " for i in range(3):\n",
408 | " for j in range(3):\n",
409 | " xoff = linewidth+(linewidth+gridwidth)*i\n",
410 | " yoff = linewidth+(linewidth+gridwidth)*j\n",
411 | " xst = gridwidth*i\n",
412 | " yst = gridwidth*j\n",
413 | " board[xoff:xoff+gridwidth, yoff:yoff+gridwidth] = digits[xst:xst+gridwidth, yst:yst+gridwidth]\n",
414 | "\n",
415 | " #img = Image.fromarray(255-board)\n",
416 | " plt.imshow(255-board, cmap='gray')\n",
417 | "\n",
418 | "display(Markdown('## Sudoku'))\n",
419 | "print(show_sudoku(X_in[0]))\n",
420 | "print()\n",
421 | "display(Markdown('## One-hot encoded Boolean Sudoku'))\n",
422 | "print(X[0])\n",
423 | " \n",
424 | "display(Markdown('## MNIST Sudoku'))\n",
425 | "show_mnist_sudoku(Ximg_in[0])"
426 | ]
427 | },
428 | {
429 | "cell_type": "markdown",
430 | "metadata": {},
431 | "source": [
432 | "# The 9x9 Sudoku Experiment"
433 | ]
434 | },
435 | {
436 | "cell_type": "markdown",
437 | "metadata": {},
438 | "source": [
439 | "The results for our 9x9 Sudoku experiment are below. In this experiment, we:\n",
440 | "* **Input** a logical (bit) representation of the initial (unsolved) Sudoku board along with a mask marking which bits are given (all bits in empty Sudoku cells must be inferred). This input is vectorized, which means that our SATNet model cannot exploit the locality structure of the input Sudoku grid when learning to solve puzzles.\n",
441 | "* **Output** a bit representation of the Sudoku board with guesses for the unknown bits."
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": 8,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "%%capture\n",
451 | "sudoku_model = SudokuSolver(args.boardSz, args.aux, args.m)\n",
452 | "if args.cuda: sudoku_model = sudoku_model.cuda()\n",
453 | " \n",
454 | "optimizer = optim.Adam(sudoku_model.parameters(), lr=args.lr)\n",
455 | "\n",
456 | "fig, axes = plt.subplots(1,2, figsize=(10,4))\n",
457 | "plt.subplots_adjust(wspace=0.4)\n",
458 | "train_logger = FigLogger(fig, axes[0], 'Training')\n",
459 | "test_logger = FigLogger(fig, axes[1], 'Testing')"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 9,
465 | "metadata": {},
466 | "outputs": [
467 | {
468 | "data": {
469 | "image/png": "\n",
470 | "text/plain": [
471 | ""
472 | ]
473 | },
474 | "metadata": {},
475 | "output_type": "display_data"
476 | }
477 | ],
478 | "source": [
479 | "test(args.boardSz, 0, sudoku_model, optimizer, test_logger, sudoku_test, args.batchSz)\n",
480 | "for epoch in range(1, args.nEpoch+1):\n",
481 | " train(args.boardSz, epoch, sudoku_model, optimizer, train_logger, sudoku_train, args.batchSz)\n",
482 | " test(args.boardSz, epoch, sudoku_model, optimizer, test_logger, sudoku_test, args.batchSz)\n",
483 | " clear_output()\n",
484 | " display(fig)"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "# The Permuted 9x9 Sudoku Experiment"
492 | ]
493 | },
494 | {
495 | "cell_type": "markdown",
496 | "metadata": {},
497 | "source": [
498 | "The results for our permuted 9x9 Sudoku experiment are below. In this experiment, we:\n",
499 | "* **Input** the same inputs as in the original 9x9 Sudoku experiment, but with a fixed permutation applied.\n",
500 | "* **Output** a bit representation of the permuted Sudoku board with guesses for the unknown bits."
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": 10,
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "%%capture\n",
510 | "perm_model = SudokuSolver(args.boardSz, args.aux, args.m)\n",
511 | "if args.cuda: perm_model = perm_model.cuda()\n",
512 | " \n",
513 | "optimizer = optim.Adam(perm_model.parameters(), lr=args.lr)\n",
514 | "\n",
515 | "fig, axes = plt.subplots(1,2, figsize=(10,4))\n",
516 | "plt.subplots_adjust(wspace=0.4)\n",
517 | "train_logger = FigLogger(fig, axes[0], 'Training')\n",
518 | "test_logger = FigLogger(fig, axes[1], 'Testing')"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": 11,
524 | "metadata": {},
525 | "outputs": [
526 | {
527 | "data": {
528 | "image/png": "\n",
529 | "text/plain": [
530 | ""
531 | ]
532 | },
533 | "metadata": {},
534 | "output_type": "display_data"
535 | }
536 | ],
537 | "source": [
538 | "unperm = find_unperm(perm)\n",
539 | "test(args.boardSz, 0, perm_model, optimizer, test_logger, perm_test, args.batchSz, unperm)\n",
540 | "for epoch in range(1, args.nEpoch+1):\n",
541 | " train(args.boardSz, epoch, perm_model, optimizer, train_logger, perm_train, args.batchSz, unperm)\n",
542 | " test(args.boardSz, epoch, perm_model, optimizer, test_logger, perm_test, args.batchSz, unperm)\n",
543 | " clear_output()\n",
544 | " display(fig)"
545 | ]
546 | },
547 | {
548 | "cell_type": "markdown",
549 | "metadata": {},
550 | "source": [
551 | "# The End-to-End MNIST Sudoku (\"Visual Sudoku\") Experiment"
552 | ]
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "metadata": {},
557 | "source": [
558 | "The results for our end-to-end visual Sudoku experiment are below. In this experiment, we:\n",
559 | "* **Input** an image representation of the initial (unsolved) Sudoku board, with each cell rendered as an MNIST digit, along with the same mask of given bits.\n",
560 | "* **Output** a bit representation of the Sudoku board with guesses for the unknown bits."
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 12,
566 | "metadata": {},
567 | "outputs": [],
568 | "source": [
569 | "%%capture\n",
570 | "mnist_sudoku = MNISTSudokuSolver(args.boardSz, args.aux, args.m)\n",
571 | "if args.cuda: mnist_sudoku = mnist_sudoku.cuda()\n",
572 | " \n",
573 | "optimizer = optim.Adam([\n",
574 | " {'params': mnist_sudoku.sudoku_solver.parameters(), 'lr': args.lr},\n",
575 | " {'params': mnist_sudoku.digit_convnet.parameters(), 'lr': 1e-5},\n",
576 | " ])\n",
577 | "\n",
578 | "fig, axes = plt.subplots(1,2, figsize=(10,4))\n",
579 | "plt.subplots_adjust(wspace=0.4)\n",
580 | "train_logger = FigLogger(fig, axes[0], 'Training')\n",
581 | "test_logger = FigLogger(fig, axes[1], 'Testing')"
582 | ]
583 | },
584 | {
585 | "cell_type": "code",
586 | "execution_count": 13,
587 | "metadata": {},
588 | "outputs": [
589 | {
590 | "data": {
591 | "image/png": "\n",
592 | "text/plain": [
593 | ""
594 | ]
595 | },
596 | "metadata": {},
597 | "output_type": "display_data"
598 | }
599 | ],
600 | "source": [
601 | "test(args.boardSz, 0, mnist_sudoku, optimizer, test_logger, mnist_test, args.mnistBatchSz)\n",
602 | "for epoch in range(1, args.nEpoch+1):\n",
603 | " train(args.boardSz, epoch, mnist_sudoku, optimizer, train_logger, mnist_train, args.mnistBatchSz)\n",
604 | " test(args.boardSz, epoch, mnist_sudoku, optimizer, test_logger, mnist_test, args.mnistBatchSz)\n",
605 | " clear_output()\n",
606 | " display(fig)"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "metadata": {},
613 | "outputs": [],
614 | "source": []
615 | }
616 | ],
617 | "metadata": {
618 | "kernelspec": {
619 | "display_name": "Python 3",
620 | "language": "python",
621 | "name": "python3"
622 | },
623 | "language_info": {
624 | "codemirror_mode": {
625 | "name": "ipython",
626 | "version": 3
627 | },
628 | "file_extension": ".py",
629 | "mimetype": "text/x-python",
630 | "name": "python",
631 | "nbconvert_exporter": "python",
632 | "pygments_lexer": "ipython3",
633 | "version": "3.6.8"
634 | }
635 | },
636 | "nbformat": 4,
637 | "nbformat_minor": 2
638 | }
639 |
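The dataset cells in the notebook above rely on a fixed bit layout: 81 cells × 9 bits = 729 = `boardSz**6` variables, where a filled cell holding digit d sets bit d-1 and an empty cell is all zeros (compare the first 18 printed bits of `X[0]` with the first row of the sample board, `6, 7, ...`). `process_inputs` then marks every bit of a filled cell as an input. The stdlib-only sketch below mirrors `process_inputs` and `show_sudoku` for a single board and shows how a fixed permutation of the bit vector is undone. All function names here are hypothetical illustrations; in particular, we only assume that `find_unperm` in `exps/sudoku.py` (not shown in this listing) computes the same inverse permutation that `invert_permutation` builds.

```python
import random
from typing import List

BOARD_SZ = 3                 # args.boardSz in the notebook
N_DIGITS = BOARD_SZ ** 2     # 9 possible digits per cell
N_CELLS = N_DIGITS ** 2      # 81 cells
N_BITS = N_CELLS * N_DIGITS  # 729 == boardSz**6, the SATNet input size n


def encode_board(board: List[List[int]]) -> List[float]:
    """Flatten a 9x9 board (0 = empty) into 729 bits: 9 one-hot bits per cell."""
    bits: List[float] = []
    for row in board:
        for value in row:
            cell = [0.0] * N_DIGITS
            if value:                  # digit d sets bit d-1 (matches the printed X[0])
                cell[value - 1] = 1.0
            bits.extend(cell)
    return bits


def is_input_mask(bits: List[float]) -> List[int]:
    """Mirror process_inputs: all 9 bits of a filled cell are inputs (1), empty cells are 0."""
    mask: List[int] = []
    for c in range(N_CELLS):
        cell = bits[c * N_DIGITS:(c + 1) * N_DIGITS]
        mask.extend([1 if sum(cell) > 0 else 0] * N_DIGITS)
    return mask


def decode_bits(bits: List[float]) -> List[List[int]]:
    """Mirror show_sudoku: argmax + 1 for filled cells, 0 for empty cells."""
    values: List[int] = []
    for c in range(N_CELLS):
        cell = bits[c * N_DIGITS:(c + 1) * N_DIGITS]
        values.append(cell.index(max(cell)) + 1 if sum(cell) > 0 else 0)
    return [values[r * N_DIGITS:(r + 1) * N_DIGITS] for r in range(N_DIGITS)]


def random_permutation(n: int, seed: int = 0) -> List[int]:
    """A seeded permutation of range(n), standing in for the stored sudoku/perm.pt."""
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    return perm


def invert_permutation(perm: List[int]) -> List[int]:
    """Return unperm such that x_perm[unperm[i]] == x[i] whenever x_perm[j] == x[perm[j]]."""
    unperm = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        unperm[old_pos] = new_pos
    return unperm
```

Because the permutation acts on individual bits rather than whole cells, the permuted dataset destroys the one-hot grouping visible above, which is what makes that experiment a test of learning structure that is not given in the input order.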
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.0.0
2 | tqdm
3 | requests
4 |
--------------------------------------------------------------------------------
/satnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import SATNet
2 |
3 | __all__ = ['SATNet']
4 |
--------------------------------------------------------------------------------
/satnet/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | import torch.optim as optim
5 |
6 | import satnet._cpp
7 | if torch.cuda.is_available(): import satnet._cuda
8 |
9 |
10 | def get_k(n):
11 | return int((2*n)**0.5+3)//4*4
12 |
13 | class MixingFunc(Function):
14 | '''Apply the Mixing method to the input probabilities.
15 |
16 | Args: see SATNet.
17 |
18 | Impl Note:
19 | The SATNet is a wrapper for the MixingFunc,
20 | handling the initialization and the wrapping of auxiliary variables.
21 | '''
22 | @staticmethod
23 | def forward(ctx, S, z, is_input, max_iter, eps, prox_lam):
24 | B, n, m, k = z.size(0), S.size(0), S.size(1), 32 #get_k(S.size(0))
25 | ctx.prox_lam = prox_lam
26 |
27 | device = 'cuda' if S.is_cuda else 'cpu'
28 | ctx.g, ctx.gnrm = torch.zeros(B,k, device=device), torch.zeros(B,n, device=device)
29 | ctx.index = torch.zeros(B,n, dtype=torch.int, device=device)
30 | ctx.is_input = torch.zeros(B,n, dtype=torch.int, device=device)
31 | ctx.V, ctx.W = torch.zeros(B,n,k, device=device).normal_(), torch.zeros(B,k,m, device=device)
32 | ctx.z = torch.zeros(B,n, device=device)
33 | ctx.niter = torch.zeros(B, dtype=torch.int, device=device)
34 |
35 | ctx.S = torch.zeros(n,m, device=device)
36 | ctx.Snrms = torch.zeros(n, device=device)
37 |
38 | ctx.z[:] = z.data
39 | ctx.S[:] = S.data
40 | ctx.is_input[:] = is_input.data
41 |
42 | perm = torch.randperm(n-1, dtype=torch.int, device=device)
43 |
44 | satnet_impl = satnet._cuda if S.is_cuda else satnet._cpp
45 | satnet_impl.init(perm, is_input, ctx.index, ctx.z, ctx.V)
46 |
47 | for b in range(B):
48 | ctx.W[b] = ctx.V[b].t().mm(ctx.S)
49 | ctx.Snrms[:] = S.norm(dim=1)**2
50 |
51 | satnet_impl.forward(max_iter, eps,
52 | ctx.index, ctx.niter, ctx.S, ctx.z,
53 | ctx.V, ctx.W, ctx.gnrm, ctx.Snrms, ctx.g)
54 |
55 | return ctx.z.clone()
56 |
57 | @staticmethod
58 | def backward(ctx, dz):
59 | B, n, m, k = dz.size(0), ctx.S.size(0), ctx.S.size(1), 32 #get_k(ctx.S.size(0))
60 |
61 | device = 'cuda' if ctx.S.is_cuda else 'cpu'
62 | ctx.dS = torch.zeros(B,n,m, device=device)
63 | ctx.U, ctx.Phi = torch.zeros(B,n,k, device=device), torch.zeros(B,k,m, device=device)
64 | ctx.dz = torch.zeros(B,n, device=device)
65 |
66 | ctx.dz[:] = dz.data
67 |
68 | satnet_impl = satnet._cuda if ctx.S.is_cuda else satnet._cpp
69 | satnet_impl.backward(ctx.prox_lam,
70 | ctx.is_input, ctx.index, ctx.niter, ctx.S, ctx.dS, ctx.z, ctx.dz,
71 | ctx.V, ctx.U, ctx.W, ctx.Phi, ctx.gnrm, ctx.Snrms, ctx.g)
72 |
73 | ctx.dS = ctx.dS.sum(dim=0)
74 |
75 | return ctx.dS, ctx.dz, None, None, None, None
76 |
77 | def insert_constants(x, pre, n_pre, app, n_app):
78 | ''' prepend and append torch tensors
79 | '''
80 | one = x.new(x.size()[0],1).fill_(1)
81 | seq = []
82 | if n_pre != 0:
83 | seq.append((pre*one).expand(-1, n_pre))
84 | seq.append(x)
85 | if n_app != 0:
86 | seq.append((app*one).expand(-1, n_app))
87 | r = torch.cat(seq, dim=1)
88 | r.requires_grad = False
89 | return r
90 |
91 | class SATNet(nn.Module):
92 | '''Apply a SATNet layer to complete the input probabilities.
93 |
94 | Args:
95 | n: Number of input variables.
96 | m: Rank of the clause matrix.
97 | aux: Number of auxiliary variables.
98 |
99 | max_iter: Maximum number of iterations for solving
100 | the inner optimization problem.
101 | Default: 40
102 | eps: The stopping threshold for the inner optimization problem.
103 | The inner Mixing method will stop when the function decrease
104 | is less than eps times the initial function decrease.
105 | Default: 1e-4
106 | prox_lam: The diagonal increment in the backward linear system
107 | to make the backward pass more stable.
108 | Default: 1e-2
109 | weight_normalize: Set True to rescale the random initial weights by sqrt(0.5 / (n + 1 + aux + m)).
110 | Default: True
111 |
112 | Inputs: (z, is_input)
113 | **z** of shape `(batch, n)`:
114 | Float tensor containing the probabilities (must be in [0,1]).
115 | **is_input** of shape `(batch, n)`:
116 | Int tensor marking which entries of **z** are inputs (1) and which are to be predicted (0).
117 |
118 | Outputs: z
119 | **z** of shape `(batch, n)`:
120 | The predicted probabilities.
121 |
122 | Attributes: S
123 | **S** of shape `(n, m)`:
124 | The learnable clauses matrix containing `m` clauses
125 | for the `n` variables.
126 |
127 | Examples:
128 | >>> sat = satnet.SATNet(3, 4, aux=5)
129 | >>> z = torch.rand(2, 3)  # probabilities must lie in [0, 1]
130 | >>> is_input = torch.IntTensor([[1, 1, 0], [1,0,1]])
131 | >>> pred = sat(z, is_input)
132 | '''
133 |
134 | def __init__(self, n, m, aux=0, max_iter=40, eps=1e-4, prox_lam=1e-2, weight_normalize=True):
135 | super(SATNet, self).__init__()
136 |
137 | S_t = torch.FloatTensor(n+1+aux, m) # n+1 for truth vector
138 | S_t = S_t.normal_()
139 | if weight_normalize: S_t = S_t * ((.5/(n+1+aux+m))**0.5)
140 |
141 | self.S = nn.Parameter(S_t)
142 | self.aux = aux
143 | self.max_iter, self.eps, self.prox_lam = max_iter, eps, prox_lam
144 |
145 | def forward(self, z, is_input):
146 | B = z.size(0)
147 | device = 'cuda' if self.S.is_cuda else 'cpu'
148 | m = self.S.shape[1]
149 | if device == 'cpu' and m%4 != 0:
150 | raise ValueError('m is required to be a multiple of 4 on CPU for SSE acceleration. Now '+str(m))
151 | is_input = insert_constants(is_input.data, 1, 1, 0, self.aux)
152 | z = torch.cat([torch.ones(z.size(0),1,device=device), z, torch.zeros(z.size(0),self.aux,device=device)],dim=1)
153 |
154 | z = MixingFunc.apply(self.S, z, is_input, self.max_iter, self.eps, self.prox_lam)
155 |
156 | return z[:,1:self.S.size(0)-self.aux]
157 |
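`SATNet.forward` above pads every input row before calling `MixingFunc`. It prepends one truth variable (z = 1, marked as an input) and appends `aux` auxiliary variables (z = 0, not inputs), so `S` has n + 1 + aux rows, and the final slice drops those extra columns again. The pure-Python sketch below mirrors that bookkeeping for a single row. `augment_row` and `strip_row` are hypothetical helpers, not part of the `satnet` package; `get_k` is copied from `models.py`. With the notebook's Sudoku settings (n = 729, aux = 300) the layer has 1030 variables, for which `get_k` would give k = 48, whereas `MixingFunc` currently hard-codes k = 32 and leaves the `get_k` call commented out.

```python
from typing import List, Sequence, Tuple


def get_k(n: int) -> int:
    # Same expression as get_k in satnet/models.py: truncate sqrt(2n) + 3,
    # then round down to a multiple of 4 (roughly sqrt(2n) rounded up to a multiple of 4).
    return int((2 * n) ** 0.5 + 3) // 4 * 4


def augment_row(z_row: Sequence[float], mask_row: Sequence[int],
                aux: int) -> Tuple[List[float], List[int]]:
    """Mirror SATNet.forward for one batch row (hypothetical helper).

    Prepends the truth variable (value 1, always an input) and appends
    `aux` auxiliary variables (value 0, never inputs).
    """
    z_aug = [1.0] + list(z_row) + [0.0] * aux
    mask_aug = [1] + list(mask_row) + [0] * aux
    return z_aug, mask_aug


def strip_row(z_aug: Sequence[float], aux: int) -> List[float]:
    """Mirror z[:, 1:self.S.size(0) - self.aux]: drop the truth and auxiliary variables."""
    return list(z_aug[1:len(z_aug) - aux])
```

On the CPU path, `m` (the number of columns of `S`) must additionally be a multiple of 4, as the `ValueError` in `forward` enforces; the notebook's `m = 600` satisfies this.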
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 |
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
5 | from torch.utils.cpp_extension import CUDA_HOME
6 |
7 | gencode = [
8 | '-gencode=arch=compute_50,code=sm_50',
9 | '-gencode=arch=compute_52,code=sm_52',
10 | '-gencode=arch=compute_60,code=sm_60',
11 | '-gencode=arch=compute_61,code=sm_61',
12 | '-gencode=arch=compute_75,code=sm_75',
13 | '-gencode=arch=compute_80,code=sm_80',
14 | '-gencode=arch=compute_86,code=sm_86',
15 | ]
16 |
17 | ext_modules = [
18 | CppExtension(
19 | name = 'satnet._cpp',
20 | include_dirs = ['./src'],
21 | sources = [
22 | 'src/satnet.cpp',
23 | 'src/satnet_cpu.cpp',
24 | ],
25 | extra_compile_args = ['-fopenmp', '-msse4.1', '-Wall', '-g']
26 | )
27 | ]
28 |
29 | if torch.cuda.is_available() and CUDA_HOME is not None:
30 | extension = CUDAExtension(
31 | name = 'satnet._cuda',
32 | include_dirs = ['./src'],
33 | sources = [
34 | 'src/satnet.cpp',
35 | 'src/satnet_cuda.cu',
36 | ],
37 | extra_compile_args = {
38 | 'cxx': ['-DMIX_USE_GPU', '-g'],
39 | 'nvcc': ['-g', '-restrict', '-maxrregcount', '32', '-lineinfo', '-Xptxas=-v']
40 | }
41 | )
42 | ext_modules.append(extension)
43 |
44 | with open("README.md", "r", encoding="utf-8") as fh:
45 | long_description = fh.read()
46 |
47 | # Python interface
48 | setup(
49 | name='satnet',
50 | version='0.1.4',
51 | install_requires=['torch>=1.3'],
52 | packages=['satnet'],
53 | ext_modules=ext_modules,
54 | cmdclass={'build_ext': BuildExtension},
55 | author='Po-Wei Wang',
56 | author_email='poweiw@cs.cmu.edu',
57 | url='https://github.com/locuslab/SATNet',
58 | zip_safe=False,
59 | description='Bridging deep learning and logical reasoning using a differentiable satisfiability solver',
60 | long_description=long_description,
61 | long_description_content_type="text/markdown",
62 | classifiers=[
63 | "License :: OSI Approved :: MIT License",
64 | ],
65 | )
66 |
--------------------------------------------------------------------------------
/src/satnet.cpp:
--------------------------------------------------------------------------------
1 | #ifdef MIX_USE_GPU
2 | #include <ATen/cuda/CUDAContext.h>
3 | #endif
4 | #include <torch/extension.h>
5 |
6 | #ifdef MIX_USE_GPU
7 | #define DEVICE_NAME cuda
8 | #define _MIX_DEV_STR "cuda"
9 | #define _MIX_CUDA_DECL , cudaStream_t stream
10 | #define _MIX_CUDA_ARG , stream
11 | #define _MIX_CUDA_HEAD cudaStream_t stream = at::cuda::getCurrentCUDAStream();
12 | #define _MIX_CUDA_TAIL AT_CUDA_CHECK(cudaGetLastError());
13 | //AT_CUDA_CHECK(cudaStreamSynchronize(stream));
14 | #else
15 | #define DEVICE_NAME cpu
16 | #define _MIX_DEV_STR "cpu"
17 | #define _MIX_CUDA_DECL
18 | #define _MIX_CUDA_ARG
19 | #define _MIX_CUDA_HEAD
20 | #define _MIX_CUDA_TAIL
21 | #endif
22 |
23 | // name mangling for CPU and CUDA
24 | #define _MIX_CAT(x,y) x ## _ ## y
25 | #define _MIX_EVAL(x,y) _MIX_CAT(x,y)
26 | #define _MIX_FUNC(name) _MIX_EVAL(name, DEVICE_NAME)
27 |
28 | #include "satnet.h"
29 |
30 | using Tensor=torch::Tensor;
31 | float *fptr(Tensor& a) { return a.data_ptr<float>(); }
32 | int *iptr(Tensor& a) { return a.data_ptr<int>(); }
33 |
34 | void _MIX_FUNC(mix_init_launcher) (mix_t mix, int32_t *perm _MIX_CUDA_DECL);
35 | void _MIX_FUNC(mix_forward_launcher) (mix_t mix, int max_iter, float eps _MIX_CUDA_DECL);
36 | void _MIX_FUNC(mix_backward_launcher)(mix_t mix, float prox_lam _MIX_CUDA_DECL);
37 |
38 | void mix_init(Tensor perm,
39 | Tensor is_input, Tensor index, Tensor z, Tensor V)
40 | {
41 | _MIX_CUDA_HEAD;
42 |
43 | mix_t mix;
44 | mix.b = V.size(0); mix.n = V.size(1); mix.k = V.size(2);
45 | mix.is_input = iptr(is_input);
46 | mix.index = iptr(index);
47 | mix.z = fptr(z);
48 | mix.V = fptr(V);
49 |
50 | _MIX_FUNC(mix_init_launcher)(mix, iptr(perm) _MIX_CUDA_ARG);
51 |
52 | _MIX_CUDA_TAIL;
53 | }
54 |
55 | void mix_forward(int max_iter, float eps,
56 | Tensor index, Tensor niter, Tensor S, Tensor z, Tensor V, Tensor W, Tensor gnrm, Tensor Snrms, Tensor cache)
57 | {
58 | _MIX_CUDA_HEAD;
59 |
60 | mix_t mix;
61 | mix.b = V.size(0); mix.n = V.size(1); mix.m = S.size(1); mix.k = V.size(2);
62 | mix.index = iptr(index);
63 | mix.niter = iptr(niter);
64 | mix.S = fptr(S);
65 | mix.z = fptr(z);
66 | mix.V = fptr(V);
67 | mix.W = fptr(W);
68 | mix.gnrm = fptr(gnrm); mix.Snrms = fptr(Snrms);
69 | mix.cache = fptr(cache);
70 |
71 | _MIX_FUNC(mix_forward_launcher)(mix, max_iter, eps _MIX_CUDA_ARG);
72 |
73 | _MIX_CUDA_TAIL;
74 | }
75 |
76 | void mix_backward(float prox_lam,
77 | Tensor is_input, Tensor index, Tensor niter, Tensor S, Tensor dS, Tensor z, Tensor dz,
78 | Tensor V, Tensor U, Tensor W, Tensor Phi, Tensor gnrm, Tensor Snrms, Tensor cache)
79 | {
80 | _MIX_CUDA_HEAD;
81 |
82 | mix_t mix;
83 | mix.b = V.size(0); mix.n = V.size(1); mix.m = S.size(1); mix.k = V.size(2);
84 | mix.is_input = iptr(is_input);
85 | mix.index = iptr(index);
86 | mix.niter = iptr(niter);
87 | mix.S = fptr(S); mix.dS = fptr(dS);
88 | mix.z = fptr(z); mix.dz = fptr(dz);
89 | mix.V = fptr(V); mix.U = fptr(U);
90 | mix.W = fptr(W); mix.Phi = fptr(Phi);
91 | mix.gnrm = fptr(gnrm); mix.Snrms = fptr(Snrms);
92 | mix.cache = fptr(cache);
93 |
94 | _MIX_FUNC(mix_backward_launcher)(mix, prox_lam _MIX_CUDA_ARG);
95 |
96 | _MIX_CUDA_TAIL;
97 | }
98 |
99 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
100 | m.def("init" , &mix_init, "SATNet init (" _MIX_DEV_STR ")");
101 | m.def("forward" , &mix_forward, "SATNet forward (" _MIX_DEV_STR ")");
102 | m.def("backward" , &mix_backward, "SATNet backward (" _MIX_DEV_STR ")");
103 | }
104 |
--------------------------------------------------------------------------------
/src/satnet.h:
--------------------------------------------------------------------------------
1 | typedef struct mix_t {
2 | int b, n, m, k;
3 | int32_t *is_input; // b*n
4 | int32_t *index; // b*n
5 | int32_t *niter; // b
6 | float *S, *dS; // n*m
7 | float *z, *dz; // b*n
8 | float *V, *U; // b*n*k
9 | float *W, *Phi; // b*m*k
10 | float *gnrm, *Snrms;// gnrm: b*n, Snrms: n
11 | float *cache;
12 | } mix_t ;
13 |
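The CPU kernels in `satnet_cpu.cpp` map each relaxed unit vector v_i back to a probability at the end of `mix_forward`, computing z_i = 1 - arccos(v_i · v_0) / π with `V[i*k]` read as the cosine between v_i and the truth vector v_0. At the start of `mix_backward` they divide the incoming gradient by π·sin(π z_i), which is the derivative of that map written in terms of z_i: dz/dc = 1 / (π·sqrt(1 - c²)) and sqrt(1 - c²) = sin(π z) for c = cos(π(1 - z)). The Python sketch below reproduces both formulas so the chain rule can be checked numerically. It assumes `saturate` clamps to [0, 1], because its C body is truncated in this listing, and it assumes (without the source showing it) that `V[i*k]` equals the cosine, i.e. the vectors are unit norm and v_0 points along the first coordinate.

```python
import math


def saturate(x: float) -> float:
    # Assumption: clamps to [0, 1]; the C helper's body is truncated in this dump.
    return min(max(x, 0.0), 1.0)


def vector_to_prob(cos_to_truth: float) -> float:
    """Mirror the tail of mix_forward: z_i = 1 - arccos(v_i . v_0) / pi."""
    c = saturate((cos_to_truth + 1) / 2) * 2 - 1   # clip the cosine to [-1, 1]
    return saturate(1 - math.acos(c) / math.pi)


def grad_wrt_cos(dl_dz: float, z: float) -> float:
    """Mirror the first loop of mix_backward: dL/dc = dL/dz / (pi * sin(pi * z))."""
    return dl_dz / math.pi / math.sin(z * math.pi)
```

The same formula explains the `invalid_flag` guard in `mix_backward`: when some z_i is exactly 0 or 1, sin(π z_i) vanishes, the division yields inf or NaN, and the C code zeroes `dz` instead of propagating it.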
--------------------------------------------------------------------------------
/src/satnet_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 |
11 | #include
12 |
13 | #include "satnet.h"
14 |
15 | #define saxpy mysaxpy
16 | #define scopy myscopy
17 | #define sscal mysscal
18 | #define sdot mysdot
19 | #define snrm2 mysnrm2
20 | #define szero myszero
21 | #define saturate mysaturate
22 |
23 | const double MEPS = 1e-24;
24 |
25 | /*
26 | * Helper functions
27 | */
28 | void saxpy(float *__restrict__ y, float a, const float *__restrict__ x, int l)
29 | {
30 | y = (float*)__builtin_assume_aligned(y, 4*sizeof(float));
31 | x = (float*)__builtin_assume_aligned(x, 4*sizeof(float));
32 | __m128 const a_ = _mm_set1_ps(a);
33 | for(int i=0; i1)*(1-x);
162 | }
163 |
164 | // consider the \min unsat problem,
165 | void mix_forward(int max_iter, float eps, int n, int m, int k, const int32_t *index, int32_t *niter, const float *S, float *z, float *V, float *W, float *gnrm, float *Snrms, float *cache)
166 | {
167 | float delta;
168 | int iter = 0;
169 | for (; iter < max_iter; iter++) {
170 | delta = mix_kernel(1, 0, m, k, index, S, NULL, V, NULL, W, gnrm, Snrms, cache);
171 | if (iter && delta < eps) break;
172 | if (iter == 0) eps = delta*eps;
173 | }
174 |
175 | *niter = iter;
176 |
177 | for (int i,i_=0; (i=index[i_]); i_++) {
178 | float zi = V[i*k];
179 | zi = saturate((zi+1)/2)*2-1;
180 | zi = saturate(1-acosf(zi)/M_PI);
181 | z[i] = zi;
182 | }
183 | }
184 |
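`mix_forward` ends by mapping the first coordinate `c` of each vector `v_i` to `z_i = 1 - acos(c)/pi`, after clamping `c` to [-1, 1]. `mix_backward` then begins by dividing `dz[i]` by `pi*sin(pi*z_i)`. That factor is the inverse Jacobian of the same map: inverting it gives `c = -cos(pi*z)`, so `dc/dz = pi*sin(pi*z)` and `dL/dc = (dL/dz) / (pi*sin(pi*z))`. The sketch below isolates the two scalar steps. The function names are hypothetical, and the original's additional `gnrm[i] < MEPS` check is omitted.

```cpp
#include <cassert>
#include <cmath>

// pi as a float constant (M_PI is a common extension, not standard C++).
const float kPi = 3.14159265358979f;

static float clamp01(float x) { return x < 0.f ? 0.f : (x > 1.f ? 1.f : x); }

// Forward step at the end of mix_forward: c is the first coordinate of v_i.
// c is clamped to [-1, 1] so acos is defined; z = 1 - acos(c)/pi lies in [0, 1].
float prob_from_coord(float c) {
    c = clamp01((c + 1.f) / 2.f) * 2.f - 1.f;
    return clamp01(1.f - std::acos(c) / kPi);
}

// Backward step at the start of mix_backward: dL/dc = dL/dz / (pi * sin(pi*z)).
// Returns false when the result is NaN or infinite (e.g. z == 0, where
// sin(pi*z) is exactly 0); mix_backward zeroes dz and returns in that case.
bool grad_coord_from_grad_prob(float dz, float z, float *dc) {
    float g = dz / kPi / std::sin(z * kPi);
    *dc = g;
    return !(std::isnan(g) || std::isinf(g));
}
```

At `z = 0.5` the scaling reduces to `dz/pi`. Near the endpoints it blows up, which is why the original guards against non-finite values before running the backward solve.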
185 | void mix_backward(float prox_lam, int n, int m, int k, int32_t *is_input, int32_t *index, int32_t *niter, const float *S, float *dS, float *z, float *dz, const float *V, float *U, float *W, float *Phi, float *gnrm, float *Snrms, float *cache)
186 | {
187 | int invalid_flag=0;
188 | for (int i,i_=0; (i=index[i_]); i_++) {
189 | float zi = z[i];
190 | float dzi = dz[i]/M_PI/sin(zi*M_PI);
191 | if (isnan(dzi) || isinf(dzi) || gnrm[i] < MEPS) invalid_flag = 1;
192 | dz[i] = dzi;
193 | }
194 | if (invalid_flag) { szero(dz, n); return; }
195 |
196 | // solve P (S'S+D_z-D_sii)xI_k P U = -dz P v0
197 | for (int iter=0; iter<*niter; iter++) {
198 | mix_kernel(0, prox_lam, m, k, index, S, dz, U, V, Phi, gnrm, Snrms, cache);
199 | }
200 |
201 | // sanity check
202 | for (int ik=0; ik
2 | #include
3 | //#include
4 | #include
5 | #include
6 |
7 | #include
8 | #include "satnet.h"
9 |
10 | const double MEPS = 1e-24;
11 | const int WARP_SIZE = 32;
12 | const int WARP_NUM = 32;
13 | const int MBUF_SIZE = 320;
14 |
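Only the opening of `warpdot` survives below: a guard for `k==0`, the lane index `threadIdx.x % WARP_SIZE`, and a loop unrolled by 2 that starts at `i = lane`. A common way to complete such a warp-level dot product has two steps. First, each lane accumulates a strided partial sum. Then a shuffle-based tree reduction combines the lanes. The host-side emulation below sketches that pattern. It assumes the elided body follows it; the stride of `WARP_SIZE` and the halving offsets are assumptions, not recovered code, and `warpdot_emulated` is a hypothetical name.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;

// Host-side emulation of a warp-level dot product (hypothetical, not the elided
// CUDA body). Step 1: lane l sums x[i]*z[i] for i = l, l+32, l+64, ...
// Step 2: tree reduction with offsets 16, 8, 4, 2, 1, mimicking
// val += __shfl_down_sync(mask, val, offset). Iterating lanes in ascending
// order reads val[lane+offset] before that slot is overwritten in the round.
float warpdot_emulated(const float *x, const float *z, int k) {
    if (k == 0) return 0.f;
    std::array<float, kWarpSize> val{};  // one partial sum per lane, zeroed
    for (int lane = 0; lane < kWarpSize; lane++)
        for (int i = lane; i < k; i += kWarpSize)
            val[lane] += x[i] * z[i];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
        for (int lane = 0; lane + offset < kWarpSize; lane++)
            val[lane] += val[lane + offset];
    return val[0];  // lane 0 holds the full sum after the last round
}
```

On the GPU all lanes execute each round simultaneously. The sequential loop only reproduces the value lane 0 ends up with.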
15 | // Warp level dot product
16 | __device__
17 | float warpdot(const float * x, const float * z, int k)
18 | {
19 | if (k==0) return 0;
20 | int lane = threadIdx.x % WARP_SIZE;
21 |
22 | float val = 0;
23 | #pragma unroll 2
24 | for (int i=lane; iMBUF_SIZE ? MBUF_SIZE : m; // mbuf = # of m inside buffer (in smem)
90 | int mrem = m>MBUF_SIZE ? m-MBUF_SIZE : 0; // mrem = # of m outside buffer (in global mem)
91 | for (int j=lane; j>>(perm,
259 | mix.n, mix.k, mix.is_input, mix.index, mix.z,
260 | mix.V);
261 | }
262 |
263 | void mix_forward_launcher_cuda(mix_t mix, int max_iter, float eps, cudaStream_t stream)
264 | {
265 | int smem_size = (mix.m+mix.k*(1+MBUF_SIZE))*sizeof(float);
266 | mix_forward<<>>(max_iter, eps,
267 | mix.n, mix.m, mix.k, mix.index, mix.niter,
268 | mix.S, mix.z, mix.V, mix.W, mix.gnrm, mix.Snrms, mix.cache);
269 | }
270 |
271 | void mix_backward_launcher_cuda(mix_t mix, float prox_lam, cudaStream_t stream)
272 | {
273 | int smem_size = (mix.m+mix.k*(1+MBUF_SIZE))*sizeof(float);
274 | mix_backward<<>>(prox_lam,
275 | mix.n, mix.m, mix.k, mix.is_input, mix.index, mix.niter,
276 | mix.S, mix.dS, mix.z, mix.dz, mix.V, mix.U, mix.W, mix.Phi, mix.gnrm, mix.Snrms, mix.cache);
277 | }
278 |
--------------------------------------------------------------------------------
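Both launchers request `(m + k*(1+MBUF_SIZE))*sizeof(float)` bytes of dynamic shared memory. The kernel splits the `m` dimension into `mbuf = min(m, MBUF_SIZE)` entries held in shared memory and `mrem = m - mbuf` entries left in global memory. The sketch below is a hypothetical host-side helper that reproduces this arithmetic, so a configuration can be checked before launch. The 48 KB budget is an assumption: it is the default per-block dynamic shared-memory limit on many NVIDIA GPUs, and larger requests need an explicit opt-in, so check it against the target device.

```cpp
#include <cassert>
#include <cstddef>

constexpr int kMbufSize = 320;  // mirrors MBUF_SIZE in satnet_cuda.cu

// Bytes of dynamic shared memory requested by the mix_forward/mix_backward launchers.
std::size_t mix_smem_bytes(int m, int k) {
    return static_cast<std::size_t>(m + k * (1 + kMbufSize)) * sizeof(float);
}

// Split of the m dimension: mbuf entries live in shared memory, mrem in global memory.
void mix_m_split(int m, int *mbuf, int *mrem) {
    *mbuf = m > kMbufSize ? kMbufSize : m;
    *mrem = m > kMbufSize ? m - kMbufSize : 0;
}

// Hypothetical check against an assumed per-block budget (48 KB by default).
bool fits_default_smem(int m, int k, std::size_t budget = 48 * 1024) {
    return mix_smem_bytes(m, k) <= budget;
}
```

For example, `m = 100, k = 32` needs 41488 bytes and fits. Doubling `k` to 64 needs 82576 bytes. Because the `k*(1+MBUF_SIZE)` term dominates, the embedding rank rather than the clause count drives the footprint.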