├── TODO.md ├── gengif_tiny.py ├── drive_log.py ├── gengif_compact.py ├── code_diffs.sh ├── start_tgi_servers.sh ├── utmp_reader.py ├── docker └── Dockerfile ├── gengif.py ├── all_gen.sh ├── utmp.py ├── analyze_cov.py ├── getcov.py ├── genvariants_diff.py ├── do_gen.sh ├── gengif_tiny_newconf └── config.yaml ├── genvariants.py ├── genvariants_async.py ├── genvariants_parallel.py ├── driver.py ├── genoutputs.py └── elmconfig.py /TODO.md: -------------------------------------------------------------------------------- 1 | * Make a configuration file for the run, with options: 2 | * Number of seeds per generation 3 | * Number of variants per seed 4 | * Number of outputs per variant 5 | * Selection strategy: elites or best of generation 6 | * Number of generations 7 | * Model(s) to use 8 | * Generation params: temperature, mutation operators, context size 9 | * Region of seed to preserve 10 | * Timeouts 11 | -------------------------------------------------------------------------------- /gengif_tiny.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import struct 4 | from typing import BinaryIO 5 | 6 | # Generates a random GIF file into `out` using the 7 | # random number generator `rng` (/dev/urandom) 8 | def generate_random_gif( 9 | rng: BinaryIO, out: BinaryIO 10 | ): 11 | header = b'GIF89a' 12 | datalen = int.from_bytes( 13 | rng.read(3), byteorder='big') 14 | data = rng.read(datalen) 15 | out.write(header + data) 16 | -------------------------------------------------------------------------------- /drive_log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def setup_custom_logger(name): 4 | formatter = logging.Formatter(fmt='%(asctime)s [%(levelname)s] %(module)s : %(message)s') 5 | 6 | handler = logging.StreamHandler() 7 | handler.setFormatter(formatter) 8 | 9 | logger = logging.getLogger(name) 10 | logger.setLevel(logging.DEBUG) 11 
| logger.addHandler(handler) 12 | return logger 13 | 14 | def set_loglevel(logger, args): 15 | if args.quiet: 16 | logger.setLevel(logging.WARNING) 17 | elif args.verbose: 18 | logger.setLevel(logging.DEBUG) 19 | else: 20 | logger.setLevel(logging.INFO) 21 | -------------------------------------------------------------------------------- /gengif_compact.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import struct 4 | from typing import BinaryIO 5 | 6 | def generate_gif(rng: BinaryIO, out: BinaryIO): 7 | header = b'GIF89a' 8 | width, height = struct.unpack('>HH', rng.read(4)) 9 | lsd = struct.pack('>HHBHB', width, height, 0x87, 0x00, 0x00) 10 | color_table = rng.read(3) 11 | gce = b'\x21\xF9\x04\x00\x00\x00\x00\x00' 12 | image_descriptor = struct.pack('>BHHHHB', 0x2C, 0x0000, 0x0000, width, height, 0x00) 13 | lzw_minimum_code_size = 2 14 | data_sub_block = b'\x02\x04\x05' 15 | image_data = bytes([lzw_minimum_code_size]) + data_sub_block 16 | trailer = b'\x3B' 17 | out.write(header + lsd + color_table + gce + image_descriptor + image_data + trailer) 18 | -------------------------------------------------------------------------------- /code_diffs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -lt 3 ]; then 4 | echo "Usage: $0 [ ...]" >&2 5 | exit 1 6 | fi 7 | 8 | output_dir=$1 ; shift 9 | orig=$1 ; shift 10 | 11 | mkdir -p ${output_dir} 12 | 13 | orig_ansi=/tmp/$(basename ${orig}).ansi 14 | pygmentize -P style=vs "${orig}" > ${orig_ansi} 15 | for f in $@; do 16 | echo $f 17 | var_ansi=/tmp/$(basename ${f}).ansi 18 | pygmentize -P style=vs "${f}" > ${var_ansi} 19 | diff -B -a -U 10000 ${orig_ansi} ${var_ansi} | ansi2html -l -s osx-basic > ${output_dir}/$(basename ${f}).diff.html 20 | # Fix background color to white 21 | sed -i 's/background-color: #AAAAAA/background-color: #FFFFFF/g' ${output_dir}/$(basename ${f}).diff.html 22 | 
rm ${var_ansi} 23 | done 24 | rm ${orig_ansi} 25 | -------------------------------------------------------------------------------- /start_tgi_servers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Share cache directory 4 | volume=/fastdata/hfcache/transformers/ 5 | # HF token 6 | token=$(cat ${HOME}/.config/huggingface/token) 7 | 8 | # StarCoder: 8192, GPUs 0,1 9 | #port=8192 10 | #model=bigcode/starcoder 11 | #docker run -d -e HUGGING_FACE_HUB_TOKEN=$token --gpus '"device=0,1"' --shm-size 1g \ 12 | # -p ${port}:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest \ 13 | # --model-id $model --trust-remote-code --dtype bfloat16 --sharded true --num-shard 2 \ 14 | # --max-total-tokens 8192 --max-input-length 8000 --max-batch-prefill-tokens 8000 15 | 16 | # Code Llama: 8193, GPUs 2,3 17 | port=8192 18 | model=codellama/CodeLlama-13b-hf 19 | docker run -e HUGGING_FACE_HUB_TOKEN=$token --gpus '"device=0,3"' --shm-size 1g \ 20 | -p ${port}:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest \ 21 | --model-id $model --trust-remote-code --dtype bfloat16 --sharded true --num-shard 2 \ 22 | --max-total-tokens 8192 --max-input-length 8000 --max-batch-prefill-tokens 8000 23 | -------------------------------------------------------------------------------- /utmp_reader.py: -------------------------------------------------------------------------------- 1 | 2 | import collections 3 | import datetime 4 | import struct 5 | from enum import Enum 6 | 7 | 8 | class UTmpRecordType(Enum): 9 | empty = 0 10 | run_lvl = 1 11 | boot_time = 2 12 | new_time = 3 13 | old_time = 4 14 | init_process = 5 15 | login_process = 6 16 | user_process = 7 17 | dead_process = 8 18 | accounting = 9 19 | 20 | 21 | def convert_string(val): 22 | if isinstance(val, bytes): 23 | return val.rstrip(b'\0').decode() 24 | return val 25 | 26 | 27 | class UTmpRecord(collections.namedtuple('UTmpRecord', 28 | 
'type pid line id user host exit0 exit1 session' + 29 | ' sec usec addr0 addr1 addr2 addr3 unused')): 30 | 31 | @property 32 | def type(self): 33 | return UTmpRecordType(self[0]) 34 | 35 | @property 36 | def time(self): 37 | return datetime.datetime.fromtimestamp(self.sec) + datetime.timedelta(microseconds=self.usec) 38 | 39 | STRUCT = struct.Struct('hi32s4s32s256shhiii4i20s') 40 | 41 | 42 | def read(buf): 43 | offset = 0 44 | while offset < len(buf): 45 | yield UTmpRecord._make(map(convert_string, STRUCT.unpack_from(buf, offset))) 46 | offset += STRUCT.size 47 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use Ubuntu 22.04 as the base image 2 | FROM ubuntu:22.04 3 | 4 | # Avoid warnings by switching to noninteractive 5 | ENV DEBIAN_FRONTEND=noninteractive 6 | 7 | # Update the system and install Python dependencies 8 | RUN apt-get update && apt-get install -y \ 9 | software-properties-common \ 10 | build-essential \ 11 | libssl-dev \ 12 | libffi-dev \ 13 | gpg-agent \ 14 | python3.11-venv \ 15 | python3-pip \ 16 | python3.11-distutils \ 17 | --no-install-recommends && \ 18 | rm -rf /var/lib/apt/lists/* 19 | # add-apt-repository ppa:deadsnakes/ppa && \ 20 | # apt-get update && apt-get install -y python3.11 \ 21 | # --no-install-recommends && \ 22 | 23 | # Create a non-root user 24 | RUN useradd --create-home appuser 25 | USER appuser 26 | WORKDIR /home/appuser 27 | 28 | # Set up a virtual environment for the non-root user 29 | RUN python3.11 -m venv venv 30 | ENV PATH="/home/appuser/venv/bin:$PATH" 31 | 32 | # Upgrade pip and install wheel within the virtual environment 33 | RUN pip install --upgrade pip wheel 34 | 35 | # Make the app directory 36 | RUN mkdir app 37 | 38 | # Your app's setup could go here, for example: 39 | COPY --chown=appuser:appuser ../driver.py ../drive_log.py /home/appuser/app/ 40 | WORKDIR 
/home/appuser/app 41 | -------------------------------------------------------------------------------- /gengif.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import struct 4 | from typing import BinaryIO 5 | 6 | # Generates a random GIF file into `out` using the 7 | # random number generator `rng` (/dev/urandom) 8 | def generate_random_gif( 9 | rng: BinaryIO, out: BinaryIO 10 | ): 11 | # 1. Header 12 | header = b'GIF89a' 13 | 14 | # 2. Logical Screen Descriptor 15 | # Random width and height (using two bytes for each) 16 | width, height = struct.unpack('>HH', rng.read(4)) 17 | 18 | # Other LSD fields: Packed fields (using a random byte for variety), Background color index, Pixel aspect ratio 19 | lsd = struct.pack('>HHBHB', width, height, rng.read(1)[0], rng.read(1)[0], rng.read(1)[0]) 20 | 21 | # 3. Global Color Table (let's make it have random size, capped at 256 colors, and fully random) 22 | num_colors = (rng.read(1)[0] % 256) or 1 23 | color_table = rng.read(3 * num_colors) # Random colors 24 | 25 | # 4. Graphics Control Extension (optional) 26 | gce = struct.pack('>BBBHB', 0x21, 0xF9, 0x04, rng.read(1)[0], rng.read(1)[0]) 27 | 28 | # 5. Image Descriptor 29 | left, top = struct.unpack('>HH', rng.read(4)) 30 | width, height = struct.unpack('>HH', rng.read(4)) 31 | image_descriptor = struct.pack('>BHHHHB', 0x2C, left, top, width, height, rng.read(1)[0]) 32 | 33 | # 6. Image Data (randomly sized, capped at 255 bytes per sub-block) 34 | lzw_minimum_code_size = (rng.read(1)[0] % 8) or 1 35 | 36 | blocks = [] 37 | total_blocks_size = rng.read(1)[0] # Make it reasonably small 38 | for _ in range(total_blocks_size): 39 | block_size = rng.read(1)[0] # Each block can be up to 255 bytes 40 | blocks.append(bytes([block_size]) + rng.read(block_size)) 41 | image_data = bytes([lzw_minimum_code_size]) + b''.join(blocks) + b'\x00' # Ending with a block size of 0 42 | 43 | # 7. 
Trailer 44 | trailer = b'\x3B' 45 | 46 | # Write everything to the output file 47 | out.write(header + lsd + color_table + gce + image_descriptor + image_data + trailer) 48 | -------------------------------------------------------------------------------- /all_gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Be strict about failures 4 | set -euo pipefail 5 | 6 | if [ "$#" -ne 1 ]; then 7 | echo "Usage: $0 rundir" 8 | exit 1 9 | fi 10 | 11 | # This needs to be the first thing we do because elmconfig.py uses it 12 | # to find the config file ($ELMFUZZ_RUNDIR/config.yaml) 13 | export ELMFUZZ_RUNDIR="$1" 14 | export ELMFUZZ_RUN_NAME=$(basename "$ELMFUZZ_RUNDIR") 15 | seeds=$(./elmconfig.py get run.seeds) 16 | num_gens=$(./elmconfig.py get run.num_generations) 17 | # Generations are zero-indexed 18 | last_gen=$((num_gens - 1)) 19 | genout_dir=$(./elmconfig.py get run.genoutput_dir -s GEN=. -s MODEL=.) 20 | # normalize the path 21 | genout_dir=$(realpath -m "$genout_dir") 22 | # Check if we should remove the output dirs if they exist 23 | should_clean=$(./elmconfig.py get run.clean) 24 | if [ -d "$genout_dir" ]; then 25 | if [ "$should_clean" == "True" ]; then 26 | echo "Removing generated outputs in $genout_dir" 27 | rm -rf "$genout_dir" 28 | else 29 | echo "Generated output directory $genout_dir already exists; exiting." 30 | echo "Set run.clean to True to remove existing rundirs." 31 | exit 1 32 | fi 33 | fi 34 | # See if we have any gen*, initial, or stamps directories 35 | for pat in "gen*" "initial" "stamps"; do 36 | if compgen -G "$ELMFUZZ_RUNDIR"/$pat > /dev/null; then 37 | if [ "$should_clean" == "True" ]; then 38 | echo "Removing existing rundir(s):" "$ELMFUZZ_RUNDIR"/$pat 39 | rm -rf "$ELMFUZZ_RUNDIR"/$pat 40 | else 41 | echo "Found existing rundir(s):" "$ELMFUZZ_RUNDIR"/$pat 42 | echo "Set run.clean to True to remove existing rundirs." 
43 | exit 1 44 | fi 45 | fi 46 | done 47 | 48 | mkdir -p "$ELMFUZZ_RUNDIR"/initial/{variants,seeds,logs} 49 | # Stamp dir tells us when a generation is fully finished 50 | # In the future this will let us resume a run 51 | mkdir -p "$ELMFUZZ_RUNDIR"/stamps 52 | cp -v $seeds "$ELMFUZZ_RUNDIR"/initial/seeds/ 53 | ./do_gen.sh initial gen0 54 | for i in $(seq 0 $last_gen); do 55 | ./do_gen.sh gen$i gen$[i+1] 56 | done 57 | -------------------------------------------------------------------------------- /utmp.py: -------------------------------------------------------------------------------- 1 | # This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild 2 | 3 | import kaitaistruct 4 | from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO 5 | 6 | 7 | if getattr(kaitaistruct, 'API_VERSION', (0, 9)) < (0, 9): 8 | raise Exception("Incompatible Kaitai Struct Python API: 0.9 or later is required, but you have %s" % (kaitaistruct.__version__)) 9 | 10 | class Utmp(KaitaiStruct): 11 | def __init__(self, _io, _parent=None, _root=None): 12 | self._io = _io 13 | self._parent = _parent 14 | self._root = _root if _root else self 15 | self._read() 16 | 17 | def _read(self): 18 | self.records = [] 19 | i = 0 20 | while not self._io.is_eof(): 21 | self.records.append(Utmp.Record(self._io, self, self._root)) 22 | i += 1 23 | 24 | 25 | class Record(KaitaiStruct): 26 | def __init__(self, _io, _parent=None, _root=None): 27 | self._io = _io 28 | self._parent = _parent 29 | self._root = _root if _root else self 30 | self._read() 31 | 32 | def _read(self): 33 | self.ut_type = self._io.read_s4le() 34 | self.ut_pid = self._io.read_s4le() 35 | self.ut_line = (self._io.read_bytes(32)).decode(u"ASCII") 36 | self.ut_id = (self._io.read_bytes(4)).decode(u"ASCII") 37 | self.ut_user = (self._io.read_bytes(32)).decode(u"ASCII") 38 | self.ut_host = (self._io.read_bytes(256)).decode(u"ASCII") 39 | self.ut_exit = Utmp.Record.ExitStatus(self._io, self, 
self._root) 40 | self.ut_session = self._io.read_s4le() 41 | self.ut_tv = Utmp.Record.Timeval(self._io, self, self._root) 42 | self.ut_addr_v6 = [] 43 | for i in range(4): 44 | self.ut_addr_v6.append(self._io.read_s4le()) 45 | 46 | self.unused = [] 47 | for i in range(20): 48 | self.unused.append(self._io.read_u1()) 49 | 50 | 51 | class ExitStatus(KaitaiStruct): 52 | def __init__(self, _io, _parent=None, _root=None): 53 | self._io = _io 54 | self._parent = _parent 55 | self._root = _root if _root else self 56 | self._read() 57 | 58 | def _read(self): 59 | self.e_termination = self._io.read_s2le() 60 | self.e_exit = self._io.read_s2le() 61 | 62 | 63 | class Timeval(KaitaiStruct): 64 | def __init__(self, _io, _parent=None, _root=None): 65 | self._io = _io 66 | self._parent = _parent 67 | self._root = _root if _root else self 68 | self._read() 69 | 70 | def _read(self): 71 | self.tv_sec = self._io.read_s4le() 72 | self.tv_usec = self._io.read_s4le() 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /analyze_cov.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import json 5 | from collections import defaultdict 6 | import re 7 | import plotext as plt 8 | 9 | gen_re = re.compile(r'gen(\d+)') 10 | 11 | def print_cov(covfiles): 12 | data = [] 13 | for covfile in covfiles: 14 | gen = int(gen_re.search(covfile).group(1)) 15 | with open(covfile, 'r') as f: 16 | cov = json.load(f) 17 | for model, generators in cov.items(): 18 | for generator, cov in generators.items(): 19 | data.append((gen, model, generator, len(cov))) 20 | return data 21 | 22 | def cumulative_cov(covfiles): 23 | cov_by_gen = defaultdict(set) 24 | for covfile in covfiles: 25 | gen = int(gen_re.search(covfile).group(1)) 26 | with open(covfile, 'r') as f: 27 | cov = json.load(f) 28 | for model, generators in cov.items(): 29 | for generator, cov in 
generators.items(): 30 | cov_by_gen[gen].update(cov) 31 | cumulative = set() 32 | data = [] 33 | for gen, cov in sorted(cov_by_gen.items()): 34 | cumulative.update(cov) 35 | data.append((gen, len(cumulative))) 36 | return data 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser("Analyze coverage") 40 | parser.add_argument('covfiles', help='Coverage file', nargs='+') 41 | parser.add_argument('-c', '--cumulative', help='Report cumulative coverage', action='store_true') 42 | parser.add_argument('-p', '--plot', help='Plot coverage', action='store_true') 43 | parser.add_argument('-m', '--max-gen', help='Maximum generation for plotting', type=int, default=None) 44 | args = parser.parse_args() 45 | 46 | rundir = args.covfiles[0].split('/')[0] 47 | if args.plot: 48 | # Don't fill the whole terminal 49 | width, height = plt.ts() 50 | plt.plotsize(width // 2, height // 2) 51 | if args.max_gen is not None: 52 | plt.xlim(0, args.max_gen) 53 | 54 | if not args.cumulative: 55 | data = print_cov(args.covfiles) 56 | if args.plot: 57 | plt.scatter([x[0] for x in data], [x[3] for x in data]) 58 | plt.title(f'Variant coverage by generation, {rundir}') 59 | plt.xlabel('Generation') 60 | plt.ylabel('Edges') 61 | plt.show() 62 | else: 63 | for gen, model, generator, cov in data: 64 | gen_str = f'gen{gen}' 65 | print(f'{cov:3} {gen_str:<5} {model:<14} {generator}') 66 | else: 67 | data = cumulative_cov(args.covfiles) 68 | if args.plot: 69 | plt.plot([x[0] for x in data], [x[1] for x in data]) 70 | plt.title(f'Cumulative coverage by generation, {rundir}') 71 | plt.xlabel('Generation') 72 | plt.ylabel('Edges') 73 | plt.show() 74 | else: 75 | for gen, cumulative in data: 76 | gen_str = f'gen{gen}' 77 | print(f'{gen_str:<5} {cumulative:3}') 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /getcov.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python3 2 | 3 | import argparse 4 | from collections import defaultdict 5 | import json 6 | from pathlib import Path 7 | import subprocess 8 | import tempfile 9 | from tqdm import tqdm 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | import glob 12 | import os 13 | 14 | AFL_DIR = '/home/moyix/git/AFLplusplus/' 15 | # ./afl-showmap -q -i {} -o {/}.cov -C -- ./gifread @@ 16 | def afl_cov(showmap_path, prog, input_dir): 17 | with tempfile.NamedTemporaryFile() as f: 18 | cov_file = f.name 19 | cmd = [showmap_path, '-q', '-i', input_dir, '-o', cov_file, '-m', 'none', '-C', '--', prog, '@@'] 20 | subprocess.run( 21 | cmd, 22 | stdout=subprocess.DEVNULL, 23 | stderr=subprocess.DEVNULL, 24 | env={'AFL_QUIET': '1'}, 25 | ) 26 | with open(cov_file, 'r') as f: 27 | return set(l.strip() for l in f) 28 | 29 | def make_parser(): 30 | parser = argparse.ArgumentParser(description="Get coverage for generated inputs") 31 | parser.add_argument('gendir', help='Base directory for generated inputs, structure: gendir/[model]/[generator]/[files]') 32 | parser.add_argument('-O', '--output', type=str, default='output.json', 33 | help='Output file where coverage will be written') 34 | parser.add_argument('-j', '--jobs', type=int, default=64, 35 | help='Number of parallel jobs') 36 | parser.add_argument("--afl_dir", type=Path, 37 | help="Path to AFL++ directory (for afl-showmap)", 38 | default=Path(AFL_DIR)) 39 | return parser 40 | 41 | def init_parser(elm): 42 | pass 43 | 44 | def main(): 45 | from elmconfig import ELMFuzzConfig 46 | parser = make_parser() 47 | config = ELMFuzzConfig(parents={'getcov': parser}) 48 | args = config.parse_args() 49 | showmap = args.afl_dir / 'afl-showmap' 50 | if not showmap.exists(): 51 | config.parser.error(f'afl-showmap not found at {showmap}') 52 | if not args.target.covbin: 53 | config.parser.error('Coverage binary not specified') 54 | covbin = args.target.covbin.expanduser() 55 | if not covbin: 56 | config.parser.error(f'Coverage 
binary not found at {args.target.covbin}') 57 | combined_cov = {} 58 | with ThreadPoolExecutor(max_workers=64) as executor: 59 | worklist = [] 60 | for model in glob.glob(os.path.join(args.gendir, '*')): 61 | for generator in glob.glob(os.path.join(model, '*')): 62 | worklist.append(( 63 | os.path.basename(model), 64 | os.path.basename(generator), 65 | generator, 66 | )) 67 | futures = {} 68 | progress = tqdm(total=len(worklist), desc='Coverage') 69 | for model, generator, gendir in worklist: 70 | future = executor.submit(afl_cov, showmap, covbin, gendir) 71 | futures[future] = (model, generator, gendir) 72 | future.add_done_callback(lambda _: progress.update()) 73 | for future in as_completed(futures): 74 | model, generator, gendir = futures[future] 75 | cov = future.result() 76 | combined_cov[(model, generator)] = cov 77 | progress.close() 78 | # for (model, generator), cov in combined_cov.items(): 79 | # print(f'{model:>20} {generator} {len(cov)}') 80 | cov_dict = {} 81 | for (model, generator), cov in combined_cov.items(): 82 | if model not in cov_dict: 83 | cov_dict[model] = {} 84 | cov_dict[model][generator] = list(cov) 85 | with open(args.output, 'w') as f: 86 | json.dump(cov_dict, f) 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /genvariants_diff.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import random 4 | import requests 5 | import argparse 6 | import os 7 | 8 | ENDPOINT = 'http://127.0.0.1:8192' 9 | 10 | def model_info(): 11 | """Get information about the model.""" 12 | return requests.get(f'{ENDPOINT}/info').json() 13 | 14 | def generate_completion( 15 | prompt, 16 | temperature=0.2, 17 | max_new_tokens=2048, 18 | repetition_penalty=1.1, 19 | ): 20 | """Generate a completion of the prompt.""" 21 | data = { 22 | 'inputs': prompt, 23 | 'parameters': { 24 | 'temperature': temperature, 25 | 
'max_new_tokens': max_new_tokens, 26 | 'do_sample': True, 27 | 'repetition_penalty': repetition_penalty, 28 | "details": True, # So we get the finish_reason 29 | }, 30 | } 31 | return requests.post(f'{ENDPOINT}/generate', json=data).json() 32 | 33 | def random_diff(text: str, msg: str) -> str: 34 | """Generate a prompt for StarCoder's diff format""" 35 | return f'{text}{msg}' 36 | 37 | def new_base(filename: str) -> str: 38 | # filename and extension 39 | base = os.path.basename(filename) 40 | base, ext = os.path.splitext(base) 41 | # Get the first occurrence (if any) of ".base_" 42 | first = base.find('.base_') 43 | if first == -1: 44 | return base, ext 45 | else: 46 | base = base[:first] 47 | return base, ext 48 | 49 | def main(): 50 | global ENDPOINT 51 | parser = argparse.ArgumentParser( 52 | description='Generate variants of a file using an LLM code diff model', 53 | ) 54 | parser.add_argument('file', type=str) 55 | parser.add_argument('-n', '--num', type=int, default=1) 56 | parser.add_argument('-O', '--output', type=str, default='.') 57 | parser.add_argument('--endpoint', type=str, default=ENDPOINT) 58 | parser.add_argument('-c', '--commit-message', type=str, default='Numerous improvements') 59 | # Generation params 60 | parser.add_argument('-t', '--temperature', type=float, default=0.2) 61 | parser.add_argument('-m', '--max-new-tokens', type=int, default=2048) 62 | parser.add_argument('-r', '--repetition-penalty', type=float, default=1.1) 63 | args = parser.parse_args() 64 | ENDPOINT = args.endpoint 65 | 66 | info = model_info() 67 | model = info['model_id'] 68 | if model != 'bigcode/starcoder': 69 | parser.error("Diffs only supported for StarCoder") 70 | 71 | os.makedirs(args.output, exist_ok=True) 72 | 73 | for i in range(args.num): 74 | prompt = random_diff(open(args.file).read(), args.commit_message) 75 | res = generate_completion( 76 | prompt, 77 | temperature=args.temperature, 78 | max_new_tokens=args.max_new_tokens, 79 | 
repetition_penalty=args.repetition_penalty, 80 | ) 81 | if 'generated_text' not in res: 82 | print(f"WARNING: no generated text in response: {res}") 83 | continue 84 | text = res['generated_text'] 85 | # one of [length, eos_token, stop_sequence] 86 | finish_reason = res['details']['finish_reason'] 87 | finish_reason = { 88 | 'length': 'len', 89 | 'eos_token': 'eos', 90 | 'stop_sequence': 'stp', 91 | }[finish_reason] 92 | # Count lines 93 | gen_lines = text.count('\n') 94 | # filename and extension 95 | base, ext = new_base(args.file) 96 | out_file = f'var_{i:04}.diffmode.gen_{gen_lines:03}-fin_{finish_reason}.base_{base}{ext}' 97 | with open(os.path.join(args.output,out_file), 'w') as f: 98 | f.write(text) 99 | print(f'Wrote {out_file} to {args.output}') 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /do_gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Be strict about failures 4 | set -euo pipefail 5 | 6 | prev_gen="$1" 7 | next_gen="$2" 8 | num_gens=$(./elmconfig.py get run.num_generations) 9 | 10 | # MODELS="codellama starcoder starcoder_diff" 11 | MODELS=$(./elmconfig.py get model.names) 12 | NUM_VARIANTS=$(./elmconfig.py get cli.genvariants_parallel.num_variants) 13 | LOGDIR=$(./elmconfig.py get run.logdir -s GEN=${next_gen}) 14 | 15 | COLOR_RED='\033[0;31m' 16 | COLOR_GREEN='\033[0;32m' 17 | COLOR_RESET='\033[0m' 18 | printf "$COLOR_GREEN"'============> %s: %6s -> %6s of %3d <============'"$COLOR_RESET"'\n' $ELMFUZZ_RUN_NAME $prev_gen $next_gen $num_gens 19 | echo "Running generation $next_gen using $MODELS with $NUM_VARIANTS variants per seed" 20 | 21 | # Create the next generation directory 22 | mkdir -p "$ELMFUZZ_RUNDIR"/${next_gen}/{variants,seeds,logs} 23 | 24 | # Select the seeds for the next generation based on coverage 25 | # If this is the first generation, just use the seed 26 | if [ "$prev_gen" 
== "initial" ]; then 27 | echo "First generation; using seed(s):" "$ELMFUZZ_RUNDIR"/initial/seeds/*.py 28 | cp "$ELMFUZZ_RUNDIR"/initial/seeds/*.py "$ELMFUZZ_RUNDIR"/${next_gen}/seeds/ 29 | else 30 | # Selection 31 | selection_strategy=$(./elmconfig.py get run.selection_strategy) 32 | # If strategy is elites, select best coverage across all generations 33 | # If it's best_of_generation, select best coverage from the previous generation 34 | # Hopefully eventually we will also have MAP-Elites 35 | if [ "$selection_strategy" == "elites" ]; then 36 | echo "$selection_strategy: Selecting best seeds from all generations" 37 | cov_files=("$ELMFUZZ_RUNDIR"/*/logs/coverage.json) 38 | elif [ "$selection_strategy" == "best_of_generation" ]; then 39 | echo "$selection_strategy: Selecting best seeds from previous generation" 40 | cov_files=("$ELMFUZZ_RUNDIR"/${prev_gen}/logs/coverage.json) 41 | else 42 | echo "Unknown selection strategy $selection_strategy; exiting" 43 | exit 1 44 | fi 45 | python analyze_cov.py "${cov_files[@]}" | sort -n | tail -n 10 | \ 46 | while read cov gen model generator ; do 47 | echo "Selecting $generator from $gen/$model with $cov edges covered" 48 | cp "$ELMFUZZ_RUNDIR"/${gen}/variants/${model}/${generator}.py \ 49 | "$ELMFUZZ_RUNDIR"/${next_gen}/seeds/${gen}_${model}_${generator}.py 50 | done 51 | fi 52 | 53 | # Generate the next generation. If this is the first generation, create 10xNUM_VARIANTS variants 54 | # for each seed with each model. Otherwise, create NUM_VARIANTS variants for each seed with each model. 
55 | if [ "$prev_gen" == "initial" ]; then 56 | NUM_VARIANTS=$((NUM_VARIANTS * 10)) 57 | VARIANT_ARGS="-n ${NUM_VARIANTS}" 58 | else 59 | VARIANT_ARGS="" 60 | fi 61 | echo "Generating next generation: ${NUM_VARIANTS} variants for each seed with each model" 62 | for model_name in $MODELS ; do 63 | MODEL=$(basename "$model_name") 64 | GVLOG="${LOGDIR}/meta" 65 | GOLOG="${LOGDIR}/outputgen_${MODEL}.jsonl" 66 | GVOUT=$(./elmconfig.py get run.genvariant_dir -s MODEL=${MODEL} -s GEN=${next_gen}) 67 | GOOUT=$(./elmconfig.py get run.genoutput_dir -s MODEL=${MODEL} -s GEN=${next_gen}) 68 | echo "====================== $model_name ======================" 69 | # TODO: have genvariants_parallel.py do all the models at once 70 | # Will have to add in model-specific args to the config and merge 71 | # in the starcoder_diff script 72 | python genvariants_parallel.py $VARIANT_ARGS \ 73 | -M "${model_name}" -O "$GVOUT" -L "$GVLOG" \ 74 | "$ELMFUZZ_RUNDIR"/${next_gen}/seeds/*.py | \ 75 | python genoutputs.py -L "${GOLOG}" -O "${GOOUT}" 76 | done 77 | 78 | # Collect the coverage of the generators 79 | echo "Collecting coverage of the generators" 80 | all_models_genout_dir=$(realpath -m "$GOOUT"/..) 
81 | python getcov.py -O "${LOGDIR}/coverage.json" "$all_models_genout_dir" 82 | 83 | # Plot cumulative coverage so far 84 | python analyze_cov.py -m $num_gens -c -p "$ELMFUZZ_RUNDIR"/*/logs/coverage.json 85 | 86 | # Create a stamp file to indicate that this generation is finished 87 | touch "$ELMFUZZ_RUNDIR"/stamps/${next_gen}.stamp 88 | -------------------------------------------------------------------------------- /gengif_tiny_newconf/config.yaml: -------------------------------------------------------------------------------- 1 | # Automatically generated by genvariants_parallel.dump_config() on 2023-12-12 11:45:29 AM with options: 2 | # {'skip_comments': False, 'skip_defaults': False, 'skip_names': ['config', 'dump_config']} 3 | # Based on existing config file(s): gengif_tiny_newconf/config.yaml 4 | 5 | # -------------------------------- Global options -------------------------------- 6 | 7 | # Options to configure the target program being fuzzed 8 | target: 9 | # Source files in the target (default: None) 10 | srcs: 11 | - ~/git/gifdec/gifread.c 12 | # Path to the target binary with coverage instrumentation (default: None) 13 | covbin: ~/git/gifdec/gifread.cov 14 | 15 | # Options to configure the model(s) used for variant generation 16 | model: 17 | # List of model names (default: None) 18 | names: 19 | - codellama/CodeLlama-13b-hf 20 | # List of model endpoints, formatted as name:endpoint (default: None) 21 | endpoints: 22 | - codellama/CodeLlama-13b-hf:http://127.0.0.1:8192 23 | 24 | # Options to configure the run of the evolutionary algorithm 25 | run: 26 | # Seed files (generator programs that will be mutated) (default: None) 27 | seeds: 28 | - ~/git/elmfuzz/gengif_tiny.py 29 | # Number of generations to run (default: 10) 30 | num_generations: 50 31 | # Selection strategy (one of: elites, best_of_generation) (default: elites) 32 | selection_strategy: elites 33 | # Number of seeds to select each generation (default: 10) 34 | num_selected: 10 35 | # 
Directory (template) to store generated variants (default: 36 | # {ELMFUZZ_RUNDIR}/{GEN}/variants/{MODEL}) 37 | genvariant_dir: '{ELMFUZZ_RUNDIR}/{GEN}/variants/{MODEL}' 38 | # Directory (template) to store generated outputs (default: 39 | # {ELMFUZZ_RUNDIR}/{GEN}/outputs/{MODEL}) 40 | genoutput_dir: '/fastdata/randomgifs/{ELMFUZZ_RUN_NAME}/{GEN}/{MODEL}' 41 | # Directory (template) to store logs (default: {ELMFUZZ_RUNDIR}/{GEN}/logs) 42 | logdir: '{ELMFUZZ_RUNDIR}/{GEN}/logs' 43 | # Clean the output directories before running (default: False) 44 | clean: true 45 | 46 | # ------------------------- Specific CLI utility options ------------------------- 47 | 48 | cli: 49 | # genvariants_parallel: Use a code model to generate variants of a file. 50 | genvariants_parallel: 51 | # Model to use for generation (default: codellama/CodeLlama-13b-hf) 52 | model_name: codellama/CodeLlama-13b-hf 53 | # Disable the completion mutator (default: False) 54 | no_completion: false 55 | # Disable the FIM (infilling) mutator (default: False) 56 | no_fim: false 57 | # Disable the splice mutator (default: False) 58 | no_splice: false 59 | # Number of variants to generate for each seed (default: 1) 60 | num_variants: 30 61 | # Directory to write variants to (default: .) 62 | output_dir: . 63 | # Directory to write generation metadata to (default: logs) 64 | log_dir: logs 65 | # When making random cuts, always start at this line. Allows specifying an 66 | # immutable region not subject to mutation. 
(default: 0) 67 | start_line: 9 68 | # Number of inference jobs to run in parallel (default: 16) 69 | jobs: 16 70 | 71 | # Generation parameters 72 | gen: 73 | # Generation temperature (default: 0.2) 74 | temperature: 0.2 75 | # Maximum number of tokens to generate (default: 2048) 76 | max_new_tokens: 2048 77 | # Repetition penalty (default: 1.1) 78 | repetition_penalty: 1.1 79 | genoutputs: 80 | # Don't catch exceptions in the main driver loop (default: False) 81 | raise_errors: true 82 | # Options for the generator; will be passed to each module 83 | driver: 84 | # The function to run in each module (default: None) 85 | function_name: generate_random_gif 86 | # Timeout for each function run (in seconds) (default: 10) 87 | timeout: 10 88 | # Maximum size of the output file (in bytes) (default: 52428800) 89 | size_limit: 52428800 90 | # Maximum memory usage (in bytes) (default: 1073741824) 91 | max_mem: 1073741824 92 | # Suffix for output files (default: .gif) 93 | output_suffix: .gif 94 | # Number of times to run each function in each module (i.e., number of 95 | # outputs to generate) (default: 100) 96 | num_iterations: 100 97 | getcov: 98 | # Output file where coverage will be written (default: output.json) 99 | output: output.json 100 | # Number of parallel jobs (default: 64) 101 | jobs: 64 102 | # Path to AFL++ directory (for afl-showmap) (default: 103 | # /home/moyix/git/AFLplusplus) 104 | afl_dir: /home/moyix/git/AFLplusplus 105 | -------------------------------------------------------------------------------- /genvariants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import random 4 | import requests 5 | import argparse 6 | import os 7 | 8 | ENDPOINT = 'http://127.0.0.1:8192' 9 | 10 | def model_info(): 11 | """Get information about the model.""" 12 | return requests.get(f'{ENDPOINT}/info').json() 13 | 14 | def generate_completion( 15 | prompt, 16 | temperature=0.2, 17 | 
max_new_tokens=2048, 18 | repetition_penalty=1.1, 19 | stop=None, 20 | ): 21 | """Generate a completion of the prompt.""" 22 | data = { 23 | 'inputs': prompt, 24 | 'parameters': { 25 | 'temperature': temperature, 26 | 'max_new_tokens': max_new_tokens, 27 | 'do_sample': True, 28 | 'repetition_penalty': repetition_penalty, 29 | "details": True, # So we get the finish_reason 30 | }, 31 | } 32 | if stop is not None: 33 | data['parameters']['stop'] = stop 34 | return requests.post(f'{ENDPOINT}/generate', json=data).json() 35 | 36 | def infilling_prompt_llama( 37 | pre: str, 38 | suf: str, 39 | ) -> str: 40 | """ 41 | Format an infilling problem for Code Llama. 42 | If `suffix_first` is set, format in suffix-prefix-middle format. 43 | """ 44 | return f'
 {pre} {suf} '
 45 | 
 46 | def infilling_prompt_starcoder(
 47 |     pre: str,
 48 |     suf: str,
 49 | ) -> str:
 50 |     """
 51 |     Format an infilling problem for StarCoder
 52 |     If `suffix_first` is set, format in suffix-prefix-middle format.
 53 |     """
 54 |     return f'{pre}{suf}'
 55 | 
 56 | def random_completion(text: str) -> [str,str]:
 57 |     """Generate a completion of the text starting from a random line.
 58 |     Always include at least 1 line to avoid an empty prompt."""
 59 |     text_lines = text.split('\n')
 60 |     # Pick a random line number to cut at
 61 |     cut_line = random.randint(1, len(text_lines) - 1)
 62 |     prompt_text = '\n'.join(text_lines[:cut_line])
 63 |     real_completion = '\n'.join(text_lines[cut_line:])
 64 |     return prompt_text, real_completion
 65 | 
 66 | def random_fim(text: str) -> [str,str,str]:
 67 |     """Fill in the middle of the text with a random completion."""
 68 |     text_lines = text.split('\n')
 69 |     # Random start and end lines. Make sure we always have at least
 70 |     # one line in each section.
 71 |     start_line = random.randint(0, len(text_lines) - 2)
 72 |     end_line = random.randint(start_line + 1, len(text_lines) - 1)
 73 |     prefix_text = '\n'.join(text_lines[:start_line]) + '\n'
 74 |     suffix_text = '\n'.join(text_lines[end_line:])
 75 |     real_middle = '\n'.join(text_lines[start_line:end_line])
 76 |     return prefix_text, suffix_text, real_middle
 77 | 
 78 | def new_base(filename: str) -> str:
 79 |     # filename and extension
 80 |     base = os.path.basename(filename)
 81 |     base, ext = os.path.splitext(base)
 82 |     # Get the first occurrence (if any) of ".base_"
 83 |     first = base.find('.base_')
 84 |     if first == -1:
 85 |         return base, ext
 86 |     else:
 87 |         base = base[:first]
 88 |         return base, ext
 89 | 
 90 | def main():
 91 |     global ENDPOINT
 92 |     parser = argparse.ArgumentParser(
 93 |         description='Generate variants of a file using an LLM code model',
 94 |     )
 95 |     parser.add_argument('file', type=str)
 96 |     parser.add_argument('--no-completion', action='store_true')
 97 |     parser.add_argument('--no-fim', action='store_true')
 98 |     parser.add_argument('-n', '--num', type=int, default=1)
 99 |     parser.add_argument('-O', '--output', type=str, default='.')
100 |     parser.add_argument('--endpoint', type=str, default=ENDPOINT)
101 |     # Generation params
102 |     parser.add_argument('-t', '--temperature', type=float, default=0.2)
103 |     parser.add_argument('-m', '--max-new-tokens', type=int, default=2048)
104 |     parser.add_argument('-s', '--start-line', type=int, default=0,
105 |                         help='Minimum start line to use when mutating, to preserve a prefix')
106 |     parser.add_argument('-r', '--repetition-penalty', type=float, default=1.1)
107 |     args = parser.parse_args()
108 |     ENDPOINT = args.endpoint
109 | 
110 |     info = model_info()
111 |     model = info['model_id']
112 |     if model == 'bigcode/starcoder':
113 |         infilling_prompt = infilling_prompt_starcoder
114 |     elif model in ('codellama/CodeLlama-13b-hf',
115 |                    'codellama/CodeLlama-7b-hf'):
116 |         infilling_prompt = infilling_prompt_llama
117 |     else:
118 |         infilling_prompt = None
119 | 
120 |     if infilling_prompt is None and not args.no_fim:
121 |         parser.error(f'Model {model} does not support FIM')
122 |     if args.no_completion and args.no_fim:
123 |         parser.error(f'Nothing to do')
124 | 
125 |     os.makedirs(args.output, exist_ok=True)
126 | 
127 |     generators = []
128 |     if not args.no_completion:
129 |         generators += ['complete']
130 |     if not args.no_fim:
131 |         generators += ['infilled']
132 | 
133 |     for i in range(args.num):
134 |         # Pick a random generator
135 |         generator = random.choice(generators)
136 |         if generator == 'infilled':
137 |             prefix, suffix, orig = random_fim(open(args.file).read())
138 |             prompt = infilling_prompt(prefix, suffix)
139 |             stop = []
140 |         else:
141 |             prefix, orig = random_completion(open(args.file).read())
142 |             suffix = ''
143 |             prompt = prefix
144 |             stop = ['\nif', '\nclass', '\nfor', '\nwhile']
145 |         res = generate_completion(
146 |             prompt,
147 |             temperature=args.temperature,
148 |             max_new_tokens=args.max_new_tokens,
149 |             repetition_penalty=args.repetition_penalty,
150 |             stop=stop,
151 |         )
152 |         if 'generated_text' not in res:
153 |             print(f"WARNING: no generated text in response: {res}")
154 |             continue
155 |         text = res['generated_text']
156 |         if 'codellama' in model:
157 |             # CodeLlama tokenizer decoding seems slightly broken in TGI,
158 |             # so we need to remove the ' ' token manually, and trim the
159 |             # stop sequences.
160 |             text = text.replace(' ', '')
161 |             for stop_seq in stop:
162 |                 if text.endswith(stop_seq):
163 |                     text = text[:-len(stop_seq)]
164 |         # one of [length, eos_token, stop_sequence]
165 |         finish_reason = res['details']['finish_reason']
166 |         finish_reason = {
167 |             'length': 'len',
168 |             'eos_token': 'eos',
169 |             'stop_sequence': 'stp',
170 |         }[finish_reason]
171 |         # Count lines
172 |         plines = prefix.count('\n')
173 |         slines = suffix.count('\n')
174 |         olines = orig.count('\n')
175 |         gen_lines = text.count('\n')
176 |         # filename and extension
177 |         base, ext = new_base(args.file)
178 |         out_file = f'var_{i:04}.{generator}.pre_{plines:03}-orig_{olines:03}-gen_{gen_lines:03}-suf_{slines:03}-fin_{finish_reason}.base_{base}{ext}'
179 |         with open(os.path.join(args.output,out_file), 'w') as f:
180 |             f.write(prefix + text + suffix)
181 |         print(f'Wrote {out_file} to {args.output}')
182 | 
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    main()
185 | 


--------------------------------------------------------------------------------
/genvariants_async.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | 
  3 | import json
  4 | import random
  5 | import argparse
  6 | import os
  7 | from tqdm.asyncio import tqdm
  8 | import aiohttp
  9 | import asyncio
 10 | import aiofiles
 11 | import requests
 12 | import itertools
 13 | 
# Base URL of the text-generation server. main() overrides it with --endpoint.
ENDPOINT = 'http://127.0.0.1:8192'
 15 | 
 16 | def model_info():
 17 |     """Get information about the model."""
 18 |     return requests.get(f'{ENDPOINT}/info').json()
 19 | 
 20 | async def generate_completion(
 21 |         prompt,
 22 |         temperature=0.2,
 23 |         max_new_tokens=1200,
 24 |         repetition_penalty=1.1,
 25 |         stop=None,
 26 | ):
 27 |     """Generate a completion of the prompt."""
 28 |     data = {
 29 |         'inputs': prompt,
 30 |         'parameters': {
 31 |             'temperature': temperature,
 32 |             'max_new_tokens': max_new_tokens,
 33 |             'do_sample': True,
 34 |             'repetition_penalty': repetition_penalty,
 35 |              "details": True, # So we get the finish_reason
 36 |         },
 37 |     }
 38 |     if stop is not None:
 39 |         data['parameters']['stop'] = stop
 40 |     async with aiohttp.ClientSession() as session:
 41 |         async with session.post(f'{ENDPOINT}/generate', json=data) as resp:
 42 |             return await resp.json()
 43 | 
 44 | def infilling_prompt_llama(
 45 |     pre: str,
 46 |     suf: str,
 47 | ) -> str:
 48 |     """
 49 |     Format an infilling problem for Code Llama.
 50 |     If `suffix_first` is set, format in suffix-prefix-middle format.
 51 |     """
 52 |     return f'
 {pre} {suf} '
 53 | 
 54 | def infilling_prompt_starcoder(
 55 |     pre: str,
 56 |     suf: str,
 57 | ) -> str:
 58 |     """
 59 |     Format an infilling problem for StarCoder
 60 |     If `suffix_first` is set, format in suffix-prefix-middle format.
 61 |     """
 62 |     return f'{pre}{suf}'
 63 | 
# Prompt formatter used by the FIM and splice mutators. main() sets it based
# on the model the server reports (None when no known FIM format exists).
infilling_prompt = None
 65 | 
 66 | def random_completion(text: str, start_line: int = 1) -> [str,str]:
 67 |     """Generate a completion of the text starting from a random line.
 68 |     Always include at least 1 line to avoid an empty prompt."""
 69 |     text_lines = text.split('\n')
 70 |     # Pick a random line number to cut at
 71 |     cut_line = random.randint(start_line + 1, len(text_lines) - 1)
 72 |     prompt_text = '\n'.join(text_lines[:cut_line])
 73 |     real_completion = '\n'.join(text_lines[cut_line:])
 74 |     return prompt_text, real_completion
 75 | 
 76 | def random_fim(text: str, start_line: int = 1) -> [str,str,str]:
 77 |     """Fill in the middle of the text with a random completion."""
 78 |     text_lines = text.split('\n')
 79 |     # Random start and end lines. Make sure we always have at least
 80 |     # one line in each section.
 81 |     fim_start_line = random.randint(start_line + 1, len(text_lines) - 2)
 82 |     fim_end_line = random.randint(fim_start_line + 1, len(text_lines) - 1)
 83 |     prefix_text = '\n'.join(text_lines[:fim_start_line]) + '\n'
 84 |     suffix_text = '\n'.join(text_lines[fim_end_line:])
 85 |     real_middle = '\n'.join(text_lines[fim_start_line:fim_end_line])
 86 |     return prefix_text, suffix_text, real_middle
 87 | 
 88 | def random_crossover(text1: str, text2: str, start_line: int = 1) -> [str,str]:
 89 |     """Generate a splice of two texts."""
 90 |     text_lines1 = text1.split('\n')
 91 |     text_lines2 = text2.split('\n')
 92 |     cut_line1 = random.randint(start_line + 1, len(text_lines1) - 1)
 93 |     # Cut line in file2.
 94 |     cut_line2 = random.randint(start_line + 1, len(text_lines2) - 1)
 95 |     prefix = '\n'.join(text_lines1[:cut_line1])
 96 |     suffix = '\n'.join(text_lines2[cut_line2:])
 97 |     return prefix, suffix
 98 | 
 99 | def new_base(filename: str) -> str:
100 |     # filename and extension
101 |     base = os.path.basename(filename)
102 |     base, ext = os.path.splitext(base)
103 |     # Get the first occurrence (if any) of ".base_"
104 |     first = base.find('.base_')
105 |     if first == -1:
106 |         return base, ext
107 |     else:
108 |         base = base[:first]
109 |         return base, ext
110 | 
async def generate_variant(i, generators, model, filename, args):
    """Generate one variant of `filename` and write it, plus metadata, to disk.

    A mutator is picked at random from `generators`:

    * ``'infilled'``: regenerate a random span of lines (FIM).
    * ``'lmsplice'``: keep a prefix of `filename` and a suffix of another
      file from ``args.files``, and let the model generate the text joining
      them. If `filename` is the only file, it supplies both parts.
    * anything else (``'complete'``): regenerate the file from a random cut.

    The variant is written to ``args.output/var_{i:04}.{generator}{ext}``.
    A JSON metadata record is written to ``args.logdir`` under the same name
    plus ``.json``. Nothing is written when the server response lacks
    ``generated_text``.

    Returns:
        None in all cases.
    """
    generator = random.choice(generators)
    with open(filename) as src:
        source = src.read()
    filename2 = filename
    if generator == 'infilled':
        prefix, suffix, orig = random_fim(source, args.start_line)
        prompt = infilling_prompt(prefix, suffix)
        stop = []
    elif generator == 'lmsplice':
        other_files = [other for other in args.files if other != filename]
        if other_files:
            filename2 = random.choice(other_files)
        with open(filename2) as src:
            source2 = src.read()
        prefix, suffix = random_crossover(source, source2, args.start_line)
        orig = ''
        prompt = infilling_prompt(prefix, suffix)
        stop = []
    else:
        prefix, orig = random_completion(source, args.start_line)
        suffix = ''
        prompt = prefix
        # Stop before the next unindented if/class/for/while line.
        stop = ['\nif', '\nclass', '\nfor', '\nwhile']
    res = await generate_completion(
        prompt,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        repetition_penalty=args.repetition_penalty,
        stop=stop,
    )
    if 'generated_text' not in res:
        return None
    text = res['generated_text']
    if 'codellama' in model:
        # CodeLlama tokenizer decoding seems slightly broken in TGI. The
        # end-of-infill token ' <EOT>' has to be removed manually, and a
        # matched stop sequence trimmed. Only that token is removed;
        # stripping every space would destroy the code's indentation.
        text = text.replace(' <EOT>', '')
        for stop_seq in stop:
            if text.endswith(stop_seq):
                text = text[:-len(stop_seq)]
    # Expected: one of [length, eos_token, stop_sequence]. Any other reason
    # is recorded verbatim instead of raising KeyError.
    raw_reason = res['details']['finish_reason']
    finish_reason = {
        'length': 'len',
        'eos_token': 'eos',
        'stop_sequence': 'stp',
    }.get(raw_reason, raw_reason)
    base, ext = new_base(filename)
    bases = [base]
    if generator == 'lmsplice':
        bases.append(new_base(filename2)[0])
    meta = {
        'model': model,
        'generator': generator,
        'prompt_lines': prefix.count('\n'),
        'orig_lines': orig.count('\n'),
        'gen_lines': text.count('\n'),
        'suffix_lines': suffix.count('\n'),
        'finish_reason': finish_reason,
        'base': bases,
    }
    out_file = f'var_{i:04}.{generator}{ext}'
    async with aiofiles.open(os.path.join(args.output, out_file), 'w') as f:
        await f.write(prefix)
        await f.write(text)
        await f.write(suffix)
    # Write metadata to logdir
    async with aiofiles.open(os.path.join(args.logdir, out_file + '.json'), 'w') as f:
        await f.write(json.dumps(meta))
192 | 
async def main():
    """Command-line entry point: generate variants for one or more seed files.

    Each file in ``files`` gets ``--num`` variants. Variant indices are
    assigned round-robin over the files. Variants are awaited one after
    another, with a tqdm progress bar.
    """
    global ENDPOINT
    global infilling_prompt
    parser = argparse.ArgumentParser(
        description='Generate variants of a file using an LLM code model',
    )
    parser.add_argument('files', type=str, nargs='+')
    parser.add_argument('--no-completion', action='store_true')
    parser.add_argument('--no-fim', action='store_true')
    parser.add_argument('--no-splice', action='store_true')
    parser.add_argument('-n', '--num', type=int, default=1)
    parser.add_argument('-O', '--output', type=str, default='.')
    parser.add_argument('-L', '--logdir', type=str, default='logs')
    parser.add_argument('-s', '--start-line', type=int, default=0,
                        help='When making random cuts, always start at this line')
    parser.add_argument('--endpoint', type=str, default=ENDPOINT)
    # Generation params
    parser.add_argument('-t', '--temperature', type=float, default=0.2)
    parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
    parser.add_argument('-r', '--repetition-penalty', type=float, default=1.1)
    args = parser.parse_args()
    ENDPOINT = args.endpoint

    info = model_info()
    model = info['model_id']
    if model == 'bigcode/starcoder':
        infilling_prompt = infilling_prompt_starcoder
    elif model in ('codellama/CodeLlama-13b-hf',
                   'codellama/CodeLlama-7b-hf'):
        infilling_prompt = infilling_prompt_llama
    else:
        infilling_prompt = None

    # The FIM and splice mutators both build their prompt with
    # infilling_prompt(), so either one needs a FIM-capable model.
    # Otherwise 'lmsplice' would fail later calling None.
    if infilling_prompt is None and not (args.no_fim and args.no_splice):
        parser.error(f'Model {model} does not support FIM '
                     '(required by the infilling and splice mutators)')
    if args.no_completion and args.no_fim and args.no_splice:
        parser.error('Nothing to do')

    os.makedirs(args.output, exist_ok=True)
    os.makedirs(args.logdir, exist_ok=True)

    generators = []
    if not args.no_completion:
        generators.append('complete')
    if not args.no_fim:
        generators.append('infilled')
    if not args.no_splice:
        generators.append('lmsplice')

    # Index i runs over (round, file) pairs, so every file gets args.num
    # variants.
    worklist = list(enumerate(
        filename for _ in range(args.num) for filename in args.files
    ))
    async for i, filename in tqdm(worklist, desc='Generating', unit='variant'):
        await generate_variant(i, generators, model, filename, args)
250 | 
if __name__ == '__main__':
    # main() is a coroutine; asyncio.run() drives it to completion.
    asyncio.run(main())
253 | 


--------------------------------------------------------------------------------
/genvariants_parallel.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | 
  3 | import json
  4 | import random
  5 | import os
  6 | from typing import List, Optional
  7 | from argparse import ArgumentParser
  8 | import requests
  9 | from concurrent.futures import ThreadPoolExecutor, as_completed
 10 | 
# Default base URL of the text-generation server. main() replaces it with the
# endpoint configured for the selected model, when one exists.
ENDPOINT = 'http://127.0.0.1:8192'
 12 | 
 13 | def model_info():
 14 |     """Get information about the model."""
 15 |     return requests.get(f'{ENDPOINT}/info').json()
 16 | 
 17 | def generate_completion(
 18 |         prompt,
 19 |         temperature=0.2,
 20 |         max_new_tokens=1200,
 21 |         repetition_penalty=1.1,
 22 |         stop=None,
 23 | ):
 24 |     """Generate a completion of the prompt."""
 25 |     data = {
 26 |         'inputs': prompt,
 27 |         'parameters': {
 28 |             'temperature': temperature,
 29 |             'max_new_tokens': max_new_tokens,
 30 |             'do_sample': True,
 31 |             'repetition_penalty': repetition_penalty,
 32 |             'details': True, # So we get the finish_reason
 33 |         },
 34 |     }
 35 |     if stop is not None:
 36 |         data['parameters']['stop'] = stop
 37 |     return requests.post(f'{ENDPOINT}/generate', json=data).json()
 38 | 
 39 | def infilling_prompt_llama(
 40 |     pre: str,
 41 |     suf: str,
 42 | ) -> str:
 43 |     """
 44 |     Format an infilling problem for Code Llama.
 45 |     If `suffix_first` is set, format in suffix-prefix-middle format.
 46 |     """
 47 |     return f'
 {pre} {suf} '
 48 | 
 49 | def infilling_prompt_starcoder(
 50 |     pre: str,
 51 |     suf: str,
 52 | ) -> str:
 53 |     """
 54 |     Format an infilling problem for StarCoder
 55 |     If `suffix_first` is set, format in suffix-prefix-middle format.
 56 |     """
 57 |     return f'{pre}{suf}'
 58 | 
# Prompt formatter used by the FIM and splice mutators. main() sets it based
# on the model the server reports (None when no known FIM format exists).
infilling_prompt = None
 60 | 
 61 | def random_completion(text: str, start_line: int = 1) -> [str,str]:
 62 |     """Generate a completion of the text starting from a random line.
 63 |     Always include at least 1 line to avoid an empty prompt."""
 64 |     text_lines = text.split('\n')
 65 |     # Pick a random line number to cut at
 66 |     cut_line = random.randint(start_line + 1, len(text_lines) - 1)
 67 |     prompt_text = '\n'.join(text_lines[:cut_line])
 68 |     real_completion = '\n'.join(text_lines[cut_line:])
 69 |     return prompt_text, real_completion
 70 | 
 71 | def random_fim(text: str, start_line: int = 1) -> [str,str,str]:
 72 |     """Fill in the middle of the text with a random completion."""
 73 |     text_lines = text.split('\n')
 74 |     # Random start and end lines. Make sure we always have at least
 75 |     # one line in each section.
 76 |     fim_start_line = random.randint(start_line + 1, len(text_lines) - 2)
 77 |     fim_end_line = random.randint(fim_start_line + 1, len(text_lines) - 1)
 78 |     prefix_text = '\n'.join(text_lines[:fim_start_line]) + '\n'
 79 |     suffix_text = '\n'.join(text_lines[fim_end_line:])
 80 |     real_middle = '\n'.join(text_lines[fim_start_line:fim_end_line])
 81 |     return prefix_text, suffix_text, real_middle
 82 | 
 83 | def random_crossover(text1: str, text2: str, start_line: int = 1) -> [str,str]:
 84 |     """Generate a splice of two texts."""
 85 |     text_lines1 = text1.split('\n')
 86 |     text_lines2 = text2.split('\n')
 87 |     cut_line1 = random.randint(start_line + 1, len(text_lines1) - 1)
 88 |     # Cut line in file2.
 89 |     cut_line2 = random.randint(start_line + 1, len(text_lines2) - 1)
 90 |     prefix = '\n'.join(text_lines1[:cut_line1])
 91 |     suffix = '\n'.join(text_lines2[cut_line2:])
 92 |     return prefix, suffix
 93 | 
 94 | # SRCS = [
 95 | #     '/home/moyix/git/gifdec/gifdec.c',
 96 | # ]
 97 | # def random_snippet(text: str, start_line: int = 1) -> [str,str]:
 98 | #     """Include commented out code from the parser code."""
 99 | #     parser_chunks = open(random.choice(SRCS)).read().split('\n\n')
100 | 
101 | #     "# NOTE: the corresponding parser code in C is:\n#\n"
102 | 
def new_base(filename: str) -> 'tuple[str, str]':
    """Return the stem of `filename` (truncated at '.base_') and its extension.

    The directory and extension are removed first. Then everything from the
    first ``.base_`` marker onward, if present, is dropped from the stem.

    Args:
        filename: Path of a seed or variant file.

    Returns:
        A ``(base, ext)`` pair, e.g. ``('gengif', '.py')``.
    """
    stem, ext = os.path.splitext(os.path.basename(filename))
    # partition() yields the whole stem when the marker is absent.
    base, _, _ = stem.partition('.base_')
    return base, ext
114 | 
def generate_variant(i, generators, model, filename, args):
    """Generate one variant of `filename`; return its path, or None on failure.

    A mutator is picked at random from `generators`:

    * ``'infilled'``: regenerate a random span of lines (FIM).
    * ``'lmsplice'``: keep a prefix of `filename` and a suffix of another
      file from ``args.files``, and let the model generate the text joining
      them. If `filename` is the only file, it supplies both parts.
    * anything else (``'complete'``): regenerate the file from a random cut.

    Generation parameters come from ``args.gen``. The variant is written to
    ``args.output_dir/var_{i:04}.{generator}{ext}``. A JSON metadata record,
    including the prompt and the raw server response, is written to
    ``args.log_dir`` under the same name plus ``.json``.

    Returns:
        The path of the written variant. Returns None when the response has
        no ``generated_text``; in that case only the metadata record is
        written, with ``finish_reason='err'``.
    """
    generator = random.choice(generators)
    with open(filename) as src:
        source = src.read()
    filename2 = filename
    if generator == 'infilled':
        prefix, suffix, orig = random_fim(source, args.start_line)
        prompt = infilling_prompt(prefix, suffix)
        stop = []
    elif generator == 'lmsplice':
        other_files = [other for other in args.files if other != filename]
        if other_files:
            filename2 = random.choice(other_files)
        with open(filename2) as src:
            source2 = src.read()
        prefix, suffix = random_crossover(source, source2, args.start_line)
        orig = ''
        prompt = infilling_prompt(prefix, suffix)
        stop = []
    else:
        prefix, orig = random_completion(source, args.start_line)
        suffix = ''
        prompt = prefix
        # Stop before the next unindented if/class/for/while line.
        stop = ['\nif', '\nclass', '\nfor', '\nwhile']

    # Prepare metadata up front so it can be logged even if generation fails.
    base, ext = new_base(filename)
    bases = [base]
    if generator == 'lmsplice':
        bases.append(new_base(filename2)[0])
    plines = prefix.count('\n')
    slines = suffix.count('\n')
    olines = orig.count('\n')
    out_file = f'var_{i:04}.{generator}{ext}'
    out_path = os.path.join(args.output_dir, out_file)
    meta_file = os.path.join(args.log_dir, out_file + '.json')

    def make_meta(gen_lines, finish_reason, response):
        # One record layout shared by the success path and the error path.
        return {
            'model': model,
            'prompt': prompt,
            'generator': generator,
            'prompt_lines': plines,
            'orig_lines': olines,
            'gen_lines': gen_lines,
            'suffix_lines': slines,
            'finish_reason': finish_reason,
            'base': bases,
            'response': response,
        }

    res = generate_completion(
        prompt,
        stop=stop,
        **vars(args.gen),
    )
    if 'generated_text' not in res:
        # Write (error) metadata to logdir
        with open(meta_file, 'w') as f:
            f.write(json.dumps(make_meta(0, 'err', res)))
        return None

    # Fix up the generated text
    text = res['generated_text']
    if 'codellama' in model:
        # CodeLlama tokenizer decoding seems slightly broken in TGI. The
        # end-of-infill token ' <EOT>' has to be removed manually, and a
        # matched stop sequence trimmed. Only that token is removed;
        # stripping every space would destroy the code's indentation.
        text = text.replace(' <EOT>', '')
        for stop_seq in stop:
            if text.endswith(stop_seq):
                text = text[:-len(stop_seq)]
    gen_lines = text.count('\n')

    # Expected: one of [length, eos_token, stop_sequence]. Any other reason
    # is recorded verbatim instead of raising KeyError.
    raw_reason = res['details']['finish_reason']
    finish_reason = {
        'length': 'len',
        'eos_token': 'eos',
        'stop_sequence': 'stp',
    }.get(raw_reason, raw_reason)

    # Write output to file
    with open(out_path, 'w') as f:
        f.write(prefix)
        f.write(text)
        f.write(suffix)

    # Write metadata to logdir
    with open(meta_file, 'w') as f:
        f.write(json.dumps(make_meta(gen_lines, finish_reason, res)))

    return out_path
221 | 
def make_parser():
    """Build the command-line parser for the genvariants_parallel tool.

    Options spelled ``--gen.*`` hold generation parameters. Their dotted
    destinations (e.g. ``gen.temperature``) are read back by
    generate_variant() as ``args.gen``; presumably the grouping is done by
    ELMFuzzConfig (verify there).
    """
    parser = ArgumentParser(
        description='Use a code model to generate variants of a file.'
    )
    parser.add_argument('files', type=str, nargs='+')
    parser.add_argument('-M', '--model_name', type=str,
                        default='codellama/CodeLlama-13b-hf',
                        help='Model to use for generation')

    # Switches that turn individual mutators off.
    mutator_switches = (
        ('--no-completion', 'Disable the completion mutator'),
        ('--no-fim', 'Disable the FIM (infilling) mutator'),
        ('--no-splice', 'Disable the splice mutator'),
    )
    for flag, flag_help in mutator_switches:
        parser.add_argument(flag, action='store_true', help=flag_help)

    parser.add_argument('-n', '--num_variants', type=int, default=1,
                        help='Number of variants to generate for each seed')
    parser.add_argument('-O', '--output_dir', type=str, default='.',
                        help='Directory to write variants to')
    parser.add_argument('-L', '--log_dir', type=str, default='logs',
                        help='Directory to write generation metadata to')
    parser.add_argument('-s', '--start_line', type=int, default=0,
                        help='When making random cuts, always start at this line. '
                             'Allows specifying an immutable region not subject to mutation.')
    parser.add_argument('-j', '--jobs', type=int, default=16,
                        help='Number of inference jobs to run in parallel')

    # Generation params: (short flag, long flag, type, default, help).
    generation_options = (
        ('-t', '--gen.temperature', float, 0.2, 'Generation temperature'),
        ('-m', '--gen.max-new-tokens', int, 2048, 'Maximum number of tokens to generate'),
        ('-r', '--gen.repetition-penalty', float, 1.1, 'Repetition penalty'),
    )
    for short_flag, long_flag, value_type, default_value, option_help in generation_options:
        parser.add_argument(short_flag, long_flag, type=value_type,
                            default=default_value, help=option_help)
    return parser
251 | 
def init_parser(elm):
    """Register the help heading for the ``gen`` option subgroup on `elm`."""
    help_by_group = elm.subgroup_help
    help_by_group['gen'] = 'Generation parameters'
255 | 
def main():
    """Command-line entry point for parallel variant generation.

    Steps:

    1. Parse options through ELMFuzzConfig.
    2. Verify that the endpoint serves the expected model.
    3. Print the total number of variants that will be attempted.
    4. Generate the variants on a thread pool, printing each written
       variant's path (one per line, flushed) as it completes.
    """
    global ENDPOINT
    global infilling_prompt
    import sys
    from elmconfig import ELMFuzzConfig
    config = ELMFuzzConfig(parents={'genvariants_parallel': make_parser()})
    init_parser(config)
    args = config.parse_args()

    try:
        ENDPOINT = args.model.endpoints[args.model_name]
    except (KeyError, TypeError):
        # KeyError: no endpoint listed for this model. TypeError: presumably
        # the endpoints option is unset (None) -- confirm against elmconfig.
        print(f'WARNING: no endpoint for model {args.model_name}, using default: {ENDPOINT}', file=sys.stderr)

    info = model_info()
    model = info['model_id']
    if model != args.model_name:
        config.parser.error(f'Expected model {args.model_name}, but {ENDPOINT} is actually {model}')

    if model == 'bigcode/starcoder':
        infilling_prompt = infilling_prompt_starcoder
    elif model in ('codellama/CodeLlama-13b-hf',
                   'codellama/CodeLlama-7b-hf'):
        infilling_prompt = infilling_prompt_llama
    else:
        infilling_prompt = None

    # The FIM and splice mutators both build their prompt with
    # infilling_prompt(), so either one needs a FIM-capable model.
    # Otherwise 'lmsplice' would fail inside a worker thread calling None.
    if infilling_prompt is None and not (args.no_fim and args.no_splice):
        config.parser.error(f'Model {model} does not support FIM '
                            '(required by the infilling and splice mutators)')
    if args.no_completion and args.no_fim and args.no_splice:
        config.parser.error('Nothing to do')

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    generators = []
    if not args.no_completion:
        generators.append('complete')
    if not args.no_fim:
        generators.append('infilled')
    if not args.no_splice:
        generators.append('lmsplice')

    # Print the number of variants we'll generate so that the next
    # stage (genoutputs) knows how many to expect.
    print(len(args.files) * args.num_variants, flush=True)

    # Index i runs over (round, file) pairs, so every seed gets
    # num_variants variants.
    worklist = [
        filename
        for _ in range(args.num_variants)
        for filename in args.files
    ]
    with ThreadPoolExecutor(max_workers=args.jobs) as executor:
        futures = [
            executor.submit(generate_variant, i, generators, model, filename, args)
            for i, filename in enumerate(worklist)
        ]
        for future in as_completed(futures):
            res = future.result()
            if res is not None:
                print(res, flush=True)
321 | 
# Script entry point; main() is not run when the module is imported.
if __name__ == '__main__':
    main()
324 | 


--------------------------------------------------------------------------------
/driver.py:
--------------------------------------------------------------------------------
  1 | from collections import namedtuple
  2 | import io
  3 | import json
  4 | import os
  5 | import shutil
  6 | import sys
  7 | import argparse
  8 | import importlib.util
  9 | from enum import Enum
 10 | import signal
 11 | import time
 12 | import traceback
 13 | import resource
 14 | from typing import BinaryIO, Callable, List, NamedTuple, Tuple, Union
 15 | from tempfile import TemporaryDirectory
 16 | from contextlib import nullcontext, redirect_stdout, redirect_stderr
 17 | from concurrent.futures import ProcessPoolExecutor, as_completed
 18 | import logging
 19 | 
 20 | from drive_log import set_loglevel
 21 | logger = logging.getLogger('root')
 22 | 
 23 | class ExceptionInfo(NamedTuple):
 24 |     exception_class: str
 25 |     exception_message: str
 26 |     module_path: str
 27 |     filename: str
 28 |     line: int
 29 |     traceback: List[str]
 30 | 
 31 |     @classmethod
 32 |     def from_exception(cls, e: Exception, module_path: str):
 33 |         return cls(
 34 |             exception_class = f'{e.__class__.__module__}.{e.__class__.__name__}',
 35 |             exception_message = str(e),
 36 |             module_path = module_path,
 37 |             filename = e.__traceback__.tb_frame.f_code.co_filename if e.__traceback__ else None,
 38 |             line = e.__traceback__.tb_lineno if e.__traceback__ else None,
 39 |             traceback = traceback.format_tb(e.__traceback__) if e.__traceback__ else [],
 40 |         )
 41 | 
 42 | class GenResult(str,Enum):
 43 |     Success     = "Success"
 44 |     Timeout     = "Timeout"
 45 |     TooBig      = "TooBig"
 46 |     ImportError = "ImportError"
 47 |     Error       = "Error"
 48 |     RunError    = "RunError"
 49 |     UnknownErr  = "UnknownErr"
 50 |     NoLogErr    = "NoLogErr"
 51 | 
 52 | class ResultInfo(NamedTuple):
 53 |     time_taken: float
 54 |     memory_used: int
 55 |     stdout: str
 56 |     stderr: str
 57 | 
 58 | class Result(NamedTuple):
 59 |     # Filled in by the callee
 60 |     result_type: GenResult
 61 |     error: Union[ExceptionInfo,None]
 62 |     data: Union[ResultInfo,None]
 63 |     # Filled in by the caller
 64 |     module_path: Union[str,None] = None
 65 |     function_name: Union[str,None] = None
 66 |     output_file: Union[str,None] = None
 67 |     args: Union[argparse.Namespace,None] = None
 68 | 
 69 |     def _convert(self, item):
 70 |         """
 71 |         Recursively converts namedtuples to dictionaries, Enums to their values,
 72 |         and lists and dicts to their converted forms.
 73 |         """
 74 |         if isinstance(item, Enum):
 75 |             return item.value
 76 |         elif isinstance(item, tuple) and hasattr(item, '_asdict'):
 77 |             return {key: self._convert(value) for key, value in item._asdict().items()}
 78 |         elif isinstance(item, list):
 79 |             return [self._convert(sub_item) for sub_item in item]
 80 |         elif isinstance(item, dict):
 81 |             return {key: self._convert(value) for key, value in item.items()}
 82 |         else:
 83 |             return item
 84 | 
 85 |     def json(self):
 86 |         """
 87 |         Converts the namedtuple into JSON.
 88 |         """
 89 |         result_dict = self._convert(self)
 90 |         return json.dumps(
 91 |             result_dict,
 92 |             default=lambda o: o.__dict__ if hasattr(o, '__dict__') else str(o)
 93 |         )
 94 | 
 95 | 
 96 | # Context manager for timing out a function
 97 | class TimedExecution():
 98 |     def __init__(self, timeout):
 99 |         self.timeout = timeout
100 |         self.timed_out = False
101 |         self.start_time = None
102 |         self.old_handler = None
103 | 
104 |     def __enter__(self):
105 |         self.start_time = time.time()
106 |         self.old_handler = signal.signal(signal.SIGALRM, self._handle_timeout)
107 |         signal.alarm(self.timeout)
108 |         return self
109 | 
110 |     def __exit__(self, exc_type, exc_value, traceback):
111 |         signal.alarm(0)
112 |         signal.signal(signal.SIGALRM, self.old_handler)
113 |         self.time_taken = time.time() - self.start_time
114 |         self.timed_out = False
115 |         return self
116 | 
117 |     def _handle_timeout(self, signum, frame):
118 |         self.time_taken = time.time() - self.start_time
119 |         self.timed_out = True
120 |         raise TimeoutError(f"Timed out after {self.timeout} seconds")
121 | 
# Context manager to run in a temporary directory
class TemporaryDirectoryContext():
    """Make a fresh temporary directory the working directory inside a ``with``.

    Arguments are forwarded to ``tempfile.TemporaryDirectory``. Entering returns
    that TemporaryDirectory object. Exiting switches back to the previous
    working directory and then removes the temporary one.
    """
    def __init__(self, *args, **kwargs):
        self.old_cwd = None
        self.td = TemporaryDirectory(*args, **kwargs)

    def __enter__(self):
        tmp = self.td
        self.old_cwd = os.getcwd()
        tmp.__enter__()
        os.chdir(tmp.name)
        return tmp

    def __exit__(self, exc_type, exc_value, traceback):
        # Leave the directory before it is deleted
        os.chdir(self.old_cwd)
        self.td.__exit__(exc_type, exc_value, traceback)
137 | 
# Context manager to limit RAM usage
class MemoryLimit():
    """Temporarily lower this process's soft RLIMIT_AS (address-space) limit.

    Only the soft limit is changed, so the original limits can be restored on
    exit without privileges. A requested limit above a finite hard limit is
    clamped to the hard limit, because an unprivileged process cannot raise
    its soft limit past the hard limit.

    After exit, ``mem_usage`` holds ``ru_maxrss`` for the whole process. This is
    the peak so far, not only within the block, and its units are
    platform-dependent: kilobytes on Linux, bytes on macOS.
    """
    def __init__(self, limit):
        self.limit = limit
        self.mem_usage = None
        self.old_limit = None

    def __enter__(self):
        self.old_limit = resource.getrlimit(resource.RLIMIT_AS)
        hard_limit = self.old_limit[1]
        soft_limit = self.limit
        if hard_limit != resource.RLIM_INFINITY and (
                soft_limit == resource.RLIM_INFINITY or soft_limit > hard_limit):
            soft_limit = hard_limit
        # Only change the soft limit so that we can set it back
        resource.setrlimit(resource.RLIMIT_AS, (soft_limit, hard_limit))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        resource.setrlimit(resource.RLIMIT_AS, self.old_limit)
        self.mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
153 | 
154 | 
# Context manager that combines the above
class Sandbox():
    """Run a block with a time limit, a memory limit, a private temporary
    working directory, and captured stdout/stderr.

    The parts are entered in this order: stderr capture, stdout capture,
    temporary directory, timer, memory limit. They are exited in reverse.
    If any step of ``__enter__`` fails, the steps already taken are undone
    before the error propagates. Otherwise stdout could stay redirected or
    the alarm stay armed.

    Exceptions raised inside the block are never suppressed.
    """
    def __init__(self, timeout, memory_limit):
        self.timeout = TimedExecution(timeout)
        self.memory_limit = MemoryLimit(memory_limit)
        self.tempdir = TemporaryDirectoryContext()
        self.stdout = io.StringIO()
        self.stderr = io.StringIO()
        self.capture_stdout = redirect_stdout(self.stdout)
        self.capture_stderr = redirect_stderr(self.stderr)
        # Entered managers, held between __enter__ and __exit__
        self._active = None

    def __enter__(self):
        from contextlib import ExitStack
        with ExitStack() as stack:
            for manager in (self.capture_stderr, self.capture_stdout,
                            self.tempdir, self.timeout, self.memory_limit):
                stack.enter_context(manager)
            # Everything entered successfully: keep it active until __exit__
            self._active = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        active, self._active = self._active, None
        if active is not None:
            # Unwinds in reverse order of entry. The "suppressed" flag it
            # returns is ignored on purpose, so this method returns None and
            # exceptions from the sandboxed code always reach the caller.
            active.__exit__(exc_type, exc_value, traceback)

    def result(self) -> ResultInfo:
        return ResultInfo(
            # time_taken only exists once the timer has exited; it can be
            # missing if the sandbox never fully started
            time_taken = getattr(self.timeout, 'time_taken', None),
            memory_used = self.memory_limit.mem_usage,
            stdout = self.stdout.getvalue(),
            stderr = self.stderr.getvalue(),
        )
188 | 
class TooBigException(Exception):
    """Raised by SizeLimitedBinaryFile when a write would exceed its size limit."""

class SizeLimitedBinaryFile(io.BufferedWriter):
    """Buffered binary writer that refuses to grow past ``max_size`` bytes.

    A write that would move the stream position beyond ``max_size`` raises
    TooBigException and writes nothing. Data from earlier writes is kept.
    """
    def __init__(self, *args, max_size: int, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def write(self, b: bytes) -> int:
        # Count bytes, not items: len() of an array/memoryview with
        # multi-byte items is smaller than the number of bytes written.
        nbytes = memoryview(b).nbytes
        new_position = self.tell() + nbytes
        if new_position > self.max_size:
            raise TooBigException(f"Writing would exceed the size limit of {self.max_size} bytes")
        return super().write(b)
202 | 
def generate_one(
        output_file: str,
        function: Callable[[BinaryIO,BinaryIO],None],
        args: argparse.Namespace,
    ) -> Result:
    """Run a generator function once inside a Sandbox.

    ``function`` is called as ``function(rng, out)``:
      * ``rng`` is /dev/urandom opened for binary reading.
      * ``out`` is ``output_file`` opened for binary writing and capped at
        ``args.size_limit`` bytes.

    The sandbox applies ``args.timeout`` seconds and ``args.max_mem`` bytes.

    Returns a Result whose result_type is one of:
      * Success;
      * TooBig (the size cap was hit);
      * Timeout;
      * Error (MemoryError or any other Exception, with error details).

    The caller fills in the remaining Result fields.
    """
    # Ensure the directory exists
    dirname = os.path.dirname(output_file)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # Created before `try` so every handler can report the captured output,
    # even when entering the sandbox itself fails
    sandbox = Sandbox(args.timeout, args.max_mem)
    # Both files are opened before the sandbox changes directory, so a
    # relative output_file resolves against the caller's cwd. The output is
    # opened unbuffered: SizeLimitedBinaryFile is itself the buffering layer.
    with open('/dev/urandom', 'rb') as rng, \
        SizeLimitedBinaryFile(open(output_file, 'wb', buffering=0),
                              max_size=args.size_limit) as out:
        try:
            with sandbox:
                function(rng, out)
        except MemoryError as e:
            # Lift the soft address-space limit immediately. Soft is set to
            # the hard limit, because an unprivileged process may not raise
            # the hard limit itself.
            hard_limit = resource.getrlimit(resource.RLIMIT_AS)[1]
            resource.setrlimit(resource.RLIMIT_AS, (hard_limit, hard_limit))
            return Result(
                result_type = GenResult.Error,
                error = ExceptionInfo.from_exception(e, args.module_path),
                data = sandbox.result(),
            )
        except TooBigException:
            return Result(
                result_type = GenResult.TooBig,
                error = None,
                data = sandbox.result(),
            )
        except TimeoutError:
            return Result(
                result_type = GenResult.Timeout,
                error = None,
                data = sandbox.result(),
            )
        except Exception as e:
            return Result(
                result_type = GenResult.Error,
                error = ExceptionInfo.from_exception(e, args.module_path),
                data = sandbox.result(),
            )
        else:
            return Result(
                result_type = GenResult.Success,
                error = None,
                data = sandbox.result(),
            )
253 | 
def get_function(module_path, function_name, args):
    """Import the generator module inside a Sandbox and return one of its functions.

    How the module is loaded:
      * It is copied into the sandbox's temporary directory as
        ``generator_module.py`` and imported under that name.
      * Any previously imported ``generator_module`` is discarded first.
      * The module's globals ``exit`` and ``quit`` are then replaced with a
        stub that raises, so generator code calling them reports an error
        instead of ending the process.

    Returns the attribute ``function_name`` of the module on success.
    Otherwise it returns a Result with result_type ImportError. This covers
    an exception while the module body runs, a SystemExit raised from the
    module's top level, or a missing attribute.
    """
    # Resolve before the sandbox switches to its temporary working directory
    full_module_path = os.path.abspath(module_path)
    # Created outside `try` so the handler can always report captured output
    sandbox = Sandbox(args.timeout, args.max_mem)
    try:
        # This needs to wrapped in the sandbox because modules can exec
        # code on load
        with sandbox:
            shutil.copy(full_module_path, './generator_module.py')
            if '.' not in sys.path:
                sys.path.append('.')
            # Make sure the freshly copied file is what gets imported, not a
            # module cached by an earlier call or a stale finder cache
            sys.modules.pop('generator_module', None)
            importlib.invalidate_caches()
            import generator_module
            def capture_exit(rv=None):
                raise Exception(f"Attempted to exit with code {rv}")
            generator_module.exit = capture_exit
            generator_module.quit = capture_exit
            return getattr(generator_module, function_name)
    except (Exception, SystemExit) as e:
        logger.info(f"Error importing module {module_path}: {e}")
        return Result(
            result_type = GenResult.ImportError,
            error = ExceptionInfo.from_exception(e, module_path),
            data = sandbox.result(),
        )
276 | 
def make_parser(description):
    """Build the command-line parser for the driver.

    Positional arguments: ``module_path`` and ``function``. Optional arguments
    control the iteration count, output naming, resource limits, the log file,
    and verbosity. Defaults are shown in ``--help``.
    """
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('module_path', type=str,
                        help='Path to the module containing the function to run')
    parser.add_argument('function', type=str, help='The function to run')

    # (flags, type, default, help) for each valued option, in display order
    valued_options = [
        (('-n', '--num'), int, 1, 'Number of times to run the function'),
        (('-S', '--size-limit'), int, 50*1024*1024, 'Maximum size of the output file (in bytes)'),
        (('-o', '--output-prefix'), str, './output', 'Output prefix'),
        (('-s', '--output-suffix'), str, '.dat', 'Output suffix'),
        (('-t', '--timeout'), int, 10, 'Timeout for the run (in seconds)'),
        (('-M', '--max-mem'), int, 1024*1024*1024, 'Maximum memory usage (in bytes)'),
        (('-L', '--logfile'), str, None, 'Log file to write to'),
    ]
    for flags, value_type, default, help_text in valued_options:
        parser.add_argument(*flags, type=value_type, default=default, help=help_text)

    for flags in (('-q', '--quiet'), ('-v', '--verbose')):
        parser.add_argument(*flags, action='store_true')
    return parser
312 | 
def fill_result(result, module_path, function_name, output_file, args):
    """Return a copy of ``result`` with the caller-supplied fields set.

    The callee fields (result_type, error, data) are carried over unchanged.
    """
    caller_fields = dict(
        module_path=module_path,
        function_name=function_name,
        output_file=output_file,
        args=args,
    )
    return Result(result.result_type, result.error, result.data, **caller_fields)
323 | 
def main():
    """Run a generator function ``--num`` times and log one JSON Result per line.

    Output goes to ``--logfile`` if given, otherwise stdout.

    If the module cannot be imported, a single ImportError record is written.
    Otherwise each run executes in a worker process, and its output is written
    to ``<output-prefix>_<8-digit index><output-suffix>``.

    A run whose worker fails outright is recorded as UnknownErr instead of
    aborting the remaining runs.
    """
    parser = make_parser('Run an input generator function in a loop')
    args = parser.parse_args()
    set_loglevel(logger, args)

    with open(args.logfile, 'w') if args.logfile else nullcontext(sys.stdout) as f:
        module_path = os.path.abspath(args.module_path)
        function_name = args.function
        function_or_result = get_function(module_path, function_name, args)
        if isinstance(function_or_result, Result):
            # Import failed: report that Result as-is (no second import attempt)
            final_result = fill_result(function_or_result, module_path, function_name, None, args)
            print(final_result.json(), file=f)
            return

        function = function_or_result
        with ProcessPoolExecutor() as executor:
            futures = {}
            for i in range(args.num):
                output_file = f'{args.output_prefix}_{i:08}{args.output_suffix}'
                future = executor.submit(generate_one, output_file, function, args)
                futures[future] = output_file
            for future in as_completed(futures):
                output_file = futures[future]
                try:
                    result = future.result()
                except Exception as e:
                    # e.g. the worker process died or the job could not be
                    # sent to it; keep one log line per requested run
                    logger.info(f"Worker failed for {output_file}: {e}")
                    result = Result(
                        result_type = GenResult.UnknownErr,
                        error = ExceptionInfo.from_exception(e, module_path),
                        data = None,
                    )
                final_result = fill_result(result, module_path, function_name, output_file, args)
                print(final_result.json(), file=f)
351 | 
# Script entry point; main() is not run when the module is imported
# (genoutputs imports Result and friends from here).
if __name__ == '__main__':
    main()
354 | 


--------------------------------------------------------------------------------
/genoutputs.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | 
  3 | import argparse
  4 | from collections import OrderedDict, defaultdict
  5 | from concurrent.futures import ProcessPoolExecutor, as_completed
  6 | import glob
  7 | import json
  8 | import logging
  9 | import os
 10 | import re
 11 | import shutil
 12 | import subprocess
 13 | import sys
 14 | from typing import BinaryIO
 15 | try:
 16 |     # Not until 3.11
 17 |     from hashlib import file_digest
 18 | except ImportError:
 19 |     def file_digest(f: BinaryIO):
 20 |         import hashlib
 21 |         BLOCKSIZE = 65536
 22 |         hasher = hashlib.sha256()
 23 |         buf = f.read(BLOCKSIZE)
 24 |         while len(buf) > 0:
 25 |             hasher.update(buf)
 26 |             buf = f.read(BLOCKSIZE)
 27 |         return hasher
 28 | 
 29 | from drive_log import setup_custom_logger
 30 | logger = setup_custom_logger('root')
 31 | 
 32 | from tqdm import tqdm
 33 | from driver import ExceptionInfo, Result, ResultInfo, GenResult
 34 | 
 35 | # Global color cycle with ANSI colors
 36 | COLOR_GREEN = '\033[92m'
 37 | COLOR_RED = '\033[91m'
 38 | COLOR_YELLOW = '\033[93m'
 39 | COLOR_BLUE = '\033[94m'
 40 | COLOR_MAGENTA = '\033[95m'
 41 | COLOR_CYAN = '\033[96m'
 42 | COLOR_WHITE = '\033[97m'
 43 | COLOR_GREY = '\033[90m'
 44 | COLOR_END = '\033[0m'
 45 | COLOR_CYCLE = [
 46 |     COLOR_GREEN,
 47 |     COLOR_RED,
 48 |     COLOR_YELLOW,
 49 |     COLOR_BLUE,
 50 |     COLOR_MAGENTA,
 51 |     COLOR_CYAN,
 52 |     COLOR_WHITE,
 53 |     COLOR_GREY,
 54 | ]
 55 | 
 56 | def draw_success_rate(stats, preferred_colors=None):
 57 |     BOX = '▓'
 58 |     WIDTH = 80
 59 | 
 60 |     if preferred_colors is None:
 61 |         preferred_colors = {}
 62 | 
 63 |     def bar(color, width):
 64 |         return f"{color}{BOX*width}{COLOR_END}"
 65 | 
 66 |     total = sum(stats.values())
 67 |     outcome_bars = []
 68 |     legends = []
 69 |     color_index = 0
 70 | 
 71 |     # Calculate the width for each key and draw the bars
 72 |     color_cycle = COLOR_CYCLE[:]
 73 |     # Use the preferred colors first
 74 |     for key in preferred_colors:
 75 |         if key not in stats:
 76 |             continue
 77 |         color = preferred_colors[key]
 78 |         width = int(WIDTH * stats[key] / total)
 79 |         outcome_bars.append((key, width, color))
 80 |         legends.append(f"{color}{BOX}{COLOR_END} {key}")
 81 |         color_cycle.remove(color)
 82 |         color_index += 1
 83 |     # Then use the color cycle for the rest
 84 |     for key, count in stats.items():
 85 |         if key in preferred_colors:
 86 |             continue
 87 |         width = int(WIDTH * count / total)
 88 |         color = color_cycle[color_index % len(color_cycle)]
 89 |         color_index += 1
 90 |         outcome_bars.append((key, width, color))
 91 |         legends.append(f"{color}{BOX}{COLOR_END} {key}")
 92 | 
 93 |     # Calculate how much width is left after the outcomes are drawn
 94 |     used_width = sum(width for _, width, _ in outcome_bars)
 95 |     remaining_width = WIDTH - used_width
 96 | 
 97 |     # Add any remaining width to the largest outcome
 98 |     largest_outcome = max(outcome_bars, key=lambda x: x[1])[0]
 99 |     outcome_bars = [(key, width + (remaining_width if key == largest_outcome else 0), color)
100 |                     for key, width, color in outcome_bars]
101 | 
102 |     # Construct the final drawn bar and legend
103 |     drawn_bar = ''.join(bar(color,width) for _, width, color in outcome_bars)
104 |     legend = ' '.join(legends)
105 | 
106 |     return drawn_bar + "  " + legend
107 | 
# Captures the generation type from variant file names, e.g.
# var_0000.diffmode.py -> diffmode. The group must be named "gentype".
gentype_re = re.compile(r'var_\d{4,}\.(?P<gentype>[a-z]+)\.')
def get_gentype(module_path):
    """Return the generation type encoded in a variant module's file name.

    E.g.: var_0000.diffmode.py => diffmode; var_0000.complete.py => complete.
    Indices with more than four digits are accepted as well.

    :raises ValueError: if the file name does not follow that pattern.
    """
    basename = os.path.basename(module_path)
    match = gentype_re.search(basename)
    if match is None:
        raise ValueError(f"Cannot determine generation type from {module_path!r}")
    return match.group('gentype')
116 | 
def generate_stats(logfile):
    """Print outcome counts per generation type for a genoutputs log to stderr.

    Log format:
      * The first line of ``logfile`` is a JSON record whose ``data.args``
        holds the run configuration.
      * Each later line is one JSON Result record.

    Counting rules:
      * Counts are grouped by the generation type taken from each record's
        module_path.
      * An ImportError record stands for the module's whole batch, so it
        counts as ``driver.num_iterations`` attempts.
      * Records without a module_path are echoed to stderr and skipped.
    """
    color_preferences = {
        'Success': COLOR_GREEN,
        'Error': COLOR_RED,
        'Timeout': COLOR_YELLOW,
    }
    def add_stats(d1, d2):
        return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)}
    # Track stats as we go and print them at the end
    running_stats = defaultdict(lambda: defaultdict(int))
    with open(logfile) as f:
        original_args = json.loads(f.readline())['data']['args']
        for line in f:
            result = json.loads(line)
            module_path = result.get('module_path')
            if module_path is None:
                # Can't attribute this record to a generation type; skip it
                # rather than reusing the previous record's module_path
                print(f"Error: {line}", file=sys.stderr)
                continue
            gentype_stats = running_stats[get_gentype(module_path)]
            if result['result_type'] == 'ImportError':
                # Mark the batch as an error
                gentype_stats['ImportError'] += original_args['driver']['num_iterations']
            else:
                gentype_stats[result['result_type']] += 1
    running_stats = { k: dict(v) for k, v in running_stats.items() }
    combined = {}
    for k in running_stats:
        combined = add_stats(combined, running_stats[k])
    print(f"Stats:", file=sys.stderr)
    for k in sorted(running_stats.keys()):
        print(f"  {k}: {running_stats[k]}", file=sys.stderr)
    print(f"  combined: {combined}", file=sys.stderr)
    print(f"Stats (visual):", file=sys.stderr)
    for k in sorted(running_stats.keys()):
        print(f"  {k}: {draw_success_rate(running_stats[k],color_preferences)}", file=sys.stderr)
    print(f"  combined: {draw_success_rate(combined,color_preferences)}", file=sys.stderr)
    total = sum(combined.values())
    success = combined.get('Success', 0)
    print(f"     total: {total} files attempted", file=sys.stderr)
    print(f"   success: {success} files generated", file=sys.stderr)
    if total != 0:
        print(f"  success%: {success/total*100:.2f}%", file=sys.stderr)
158 | 
def generate_filestats(logfile):
    """Summarize the files produced by each generator module of a run.

    Inputs:
      * The first line of ``logfile`` is a JSON record whose ``data.args``
        holds the run configuration; ``output_dir`` and
        ``driver.output_suffix`` are used.
      * Each later line is a JSON Result record. Records without a
        module_path are skipped.

    For every module in the log, it looks at the files matching
    ``*<output_suffix>`` in ``<output_dir>/<module basename without extension>``.
    For each module it records the file sizes and the number of distinct
    contents (by SHA-256).

    Outputs:
      * A per-generation-type and combined summary, printed to stderr.
      * Everything, including the raw per-module numbers, written to
        ``<logfile without extension>.filestats.json``.
    """
    import hashlib

    def output_files(outdir, ext):
        # glob already returns paths that include outdir. Both parts are
        # escaped so glob metacharacters in names match literally.
        return glob.glob(os.path.join(glob.escape(outdir), '*' + glob.escape(ext)))

    def sha256_hex(path):
        # Hash in chunks so large outputs are not read into memory at once
        hasher = hashlib.sha256()
        with open(path, 'rb') as fh:
            for chunk in iter(lambda: fh.read(65536), b''):
                hasher.update(chunk)
        return hasher.hexdigest()

    def count_unique_files(outdir, ext):
        try:
            return len({sha256_hex(path) for path in output_files(outdir, ext)})
        except FileNotFoundError:
            return 0

    def file_sizes(outdir, ext):
        try:
            return [os.path.getsize(path) for path in output_files(outdir, ext)]
        except FileNotFoundError:
            return []

    def new_filestats():
        return {
            'file_sizes': {},
            'unique_hashes': {},
        }

    # Tracks file stats for each module, keyed by generation type
    file_stats = defaultdict(lambda: defaultdict(new_filestats))
    with open(logfile) as f:
        original_args = json.loads(f.readline())['data']['args']
        ext = original_args['driver']['output_suffix']
        output_dir = original_args['output_dir']
        for line in f:
            result = json.loads(line)
            module_path = result.get('module_path')
            if module_path is None:
                # Nothing to locate on disk for this record
                continue
            file_stats[get_gentype(module_path)][module_path] = new_filestats()

    # Compute the file stats
    total = sum(len(modules) for modules in file_stats.values())
    progress = tqdm(total=total, desc="Computing file stats", unit="mod")
    for modules in file_stats.values():
        for module_path, stats in modules.items():
            worker_dir = os.path.join(output_dir, os.path.splitext(os.path.basename(module_path))[0])
            stats['file_sizes'] = file_sizes(worker_dir, ext)
            stats['unique_hashes'] = count_unique_files(worker_dir, ext)
            progress.update()
    progress.close()

    print(f"File stats:", file=sys.stderr)
    computed_file_stats = defaultdict(dict)
    total_unique = 0
    single_unique = 0
    zero_unique = 0
    # Keep both average and n so we compute the combined average correctly
    average_file_size = []
    average_nonzero_file_size = []
    for generation_type, modules in file_stats.items():
        print(f"  {generation_type}:", file=sys.stderr)
        unique_counts = [stats['unique_hashes'] for stats in modules.values()]
        gen_total_unique = sum(unique_counts)
        total_unique += gen_total_unique
        computed_file_stats[generation_type]['total_unique'] = gen_total_unique
        print(f"    total unique: {gen_total_unique}", file=sys.stderr)
        # Number of generators with just one unique file
        gen_single_unique = unique_counts.count(1)
        single_unique += gen_single_unique
        computed_file_stats[generation_type]['single_unique'] = gen_single_unique
        print(f"    single unique: {gen_single_unique}", file=sys.stderr)
        # Number of generators with zero unique files
        gen_zero_unique = unique_counts.count(0)
        zero_unique += gen_zero_unique
        computed_file_stats[generation_type]['zero_unique'] = gen_zero_unique
        print(f"    zero unique: {gen_zero_unique}", file=sys.stderr)
        # Average file size over all files of this generation type
        gen_file_sizes = [size for stats in modules.values() for size in stats['file_sizes']]
        gen_avg_file_size = sum(gen_file_sizes) / len(gen_file_sizes) if gen_file_sizes else 0
        nonzero_file_sizes = [s for s in gen_file_sizes if s > 0]
        avg_nonzero_file_size = sum(nonzero_file_sizes) / len(nonzero_file_sizes) if nonzero_file_sizes else 0
        average_file_size.append((gen_avg_file_size, len(gen_file_sizes)))
        average_nonzero_file_size.append((avg_nonzero_file_size, len(nonzero_file_sizes)))
        computed_file_stats[generation_type]['average_file_size'] = gen_avg_file_size
        computed_file_stats[generation_type]['average_nonzero_file_size'] = avg_nonzero_file_size
        print(f"    average file size: {gen_avg_file_size:.2f} bytes", file=sys.stderr)
        print(f"    average nonzero file size: {avg_nonzero_file_size:.2f} bytes", file=sys.stderr)
    # Combined stats
    print(f"  combined:", file=sys.stderr)
    computed_file_stats['combined']['total_unique'] = total_unique
    print(f"    total unique: {total_unique}", file=sys.stderr)
    computed_file_stats['combined']['single_unique'] = single_unique
    print(f"    single unique: {single_unique}", file=sys.stderr)
    computed_file_stats['combined']['zero_unique'] = zero_unique
    print(f"    zero unique: {zero_unique}", file=sys.stderr)
    # Compute the combined average file size
    total_file_size = sum(s*n for s, n in average_file_size)
    total_nonzero_file_size = sum(s*n for s, n in average_nonzero_file_size)
    total_files = sum(n for _, n in average_file_size)
    total_nonzero_files = sum(n for _, n in average_nonzero_file_size)
    combined_avg_file_size = total_file_size / total_files if total_files > 0 else 0
    combined_avg_nonzero_file_size = total_nonzero_file_size / total_nonzero_files if total_nonzero_files > 0 else 0
    computed_file_stats['combined']['average_file_size'] = combined_avg_file_size
    computed_file_stats['combined']['average_nonzero_file_size'] = combined_avg_nonzero_file_size
    print(f"    average file size: {combined_avg_file_size:.2f} bytes", file=sys.stderr)
    print(f"    average nonzero file size: {combined_avg_nonzero_file_size:.2f} bytes", file=sys.stderr)

    # Include the raw file stats in the output: always the historical keys
    # (possibly empty), plus any other generation type present (e.g. lmsplice)
    raw_types = ['infilled', 'complete', 'diffmode']
    raw_types += [t for t in file_stats if t not in raw_types]
    for generation_type in raw_types:
        computed_file_stats[generation_type]['raw'] = file_stats.get(generation_type, {})

    # Save to a new JSON file based on the output log's name
    output_log = os.path.splitext(logfile)[0]
    output_log += '.filestats.json'
    with open(output_log, 'w') as f:
        json.dump(computed_file_stats, f, indent=2)
    print(f"Wrote file stats to {output_log}", file=sys.stderr)
291 | 
class filestats_action(argparse.Action):
    """argparse action behind ``--stats-only``.

    It prints outcome and file statistics for an existing log, then exits.
    It takes no value. ``--logfile`` must already be set when the flag is
    processed; argparse handles options in command-line order, so it has to
    come earlier.
    """
    def __init__(self, option_strings, dest, **kwargs):
        # A pure switch: never consumes a command-line value
        super().__init__(option_strings, dest, nargs=0, **kwargs)

    def __call__(self, parser, namespace, values, option_string, **kwargs):
        logfile = namespace.logfile
        if logfile is None:
            parser.error('Must specify --logfile with --stats-only')
        for report in (generate_stats, generate_filestats):
            report(logfile)

        parser.exit()
303 | 
def generate_docker(module_path, worker_dir, args):
    """Run one generator module inside the ``elmfuzz:latest`` container.

    The module is copied into ``worker_dir``, which is mounted at ``/data``
    in the container, and ``driver.py`` is run there with the
    ``args.driver.*`` settings. Outputs are written to ``worker_dir/output``.
    The driver's JSON-lines log is read back from ``worker_dir/logfile.json``.
    The log and the copied module are then deleted.

    :param module_path: Host path of the generator module.
    :param worker_dir: Host directory mounted as ``/data`` in the container.
    :param args: Nested argument namespace. ``args.driver`` supplies the
        driver options, and ``args`` is stored in any error ``Result``.
    :returns: A list of result dicts.

        - If the log holds exactly one ``ImportError`` entry, that
          one-element list is returned unchanged.
        - Otherwise, a list shorter than ``args.driver.num_iterations`` is
          padded with copies of one error result. This is ``RunError`` (with
          the container's stdout/stderr) if ``docker run`` exited non-zero,
          else ``NoLogErr`` if no log was written, else ``UnknownErr``.
    """
    module_name = os.path.basename(module_path)
    # Copy the module into the directory that is mounted into the container
    copied_module_name = os.path.join(worker_dir, module_name)
    shutil.copyfile(module_path, copied_module_name)
    docker_module_name = os.path.join("/data", module_name)
    docker_outdir = os.path.join("/data", "output")
    logfile_name = 'logfile.json'
    host_logfile = os.path.join(worker_dir, logfile_name)
    docker_logfile_name = os.path.join("/data", logfile_name)
    # Run the module in the docker container
    container_name = f'elmfuzz_{os.path.basename(worker_dir)}'
    cmd = [
        'docker', 'run', '--rm',
        '--name', container_name,
        '-v', f'{worker_dir}:/data',
        'elmfuzz:latest',
        'python', 'driver.py',
        '-n', str(args.driver.num_iterations),
        '-o', docker_outdir,
        '-L', docker_logfile_name,
        '-t', str(args.driver.timeout),
        '-S', str(args.driver.size_limit),
        '-M', str(args.driver.max_mem),
        '-s', args.driver.output_suffix,
        docker_module_name, args.driver.function_name,
    ]
    logger.debug(f"Running: {' '.join(cmd)}")
    result = None
    try:
        # Timeout for the whole container run: 30 minutes
        timeout = 30 * 60
        subprocess.run(cmd, check=True, text=True, capture_output=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        # On timeout subprocess.run only kills the local `docker` client
        # process; stop the container itself explicitly.
        subprocess.run(['docker', 'stop', container_name])
    except subprocess.CalledProcessError as e:
        result = Result(
            error=ExceptionInfo.from_exception(e, module_path),
            data=ResultInfo(
                time_taken=None,
                memory_used=None,
                stdout=e.stdout,
                stderr=e.stderr,
            ),
            module_path=module_path,
            result_type=GenResult.RunError,
            function_name=args.driver.function_name,
            args=args,
        )
    finally:
        # Remove the copied module even if the run raised something unexpected
        try:
            os.remove(copied_module_name)
        except FileNotFoundError:
            pass

    gen_results = []
    try:
        # Read the results from the logfile, one JSON object per line
        with open(host_logfile) as f:
            for lineno, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    gen_results.append(json.loads(line))
                except json.JSONDecodeError:
                    # Possibly a final line cut short when a timed-out container
                    # was stopped; the missing entry is padded below.
                    logger.warning(f"Skipping malformed line {lineno} of {host_logfile}")
        os.remove(host_logfile)
    except FileNotFoundError:
        # The logfile wasn't created, so something went wrong. Keep a RunError
        # from above if there is one: it carries the container's stdout/stderr.
        if result is None:
            result = Result(
                error=None,
                data=None,
                module_path=module_path,
                result_type=GenResult.NoLogErr,
                function_name=args.driver.function_name,
                args=args,
            )

    # A lone ImportError entry is returned as-is, without padding
    if len(gen_results) == 1 and gen_results[0].get('result_type') == 'ImportError':
        return gen_results

    if len(gen_results) != args.driver.num_iterations:
        if result is None:
            result = Result(
                error=None,
                data=None,
                module_path=module_path,
                result_type=GenResult.UnknownErr,
                function_name=args.driver.function_name,
                args=args,
            )
        # Fill in the remaining entries with the error
        for _ in range(args.driver.num_iterations - len(gen_results)):
            gen_results.append(json.loads(result.json()))
    return gen_results
393 | 
def make_parser():
    """Build the argument parser for genoutputs.

    Global options control the output directory, parallelism and logging.
    The ``--driver.*`` options are forwarded to ``driver.py`` for every
    module. The parser is used as a parent of ``ELMFuzzConfig`` in ``main``.
    """
    parser = argparse.ArgumentParser(description='Create outputs using generated programs')
    add = parser.add_argument

    # Global options
    add('-O', '--output-dir', type=str, default='.', help='Output directory')
    add('-j', '--jobs', type=int, default=None,
        help='Maximum number of jobs to run in parallel; None means ncpu')
    add('--raise-errors', action='store_true',
        help="Don't catch exceptions in the main driver loop")
    add('-L', '--logfile', type=str, default=None, help='Log file for JSON results')
    add('--stats-only', action=filestats_action, default=argparse.SUPPRESS,
        help='Only compute stats for the given log file')

    # Options forwarded to every module's driver run
    add('-f', '--driver.function_name', type=str,
        help='The function to run in each module')
    add('-t', '--driver.timeout', type=int, default=10,
        help='Timeout for each function run (in seconds)')
    add('-S', '--driver.size-limit', type=int, default=50 * 1024 * 1024,
        help='Maximum size of the output file (in bytes)')
    add('-M', '--driver.max-mem', type=int, default=1024 * 1024 * 1024,
        help='Maximum memory usage (in bytes)')
    add('-s', '--driver.output-suffix', type=str, default='.gif',
        help='Suffix for output files')
    add('-n', '--driver.num_iterations', type=int, default=100,
        help='Number of times to run each function in each module '
             '(i.e., number of outputs to generate)')
    return parser
438 | 
def init_parser(elm):
    """Register the help text for the ``driver`` option subgroup on *elm*.

    :param elm: An ``ELMFuzzConfig`` (anything with a ``subgroup_help`` dict).
    """
    driver_help = 'Options for the generator; will be passed to each module'
    elm.subgroup_help['driver'] = driver_help
441 | 
def main():
    """Generate outputs for every module path received on stdin.

    Stdin follows the genvariants protocol: the first line is the number of
    modules, and each following line is one module path. Each module is run
    with ``generate_docker`` in a process pool of ``--jobs`` workers. A
    module's outputs go under ``<output_dir>/<module basename>``.

    Every result is written as one JSON line to ``--logfile`` (or stdout).
    The first line records the parsed arguments. When a logfile is given,
    summary stats are printed at the end.
    """
    from elmconfig import ELMFuzzConfig
    parser = make_parser()
    config = ELMFuzzConfig(parents={'genoutputs': parser})
    init_parser(config)
    args = config.parse_args()
    logger.setLevel(logging.INFO)

    output_log = open(args.logfile, 'w') if args.logfile is not None else sys.stdout
    try:
        # Record the arguments we're using in the log
        print(json.dumps(
            {'error': None, 'data': {'args': args.__dict__}},
            default=lambda x: x.__dict__ if hasattr(x, '__dict__') else str(x),
        ), file=output_log)

        # The first line sent by genvariants is the number of modules it will produce
        module_count = int(sys.stdin.readline())

        # Progress bar outermost so it is still open while the executor waits
        # for (and runs the callbacks of) any outstanding futures on shutdown.
        with tqdm(total=module_count, desc="Generating", unit="mod") as progress:
            with ProcessPoolExecutor(max_workers=args.jobs) as executor:
                futures_to_paths = OrderedDict()
                for line in sys.stdin:
                    module_path = line.strip()
                    if not module_path:
                        # Ignore blank lines (e.g. a trailing empty line)
                        continue
                    # Make an output directory for this module's outputs
                    module_base = os.path.splitext(os.path.basename(module_path))[0]
                    worker_dir = os.path.join(args.output_dir, module_base)
                    os.makedirs(worker_dir, exist_ok=True)
                    future = executor.submit(generate_docker, module_path, worker_dir, args)
                    future.add_done_callback(lambda _: progress.update())
                    futures_to_paths[future] = module_path
                for future in as_completed(futures_to_paths):
                    module_path = futures_to_paths[future]
                    try:
                        results = future.result()
                    except Exception as e:
                        if args.raise_errors:
                            raise
                        # json.dumps() cannot serialize an ExceptionInfo object, so
                        # emit the full Result through its own .json(), as
                        # generate_docker does for its error entries.
                        res = Result(
                            error=ExceptionInfo.from_exception(e, module_path),
                            data=None,
                            module_path=module_path,
                            result_type=GenResult.Error,
                            function_name=args.driver.function_name,
                            args=args,
                        )
                        print(res.json(), file=output_log)
                        continue
                    for res in results:
                        print(json.dumps(res), file=output_log)
    finally:
        if output_log is not sys.stdout:
            output_log.close()

    # Collect statistics if we have a log
    if args.logfile is None:
        return

    # Print the stats out to stderr now that we're done. File stats
    # (generate_filestats) are skipped here because they take too long;
    # --stats-only computes them.
    generate_stats(args.logfile)

if __name__ == '__main__':
    main()
514 | 


--------------------------------------------------------------------------------
/elmconfig.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | 
  3 | import argparse
  4 | import copy
  5 | from datetime import datetime
  6 | from enum import Enum
  7 | import sys
  8 | import textwrap
  9 | from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple
 10 | import os
 11 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Action, Namespace
 12 | from pathlib import Path, PosixPath
 13 | from ruamel.yaml import YAML
 14 | from ruamel.yaml.comments import CommentedMap
 15 | from collections import OrderedDict
 16 | from collections.abc import Sequence as SequenceABC
 17 | from io import StringIO
 18 | 
class SelectionStrategy(Enum):
    """Selection strategy used when choosing seeds for each generation.

    The member values are exactly the strings accepted by
    ``--run.selection_strategy`` (its ``choices`` are built from them).
    """
    # NOTE(review): the precise semantics of each strategy are implemented
    # elsewhere; the names suggest "keep the best overall" vs "best of the
    # current generation" — confirm against the selection code.
    Elites = 'elites'
    BestOfGeneration = 'best_of_generation'
 23 | 
 24 | class StoreDictKeyPair(Action):
 25 |     """Store a key-value pair in a dict"""
 26 |     def __init__(self, option_strings, dest, nargs=None, **kwargs):
 27 |         super().__init__(option_strings, dest, nargs=nargs, **kwargs)
 28 |     def __call__(self, parser, namespace, values, option_string=None):
 29 |         d = namespace.__dict__.get(self.dest)
 30 |         if d is None:
 31 |             d = {}
 32 |         for kv in values:
 33 |             k, v = kv.split(':', 1)
 34 |             d[k] = v
 35 |         namespace.__dict__[self.dest] = d
 36 |     # Helper to turn a dict into a list of key-value pairs
 37 |     @staticmethod
 38 |     def invert(d: Dict) -> List[str]:
 39 |         if d is None:
 40 |             return None
 41 |         return [ f"{k}:{v}" for k, v in d.items() ]
 42 | 
 43 | def value_is_default(key: str, args: Namespace, parser: ArgumentParser) -> bool:
 44 |     """Check if a value is the default value for the corresponding option"""
 45 |     for opt in parser._get_optional_actions():
 46 |         if key == opt.dest:
 47 |             return args.__dict__[key] == opt.default
 48 |     return False
 49 | 
 50 | def convert_conf_item(key: str, val: Any, args: Namespace, parser: ArgumentParser):
 51 |     store_actions = {
 52 |         act for name, act in parser._registries['action'].items()
 53 |         if isinstance(name, str) and name.startswith('store_')
 54 |     }
 55 |     for opt in parser._get_optional_actions():
 56 |         def type_conv(v):
 57 |             if opt.type is None:
 58 |                 return v
 59 |             elif v is None:
 60 |                 return None
 61 |             else:
 62 |                 return opt.type(v)
 63 |         if key == opt.dest:
 64 |             # If the action is one that stores a value, just set the value
 65 |             if type(opt) in store_actions:
 66 |                 args.__dict__[key] = val
 67 |             # If the action is one that returns a list, make sure val is a list
 68 |             elif opt.nargs in ['+', '*'] or isinstance(opt.nargs, int) and opt.nargs > 1:
 69 |                 # Make sure val is a list
 70 |                 if not isinstance(val, list):
 71 |                     print(f"Warning: {key} should be a list, but is {type(val)}; skipping", file=sys.stderr)
 72 |                     return
 73 |                 # Convert each item in the list
 74 |                 conv = [ type_conv(v) for v in val ]
 75 |                 opt(parser, args, conv, f'--{key}')
 76 |             else:
 77 |                 conv = type_conv(val)
 78 |                 # Set the value by calling the Action directly
 79 |                 opt(parser, args, conv, f'--{key}')
 80 |             return
 81 |     # If we get here, we didn't find a matching option
 82 |     print(f"Warning: unknown option {key}; skipping", file=sys.stderr)
 83 | 
 84 | def nest_namespace(ns: Namespace) -> Namespace:
 85 |     """Recursively create a nested namespace from a flat one"""
 86 |     # Make a copy of the namespace
 87 |     nested = Namespace(**ns.__dict__)
 88 |     for k, v in ns.__dict__.items():
 89 |         if '.' in k:
 90 |             # Split the key
 91 |             nested_name, rest = k.split('.', 1)
 92 |             # Create the nested namespace if it doesn't exist
 93 |             if not hasattr(nested, nested_name):
 94 |                 setattr(nested, nested_name, Namespace())
 95 |             # Set the value
 96 |             setattr(getattr(nested, nested_name), rest, v)
 97 |             # Delete the old value
 98 |             delattr(nested, k)
 99 |     # Recurse
100 |     for k, v in nested.__dict__.items():
101 |         if isinstance(v, Namespace):
102 |             setattr(nested, k, nest_namespace(v))
103 |     return nested
104 | class ELMFuzzConfig:
105 |     """ELMFuzz configuration"""
106 |     def __init__(self,
107 |                  prog: str = None,
108 |                  default_config_file: str = 'config.yaml',
109 |                  parents: Dict[str,ArgumentParser] = None
110 |                  ):
111 |         if prog is None:
112 |             prog = os.path.splitext(os.path.basename(sys.argv[0]))[0]
113 |         self.prog = prog
114 |         self.default_config_file = default_config_file
115 |         self.parent_parsers = parents if parents is not None else {}
116 |         # Used so that we can add help text to subgroups, e.g.
117 |         # the "target" group which has target.covbin, target.srcs, etc.
118 |         self.subgroup_help = {}
119 |         self.init_parser()
120 |         # Defer loading config file until parse_args() is called, because
121 |         # a config file may be specified on the command line
122 |         self.config = None
123 |         self.init_dumper()
124 | 
125 |     def init_dumper(self) -> None:
126 |         self.yaml = YAML(typ='rt')
127 |         self.yaml.preserve_quotes = True
128 |         self.yaml.representer.add_representer(
129 |             PosixPath,
130 |             lambda r, v: r.represent_str(str(v))
131 |         )
132 | 
    def merge_yaml_files(self, files: List[str]) -> CommentedMap:
        """Merge several YAML config files into one document using yamlpath.

        :param files: Paths of the config files, in the order produced by
            ``config_file_search()``. Its comment says later files are meant
            to override earlier ones.
        :returns: The data of the first merged document (``dumps[0]``).

        NOTE(review): the files are processed last-to-first. The last file
        therefore becomes the left-hand (base) document, and the others are
        merged into it. With ``arrays='left'`` its lists are kept. How
        conflicting scalar values are resolved depends on yamlpath's merge
        rules — TODO confirm the effective precedence matches the intended
        "later overrides earlier" order.

        NOTE(review): if no document is loaded, ``dumps`` is empty and
        ``dumps[0]`` raises IndexError.
        """
        # Imported here rather than at module level, so yamlpath is only
        # needed when files actually have to be merged
        from yamlpath.merger import Merger, MergerConfig
        from yamlpath.commands.yaml_merge import get_doc_mergers, merge_docs
        from yamlpath.wrappers import ConsolePrinter
        from yamlpath.common import Parsers
        # Stand-in for the command-line options yamlpath's merge tool expects
        merge_args = Namespace(
            arrays='left', # Merge arrays by using the left value
            aoh='deep', # Merge arrays of hashes by merging hashes
            preserve_lhs_comments=True,
            debug=False, quiet=True, verbose=False
        )
        log = ConsolePrinter(merge_args)
        merge_config = MergerConfig(log, merge_args)
        yaml_editor: YAML = Parsers.get_yaml_editor()
        mergers: List[Merger] = []
        # NOTE(review): merge_count is incremented below but never read
        merge_count: int = 0
        # Process in reverse order because yamlpath only saves the comments
        # from the leftmost file
        for yaml_file in files[::-1]:
            if len(mergers) < 1:
                # No base document yet: load this file as the base.
                # NOTE(review): mergers_loaded is not checked
                (mergers, mergers_loaded) = get_doc_mergers(
                    log, yaml_editor, merge_config, yaml_file)
            else:
                # Merge this file into the base document(s); 0 means success
                exit_state = merge_docs(
                    log, yaml_editor, merge_config, mergers, yaml_file)
                if not exit_state == 0:
                    print(f"Error merging {yaml_file}", file=sys.stderr)
                    break
            merge_count += 1
        dumps = []
        for doc in mergers:
            doc.prepare_for_dump(yaml_editor, '')
            dumps.append(doc.data)
        return dumps[0]
167 | 
168 |     def parse_args_nofail(self, args=None, namespace=None) -> Namespace:
169 |         # Helper to remove an option from the parser
170 |         def _remove_argument(parser, arg):
171 |             for action in parser._actions:
172 |                 opts = action.option_strings
173 |                 if (opts and opts[0] == arg) or action.dest == arg:
174 |                     parser._remove_action(action)
175 |                     break
176 |             for action in parser._action_groups:
177 |                 for group_action in action._group_actions:
178 |                     opts = group_action.option_strings
179 |                     if (opts and opts[0] == arg) or group_action.dest == arg:
180 |                         action._group_actions.remove(group_action)
181 |                         return
182 |         # Work on a copy of the parser, args, and namespace
183 |         dump_parser = copy.deepcopy(self.parser)
184 |         if args is None:
185 |             dump_args = sys.argv[1:][:]
186 |         else:
187 |             dump_args = copy.deepcopy(args)
188 |         dump_namespace = copy.deepcopy(namespace)
189 |         _remove_argument(dump_parser, '--dump-config')
190 | 
191 |         # Remove the option from the args; tricky since nargs is '?'. We'll do it
192 |         # the dumb way and just see if arg+1 contains any of the strings we expect
193 |         dump_config_index = None
194 |         for i, arg in enumerate(dump_args):
195 |             if arg.startswith('--dump-config'):
196 |                 dump_config_index = i
197 |                 break
198 |         assert dump_config_index is not None, "Couldn't find --dump-config in args"
199 |         dump_option = dump_args[dump_config_index]
200 |         if '=' in dump_option or len(dump_args) == dump_config_index + 1:
201 |             # Remove the whole thing
202 |             del dump_args[dump_config_index]
203 |         else:
204 |             next_arg = dump_args[dump_config_index + 1]
205 |             expected_params = { 'skip_comments', 'skip_defaults', 'file' }
206 |             if any(e in next_arg for e in expected_params):
207 |                 # Remove both
208 |                 del dump_args[dump_config_index:dump_config_index+2]
209 |             else:
210 |                 # Remove just the option
211 |                 del dump_args[dump_config_index]
212 | 
213 |         # Make all arguments optional
214 |         for action in dump_parser._actions:
215 |             action.required = False
216 | 
217 |         # Parse again
218 |         args = dump_parser.parse_args(dump_args, dump_namespace)
219 |         # Load config file
220 |         conf = self.load_config(args)
221 |         # Add config args
222 |         self.add_config_args(args)
223 |         return args
224 | 
225 |     def dump_config_action(self,
226 |                 parser: ArgumentParser,
227 |                 namespace: Namespace,
228 |                 values: Any,
229 |                 option_string: str,
230 |             ) -> None:
231 |         # The action is invoked during parsing, so the args aren't fully set up yet.
232 |         # Unfortunately, if we do a full parse, then any missing required arguments
233 |         # will raise errors. So we make a copy of the parser, change its required
234 |         # arguments to optional, and parse the args again. Also, we need to remove the
235 |         # --dump-config option from the parser and the arguments, or we'll either
236 |         # recurse or get an error about an unknown option. Oof!
237 |         args = self.parse_args_nofail(
238 |             self._most_recent_args,
239 |             self._most_recent_namespace,
240 |         )
241 | 
242 |         # Dump config
243 |         kwargs = {}
244 |         if values is True:
245 |             # No arguments
246 |             pass
247 |         else:
248 |             # Split comma-separated arguments
249 |             dump_opts = values.split(',')
250 |             for opt in dump_opts:
251 |                 if opt.startswith('file='):
252 |                     kwargs['file'] = opt.split('=')[1]
253 |                 elif opt in ['skip_comments', 'skip_defaults']:
254 |                     kwargs[opt] = True
255 |                 else:
256 |                     self.parser.error(f"Unknown argument to --dump-config: {opt}")
257 |         # Dump to file if requested
258 |         if 'file' in kwargs:
259 |             kwargs['file'] = open(kwargs['file'], 'w')
260 |         else:
261 |             kwargs['file'] = sys.stdout
262 |         # Dump
263 |         self.dump_config(args, **kwargs)
264 |         sys.exit(0)
265 | 
266 |     def init_parser(self) -> ArgumentParser:
267 |         # First try without a conflict handler so we can
268 |         # print a warning if there are any conflicts
269 |         self.parser = ArgumentParser(
270 |             parents=self.parent_parsers.values(),
271 |             prog=self.prog,
272 |             add_help=not bool(self.parent_parsers),
273 |         )
274 |         config_group = self.parser.add_argument_group('Config options')
275 |         config_group.add_argument("--config", default=None, type=Path,
276 |                                   help="Path to config file (overrides default search)")
277 | 
278 |         def make_dump_config_action(parent):
279 |             class DumpConfigAction(Action):
280 |                 def __init__(self, option_strings, dest, nargs=None, **kwargs):
281 |                     super().__init__(option_strings, dest, nargs=nargs, **kwargs)
282 |                     self.parent = parent
283 |                 def __call__(self, parser, namespace, values, option_string=None):
284 |                     self.parent.dump_config_action(parser, namespace, values, option_string)
285 |             return DumpConfigAction
286 |         config_group.add_argument(
287 |             '--dump-config', type=str, nargs='?',
288 |             help=("Dump config and exit.  "
289 |                   "The optional (comma-separated) argument controls how the config is dumped: "
290 |                   "skip_comments (Don't add help text as comments), "
291 |                   "skip_defaults (Skip options that are set to their default value), and"
292 |                   "file= (Dump to the specified file instead of stdout)."
293 |             ),
294 |             default=False,
295 |             const=True,
296 |             action=make_dump_config_action(self),
297 |         )
298 | 
299 |         group = self.parser.add_argument_group('Global options')
300 |         self.subgroup_help['target'] = 'Options to configure the target program being fuzzed'
301 |         group.add_argument("--target.srcs", type=Path, nargs='+', action='extend',
302 |                            help="Source files in the target")
303 |         group.add_argument("--target.covbin", type=Path,
304 |                            help="Path to the target binary with coverage instrumentation")
305 |         self.subgroup_help['model'] = 'Options to configure the model(s) used for variant generation'
306 |         group.add_argument("--model.names", type=str, nargs='+', action='extend', help="List of model names")
307 |         group.add_argument("--model.endpoints", type=str, nargs='+', action=StoreDictKeyPair,
308 |                            metavar="NAME:ENDPOINT", help="List of model endpoints, formatted as name:endpoint")
309 |         self.subgroup_help['run'] = 'Options to configure the run of the evolutionary algorithm'
310 |         group.add_argument("--run.seeds", type=Path, nargs='+', action='extend',
311 |                            help="Seed files (generator programs that will be mutated)")
312 |         group.add_argument("--run.num_generations", type=int, default=10, help="Number of generations to run")
313 |         selection_choices = [ s.value for s in SelectionStrategy ]
314 |         selection_choices_str = ', '.join(selection_choices)
315 |         group.add_argument("--run.selection_strategy", type=str, default='elites',
316 |                            help=f"Selection strategy (one of: {selection_choices_str})",
317 |                            choices=selection_choices)
318 |         group.add_argument("--run.num_selected", type=int, default=10,
319 |                            help="Number of seeds to select each generation")
320 |         group.add_argument("--run.genvariant_dir", type=str,
321 |                            default='{ELMFUZZ_RUNDIR}/{GEN}/variants/{MODEL}',
322 |                            help="Directory (template) to store generated variants")
323 |         group.add_argument("--run.genoutput_dir", type=str,
324 |                             default='{ELMFUZZ_RUNDIR}/{GEN}/outputs/{MODEL}',
325 |                             help="Directory (template) to store generated outputs")
326 |         group.add_argument("--run.logdir", type=str,
327 |                             default='{ELMFUZZ_RUNDIR}/{GEN}/logs',
328 |                             help="Directory (template) to store logs")
329 |         group.add_argument("--run.clean", action='store_true',
330 |                             help="Clean the output directories before running")
331 | 
332 |         # XXX: For testing only
333 |         # group = self.parser.add_argument_group('Test options', 'Lorem ipsum dolor sit amet')
334 |         # self.subgroup_help['test_a'] = 'Test options for group A'
335 |         # group.add_argument("--test_a.boolopt", action='store_true', help="Test option A (bool)")
336 |         # group.add_argument("--test_a.stropt", type=str, help="Test option A (str)")
337 |         # self.subgroup_help['test_b'] = 'Test options for group B'
338 |         # group.add_argument("--test_b.intopt", type=int, help="Test option B (int)")
339 |         # group.add_argument("--test_b.floatopt", type=float, help="Test option B (float)")
340 | 
341 |     def parse_args(self,
342 |                    args : Optional[Sequence[str]] = None,
343 |                    namespace : Optional[Namespace] = None,
344 |                    nested : bool = True,
345 |                    ) -> Namespace:
346 |         """Parse arguments
347 | 
348 |         :param args: Arguments to parse (default: sys.argv[1:])
349 |         :param namespace: Namespace to store the parsed arguments (default: create a new one)
350 |         :param nested: Whether the returned namespace should be nested to convert key names
351 |             like foo.bar into foo = Namespace(bar=...) (default: True)
352 | 
353 |         Returns a namespace with the parsed arguments.
354 |         """
355 |         # HACK: save the most recently parsed argv/namespace so that dump_config_action
356 |         # can access them
357 |         self._most_recent_args = args
358 |         self._most_recent_namespace = namespace
359 |         self.args = self.parser.parse_args(args, namespace)
360 |         self.config = self.load_config(self.args)
361 |         self.add_config_args(self.args)
362 | 
363 |         # For convenience, nest the namespace we return
364 |         # We keep the original namespace in self.args
365 |         if nested:
366 |             return nest_namespace(self.args)
367 |         else:
368 |             return self.args
369 | 
370 |     def load_config(self, args: Namespace) -> None:
371 |         # Load config file(s)
372 |         if args.config is not None:
373 |             conf = self.yaml.load(args.config)
374 |         else:
375 |             to_merge = [ config_file for config_file in self.config_file_search() if os.path.exists(config_file) ]
376 |             if len(to_merge) == 1:
377 |                 conf = self.yaml.load(open(to_merge[0]).read())
378 |             elif len(to_merge) > 1:
379 |                 conf = self.merge_yaml_files(to_merge)
380 |             else:
381 |                 conf = CommentedMap()
382 |         self.config = conf
383 |         return conf
384 | 
385 |     def config_file_search(self) -> List[str]:
386 |         # NB: The order matters here because later files override earlier files
387 |         config_files = []
388 |         # Check script dir
389 |         script_dir = os.path.dirname(os.path.realpath(__file__))
390 |         config_files.append(
391 |             os.path.join(script_dir, self.default_config_file)
392 |         )
393 |         # Check CWD
394 |         config_files.append(self.default_config_file)
395 |         # Check ELMFUZZ_RUNDIR env var
396 |         if 'ELMFUZZ_RUNDIR' in os.environ:
397 |             config_files.append(
398 |                 os.path.join(os.environ['ELMFUZZ_RUNDIR'], self.default_config_file)
399 |             )
400 |         # Check ELMFUZZ_CONFIG env var
401 |         if 'ELMFUZZ_CONFIG' in os.environ:
402 |             config_files.append(
403 |                 os.environ['ELMFUZZ_CONFIG']
404 |             )
405 |         return config_files
406 | 
407 |     @staticmethod
408 |     def flattened_conf(conf: Dict, prefix='', flatten_lists=False) -> Dict:
409 |         # Flatten config dict, separating nested keys with '.'
410 |         flat_conf = {}
411 |         def _flatten(d, prefix=''):
412 |             for k, v in d.items():
413 |                 if isinstance(v, dict):
414 |                     _flatten(v, prefix + k + '.')
415 |                 elif isinstance(v, list):
416 |                     if flatten_lists:
417 |                         for i, item in enumerate(v):
418 |                             _flatten({str(i): item}, prefix + k + '.')
419 |                     else:
420 |                         flat_conf[prefix + k] = v
421 |                 else:
422 |                     flat_conf[prefix + k] = v
423 |         _flatten(conf, prefix)
424 |         return flat_conf
425 | 
426 |     def add_config_args(self, args: Namespace) -> None:
427 |         """Add config arguments to an existing parser"""
428 |         if self.config is None:
429 |             raise RuntimeError("Config not loaded; call parse_args() first")
430 |         # Flatten config dict
431 |         config_flat = self.flattened_conf(self.config)
432 |         for k, v in config_flat.items():
433 |             # CLI options are in their own section, so will appear as
434 |             #   cli..