├── .gitignore ├── .gitmodules ├── circuits ├── jwt_proof.auth0.circom ├── calculate_total.circom ├── bitify.circom ├── slice.circom ├── claim_proof.circom ├── jwt_proof.circom └── sha256.circom ├── index.js ├── js ├── test.js ├── circuit.js └── utils.js ├── package.json └── test ├── slice.js ├── utils.js ├── sha256.js ├── claim_proof.js └── jwt_proof.js /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | build/ 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "circomlib"] 2 | path = circomlib 3 | url = https://github.com/iden3/circomlib.git 4 | -------------------------------------------------------------------------------- /circuits/jwt_proof.auth0.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | include "jwt_proof.circom"; 4 | 5 | component main = JwtProof(384, 8, 248, 248); 6 | -------------------------------------------------------------------------------- /index.js: -------------------------------------------------------------------------------- 1 | const utils = require('./js/utils'); 2 | const circuit = require('./js/circuit'); 3 | 4 | module.exports = { 5 | utils: utils, 6 | circuit: circuit 7 | } 8 | -------------------------------------------------------------------------------- /circuits/calculate_total.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | // This circuit returns the sum of the inputs. 4 | // n must be greater than 0. 
5 | template CalculateTotal(n) { 6 | signal input nums[n]; 7 | signal output sum; 8 | 9 | signal sums[n]; 10 | sums[0] <== nums[0]; 11 | 12 | for (var i=1; i < n; i++) { 13 | sums[i] <== sums[i - 1] + nums[i]; 14 | } 15 | 16 | sum <== sums[n - 1]; 17 | } 18 | -------------------------------------------------------------------------------- /circuits/bitify.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | template Num2BitsLE(n) { 4 | signal input in; 5 | signal output out[n]; 6 | var lc1 = 0; 7 | 8 | var e2 = 1; 9 | for (var i = 0; i < n; i++) { 10 | var b = (n - 1) - i; 11 | out[b] <-- (in >> i) & 1; 12 | out[b] * (out[b] - 1 ) === 0; 13 | lc1 += out[b] * e2; 14 | e2 = e2 + e2; 15 | } 16 | 17 | lc1 === in; 18 | } 19 | 20 | 21 | template Bits2NumLE(n) { 22 | signal input in[n]; 23 | signal output out; 24 | var lc1=0; 25 | 26 | var e2 = 1; 27 | for (var i = 0; i < n; i++) { 28 | lc1 += in[(n - 1) - i] * e2; 29 | e2 = e2 + e2; 30 | } 31 | 32 | lc1 ==> out; 33 | } 34 | -------------------------------------------------------------------------------- /js/test.js: -------------------------------------------------------------------------------- 1 | const temp = require("temp"); 2 | const path = require("path"); 3 | const fs = require("fs"); 4 | 5 | const circom_wasm = require("circom_tester").wasm; 6 | 7 | async function genMain(template_file, template_name, params = [], tester = circom_wasm) { 8 | temp.track(); 9 | 10 | const temp_circuit = await temp.open({prefix: template_name, suffix: ".circom"}); 11 | const include_path = path.relative(temp_circuit.path, template_file); 12 | const params_string = JSON.stringify(params).slice(1, -1); 13 | 14 | fs.writeSync(temp_circuit.fd, ` 15 | pragma circom 2.0.0; 16 | 17 | include "${include_path}"; 18 | 19 | component main = ${template_name}(${params_string}); 20 | `); 21 | 22 | return circom_wasm(temp_circuit.path); 23 | } 24 | 25 | module.exports = { 26 | 
genMain: genMain 27 | } 28 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "snark-jwt-verify", 3 | "version": "1.0.0", 4 | "description": "Verify JWTs using SNARK circuits", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "mocha --max-old-space-size=16000 -t 10000s" 8 | }, 9 | "repository": { 10 | "type": "git", 11 | "url": "git+https://github.com/TheFrozenFire/snark-jwt-verify.git" 12 | }, 13 | "keywords": [ 14 | "circuit", 15 | "circom", 16 | "zksnark", 17 | "rsa", 18 | "jwt", 19 | "openid" 20 | ], 21 | "author": "thefrozenfire", 22 | "license": "MIT", 23 | "bugs": { 24 | "url": "https://github.com/TheFrozenFire/snark-jwt-verify/issues" 25 | }, 26 | "homepage": "https://github.com/TheFrozenFire/snark-jwt-verify#readme", 27 | "dependencies": { 28 | "bigint-buffer": "^1.1.5", 29 | "node-rsa": "^1.1.1", 30 | "snarkjs": "^0.4.10", 31 | "temp": "^0.9.4" 32 | }, 33 | "devDependencies": { 34 | "chai": "^4.3.4", 35 | "circom_tester": "0.0.7", 36 | "jose": "^4.3.7", 37 | "mocha": "^9.1.3", 38 | "node-jose": "^2.0.0" 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /test/slice.js: -------------------------------------------------------------------------------- 1 | const chai = require("chai"); 2 | const path = require("path"); 3 | const assert = chai.assert; 4 | const crypto = require("crypto"); 5 | 6 | const tester = require("circom_tester").wasm; 7 | 8 | const utils = require("../js/utils"); 9 | const test = require("../js/test"); 10 | 11 | describe("Array Slice", () => { 12 | var cir_fixed; 13 | var cir; 14 | 15 | before(async() => { 16 | cir_fixed = await test.genMain(path.join(__dirname, "..", "circuits", "slice.circom"), "SliceFixed", [6, 2]); 17 | await cir_fixed.loadSymbols(); 18 | 19 | cir = await test.genMain(path.join(__dirname, "..", "circuits", "slice.circom"), "Slice", [10, 5]);
20 | await cir.loadSymbols(); 21 | }); 22 | 23 | it("Fixed circuit extracts correct value", async () => { 24 | input = [1,2,3,4,5,6]; 25 | 26 | const witness = await cir_fixed.calculateWitness({ "in": input, "offset": 1 }); 27 | 28 | assert.sameOrderedMembers(utils.getWitnessArray(witness, cir_fixed.symbols, "main.out"), [2n, 3n]); 29 | }); 30 | 31 | it("Non-fixed circuit extracts correct, masked value", async () => { 32 | input = [1,2,3,4,5,6,7,8,9,10]; 33 | 34 | const witness = await cir.calculateWitness({ "in": input, "offset": 1, "length": 2 }); 35 | 36 | assert.sameOrderedMembers(utils.getWitnessArray(witness, cir.symbols, "main.out"), [2n, 3n, 0n, 0n, 0n]); 37 | }); 38 | }); 39 | -------------------------------------------------------------------------------- /circuits/slice.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | include "calculate_total.circom"; 4 | include "../circomlib/circuits/comparators.circom"; 5 | 6 | template SliceFixed(inSize, outSize) { 7 | signal input in[inSize]; 8 | signal input offset; 9 | 10 | signal output out[outSize]; 11 | 12 | component selector[outSize]; 13 | component eqs[inSize][outSize]; 14 | for(var i = 0; i < outSize; i++) { 15 | selector[i] = CalculateTotal(inSize); 16 | 17 | for(var j = 0; j < inSize; j++) { 18 | eqs[j][i] = IsEqual(); 19 | eqs[j][i].in[0] <== j; 20 | eqs[j][i].in[1] <== offset + i; 21 | 22 | selector[i].nums[j] <== eqs[j][i].out * in[j]; 23 | } 24 | 25 | out[i] <== selector[i].sum; 26 | } 27 | } 28 | 29 | template Slice(inSize, outSize) { 30 | signal input in[inSize]; 31 | signal input offset; 32 | signal input length; 33 | 34 | signal output out[outSize]; 35 | 36 | component selector[outSize]; 37 | component eqs[inSize][outSize]; 38 | component lt[outSize]; 39 | signal mask[inSize][outSize]; 40 | for(var i = 0; i < outSize; i++) { 41 | selector[i] = CalculateTotal(inSize); 42 | 43 | lt[i] = LessThan(8); 44 | lt[i].in[0] <== i; 45 | 
lt[i].in[1] <== length; 46 | 47 | for(var j = 0; j < inSize; j++) { 48 | eqs[j][i] = IsEqual(); 49 | eqs[j][i].in[0] <== j; 50 | eqs[j][i].in[1] <== offset + i; 51 | 52 | mask[j][i] <== eqs[j][i].out * lt[i].out; 53 | 54 | selector[i].nums[j] <== mask[j][i] * in[j]; 55 | } 56 | 57 | out[i] <== selector[i].sum; 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /test/utils.js: -------------------------------------------------------------------------------- 1 | const chai = require("chai"); 2 | const assert = chai.assert; 3 | const crypto = require("crypto"); 4 | 5 | const circuit = require("../js/circuit"); 6 | const utils = require("../js/utils"); 7 | 8 | describe("Circuit Utilities", () => { 9 | it("Buffer to/from bit array works as expected", async () => { 10 | const input = crypto.randomBytes(20*32).toString("hex"); 11 | 12 | const bits = utils.buffer2BitArray(Buffer.from(input)); 13 | const buffer = utils.bitArray2Buffer(bits); 14 | 15 | assert.equal(input, buffer.toString()); 16 | }); 17 | 18 | it("rfc4634#4.1 padding conforms: L % 512 = 0", async () => { 19 | const input = crypto.randomBytes(512/8/2).toString("hex"); 20 | 21 | const bits = utils.buffer2BitArray(Buffer.from(input)); 22 | const padded = circuit.padMessage(bits); 23 | 24 | assert.equal(bits.length, 512); 25 | assert.equal(padded.length, 1024); // Padding a 448+-bit message requires an additional block 26 | assert.equal(1, padded.slice(-512, -511)); // Padding begins with 1 27 | assert.equal(bits.length, parseInt(padded.slice(-64).join(''), 2)); // base2(L) 28 | }); 29 | 30 | it("rfc4634#4.1 padding conforms: L % 512 = 65", async () => { 31 | const input = crypto.randomBytes(512/8/2).toString("hex"); 32 | 33 | const bits = utils.buffer2BitArray(Buffer.from(input)).slice(0, 447); 34 | const padded = circuit.padMessage(bits); 35 | 36 | assert.equal(bits.length, 447); 37 | assert.equal(padded.length, 512); 38 | assert.equal(1, padded.slice(-65, -64)); 
// Padding begins with 1 39 | assert.equal(bits.length, parseInt(padded.slice(-64).join(''), 2)); 40 | }); 41 | 42 | it("rfc4634#4.1 padding conforms: L % 512 = 100", async () => { 43 | const input = crypto.randomBytes(512/8/2).toString("hex"); 44 | 45 | const bits = utils.buffer2BitArray(Buffer.from(input)).slice(0, 412); 46 | const padded = circuit.padMessage(bits); 47 | 48 | assert.equal(bits.length, 412); 49 | assert.equal(padded.length, 512); 50 | assert.equal(1, padded.slice(-100, -99)); // Padding begins with 1 51 | assert.equal(bits.length, parseInt(padded.slice(-64).join(''), 2)); 52 | }); 53 | }); 54 | -------------------------------------------------------------------------------- /test/sha256.js: -------------------------------------------------------------------------------- 1 | const chai = require("chai"); 2 | const path = require("path"); 3 | const assert = chai.assert; 4 | const crypto = require("crypto"); 5 | 6 | const tester = require("circom_tester").wasm; 7 | 8 | const circuit = require("../js/circuit"); 9 | const utils = require("../js/utils"); 10 | const test = require("../js/test"); 11 | 12 | describe("Unsafe SHA256", () => { 13 | const nBlocks = 20; 14 | const hexBytesToBlock = 512/8/2; 15 | var cir; 16 | 17 | before(async() => { 18 | cir = await test.genMain(path.join(__dirname, "..", "circuits", "sha256.circom"), "Sha256_unsafe", [20]); 19 | await cir.loadSymbols(); 20 | }); 21 | 22 | it("Hashing produces expected output for filled blocks", async () => { 23 | const input = crypto.randomBytes((nBlocks * hexBytesToBlock)-32).toString("hex"); 24 | const hash = crypto.createHash("sha256").update(input).digest("hex"); 25 | 26 | const inputs = circuit.genSha256Inputs(input, nBlocks); 27 | 28 | const witness = await cir.calculateWitness(inputs, true); 29 | 30 | const hash2 = utils.getWitnessBuffer(witness, cir.symbols, "main.out").toString("hex"); 31 | 32 | assert.equal(hash2, hash); 33 | }); 34 | 35 | it("Hashing produces expected output for 
partial last block", async () => { 36 | const input = crypto.randomBytes((nBlocks * hexBytesToBlock)-100).toString("hex"); 37 | const hash = crypto.createHash("sha256").update(input).digest("hex"); 38 | 39 | const inputs = circuit.genSha256Inputs(input, nBlocks); 40 | 41 | const witness = await cir.calculateWitness(inputs, true); 42 | 43 | const hash2 = utils.getWitnessBuffer(witness, cir.symbols, "main.out").toString("hex"); 44 | 45 | assert.equal(hash2, hash); 46 | }); 47 | 48 | it("Hashing produces expected output for less than nBlocks blocks", async () => { 49 | const input = crypto.randomBytes((nBlocks-8) * hexBytesToBlock).toString("hex"); 50 | const hash = crypto.createHash("sha256").update(input).digest("hex"); 51 | 52 | const inputs = circuit.genSha256Inputs(input, nBlocks); 53 | 54 | const witness = await cir.calculateWitness(inputs, true); 55 | 56 | const hash2 = utils.getWitnessBuffer(witness, cir.symbols, "main.out").toString("hex"); 57 | 58 | assert.equal(hash2, hash); 59 | }); 60 | }); 61 | 62 | 63 | -------------------------------------------------------------------------------- /test/claim_proof.js: -------------------------------------------------------------------------------- 1 | const chai = require("chai"); 2 | const path = require("path"); 3 | const assert = chai.assert; 4 | const crypto = require("crypto"); 5 | const {toBigIntBE} = require('bigint-buffer'); 6 | 7 | const tester = require("circom_tester").wasm; 8 | 9 | const circuit = require("../js/circuit"); 10 | const utils = require("../js/utils"); 11 | const test = require("../js/test"); 12 | 13 | describe("Claim Proof", () => { 14 | const nCount = 64; 15 | const nWidth = 16; 16 | const maxClaimLength = 12; 17 | 18 | const hexBytesToSegment = 16/8/2; 19 | const segmentsToBlock = 512/nWidth; 20 | 21 | var cir; 22 | 23 | before(async() => { 24 | cir = await test.genMain(path.join(__dirname, "..", "circuits", "claim_proof.circom"), "ClaimProof", [nCount, nWidth, maxClaimLength]); 25 | await 
cir.loadSymbols(); 26 | }); 27 | 28 | it("Num2Bits converts inputs to left-hand LSB", async () => { 29 | num2bits = await test.genMain(path.join(__dirname, "..", "circomlib", "circuits", "bitify.circom"), "Num2Bits", [16]); 30 | await num2bits.loadSymbols(); 31 | 32 | input = crypto.randomBytes(2); 33 | 34 | const witness = await num2bits.calculateWitness({"in": "0x" + input.toString("hex") }); 35 | const out = toBigIntBE(utils.bitArray2Buffer(utils.getWitnessArray(witness, num2bits.symbols, "main.out").reverse())); 36 | 37 | assert.equal(out, toBigIntBE(Buffer.from(input))); 38 | }); 39 | 40 | it("Extract from JSON", async () => { 41 | const input = '{ "sub": "1234567890", "name": "John Doe", "iat": 1516239022 }'; 42 | const hash = crypto.createHash("sha256").update(input).digest("hex"); 43 | 44 | const fieldLength = utils.getJSONFieldLength(input, "sub"); 45 | const claimLength = Math.ceil(fieldLength / (nWidth / 8)); 46 | var inputs = circuit.genClaimProofInputs(input, nCount, "sub", claimLength, nWidth); 47 | const expectedClaim = '"sub": "1234567890",'; 48 | 49 | const witness = await cir.calculateWitness(inputs, true); 50 | 51 | const hash2 = utils.getWitnessBuffer(witness, cir.symbols, "main.hash").toString("hex"); 52 | const claim = utils.trimEndByChar(utils.bigIntArray2Buffer(utils.getWitnessArray(witness, cir.symbols, "main.claim")).toString(), '\u0000'); 53 | 54 | assert.equal(hash2, hash); 55 | assert.equal(claim, expectedClaim); 56 | }); 57 | }); 58 | -------------------------------------------------------------------------------- /js/circuit.js: -------------------------------------------------------------------------------- 1 | const utils = require('./utils'); 2 | const {toBigIntBE} = require('bigint-buffer'); 3 | 4 | // https://datatracker.ietf.org/doc/html/rfc4634#section-4.1 5 | function padMessage(bits) { 6 | const L = bits.length; 7 | const K = (512 + 448 - (L % 512 + 1)) % 512; 8 | 9 | bits = bits.concat([1]); 10 | if(K > 0) { 11 | bits = 
bits.concat(Array(K).fill(0)); 12 | } 13 | bits = bits.concat(utils.buffer2BitArray(Buffer.from(L.toString(16).padStart(16, '0'), 'hex'))) 14 | 15 | return bits; 16 | } 17 | 18 | function genClaimParams(input, claimField, claimLength, nWidth) { 19 | const claimPattern = new RegExp(`"${claimField}"\\:\\s*"`); 20 | const claimOffset = Math.floor(input.search(claimPattern) / (nWidth / 8)); 21 | 22 | var inputs = { "claimOffset": claimOffset }; 23 | 24 | if(claimLength !== undefined) { 25 | inputs = Object.assign({}, 26 | inputs, 27 | { "claimLength": claimLength } 28 | ); 29 | } 30 | 31 | return inputs; 32 | } 33 | 34 | function genJwtMask(input, fields) { 35 | const [header, payload] = input.split('.'); 36 | 37 | var payloadMask = Array(payload.length).fill(0); 38 | for(const field of fields) { 39 | var [start, end] = utils.getBase64JSONSlice(payload, field); 40 | 41 | for(var i = start; i <= end; i++) { 42 | payloadMask[i] = 1; 43 | } 44 | } 45 | 46 | return Array(header.length + 1).fill(0).concat(payloadMask); 47 | } 48 | 49 | function genSha256Inputs(input, nCount, nWidth = 512, inParam = "in") { 50 | var segments = utils.arrayChunk(padMessage(utils.buffer2BitArray(Buffer.from(input))), nWidth); 51 | const tBlock = segments.length / (512 / nWidth); 52 | 53 | if(segments.length < nCount) { 54 | segments = segments.concat(Array(nCount-segments.length).fill(Array(nWidth).fill(0))); 55 | } 56 | 57 | if(segments.length > nCount) { 58 | throw new Error('Padded message exceeds maximum blocks supported by circuit'); 59 | } 60 | 61 | return { [inParam]: segments, "tBlock": tBlock }; 62 | } 63 | 64 | function genClaimProofInputs(input, nCount, claimField, claimLength = undefined, nWidth = 16, inParam = "payload") { 65 | var inputs = genSha256Inputs(input, nCount, nWidth, inParam); 66 | inputs[inParam] = inputs[inParam].map(bits => toBigIntBE(utils.bitArray2Buffer(bits))); 67 | 68 | inputs = Object.assign({}, 69 | inputs, 70 | genClaimParams(input, claimField, claimLength, 
nWidth) 71 | ); 72 | 73 | return inputs; 74 | } 75 | 76 | function genJwtProofInputs(input, nCount, fields, nWidth = 16, inParam = "payload") { 77 | var inputs = genSha256Inputs(input, nCount, nWidth, inParam); 78 | inputs[inParam] = inputs[inParam].map(bits => toBigIntBE(utils.bitArray2Buffer(bits))); 79 | 80 | inputs = Object.assign({}, 81 | inputs, 82 | { "mask": genJwtMask(input, fields).concat(Array(nCount - input.length).fill(0)) } 83 | ); 84 | 85 | return inputs; 86 | } 87 | 88 | module.exports = { 89 | padMessage: padMessage, 90 | genClaimParams: genClaimParams, 91 | genJwtMask: genJwtMask, 92 | genSha256Inputs: genSha256Inputs, 93 | genClaimProofInputs: genClaimProofInputs, 94 | genJwtProofInputs: genJwtProofInputs, 95 | } 96 | -------------------------------------------------------------------------------- /js/utils.js: -------------------------------------------------------------------------------- 1 | function arrayChunk(array, chunk_size) { 2 | return Array(Math.ceil(array.length / chunk_size)).fill().map((_, index) => index * chunk_size).map(begin => array.slice(begin, begin + chunk_size)); 3 | } 4 | 5 | function trimEndByChar(string, character) { 6 | const arr = Array.from(string); 7 | const last = arr.reverse().findIndex(char => char !== character); 8 | return string.substring(0, string.length - last); 9 | } 10 | 11 | function getJSONFieldLength(input, field) { 12 | const json_input = JSON.parse(input); 13 | const fieldNameLength = input.match(new RegExp(`"${field}"\\:\\s*`))[0].length; 14 | const fieldValueLength = JSON.stringify(json_input[field]).length; 15 | 16 | return fieldNameLength + fieldValueLength; 17 | } 18 | 19 | function getBase64JSONSlice(input, field) { 20 | const decoded = Buffer.from(input, 'base64').toString(); 21 | const fieldStart = decoded.indexOf(`"${field}"`); 22 | const lead = trimEndByChar(Buffer.from(decoded.slice(0, fieldStart)).toString('base64'), '='); 23 | const fieldLength = getJSONFieldLength(decoded, field); 24 | 
const target = trimEndByChar(Buffer.from(decoded.slice(fieldStart, fieldStart + fieldLength)).toString('base64'), '='); 25 | 26 | const start = Math.floor(lead.length / 4) * 4; 27 | const end = Math.ceil(((lead.length + target.length) - 1) / 4) * 4; 28 | 29 | return [start, end >= input.length ? input.length - 1 : end - 1]; 30 | } 31 | 32 | function buffer2BitArray(b) { 33 | return [].concat(...Array.from(b.entries()).map(([index, byte]) => byte.toString(2).padStart(8, '0').split('').map(bit => bit == '1' ? 1 : 0) )) 34 | } 35 | 36 | function bitArray2Buffer(a) { 37 | return Buffer.from(arrayChunk(a, 8).map(byte => parseInt(byte.join(''), 2))) 38 | } 39 | 40 | function bigIntArray2Bits(arr, intSize=16) { 41 | return [].concat(...arr.map(n => n.toString(2).padStart(intSize, '0').split(''))).map(bit => bit == '1' ? 1 : 0); 42 | } 43 | 44 | function bigIntArray2Buffer(arr, intSize=16) { 45 | return bitArray2Buffer(bigIntArray2Bits(arr, intSize)); 46 | } 47 | 48 | function getWitnessValue(witness, symbols, varName) { 49 | return witness[symbols[varName]['varIdx']]; 50 | } 51 | 52 | function getWitnessMap(witness, symbols, arrName) { 53 | return Object.entries(symbols).filter(([index, symbol]) => index.startsWith(arrName)).map(([index, symbol]) => Object.assign({}, symbol, { "name": index, "value": witness[symbol['varIdx']] }) ); 54 | } 55 | 56 | function getWitnessArray(witness, symbols, arrName) { 57 | return Object.entries(symbols).filter(([index, symbol]) => index.startsWith(`${arrName}[`)).map(([index, symbol]) => witness[symbol['varIdx']] ); 58 | } 59 | 60 | function getWitnessBuffer(witness, symbols, arrName, varSize=1) { 61 | const witnessArray = getWitnessArray(witness, symbols, arrName); 62 | if(varSize == 1) { 63 | return bitArray2Buffer(witnessArray); 64 | } else { 65 | return bigIntArray2Buffer(witnessArray, varSize); 66 | } 67 | } 68 | 69 | module.exports = { 70 | arrayChunk: arrayChunk, 71 | trimEndByChar: trimEndByChar, 72 | getJSONFieldLength: 
getJSONFieldLength, 73 | getBase64JSONSlice: getBase64JSONSlice, 74 | buffer2BitArray: buffer2BitArray, 75 | bitArray2Buffer: bitArray2Buffer, 76 | bigIntArray2Bits: bigIntArray2Bits, 77 | bigIntArray2Buffer: bigIntArray2Buffer, 78 | getWitnessValue: getWitnessValue, 79 | getWitnessMap: getWitnessMap, 80 | getWitnessArray: getWitnessArray, 81 | getWitnessBuffer: getWitnessBuffer, 82 | } 83 | -------------------------------------------------------------------------------- /circuits/claim_proof.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | include "sha256.circom"; 4 | include "slice.circom"; 5 | include "../circomlib/circuits/bitify.circom"; 6 | 7 | /* 8 | Claim Proof 9 | Takes a payload segmented into nWidth chunks and calculates a SHA256 hash, for which an RSA signature is known, 10 | as well as extracting a claim from the payload to be publicly output. 11 | 12 | Construction Parameters: 13 | - nCount: Number of payload inputs of nWidth size 14 | - nWidth: Bit width of payload inputs 15 | - maxClaimLength: Maximum length of claim in nWidth segments 16 | 17 | Inputs: 18 | - payload[nCount]: Segments of payload as nWidth bit chunks 19 | - tBlock: At which 512-bit block to select output hash 20 | - claimOffset: Offset in nWidth segments to start extracting claim 21 | - claimLength: Number of nWidth segments to extract 22 | 23 | Outputs: 24 | - hash: 256-bit SHA256 hash output 25 | - claim: nWidth-bit claim segments 26 | */ 27 | template ClaimProof(nCount, nWidth, maxClaimLength) { 28 | signal input payload[nCount]; 29 | signal input tBlock; 30 | 31 | signal input claimOffset; 32 | signal input claimLength; 33 | 34 | signal output hash[256]; 35 | signal output claim[maxClaimLength]; 36 | 37 | // Segments must divide evenly into 512 bit blocks 38 | assert((nCount * nWidth) % 512 == 0); 39 | assert(nWidth <= 512); 40 | assert(512 % nWidth == 0); 41 | 42 | // The number of payload segments, times 
the bit width of each is the bit length of the payload. 43 | // The payload is decomposed to 512-bit blocks for SHA-256 44 | var nBlocks = (nCount * nWidth) / 512; 45 | 46 | // How many segments are in each block 47 | var nSegments = 512 / nWidth; 48 | 49 | component sha256 = Sha256_unsafe(nBlocks); 50 | component sha256_blocks[nBlocks][nSegments]; 51 | 52 | component claimExtract = Slice(nCount, maxClaimLength); 53 | claimExtract.offset <== claimOffset; 54 | claimExtract.length <== claimLength; 55 | 56 | // For each 512-bit block going into SHA-256 57 | for(var b = 0; b < nBlocks; b++) { 58 | // For each segment going into that block 59 | for(var s = 0; s < nSegments; s++) { 60 | // The index from the payload is offset by the block we're composing times the number of segments per block, 61 | // s is then the segment offset within the block. 62 | var payloadIndex = (b * nSegments) + s; 63 | 64 | // Decompose each segment into an array of individual bits 65 | sha256_blocks[b][s] = Num2Bits(nWidth); 66 | sha256_blocks[b][s].in <== payload[payloadIndex]; 67 | 68 | // The bit index going into the current SHA-256 block is offset by the segment number times the bit width 69 | // of each payload segment. sOffset + i is then the bit offset within the block (0-511). Num2Bits outputs 70 | // in left-hand LSB, so we reverse the ordering of the bits as they go into the SHA-256 circuit. 
71 | var sOffset = s * nWidth; 72 | for(var i = 0; i < nWidth; i++) { 73 | sha256.in[b][sOffset + i] <== sha256_blocks[b][s].out[nWidth - i - 1]; 74 | } 75 | } 76 | } 77 | sha256.tBlock <== tBlock; 78 | 79 | for(var p = 0; p < nCount; p++) { 80 | claimExtract.in[p] <== payload[p]; 81 | } 82 | 83 | for(var i = 0; i < 256; i++) { 84 | hash[i] <== sha256.out[i]; 85 | } 86 | 87 | for(var i = 0; i < maxClaimLength; i++) { 88 | claim[i] <== claimExtract.out[i]; 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /test/jwt_proof.js: -------------------------------------------------------------------------------- 1 | const chai = require("chai"); 2 | const path = require("path"); 3 | const assert = chai.assert; 4 | const crypto = require("crypto"); 5 | const jose = require("jose"); 6 | const {toBigIntBE} = require('bigint-buffer'); 7 | 8 | const tester = require("circom_tester").wasm; 9 | 10 | const circuit = require("../js/circuit"); 11 | const utils = require("../js/utils"); 12 | const test = require("../js/test"); 13 | 14 | describe("JWT Proof", () => { 15 | const inCount = 384; 16 | const inWidth = 8; 17 | const outWidth = 248; 18 | const hashWidth = 248; 19 | 20 | const jwt = 'eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6InBldmpiYS1welhGU0ZDcnRTYlg5SyJ9.eyJpc3MiOiJodHRwczovL2Rldi05aDQ3YWpjOS51cy5hdXRoMC5jb20vIiwic3ViIjoidHdpdHRlcnwzMzc4MzQxMiIsImF1ZCI6IlQxNWU2NDZiNHVoQXJ5eW9qNEdOUm9uNnpzNE1ySEZWIiwiaWF0IjoxNjM5MTczMDI4LCJleHAiOjE2MzkyMDkwMjgsIm5vbmNlIjoiNDQwMTdhODkifQ.Vg2Vv-NJXdCqLy_JF4ecEsU_NgaA3DXbjwPfqr-euuXc-WPeyF00yRDP6_PVCx9p8PAU48fCMfNAKEFemPpY5Trn8paeweFk6uWZWGR42vo6BShryLFGRdce0MfTEBdZVsYnx-PDFz5aRFYxNnZL8sv2DUJ4NQM_8Zmz2EI7sSV7_kHCoXz7UHIOAtN8_otxCRwvrR3xAJ9P-Qp43HhUqM0fiC4RC3YkVKHRARcWC4bdVLBpKa1BBs4cd2wQ_tzv15YHPEyy4ODZGSX_M9cic-95TcpvVSuymw3bGj6_a7EPxcs6BzZGWlBwsh2ltB6FcLsDuAxxCPIG39tZ3Arp6Q'; 21 | const input = jwt.split('.').slice(0,2).join('.'); 22 | const signature = jwt.split('.')[2]; 23 | const jwk = { 
24 | "alg": "RS256", 25 | "kty": "RSA", 26 | "use": "sig", 27 | "n": "sR4EKsGJBJWzlQZfx-Az5IgyhWOo4deEY3PAadE9kjQcyxj5zcTSae4rB6YiOYtEzqjce-dXhpubxjS_olr0n0puCRN0m7u5Hhim029_f1gN2HQofcCRtJY4c6Vr5xkprmSSxk127tYKJ1X-86vzJLPR2p3VUkznTgskEP5bxvHVyj814NQLhdMmFAJOwu9Uuu2oGE4TB3IiZgSCY8gAdt4YfCXqFeLBWPO93JVwPdU4TN3wRMTwEz_by5ZV29jg8On2WBWEt6RL5BZEg_Mxy6OW_YM_csKvr8irJMTv8s4V-GizO2FUQCdURQyfCHyD95WyW2_u3PpxzC_lizeBZQ", 28 | "e": "AQAB", 29 | "kid": "pevjba-pzXFSFCrtSbX9K" 30 | }; 31 | 32 | var cir; 33 | 34 | before(async() => { 35 | cir = await test.genMain(path.join(__dirname, "..", "circuits", "jwt_proof.circom"), "JwtProof", [inCount, inWidth, outWidth, hashWidth]); 36 | await cir.loadSymbols(); 37 | }); 38 | 39 | it("JWT masking", async() => { 40 | const mask = circuit.genJwtMask(input, ["sub", "nonce"]); 41 | 42 | const claims = input.split('').map((c, i) => mask[i] == 1 ? c : ' ').join('').split(/\s+/).filter(e => e !== '').map(e => Buffer.from(e, 'base64').toString()); 43 | 44 | assert.include(claims[0], '"sub":"twitter|33783412"', "Does not contain sub claim"); 45 | assert.include(claims[1], '"nonce":"44017a89"', "Does not contain nonce claim"); 46 | }); 47 | 48 | it("Extract from Base64 JSON", async () => { 49 | const hash = crypto.createHash("sha256").update(input).digest("hex").slice(0, hashWidth / 4); 50 | const pubkey = await jose.importJWK(jwk); 51 | 52 | var inputs = circuit.genJwtProofInputs(input, inCount, ["sub", "nonce"], inWidth); 53 | 54 | const witness = await cir.calculateWitness(inputs, true); 55 | 56 | const hash2 = utils.getWitnessValue(witness, cir.symbols, "main.hash").toString(16); 57 | const masked = utils.getWitnessBuffer(witness, cir.symbols, "main.out", varSize=outWidth).toString(); 58 | const claims = masked.split(/\x00+/).filter(e => e !== '').map(e => Buffer.from(e, 'base64').toString()); 59 | 60 | assert.equal(hash2, hash); 61 | assert.include(claims[0], '"sub":"twitter|33783412"', "Does not contain sub claim"); 62 | 
assert.include(claims[1], '"nonce":"44017a89"', "Does not contain nonce claim");

assert.isTrue(crypto.createVerify('RSA-SHA256').update(input).verify(pubkey, Buffer.from(signature, 'base64')), "Signature does not correspond to hash");
});
});

--------------------------------------------------------------------------------
/circuits/jwt_proof.circom:
--------------------------------------------------------------------------------
pragma circom 2.0.0;

include "sha256.circom";
include "bitify.circom";

/*
JWT Proof
Takes a payload segmented into inWidth chunks and calculates a SHA256 hash, for which an RSA signature is known,
as well as masking the payload to obscure private fields.

Construction Parameters:
- inCount: Number of payload inputs of inWidth size
- inWidth: Bit width of payload inputs (must divide 512; inCount * inWidth must be a multiple of 512)
- outWidth: Bit width of masked payload outputs (must be >= inWidth)
- hashWidth: Bit width of truncated hash output (at most 256)

NOTE(review): inWidth / outWidth / hashWidth at or above the prime field's bit size (254 for circom's
default bn128 curve) make the bit decompositions/packings non-unique - confirm the intended limits.

Inputs:
- payload[inCount]: Segments of payload as inWidth bit chunks (range-checked to inWidth bits)
- mask[inCount]: Binary mask of payload segments; each entry is constrained to be 0 or 1
- tBlock: 1-based number of the terminating 512-bit block, used to select the output hash

Outputs:
- hash: The first hashWidth bits of the SHA256 output, packed with the first bit as most significant
- out[outCount]: Masked payload repacked into outWidth-bit chunks, outCount = ceil(inCount * inWidth / outWidth);
  unused trailing bits of the last chunk are zero
*/
template JwtProof(inCount, inWidth, outWidth, hashWidth) {
    // Segments must divide evenly into 512 bit blocks
    assert((inCount * inWidth) % 512 == 0);
    assert(inWidth <= 512);
    assert(512 % inWidth == 0);

    // outCount = ceil(inBits / outWidth). outExtra is the number of payload bits that land in
    // the final output chunk when it is only partially filled (0 when the split is exact).
    var inBits = inCount * inWidth;
    var outExtra = inBits % outWidth;
    var outCount = (inBits - outExtra) / outWidth;
    if (outExtra > 0) {
        outCount += 1;
    }

    assert(inWidth <= outWidth);
    assert(outCount * outWidth >= inCount * inWidth);

    // The number of payload segments, times the bit width of each is the bit length of the payload.
    // The payload is decomposed to 512-bit blocks for SHA-256
    var nBlocks = inBits / 512;

    // How many segments are in each block
    var nSegments = 512 / inWidth;

    signal input payload[inCount];
    signal input mask[inCount];
    signal input tBlock;

    signal output hash;
    signal output out[outCount];

    component sha256 = Sha256_unsafe(nBlocks);
    component sha256_blocks[nBlocks][nSegments];

    // For each 512-bit block going into SHA-256
    for (var b = 0; b < nBlocks; b++) {
        // For each segment going into that block
        for (var s = 0; s < nSegments; s++) {
            // The index from the payload is offset by the block we're composing times the number of segments per block,
            // s is then the segment offset within the block.
            var payloadIndex = (b * nSegments) + s;

            // Decompose each segment into its bits, most significant bit first. This also
            // range-checks payload[payloadIndex] to inWidth bits.
            sha256_blocks[b][s] = Num2BitsLE(inWidth);
            sha256_blocks[b][s].in <== payload[payloadIndex];

            // The bit index going into the current SHA-256 block is offset by the segment number times the bit width
            // of each payload segment. sOffset + i is then the bit offset within the block (0-511).
            var sOffset = s * inWidth;
            for (var i = 0; i < inWidth; i++) {
                sha256.in[b][sOffset + i] <== sha256_blocks[b][s].out[i];
            }
        }
    }
    sha256.tBlock <== tBlock;

    // Pack the leading hashWidth bits of the digest into a single field element;
    // Bits2NumLE treats in[0] (here sha256.out[0]) as the most significant bit.
    component hash_packer = Bits2NumLE(hashWidth);
    for (var i = 0; i < hashWidth; i++) {
        hash_packer.in[i] <== sha256.out[i];
    }
    hash <== hash_packer.out;

    // Mask each segment and decompose the result into bits for repacking.
    component masked[inCount];
    for (var i = 0; i < inCount; i++) {
        // The mask must be binary. Without this constraint a prover could choose
        // mask[i] = v / payload[i] (in the field, for nonzero payload[i]) so that the
        // masked output reveals an arbitrary inWidth-bit value v that was never hashed.
        mask[i] * (mask[i] - 1) === 0;

        masked[i] = Num2BitsLE(inWidth);
        masked[i].in <== payload[i] * mask[i];
    }

    component out_packer[outCount];
    for (var i = 0; i < outCount; i++) {
        out_packer[i] = Bits2NumLE(outWidth);
    }

    // Route payload bit i (bit mB of segment m) to bit oB of output chunk o.
    for (var i = 0; i < inBits; i++) {
        var oB = i % outWidth;
        var o = (i - oB) / outWidth;
        var mB = i % inWidth;
        var m = (i - mB) / inWidth;

        out_packer[o].in[oB] <== masked[m].out[mB];
    }

    // Zero-fill the unused tail of the last output chunk. When inBits is an exact multiple of
    // outWidth (outExtra == 0), every bit of the last chunk was already assigned above, and
    // assigning it again would be a double assignment.
    if (outExtra > 0) {
        for (var i = outExtra; i < outWidth; i++) {
            out_packer[outCount - 1].in[i] <== 0;
        }
    }

    for (var i = 0; i < outCount; i++) {
        out[i] <== out_packer[i].out;
    }
}

--------------------------------------------------------------------------------
/circuits/sha256.circom:
--------------------------------------------------------------------------------
pragma circom 2.0.0;

include "../circomlib/circuits/sha256/constants.circom";
include "../circomlib/circuits/sha256/sha256compression.circom";
include "../circomlib/circuits/comparators.circom";
include "calculate_total.circom";

/*
SHA256 Unsafe
Calculates the SHA256 hash of the input, using a signal to select the output round corresponding to the number of
non-empty input blocks. This implementation is referred to as "unsafe", as it relies upon the caller to ensure that
the input is padded correctly, and to ensure that the tBlock input corresponds to the actual terminating data block.
Crafted inputs could result in Length Extension Attacks.

Construction Parameters:
- nBlocks: Maximum number of 512-bit blocks for payload input
Inputs:
- in: An array of blocks exactly nBlocks in length, each block containing an array of exactly 512 bits.
  Padding of the input according to RFC4634 Section 4.1 is left to the caller.
  Blocks following tBlock must be supplied, and *should* contain all zeroes
- tBlock: 1-based index of the terminating block of the input, which contains the message padding.
  Must lie in [1, nBlocks]; any other value fails the constraints.
Outputs:
- out: An array of 256 bits corresponding to the SHA256 output as of the terminating block
*/
template Sha256_unsafe(nBlocks) {
    signal input in[nBlocks][512];
    signal input tBlock;

    signal output out[256];

    // Initial hash words H(0)..H(7), taken from circomlib's sha256 constants.
    component ha0 = H(0);
    component hb0 = H(1);
    component hc0 = H(2);
    component hd0 = H(3);
    component he0 = H(4);
    component hf0 = H(5);
    component hg0 = H(6);
    component hh0 = H(7);

    component sha256compression[nBlocks];

    // Run the compression function over every block. All nBlocks are always computed;
    // the result at the terminating block is selected afterwards.
    for (var i = 0; i < nBlocks; i++) {

        sha256compression[i] = Sha256compression();

        if (i == 0) {
            // The first block starts from the initial hash words.
            for (var k = 0; k < 32; k++) {
                sha256compression[i].hin[0*32+k] <== ha0.out[k];
                sha256compression[i].hin[1*32+k] <== hb0.out[k];
                sha256compression[i].hin[2*32+k] <== hc0.out[k];
                sha256compression[i].hin[3*32+k] <== hd0.out[k];
                sha256compression[i].hin[4*32+k] <== he0.out[k];
                sha256compression[i].hin[5*32+k] <== hf0.out[k];
                sha256compression[i].hin[6*32+k] <== hg0.out[k];
                sha256compression[i].hin[7*32+k] <== hh0.out[k];
            }
        } else {
            // Later blocks chain from the previous compression output. Each 32-bit word is fed
            // in bit-reversed order (out[32*j+31-k] -> hin[32*j+k]), as circomlib's Sha256 does.
            for (var k = 0; k < 32; k++) {
                sha256compression[i].hin[32*0+k] <== sha256compression[i-1].out[32*0+31-k];
                sha256compression[i].hin[32*1+k] <== sha256compression[i-1].out[32*1+31-k];
                sha256compression[i].hin[32*2+k] <== sha256compression[i-1].out[32*2+31-k];
                sha256compression[i].hin[32*3+k] <== sha256compression[i-1].out[32*3+31-k];
                sha256compression[i].hin[32*4+k] <== sha256compression[i-1].out[32*4+31-k];
                sha256compression[i].hin[32*5+k] <== sha256compression[i-1].out[32*5+31-k];
                sha256compression[i].hin[32*6+k] <== sha256compression[i-1].out[32*6+31-k];
                sha256compression[i].hin[32*7+k] <== sha256compression[i-1].out[32*7+31-k];
            }
        }

        for (var k = 0; k < 512; k++) {
            sha256compression[i].inp[k] <== in[i][k];
        }
    }

    // Collapse the hashing result at the terminating data block.
    // A modified Quin Selector allows us to select the block based on the tBlock signal.
    //
    // Whether block i is the terminating block depends only on i and tBlock, not on the
    // output bit, so the comparison is computed once per block and shared by all 256 bits.
    component isTerminal[nBlocks];
    component selectedCount = CalculateTotal(nBlocks);
    for (var i = 0; i < nBlocks; i++) {
        isTerminal[i] = IsEqual();
        isTerminal[i].in[0] <== i;
        isTerminal[i].in[1] <== tBlock - 1;
        selectedCount.nums[i] <== isTerminal[i].out;
    }

    // Exactly one block must be selected. Otherwise (tBlock outside [1, nBlocks]) every
    // selector term would be 0 and out would silently be an all-zero "digest".
    selectedCount.sum === 1;

    component calcTotal[256];

    // For each bit of the output
    for (var k = 0; k < 256; k++) {
        calcTotal[k] = CalculateTotal(nBlocks);

        // Exactly one isTerminal[i].out is 1, so the sum is that block's bit k.
        for (var i = 0; i < nBlocks; i++) {
            calcTotal[k].nums[i] <== isTerminal[i].out * sha256compression[i].out[k];
        }

        out[k] <== calcTotal[k].sum;
    }

}

--------------------------------------------------------------------------------