├── .gitignore ├── CMakeLists.txt ├── README.md ├── pass ├── CMakeLists.txt └── CPI.cpp ├── rt.c └── tests ├── .gitignore ├── run.sh ├── sqlite3 ├── README.md └── run.sh └── test.c /.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /cmake-build-debug 3 | /.idea 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.4.3) 2 | project(CPI) 3 | 4 | find_package(LLVM REQUIRED CONFIG) 5 | list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") 6 | include(AddLLVM) 7 | 8 | set(CMAKE_CXX_STANDARD 14) 9 | add_definitions(${LLVM_DEFINITIONS}) 10 | include_directories(${LLVM_INCLUDE_DIRS}) 11 | 12 | if (APPLE) 13 | list (APPEND CMAKE_CXX_FLAGS "-undefined dynamic_lookup") 14 | endif(APPLE) 15 | 16 | add_subdirectory(pass) 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CPI 2 | 3 | After cloning the source to a folder `CPI`: 4 | 5 | ```sh 6 | # Set environment variable for test script 7 | export LLVM_BIN= 8 | cd CPI 9 | 10 | # Always build in separate directory!! 11 | mkdir build 12 | cd build 13 | 14 | # Generate makefiles 15 | # You might need to set LLVM_DIR to LLVM sources, or add 16 | # LLVM_BIN to PATH for cmake to work properly 17 | cmake ../ 18 | 19 | # Run test scripts 20 | cd ../test 21 | ./run.sh -b test.c 22 | ``` 23 | 24 | - `test.llvm.ll` is unpatched assembly 25 | - `test.llvm.p.ll` is patched assembly 26 | - `test.llvm.out` is unpatched program 27 | - `test.llvm.p.out` is patched program 28 | -------------------------------------------------------------------------------- /pass/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_llvm_loadable_module( 2 | LLVMCPI 3 | CPI.cpp 4 | ) -------------------------------------------------------------------------------- /pass/CPI.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "llvm/ADT/Statistic.h" 5 | #include "llvm/Pass.h" 6 | #include "llvm/Support/Debug.h" 7 | #include "llvm/Support/raw_ostream.h" 8 | #include "llvm/IR/Type.h" 9 | #include "llvm/IR/Module.h" 10 | #include "llvm/IR/Instructions.h" 11 | #include "llvm/IR/IRBuilder.h" 12 | #include "llvm/IR/GlobalVariable.h" 13 | #include "llvm/IR/InstrTypes.h" 14 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" 15 | 16 | #define DEBUG_TYPE "cpi" 17 | 18 | STATISTIC(NumSMAlloca, "Number of __sm_alloca"); 19 | STATISTIC(NumSMMalloc, "Number of __sm_malloc"); 20 | STATISTIC(NumRMStore, "Number of replaced stack stores"); 21 | STATISTIC(NumRMLoad, "Number of replaced stack loads"); 22 | STATISTIC(NumRMPStore, "Number of replaced stores"); 23 | STATISTIC(NumRMPLoad, "Number of replaced loads"); 24 | STATISTIC(NumCommit, "Number of commits"); 25 | STATISTIC(NumRestore, "Number of restores"); 26 | 27 | using namespace llvm; 28 | 29 | struct CPI : public ModulePass { 30 | static char ID; 31 | CPI() : ModulePass(ID) {} 32 | 33 | bool runOnModule(Module &M) override { 34 | LLVMContext &ctx = M.getContext(); 35 | 36 | // Commonly used types 37 | intT = Type::getInt32Ty(ctx); 38 | voidT = Type::getVoidTy(ctx); 39 | voidPT = Type::getInt8PtrTy(ctx); 40 | voidPPT = PointerType::get(voidPT, 0); 41 | 42 | // Create NOP used for instruction insertion 43 | auto zero = ConstantInt::get(intT, 0); 44 | nop = BinaryOperator::Create(Instruction::Add, zero, zero, "NOP", 45 | M.getFunctionList().front().getEntryBlock().getFirstNonPHIOrDbg()); 46 | 47 | // Add global variables rt.c 48 | smSp = new GlobalVariable(M, intT, false, GlobalValue::ExternalLinkage, nullptr, "__sm_sp"); 49 | 50 | // Add function references in rt.c 51 | smAlloca = cast(M.getOrInsertFunction("__sm_alloca", voidPPT)); 52 | smMalloc = cast(M.getOrInsertFunction("__sm_malloc", voidPPT, voidPPT)); 53 | smLoad = cast(M.getOrInsertFunction("__sm_load", voidPT, voidPPT, voidPPT)); 54 | 55 | // Find all sensitive structs 56 | for (auto s : M.getIdentifiedStructTypes()) { 57 | for (unsigned i = 0; i < s->getNumElements(); ++i) { 58 | if (isFunctionPtr(s->getElementType(i))) { 59 | ssMap[s].push_back(i); 60 | } 61 | } 62 | } 63 | 64 | // Loop through all functions 65 | for (auto &F: M.getFunctionList()) { 66 | // Only care locally implemented functions 67 | if (!F.isDeclaration()) 68 | runOnFunction(F); 69 | } 70 | 71 | // Remove NOP 72 | nop->eraseFromParent(); 73 | 74 | return true; 75 | } 76 | 77 | private: 78 | Function *smAlloca; 79 | Function *smMalloc; 80 | Function *smLoad; 81 | Value *smSp; 82 | 83 | IntegerType *intT; 84 | Type *voidT; 85 | PointerType *voidPT; 86 | PointerType *voidPPT; 87 | 88 | Instruction *nop; 89 | 90 | // A map of StructType to the list of entries numbers that are function pointers 91 | std::map > ssMap; 92 | 93 | void runOnFunction(Function &F) { 94 | bool hasInject = false; 95 | 96 | BasicBlock &entryBlock = F.getEntryBlock(); 97 | 98 | // Alloca only happens in the first basic block 99 | hasInject |= swapFunctionPtrAlloca(entryBlock); 100 | hasInject |= handleStructAlloca(entryBlock); 101 | 102 | // Handle struct ptrs from unknown sources 103 | hasInject |= handleStructPtrs(F); 104 | 105 | if (hasInject) { 106 | // Create checkpoint 107 | auto spLoad = new LoadInst(smSp, "smStackCheckpoint", entryBlock.getFirstNonPHI()); 108 | for (auto &bb : F) { 109 | auto ti = bb.getTerminator(); 110 | if (isa(ti)) { 111 | // Restore checkpoint 112 | new StoreInst(spLoad, smSp, ti); 113 | } 114 | } 115 | } 116 | } 117 | 118 | bool swapFunctionPtrAlloca(BasicBlock &bb) { 119 | auto v = getFunctionPtrAlloca(bb); 120 | for (auto alloc : v) { 121 | IRBuilder<> b(alloc); 122 | std::string name(alloc->getName()); 123 | auto addr = b.CreateCall(smAlloca); 124 | DEBUG(dbgs() << "ADD:" << *addr << "\n"); 125 | NumSMAlloca++; 126 | swapAllocaPtr(alloc, addr); 127 | addr->setName(name); 128 | } 129 | return !v.empty(); 130 | } 131 | 132 | bool handleStructAlloca(BasicBlock &bb) { 133 | bool hasInject = false; 134 | for (auto alloc: getSSAlloca(bb)) { 135 | nop->moveAfter(alloc); 136 | hasInject |= replaceSSAllocaFPEntries(alloc); 137 | } 138 | return hasInject; 139 | } 140 | 141 | bool handleStructPtrs(Function &F) { 142 | bool hasInject = false; 143 | nop->moveBefore(F.getEntryBlock().getFirstNonPHI()); 144 | for (auto &arg : F.args()) { 145 | if (isSSPtr(arg.getType())) { 146 | hasInject |= replaceUnknownSrcSSFPEntries(&arg); 147 | } 148 | } 149 | for (auto &bb : F) { 150 | for (auto &I : bb) { 151 | if (!isa(I) && isSSPtr(I.getType())) { 152 | nop->moveAfter(&I); 153 | hasInject |= replaceUnknownSrcSSFPEntries(&I); 154 | } 155 | } 156 | } 157 | return hasInject; 158 | } 159 | 160 | bool replaceSSAllocaFPEntries(Value *ssp) { 161 | std::map, std::vector > rmMap; 162 | for (int sentry : ssMap[cast(ssp->getType())->getElementType()]) { 163 | for (auto user : ssp->users()) { 164 | int idx; 165 | auto *gep = dyn_cast(user); 166 | if ((idx = isSensitiveGEP(gep, sentry)) >= 0) { 167 | rmMap[{idx, sentry}].push_back(gep); 168 | } 169 | } 170 | } 171 | if (rmMap.empty()) 172 | return false; 173 | 174 | for (const auto &geps : rmMap) { 175 | IRBuilder<> b(nop); 176 | auto addr = b.CreateCall(smAlloca, None, 177 | ssp->getName() + "." + std::to_string(geps.first.first) + "." + std::to_string(geps.first.second)); 178 | DEBUG(dbgs() << "ADD:" << *addr << "\n"); 179 | NumSMAlloca++; 180 | auto tmp = b.CreateGEP(ssp, {ConstantInt::get(intT, geps.first.first), ConstantInt::get(intT, geps.first.second)}); 181 | auto orig = b.CreatePointerCast(tmp, voidPPT, addr->getName() + ".orig"); 182 | 183 | for (auto u: geps.second) { 184 | swapAllocaPtr(u, addr); 185 | } 186 | 187 | // Check for external calls 188 | for (auto user : ssp->users()) { 189 | CallInst *ci; 190 | if ((ci = dyn_cast(user))) { 191 | commitAndRestore(addr, orig, ci); 192 | } 193 | } 194 | 195 | // Remove if not needed 196 | if (orig->getNumUses() == 0) { 197 | cast(orig)->eraseFromParent(); 198 | cast(tmp)->eraseFromParent(); 199 | } 200 | } 201 | 202 | return true; 203 | } 204 | 205 | bool replaceUnknownSrcSSFPEntries(Value *ssp) { 206 | std::map, std::vector > rmMap; 207 | for (int sentry : ssMap[cast(ssp->getType())->getElementType()]) { 208 | for (auto user : ssp->users()) { 209 | int idx; 210 | auto *gep = dyn_cast(user); 211 | if ((idx = isSensitiveGEP(gep, sentry)) >= 0) { 212 | rmMap[{idx, sentry}].push_back(gep); 213 | } 214 | } 215 | } 216 | if (rmMap.empty()) 217 | return false; 218 | 219 | for (const auto &geps : rmMap) { 220 | std::string name(ssp->getName()); 221 | name += "." + std::to_string(geps.first.first) + "." + std::to_string(geps.first.second); 222 | 223 | IRBuilder<> b(nop); 224 | auto tmp = b.CreateGEP(ssp, 225 | {ConstantInt::get(intT, geps.first.first), ConstantInt::get(intT, geps.first.second)}); 226 | auto orig = b.CreatePointerCast(tmp, voidPPT, name + ".orig"); 227 | auto addr = b.CreateCall(smMalloc, orig, name); 228 | DEBUG(dbgs() << "ADD:" << *addr << "\n"); 229 | NumSMMalloc++; 230 | 231 | for (auto u: geps.second) { 232 | swapUnknownSrcPtr(u, addr, orig); 233 | } 234 | 235 | // Check for external calls 236 | for (auto user : ssp->users()) { 237 | CallInst *ci; 238 | if ((ci = dyn_cast(user))) { 239 | restore(addr, orig, ci->getNextNode()); 240 | } 241 | } 242 | } 243 | 244 | return true; 245 | } 246 | 247 | void swapAllocaPtr(Instruction *from, Instruction *to) { 248 | // Swap out all uses (store and load) 249 | for (auto a : from->users()) { 250 | StoreInst *s; 251 | LoadInst *l; 252 | if ((s = dyn_cast(a))) { 253 | IRBuilder<> b(s); 254 | auto cast = b.CreatePointerCast(s->getValueOperand(), voidPT); 255 | b.CreateStore(cast, to); 256 | DEBUG(dbgs() << "SWAP:" << *s << "\n"); 257 | NumRMStore++; 258 | s->eraseFromParent(); 259 | } else if ((l = dyn_cast(a))) { 260 | IRBuilder<> b(l); 261 | auto raw = b.CreateLoad(to); 262 | auto cast = b.CreatePointerCast(raw, l->getType()); 263 | DEBUG(dbgs() << "SWAP:" << *l << "\n"); 264 | NumRMLoad++; 265 | BasicBlock::iterator ii(l); 266 | ReplaceInstWithValue(l->getParent()->getInstList(), ii, cast); 267 | } else { 268 | DEBUG(dbgs() << "OTHER:" << *from << "\n"); 269 | } 270 | } 271 | if (from->getNumUses() == 0) { 272 | DEBUG(dbgs() << "RM:" << *from << "\n"); 273 | from->eraseFromParent(); 274 | } 275 | } 276 | 277 | void swapUnknownSrcPtr(Instruction *from, Instruction *to, Value *orig) { 278 | // Swap out all uses (store and load) 279 | for (auto a : from->users()) { 280 | StoreInst *s; 281 | LoadInst *l; 282 | if ((s = dyn_cast(a))) { 283 | IRBuilder<> b(s); 284 | auto cast = b.CreatePointerCast(s->getValueOperand(), voidPT); 285 | b.CreateStore(cast, to); 286 | b.CreateStore(cast, orig); 287 | DEBUG(dbgs() << "SWAP:" << *s << "\n"); 288 | NumRMPStore++; 289 | s->eraseFromParent(); 290 | } else if ((l = dyn_cast(a))) { 291 | IRBuilder<> b(l); 292 | auto raw = b.CreateCall(smLoad, {to, orig}); 293 | auto cast = b.CreatePointerCast(raw, l->getType()); 294 | DEBUG(dbgs() << "SWAP:" << *l << "\n"); 295 | NumRMPLoad++; 296 | BasicBlock::iterator ii(l); 297 | ReplaceInstWithValue(l->getParent()->getInstList(), ii, cast); 298 | } else { 299 | DEBUG(dbgs() << "OTHER:" << *from << "\n"); 300 | } 301 | } 302 | if (from->getNumUses() == 0) { 303 | DEBUG(dbgs() << "RM:" << *from << "\n"); 304 | from->eraseFromParent(); 305 | } 306 | } 307 | 308 | // Commit sm memory to actual memory 309 | void commit(Value *a, Value *b, Instruction *i) { 310 | IRBuilder<> builder(i); 311 | auto v = builder.CreateLoad(a); 312 | builder.CreateStore(v, b); 313 | NumCommit++; 314 | } 315 | 316 | void restore(Value *a, Value *b, Instruction *i) { 317 | IRBuilder<> builder(i); 318 | auto v = builder.CreateLoad(b); 319 | builder.CreateStore(v, a); 320 | NumRestore++; 321 | } 322 | 323 | void commitAndRestore(Value *a, Value *b, Instruction *i) { 324 | commit(a, b, i); 325 | restore(a, b, i->getNextNode()); 326 | } 327 | 328 | std::vector getSensitiveAlloca(BasicBlock &bb, const std::function &filter) { 329 | std::vector v; 330 | AllocaInst *ai; 331 | for (auto &I : bb) { 332 | if ((ai = dyn_cast(&I)) && filter(ai)) { 333 | DEBUG(dbgs() << "SENS:" << I << "\n"); 334 | v.push_back(ai); 335 | } 336 | } 337 | return v; 338 | } 339 | 340 | std::vector getFunctionPtrAlloca(BasicBlock &bb) { 341 | return getSensitiveAlloca(bb, [this](auto i) -> bool { 342 | if (!isFunctionPtr(i->getAllocatedType())) 343 | return false; 344 | for (auto user : i->users()) { 345 | // If the pointer is passed to a function call, skip it 346 | if (isa(user)) 347 | return false; 348 | } 349 | return true; 350 | }); 351 | } 352 | 353 | std::vector getSSAlloca(BasicBlock &bb) { 354 | return getSensitiveAlloca(bb, [this](auto i) -> bool { 355 | return ssMap.count(i->getAllocatedType()); 356 | }); 357 | } 358 | 359 | /* Check struct entry to function pointer (2nd GEP index, or 3rd operand) */ 360 | int isSensitiveGEP(GetElementPtrInst *gep, int fpentry) { 361 | ConstantInt *ci; 362 | if(gep && gep->getNumOperands() >= 3 && 363 | (ci = dyn_cast(gep->getOperand(2))) && 364 | ci->getSExtValue() == fpentry && 365 | (ci = dyn_cast(gep->getOperand(1)))) { 366 | return ci->getSExtValue(); 367 | } 368 | return -1; 369 | } 370 | 371 | bool isFunctionPtr(Type *T) { 372 | PointerType *t; 373 | return (t = dyn_cast(T)) && t->getElementType()->isFunctionTy(); 374 | } 375 | 376 | bool isSSPtr(Type *T) { 377 | PointerType *t; 378 | if (!(t = dyn_cast(T))) 379 | return false; 380 | for (const auto &p : ssMap) { 381 | if (t->getElementType() == p.first) 382 | return true; 383 | } 384 | return false; 385 | } 386 | }; 387 | 388 | char CPI::ID = 0; 389 | RegisterPass X("cpi", "Code Pointer Integrity"); 390 | -------------------------------------------------------------------------------- /rt.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define BLOCK_SZ 16 7 | #define PTR_MEM_SZ(i) ((i) * sizeof(void *)) 8 | 9 | int __sm_sp = 0; 10 | static void ***block_table; 11 | static int table_sz = 0; 12 | 13 | __attribute__((always_inline)) 14 | static inline void **sm_alloca() { 15 | int block_num = __sm_sp >> 4; 16 | if (block_num == table_sz) { 17 | int new_sz = table_sz ? table_sz * 2 : 1; 18 | block_table = realloc(block_table, PTR_MEM_SZ(table_sz)); 19 | // Zero out the new space 20 | memset(&block_table[table_sz], 0, PTR_MEM_SZ(new_sz - table_sz)); 21 | table_sz = new_sz; 22 | } 23 | if (!block_table[block_num]) 24 | block_table[block_num] = malloc(PTR_MEM_SZ(BLOCK_SZ)); 25 | fprintf(stderr, "__sm_alloca %d, %d\n", __sm_sp, table_sz << 4); 26 | return &block_table[block_num][__sm_sp++ & 0xF]; 27 | } 28 | 29 | void **__sm_alloca() { 30 | return sm_alloca(); 31 | } 32 | 33 | void **__sm_malloc(void **ua) { 34 | void **sa = sm_alloca(); 35 | *sa = *ua; 36 | return sa; 37 | } 38 | 39 | __attribute__((always_inline)) 40 | void *__sm_load(void **sa, void **ua) { 41 | assert(*sa == *ua); 42 | return *sa; 43 | } 44 | -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | *.ll 2 | *.bc 3 | *.out 4 | -------------------------------------------------------------------------------- /tests/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Usage: ./run.sh [-b] 4 | # With -b, it will rebuild libsafe_rt and LLVMCPI 5 | 6 | # Note: llvm built binaries can be in LLVM_BIN env variable 7 | 8 | [[ -z $LLVM_BIN ]] || export PATH="$LLVM_BIN:$PATH" 9 | 10 | if [[ "$1" = "-b" ]]; then 11 | # Rebuild libraries 12 | cwd=`pwd` 13 | cd ../build 14 | make -j4 15 | cd $cwd 16 | shift 17 | fi 18 | 19 | case `uname -s` in 20 | Linux) 21 | dll=so 22 | CPI_FLAG="-cpi" 23 | ;; 24 | Darwin) 25 | dll=dylib 26 | CPI_FLAG="-cpi -debug-only=cpi" 27 | ;; 28 | *) 29 | exit 1 30 | ;; 31 | esac 32 | 33 | src=$1 34 | name=${src%.*}.llvm 35 | 36 | OPT=-O0 37 | 38 | # Compile test program 39 | # clang -S -emit-llvm -c $src -o ${name}.ll 40 | 41 | # Compile test program with alloca-hoisting and mem2reg 42 | clang -emit-llvm -O1 -mllvm -disable-llvm-optzns -c $src -o - | opt -S -alloca-hoisting -mem2reg -o ${name}.ll || exit 1 43 | clang $OPT ${name}.ll -o ${name}.out 44 | 45 | # Run CPI pass 46 | opt -S -o ${name}.p.ll -load ../build/pass/LLVMCPI.${dll} -stats $CPI_FLAG ${name}.ll || exit 1 47 | 48 | # Build patched code 49 | clang $OPT ${name}.p.ll ../rt.c -o ${name}.p.out 50 | exit 51 | 52 | # Generate combined bitcode 53 | clang -S -emit-llvm -c ../rt.c -o rt.llvm.ll 54 | llvm-link ${name}.p.ll rt.llvm.ll | opt -S $OPT -o ${name}.o.ll 55 | -------------------------------------------------------------------------------- /tests/sqlite3/README.md: -------------------------------------------------------------------------------- 1 | # SQLite3 Benchmarks 2 | 3 | Download 2 zips from the [official website](https://www.sqlite.org/download.html) 4 | 5 | - sqlite-amalgamation-[ver].zip 6 | - sqlite-src-[ver].zip 7 | 8 | From the amalgamation zip, copy `sqlite3.c` and `sqlite3.h` to the current folder. 9 | 10 | From the src zip, copy `speedtest1.c` from the `test` folder. 11 | 12 | Run `run.sh` to compile. 13 | -------------------------------------------------------------------------------- /tests/sqlite3/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | [[ -z $LLVM_BIN ]] || export PATH="$LLVM_BIN:$PATH" 4 | 5 | if [[ "$1" = "-b" ]]; then 6 | # Rebuild libraries 7 | cwd=`pwd` 8 | cd ../../build 9 | make -j4 10 | cd $cwd 11 | shift 12 | fi 13 | 14 | case `uname -s` in 15 | Linux) 16 | dll=so 17 | CPI_FLAG="-cpi" 18 | ;; 19 | Darwin) 20 | dll=dylib 21 | CPI_FLAG="-cpi -debug-only=cpi" 22 | ;; 23 | *) 24 | exit 1 25 | ;; 26 | esac 27 | 28 | clang -emit-llvm -O1 -mllvm -disable-llvm-optzns -c sqlite3.c speedtest1.c 29 | 30 | # Original 31 | clang sqlite3.bc speedtest1.bc -o sqlite3.speed.out 32 | 33 | # Pass 34 | llvm-link sqlite3.bc speedtest1.bc | opt -alloca-hoisting -mem2reg | opt -o sqlite3.p.bc -load ../../build/pass/LLVMCPI.${dll} -stats $CPI_FLAG 35 | clang sqlite3.p.bc ../../rt.c -o sqlite3.speed.p.out 36 | 37 | -------------------------------------------------------------------------------- /tests/test.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | struct foo { 6 | int i; 7 | void (*func)(); 8 | }; 9 | 10 | struct bar { 11 | void (*f1)(); 12 | int i; 13 | void (*f2)(); 14 | }; 15 | 16 | struct baz { 17 | size_t i; 18 | void (*f)(size_t s); 19 | }; 20 | 21 | static void T() { 22 | printf("T\n"); 23 | } 24 | 25 | static void F() { 26 | printf("F\n"); 27 | } 28 | 29 | static void good(size_t s) { 30 | printf("Proper control flow!\n"); 31 | } 32 | 33 | static void bad(size_t s) { 34 | printf("Hijacked control flow!\n"); 35 | } 36 | 37 | static void test_1(int i); 38 | static void test_2(struct bar *b); 39 | static void test_3(struct foo *f); 40 | 41 | static void vuln(int off, size_t val) { 42 | struct baz b; 43 | size_t buf[4]; 44 | printf("Vuln offset: %ld\n", ((void * )&b.f - (void *) buf) / (ssize_t) sizeof(size_t)); 45 | b.f = good; 46 | buf[off] = val; 47 | b.f(buf[off]); 48 | } 49 | 50 | static void test_1(int i) { 51 | void (*fptr)(); 52 | struct foo f; 53 | struct bar b; 54 | b.i = i; 55 | if (i) { 56 | fptr = T; 57 | f.func = T; 58 | b.f1 = T; 59 | b.f2 = T; 60 | } else { 61 | fptr = F; 62 | f.func = F; 63 | b.f1 = F; 64 | b.f2 = F; 65 | } 66 | printf("* test_2\n"); 67 | test_2(&b); 68 | printf("* test_1\n"); 69 | fptr(); 70 | f.func(); 71 | b.f2(); 72 | } 73 | 74 | static void test_2(struct bar *b) { 75 | b->f1(); 76 | b->f2 = b->i ? F : T; 77 | struct foo *f = malloc(sizeof(*f)); 78 | f->i = b->i; 79 | f->func = b->f1; 80 | printf("* test_3\n"); 81 | test_3(f); 82 | printf("* test_2\n"); 83 | f->func(); 84 | free(f); 85 | } 86 | 87 | static void test_3(struct foo *f) { 88 | f->func(); 89 | f->func = f->i ? F : T; 90 | } 91 | 92 | int main(int argc, char const *argv[]) { 93 | /* Prevent segfault */ 94 | if (argc < 2) 95 | return 1; 96 | 97 | int val = atoi(argv[1]); 98 | 99 | /* Test control flow hijack */ 100 | printf("------- Control Flow Test -------\n"); 101 | vuln(val, (size_t) bad); 102 | 103 | printf("------- Correctness Test -------\n"); 104 | printf("* test_1\n"); 105 | test_1(val); 106 | 107 | return 0; 108 | } 109 | --------------------------------------------------------------------------------