├── .gitignore ├── README.md ├── CMakeLists.txt └── cff.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | !build/.gitkeep 3 | .vscode 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # obfus 2 | 3 | This is a proof-of-concept control-flow flattener implemented as an LLVM pass plugin (requires LLVM/Clang 14). 4 | 5 | ## Building 6 | 7 | ``` 8 | cd build 9 | cmake .. 10 | cmake --build . 11 | ``` 12 | 13 | ## Usage 14 | 15 | ``` 16 | clang main.c -fpass-plugin=../obfus/build/libcff.so -o main 17 | ``` 18 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12) 2 | 3 | project(HelloWorld) 4 | 5 | # Set the C++ standard to C++14 6 | set(CMAKE_CXX_STANDARD 14) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | 9 | # Find LLVM package 10 | find_package(LLVM 14 REQUIRED CONFIG) 11 | 12 | # Add include directories for LLVM headers 13 | include_directories(${LLVM_INCLUDE_DIRS}) 14 | add_definitions(${LLVM_DEFINITIONS}) 15 | 16 | # Define source files for cff library 17 | set(SOURCE_FILES_cff cff.cpp) 18 | 19 | # Build the cff shared library 20 | add_library(cff SHARED ${SOURCE_FILES_cff}) 21 | 22 | # Link cff against LLVM libraries 23 | target_link_libraries(cff PRIVATE LLVM) 24 | 25 | # Set the RPATH for the cff library to find LLVM libraries 26 | set_target_properties(cff PROPERTIES 27 | INSTALL_RPATH "${LLVM_LIBRARY_DIR}" 28 | BUILD_WITH_INSTALL_RPATH TRUE 29 | ) 30 | 31 | # Copy LLVM shared libraries to the output directory for cff 32 | add_custom_command(TARGET cff POST_BUILD 33 | COMMAND ${CMAKE_COMMAND} -E copy_directory 34 | ${LLVM_LIBRARY_DIR} $ 35 | ) 36 | -------------------------------------------------------------------------------- /cff.cpp: 
-------------------------------------------------------------------------------- 1 | #include "llvm/IR/PassManager.h" 2 | #include "llvm/Passes/PassBuilder.h" 3 | #include "llvm/Passes/PassPlugin.h" 4 | #include "llvm/Support/raw_ostream.h" 5 | #include "llvm/IR/Function.h" 6 | #include "llvm/IR/Instructions.h" 7 | #include "llvm/IR/IRBuilder.h" 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace llvm; 13 | 14 | namespace obfs { 15 | // https://stackoverflow.com/questions/26281823/llvm-how-to-get-the-label-of-basic-blocks 16 | static std::string getSimpleNodeLabel(const BasicBlock *Node) { 17 | if (!Node->getName().empty()) 18 | return Node->getName().str(); 19 | 20 | std::string Str; 21 | raw_string_ostream OS(Str); 22 | 23 | Node->printAsOperand(OS, false); 24 | return OS.str(); 25 | } 26 | 27 | static void demotePhiNodes(Function& F) { 28 | std::vector phiNodes; 29 | do { 30 | phiNodes.clear(); 31 | 32 | for (auto& BB : F) { 33 | for (auto& I : BB.phis()) { 34 | phiNodes.push_back(&I); 35 | } 36 | } 37 | 38 | for (PHINode* phi : phiNodes) { 39 | DemotePHIToStack(phi, F.begin()->getTerminator()); 40 | } 41 | } while (!phiNodes.empty()); 42 | } 43 | 44 | void printFunction(Function& F) { 45 | for (BasicBlock& BB : F) { 46 | outs() << "New basic block " << getSimpleNodeLabel(&BB) << "\n"; 47 | for(Instruction& instruction : BB) { 48 | outs() << " Instruction: " << instruction << "\n"; 49 | } 50 | } 51 | } 52 | 53 | SmallVector getBlocksToFlatten(Function& F) { 54 | SmallVector flattenedBB; 55 | 56 | for (BasicBlock& BB : F) { 57 | if (&BB == &(F.getEntryBlock())) { 58 | outs() << "Not flattening entry block " << getSimpleNodeLabel(&BB) << "\n"; 59 | continue; 60 | } 61 | outs() << "Adding block to flatten: " << getSimpleNodeLabel(&BB) << "\n"; 62 | flattenedBB.push_back(&BB); 63 | } 64 | 65 | return flattenedBB; 66 | } 67 | 68 | AllocaInst* initDispatchVar(Function& F, int initialValue) { 69 | BasicBlock &EntryBlock = F.getEntryBlock(); 70 | 
IRBuilder<> EntryBuilder(&EntryBlock, EntryBlock.begin()); 71 | AllocaInst *DispatchVar = EntryBuilder.CreateAlloca(EntryBuilder.getInt32Ty(), nullptr, "dispatch_var"); 72 | EntryBuilder.CreateStore(ConstantInt::get(EntryBuilder.getInt32Ty(), initialValue), DispatchVar); 73 | 74 | return DispatchVar; 75 | } 76 | 77 | BasicBlock& splitBranchOffEntryBlock(Function& F) { 78 | BasicBlock &entryBlockTail = F.getEntryBlock(); 79 | BasicBlock* pNewEntryBlock = entryBlockTail.splitBasicBlockBefore(entryBlockTail.getTerminator(), ""); 80 | 81 | return entryBlockTail; 82 | } 83 | 84 | BasicBlock* insertDispatchBlockAfterEntryBlock(Function& F, AllocaInst *DispatchVar) { 85 | BasicBlock &EntryBlock = F.getEntryBlock(); 86 | auto* br = dyn_cast(EntryBlock.getTerminator()); 87 | BasicBlock *Successor = br->getSuccessor(0); 88 | 89 | // EntryBlock -> Successor 90 | // we create DispatchBlock and plug it in at both ends 91 | 92 | // DispatchBlock -> Successor 93 | BasicBlock* DispatchBlock = BasicBlock::Create(F.getContext(), "dispatch_block", &F); 94 | IRBuilder<> DispatchBuilder(DispatchBlock, DispatchBlock->begin()); 95 | DispatchBuilder.CreateBr(Successor); 96 | 97 | // EntryBlock -> DispatchBlock 98 | br->setSuccessor(0, DispatchBlock); 99 | DispatchBlock->moveAfter(&EntryBlock); 100 | 101 | return DispatchBlock; 102 | } 103 | 104 | void flattenBlock(BasicBlock* block, int& dispatchVal, Function& F, AllocaInst *DispatchVar, BasicBlock* DispatchBlock) { 105 | // only handle branches for now 106 | outs() << "Flattening block " << getSimpleNodeLabel(block) << "\n"; 107 | if (auto* br = dyn_cast(block->getTerminator())) { 108 | for (unsigned i = 0; i < br->getNumSuccessors(); ++i) { 109 | // we start with block -> successor 110 | BasicBlock *Successor = br->getSuccessor(i); 111 | 112 | // create detour block 113 | // DispatchVar = X 114 | // jmp DispatchBlock 115 | BasicBlock *DetourBlock = BasicBlock::Create(F.getContext(), "", &F); 116 | IRBuilder<> Builder(DetourBlock); 117 | 
Builder.CreateStore(ConstantInt::get(Builder.getInt32Ty(), ++dispatchVal), DispatchVar); 118 | Builder.CreateBr(DispatchBlock); 119 | 120 | // insert block after our current one 121 | // block -> DetourBlock 122 | br->setSuccessor(i, DetourBlock); 123 | DetourBlock->moveAfter(block); 124 | 125 | // Add a new branch in the dispatcher to jump to the new block 126 | // DispatchBlock 127 | // if (DispatchVar == dispatchVal) goto successor; 128 | Instruction* FirstInst = DispatchBlock->getFirstNonPHI(); 129 | IRBuilder<> DispatchBuilder(FirstInst); 130 | LoadInst* loadSwitchVar = DispatchBuilder.CreateLoad(DispatchBuilder.getInt32Ty(), DispatchVar, "dispatch_var"); // some PHI weirdness means we need to load this here 131 | auto *Cond = DispatchBuilder.CreateICmpEQ(ConstantInt::get(DispatchBuilder.getInt32Ty(), dispatchVal), loadSwitchVar); 132 | SplitBlockAndInsertIfThen(Cond, FirstInst, false, nullptr, (DomTreeUpdater *)nullptr, nullptr, Successor); 133 | } 134 | } 135 | } 136 | 137 | bool flattenFunction(Function& F) { 138 | outs() << "Flattening " << F.getName() << "\n"; 139 | outs() << F.getInstructionCount() << " instructions\n"; 140 | 141 | if(F.getInstructionCount() < 1) { 142 | outs() << "Skipping\n"; 143 | return false; 144 | } 145 | 146 | int dispatchVal = 0x1001; 147 | 148 | printFunction(F); 149 | auto flattenedBB = getBlocksToFlatten(F); 150 | demotePhiNodes(F); 151 | AllocaInst *DispatchVar = initDispatchVar(F, dispatchVal); 152 | BasicBlock &entryBlockTail = splitBranchOffEntryBlock(F); 153 | flattenedBB.push_back(&entryBlockTail); 154 | BasicBlock* DispatchBlock = insertDispatchBlockAfterEntryBlock(F, DispatchVar); 155 | 156 | for (BasicBlock* pToflatBB : flattenedBB){ 157 | flattenBlock(pToflatBB, dispatchVal, F, DispatchVar, DispatchBlock); 158 | } 159 | 160 | outs() << "After flattening:\n"; 161 | printFunction(F); 162 | 163 | return true; 164 | } 165 | 166 | struct ControlFlowFlattening : public PassInfoMixin { 167 | PreservedAnalyses run(Module &M, 
ModuleAnalysisManager &MAM) { 168 | for (Function& F : M) { 169 | flattenFunction(F); 170 | } 171 | return PreservedAnalyses::none(); 172 | } 173 | }; 174 | } 175 | 176 | PassPluginLibraryInfo getPassPluginInfo() { 177 | static std::atomic ONCE_FLAG(false); 178 | return {LLVM_PLUGIN_API_VERSION, "obfs", "0.0.1", 179 | [](PassBuilder &PB) { 180 | 181 | try { 182 | PB.registerPipelineEarlySimplificationEPCallback( 183 | [&] (ModulePassManager &MPM, OptimizationLevel opt) { 184 | if (ONCE_FLAG) { 185 | return true; 186 | } 187 | MPM.addPass(obfs::ControlFlowFlattening()); 188 | ONCE_FLAG = true; 189 | return true; 190 | } 191 | ); 192 | } catch (const std::exception& e) { 193 | outs() << "Error: " << e.what() << "\n"; 194 | } 195 | }}; 196 | }; 197 | 198 | extern "C" __attribute__((visibility("default"))) LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo 199 | llvmGetPassPluginInfo() { 200 | return getPassPluginInfo(); 201 | } 202 | --------------------------------------------------------------------------------