├── rustfmt.toml ├── .gitignore ├── checker ├── verified │ ├── .gitignore │ ├── strict_and_parallel.patch │ ├── ROOT │ ├── Checker_Codegen.thy │ ├── document │ │ └── root.tex │ └── Sorting_Network_Bound.thy ├── snocheck │ ├── .gitignore │ ├── Setup.hs │ ├── src │ │ ├── Parallel.hs │ │ ├── ProofStep.hs │ │ ├── Main.hs │ │ ├── Translate.hs │ │ ├── Decode.hs │ │ ├── VectSet.hs │ │ ├── Check.hs │ │ └── Verified │ │ │ └── Checker.hs │ ├── stack.yaml.lock │ ├── package.yaml │ └── stack.yaml └── update_extracted_code.sh ├── src ├── lib.rs ├── huffman.rs ├── logging.rs ├── fix.rs ├── prune.rs ├── bin │ └── sortnetopt.rs ├── thread_pool.rs ├── proof.rs ├── output_set │ ├── subsume.rs │ ├── canon.rs │ ├── index.rs │ └── index │ │ └── tree.rs ├── search │ └── states.rs ├── search.rs └── output_set.rs ├── verify_proof_cert.sh ├── search_and_verify.sh ├── Cargo.toml └── README.md /rustfmt.toml: -------------------------------------------------------------------------------- 1 | merge_imports = true 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | -------------------------------------------------------------------------------- /checker/verified/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | /export 3 | /output 4 | -------------------------------------------------------------------------------- /checker/snocheck/.gitignore: -------------------------------------------------------------------------------- 1 | .stack-work/ 2 | *.cabal 3 | *.prof 4 | -------------------------------------------------------------------------------- /checker/snocheck/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | 
-------------------------------------------------------------------------------- /checker/snocheck/src/Parallel.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Trustworthy #-} 2 | module Parallel 3 | ( par 4 | ) 5 | where 6 | 7 | import Control.Parallel 8 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod fix; 2 | pub mod huffman; 3 | pub mod logging; 4 | pub mod output_set; 5 | pub mod proof; 6 | pub mod prune; 7 | pub mod search; 8 | pub mod thread_pool; 9 | -------------------------------------------------------------------------------- /checker/verified/strict_and_parallel.patch: -------------------------------------------------------------------------------- 1 | @@ -14,2 +14,3 @@ 2 | import qualified Prelude; 3 | +import qualified Parallel; 4 | 5 | @@ -72,3 +73,3 @@ 6 | 7 | -data Vect_trie = VtEmpty | VtNode Bool Vect_trie Vect_trie; 8 | +data Vect_trie = VtEmpty | VtNode !Bool !Vect_trie !Vect_trie; 9 | 10 | @@ -121,3 +122,3 @@ 11 | par :: forall a b. 
a -> b -> b; 12 | -par a b = b; 13 | +par = Parallel.par; 14 | 15 | -------------------------------------------------------------------------------- /checker/verified/ROOT: -------------------------------------------------------------------------------- 1 | session Sorting_Networks = HOL + 2 | options [document = pdf, document_output = "output", 3 | document_variants = "document:outline=/proof,/ML"] 4 | sessions 5 | "HOL-Library" 6 | theories 7 | Sorting_Network_Bound 8 | Huffman 9 | Sorting_Network 10 | Checker 11 | Checker_Codegen 12 | document_files 13 | "root.tex" 14 | export_files 15 | "Sorting_Networks.Checker_Codegen:**" 16 | -------------------------------------------------------------------------------- /src/huffman.rs: -------------------------------------------------------------------------------- 1 | use std::{cmp::Reverse, collections::BinaryHeap}; 2 | 3 | pub fn max_plus_1_huffman(values: &[u8]) -> u8 { 4 | let mut heap: BinaryHeap> = values.iter().map(|&v| Reverse(v)).collect(); 5 | 6 | while let Some(Reverse(first)) = heap.pop() { 7 | if let Some(Reverse(second)) = heap.pop() { 8 | heap.push(Reverse(first.max(second) + 1)); 9 | } else { 10 | return first; 11 | } 12 | } 13 | 0 14 | } 15 | -------------------------------------------------------------------------------- /checker/snocheck/stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 
3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: [] 7 | snapshots: 8 | - completed: 9 | size: 524130 10 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/14/10.yaml 11 | sha256: 04841e0397d7ce6364f19359fce89aa45ca8684609d665845db80d398002baa9 12 | original: lts-14.10 13 | -------------------------------------------------------------------------------- /checker/snocheck/package.yaml: -------------------------------------------------------------------------------- 1 | name: snocheck 2 | version: 0.1.0.0 3 | #synopsis: 4 | #description: 5 | #homepage: 6 | #license: 7 | author: Jannis Harder 8 | maintainer: me@jix.one 9 | copyright: 2019 Jannis Harder 10 | #category: Web 11 | 12 | dependencies: 13 | - base >= 4.7 && < 5 14 | - bytestring 15 | - containers 16 | - parallel 17 | 18 | executables: 19 | snocheck: 20 | source-dirs: src 21 | ghc-options: -threaded -rtsopts -with-rtsopts=-N 22 | main: Main.hs 23 | -------------------------------------------------------------------------------- /verify_proof_cert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "$1" ]]; then 4 | echo "usage: bash verify_proof_cert.sh CERTBIN" 5 | exit 1 6 | fi 7 | 8 | CERTBIN="$1" 9 | 10 | set -euo pipefail 11 | 12 | SWD=$(pwd) 13 | 14 | cd "$(dirname "$0")" 15 | 16 | cd checker/snocheck 17 | 18 | echo 'Running verified proof certificate checker...' 19 | echo 20 | echo 'An output of `None` indicates verification failure.' 21 | echo 'An output of `Some (w,b)` indicates successful verification of the' 22 | echo 'lower bound b for w channels.' 
23 | echo 24 | 25 | stack run --cwd "$SWD" -- -v "$CERTBIN" 26 | -------------------------------------------------------------------------------- /checker/update_extracted_code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | cd "$(dirname "$0")" 6 | 7 | cd verified 8 | 9 | echo "Running isabelle build ..." 10 | isabelle build -D . -e 11 | 12 | echo "-- Generated using update_extracted_code.sh, do not edit --" \ 13 | > ../snocheck/src/Verified/Checker.hs 14 | 15 | echo "Patch extracted source..." 16 | patch -o- \ 17 | export/Sorting_Networks.Checker_Codegen/code/checker/Verified/Checker.hs \ 18 | strict_and_parallel.patch >> ../snocheck/src/Verified/Checker.hs 19 | 20 | echo <<'EOM' 21 | The patch adds strictness annotations to a datatype and replaces 22 | `par a b = b` with `par = Control.Parallel.par`. 23 | 24 | Both changes only affect the evaluation order. 25 | EOM 26 | -------------------------------------------------------------------------------- /checker/snocheck/src/ProofStep.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Safe #-} 2 | module ProofStep 3 | ( ProofStep(..) 4 | , Witnesses(..) 5 | , Witness(..) 
6 | , witnessList 7 | ) 8 | where 9 | 10 | import VectSet ( VectSet ) 11 | import qualified VectSet as VS 12 | data ProofStep = ProofStep 13 | { vectSet :: !VectSet 14 | , bound :: !Int 15 | , witnesses :: !Witnesses 16 | } deriving (Show) 17 | 18 | data Witnesses = Huffman Bool [Maybe Witness] 19 | | Successors [Maybe Witness] 20 | deriving (Show) 21 | 22 | witnessList :: Witnesses -> [Maybe Witness] 23 | witnessList (Huffman _ ws ) = ws 24 | witnessList (Successors ws) = ws 25 | 26 | data Witness = Witness 27 | { stepId :: !Int 28 | , invert :: !Bool 29 | , perm :: ![Int] 30 | } deriving (Show) 31 | 32 | -------------------------------------------------------------------------------- /search_and_verify.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "$2" ]]; then 4 | echo "usage: bash search_and_verify.sh CHANNELS DATADIR" 5 | exit 1 6 | fi 7 | 8 | set -euo pipefail 9 | 10 | CHANNELS="$1" 11 | DATADIR=$(realpath "$2") 12 | 13 | cd "$(dirname "$0")" 14 | 15 | if [[ ! -d "$DATADIR" ]]; then 16 | mkdir "$DATADIR" 17 | fi 18 | 19 | echo "Building sortnetopt binary..." 20 | 21 | cargo build --release 22 | 23 | echo "Searching lower bound..." 24 | 25 | cargo run --release -- -m search "$CHANNELS" "$DATADIR/_search_$CHANNELS" 26 | 27 | echo "Pruning search log..." 28 | 29 | cargo run --release -- -m prune-all "$DATADIR/_search_$CHANNELS" 30 | 31 | echo "Generating proof certificate..." 32 | 33 | cargo run --release -- -m gen-proof "$DATADIR/_search_$CHANNELS" 34 | 35 | echo "Verifying proof certificate..." 
36 | 37 | bash verify_proof_cert.sh "$DATADIR/_search_$CHANNELS/proof.bin" 38 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "sortnetopt" 3 | version = "0.1.0" 4 | authors = ["Jannis Harder "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [profile.release] 10 | debug = true # Enables profiling 11 | 12 | [profile.bench] 13 | debug = true # Enables profiling 14 | 15 | [dependencies] 16 | log = "0.4.11" 17 | env_logger = "0.7.1" 18 | better-panic = "0.2.0" 19 | jemallocator = "0.3.2" 20 | structopt = "0.3.16" 21 | arrayvec = "0.5.1" 22 | rustc-hash = "1.1.0" 23 | arrayref = "0.3.6" 24 | byteorder = "1.3.4" 25 | num_cpus = "1.13.0" 26 | async-std = "1.6.3" 27 | async-task = "3.0.0" 28 | futures = "0.3.5" 29 | crossbeam-channel = "0.4.3" 30 | abort_on_panic = "2.0.0" 31 | scopeguard = "1.1.0" 32 | pin-utils = "0.1.0" 33 | memmap = "0.7.0" 34 | rayon = "1.3.1" 35 | indicatif = "0.15.0" 36 | 37 | [dev-dependencies] 38 | proptest = "0.10.1" 39 | -------------------------------------------------------------------------------- /checker/verified/Checker_Codegen.thy: -------------------------------------------------------------------------------- 1 | theory Checker_Codegen 2 | imports Main Sorting_Network Checker "HOL-Library.Code_Target_Numeral" 3 | begin 4 | 5 | lemma check_proof_get_bound_spec: 6 | assumes \check_proof_get_bound cert = Some (width, bound)\ 7 | shows \lower_size_bound (nat width) (nat bound)\ 8 | using assms by (rule Checker.check_proof_get_bound_spec) 9 | 10 | definition nat_pred_code :: \nat \ nat\ where 11 | \nat_pred_code n = (case n of 0 \ nat.pred 0 | Suc n' \ n')\ 12 | 13 | lemma nat_pred_code[code]: \nat.pred = nat_pred_code\ 14 | by (rule; metis nat_pred_code_def old.nat.simps(4) pred_def) 15 | 16 | export_code 
17 | check_proof_get_bound integer_of_int int_of_integer 18 | ProofCert ProofStep HuffmanWitnesses SuccessorWitnesses ProofWitness 19 | in Haskell module_name Verified.Checker file_prefix "checker" 20 | 21 | end -------------------------------------------------------------------------------- /checker/snocheck/src/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Safe #-} 2 | module Main where 3 | 4 | import System.Environment 5 | import qualified Data.ByteString as B 6 | import Decode 7 | import VectSet 8 | import Check 9 | import Translate 10 | import qualified Verified.Checker as VC 11 | 12 | main :: IO () 13 | main = do 14 | args <- getArgs 15 | case args of 16 | ["-r", proofFileName] -> 17 | print . checkProof . proofSteps =<< B.readFile proofFileName 18 | ["-v", proofFileName] -> do 19 | proofData <- B.readFile proofFileName 20 | let steps = proofSteps proofData 21 | vcSteps = translateProofSteps steps 22 | result = do 23 | (width, bound) <- VC.check_proof_get_bound vcSteps 24 | return (VC.integer_of_int width, VC.integer_of_int bound) 25 | print result 26 | _ -> do 27 | putStrLn "usage (reference checker): snocheck -r PROOF" 28 | putStrLn " (verified checker): snocheck -v PROOF" 29 | -------------------------------------------------------------------------------- /checker/verified/document/root.tex: -------------------------------------------------------------------------------- 1 | \documentclass[11pt,a4paper]{article} 2 | \usepackage{isabelle,isabellesym} 3 | 4 | % further packages required for unusual symbols (see also 5 | % isabellesym.sty), use only when needed 6 | 7 | %\usepackage{amssymb} 8 | %for \, \, \, \, \, \, 9 | %\, \, \, \, \, 10 | %\, \, \ 11 | 12 | %\usepackage{eurosym} 13 | %for \ 14 | 15 | %\usepackage[only,bigsqcap]{stmaryrd} 16 | %for \ 17 | 18 | %\usepackage{eufrak} 19 | %for \ ... \, \ ... 
\ (also included in amssymb) 20 | 21 | %\usepackage{textcomp} 22 | %for \, \, \, \, \, 23 | %\ 24 | 25 | % this should be the last package used 26 | \usepackage{pdfsetup} 27 | 28 | % urls in roman style, theory text in math-similar italics 29 | \urlstyle{rm} 30 | \isabellestyle{it} 31 | 32 | % for uniform font size 33 | %\renewcommand{\isastyle}{\isastyleminor} 34 | 35 | \begin{document} 36 | 37 | \title{Verified Checker for Sorting Network Size Bounds} 38 | \author{Jannis Harder} 39 | \maketitle 40 | 41 | \tableofcontents 42 | 43 | % sane default for proof documents 44 | \parindent 0pt\parskip 0.5ex 45 | 46 | % generated text of all theories 47 | \input{session} 48 | 49 | % optional bibliography 50 | %\bibliographystyle{abbrv} 51 | %\bibliography{root} 52 | 53 | \end{document} 54 | 55 | %%% Local Variables: 56 | %%% mode: latex 57 | %%% TeX-master: t 58 | %%% End: 59 | -------------------------------------------------------------------------------- /checker/snocheck/src/Translate.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Safe #-} 2 | module Translate 3 | ( translateProofSteps 4 | ) 5 | where 6 | 7 | import qualified Verified.Checker as VC 8 | import ProofStep 9 | import VectSet ( VectSet ) 10 | import qualified VectSet as VS 11 | 12 | translateProofSteps :: (Int, Int -> ProofStep) -> VC.Proof_cert 13 | translateProofSteps (stepCount, steps) = 14 | VC.ProofCert (VC.Int_of_integer $ toInteger stepCount) (translateStepFn steps) 15 | 16 | translateStepFn :: (Int -> ProofStep) -> VC.Int -> VC.Proof_step 17 | translateStepFn steps (VC.Int_of_integer stepId) = 18 | translateStep (steps $ fromInteger stepId) 19 | 20 | translateStep :: ProofStep -> VC.Proof_step 21 | translateStep step = VC.ProofStep width vectList stepBound stepWitnesses 22 | where 23 | width = VC.Int_of_integer . toInteger $ VS.channels (vectSet step) 24 | vectList = VS.asBoolVectList (vectSet step) 25 | stepBound = VC.Int_of_integer . 
toInteger $ bound step 26 | stepWitnesses = translateWitnesses (witnesses step) 27 | 28 | translateWitnesses :: Witnesses -> VC.Proof_step_witnesses 29 | translateWitnesses (Huffman pol witnesses) = 30 | VC.HuffmanWitnesses pol (translateWitnessList witnesses) 31 | translateWitnesses (Successors witnesses) = 32 | VC.SuccessorWitnesses (translateWitnessList witnesses) 33 | 34 | translateWitnessList :: [Maybe Witness] -> [Maybe VC.Proof_witness] 35 | translateWitnessList = map (fmap translateWitness) 36 | 37 | translateWitness :: Witness -> VC.Proof_witness 38 | translateWitness witness = VC.ProofWitness 39 | (VC.Int_of_integer . toInteger $ stepId witness) 40 | (invert witness) 41 | (map (VC.Int_of_integer . toInteger) $ perm witness) 42 | -------------------------------------------------------------------------------- /src/logging.rs: -------------------------------------------------------------------------------- 1 | use std::io::Write; 2 | 3 | pub fn setup(print_mem_usage: bool) { 4 | better_panic::install(); 5 | let startup = std::time::Instant::now(); 6 | 7 | let _ = env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) 8 | .format(move |buf, record| { 9 | let elapsed = startup.elapsed().as_millis(); 10 | let minutes = elapsed / 60000; 11 | let seconds = (elapsed % 60000) / 1000; 12 | let millis = elapsed % 1000; 13 | 14 | if print_mem_usage { 15 | let (mem, unit) = get_mem_usage(); 16 | writeln!( 17 | buf, 18 | "{:3}:{:02}.{:03} [{:4} {}]: {}", 19 | minutes, 20 | seconds, 21 | millis, 22 | mem, 23 | unit, 24 | record.args() 25 | ) 26 | } else { 27 | writeln!( 28 | buf, 29 | "{:3}:{:02}.{:03}: {}", 30 | minutes, 31 | seconds, 32 | millis, 33 | record.args() 34 | ) 35 | } 36 | }) 37 | .target(env_logger::Target::Stdout) 38 | .is_test(cfg!(test)) 39 | .try_init(); 40 | } 41 | 42 | fn get_mem_usage() -> (usize, &'static str) { 43 | let prefix = "VmHWM:"; 44 | if let Some(mem_line) = std::fs::read_to_string("/proc/self/status") 45 | 
.unwrap() 46 | .lines() 47 | .filter(|line| line.starts_with(prefix)) 48 | .next() 49 | { 50 | let usage_kb_str = mem_line[prefix.len()..].trim(); 51 | let usage_kb = str::parse::(&usage_kb_str[..usage_kb_str.len() - 3]).unwrap(); 52 | 53 | if usage_kb > 10 * 1024 * 1024 { 54 | return (usage_kb / (1024 * 1024), "G"); 55 | } else if usage_kb > 10 * 1024 { 56 | return (usage_kb / 1024, "M"); 57 | } else { 58 | return (usage_kb, "k"); 59 | } 60 | } 61 | 62 | (0, "?") 63 | } 64 | -------------------------------------------------------------------------------- /checker/verified/Sorting_Network_Bound.thy: -------------------------------------------------------------------------------- 1 | theory Sorting_Network_Bound 2 | imports Main 3 | begin 4 | 5 | text \Due to the 0-1-principle we're only concerned with Boolean sequences. While we're interested 6 | in sorting vectors of a given fixed width, it is advantageous to represent them as a function from 7 | the naturals to Booleans.\ 8 | 9 | type_synonym bseq = \nat \ bool\ 10 | 11 | text \To represent Boolean sequences of a fixed length, we extend them with True to infinity. This 12 | way monotonicity of a fixed length sequence corresponds to monotonicity of our representation.\ 13 | 14 | definition fixed_len_bseq :: \nat \ bseq \ bool\ where 15 | \fixed_len_bseq n x = (\i \ n. x i = True)\ 16 | 17 | text \A comparator is represented as an ordered pair of channel indices. 
Applying a comparator to a 18 | sequence will order the values of the two channels so that the channel corresponding to the first 19 | index receives the smaller value.\ 20 | 21 | type_synonym cmp = \nat \ nat\ 22 | 23 | definition apply_cmp :: \cmp \ bseq \ bseq\ where 24 | \apply_cmp c x = ( 25 | let (i, j) = c 26 | in x( 27 | i := min (x i) (x j), 28 | j := max (x i) (x j) 29 | ) 30 | )\ 31 | 32 | text \A lower size bound for a partial sorting network on a given set of input sequences is the 33 | number of comparators required for any comparator network that is able to sort every sequence of the 34 | given set.\ 35 | 36 | definition partial_lower_size_bound :: \bseq set \ nat \ bool\ where 37 | \partial_lower_size_bound X k = (\cn. (\x \ X. mono (fold apply_cmp cn x)) \ length cn \ k)\ 38 | 39 | text \A lower size bound for a sorting network on n channels is the same as a lower size bound for a 40 | partial sorting network on all length n sequences.\ 41 | 42 | definition lower_size_bound :: \nat \ nat \ bool\ where 43 | \lower_size_bound n k = partial_lower_size_bound {x. fixed_len_bseq n x} k\ 44 | 45 | end -------------------------------------------------------------------------------- /checker/snocheck/stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. 
For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # 15 | # The location of a snapshot can be provided as a file or url. Stack assumes 16 | # a snapshot provided as a file might change, whereas a url resource does not. 17 | # 18 | # resolver: ./custom-snapshot.yaml 19 | # resolver: https://example.com/snapshots/2018-01-01.yaml 20 | resolver: lts-14.10 21 | 22 | # User packages to be built. 23 | # Various formats can be used as shown in the example below. 24 | # 25 | # packages: 26 | # - some-directory 27 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 28 | # subdirs: 29 | # - auto-update 30 | # - wai 31 | packages: 32 | - . 33 | # Dependency packages to be pulled from upstream that are not in the resolver. 34 | # These entries can reference officially published versions as well as 35 | # forks / in-progress versions pinned to a git hash. For example: 36 | # 37 | # extra-deps: 38 | # - acme-missiles-0.3 39 | # - git: https://github.com/commercialhaskell/stack.git 40 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 41 | # 42 | # extra-deps: [] 43 | 44 | # Override default flag values for local packages and extra-deps 45 | # flags: {} 46 | 47 | # Extra package databases containing global packages 48 | # extra-package-dbs: [] 49 | 50 | # Control whether we use the GHC we find on the path 51 | # system-ghc: true 52 | # 53 | # Require a specific version of stack, using version ranges 54 | # require-stack-version: -any # Default 55 | # require-stack-version: ">=2.1" 56 | # 57 | # Override the architecture used by stack, especially useful on Windows 58 | # arch: i386 59 | # arch: x86_64 60 | # 61 | # Extra directories used by stack for building 62 | # extra-include-dirs: [/path/to/dir] 63 | # extra-lib-dirs: [/path/to/dir] 64 | # 65 | # Allow a newer minor version of GHC than the snapshot specifies 66 | # compiler-check: newer-minor 67 | 
-------------------------------------------------------------------------------- /src/fix.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | fs::File, 4 | io::{BufRead, BufReader, BufWriter, Read, Write}, 5 | path::PathBuf, 6 | }; 7 | 8 | use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; 9 | use memmap::MmapOptions; 10 | 11 | use crate::output_set::OutputSet; 12 | 13 | pub fn fix_proof(input: PathBuf) { 14 | let mut pruned_sets = BTreeMap::default(); 15 | 16 | let index = input.join("index.txt"); 17 | let index_file = BufReader::new(File::open(index).unwrap()); 18 | for line in index_file.lines() { 19 | let line = line.unwrap(); 20 | let channels = str::parse::(line.split('_').nth(1).unwrap()).unwrap(); 21 | let bound = 22 | str::parse::(line.split('_').nth(2).unwrap().split('.').next().unwrap()).unwrap(); 23 | 24 | log::info!("loading {}_{}", channels, bound); 25 | 26 | let pruned_path = input.join(line).with_extension("pbin"); 27 | 28 | let pruned_file = File::open(pruned_path).unwrap(); 29 | let pruned_data = unsafe { MmapOptions::new().map(&pruned_file).unwrap() }; 30 | 31 | let stride = OutputSet::packed_len_for_channels(channels); 32 | for packed in pruned_data.chunks(stride) { 33 | pruned_sets.insert((channels, packed.to_owned()), bound); 34 | } 35 | } 36 | 37 | let proof_path = input.join("proof.old.bin"); 38 | let mut proof_file = BufReader::new(File::open(proof_path).unwrap()); 39 | 40 | let proof_fix_path = input.join("proof.bin"); 41 | if proof_fix_path.exists() { 42 | panic!("proof.bin already exists"); 43 | } 44 | let mut proof_fix_file = BufWriter::new(File::create(proof_fix_path).unwrap()); 45 | 46 | let step_count = proof_file.read_u32::().unwrap(); 47 | 48 | proof_fix_file 49 | .write_u32::(step_count) 50 | .unwrap(); 51 | 52 | let mut lengths = vec![]; 53 | 54 | for step in 0..step_count { 55 | let offset = proof_file.read_u64::().unwrap(); 56 | let len = 
proof_file.read_u32::().unwrap(); 57 | 58 | lengths.push(len); 59 | 60 | proof_fix_file 61 | .write_u64::(offset + step as u64) 62 | .unwrap(); 63 | 64 | proof_fix_file.write_u32::(len + 1).unwrap(); 65 | } 66 | 67 | for len in lengths { 68 | let mut buffer = vec![]; 69 | (&mut proof_file) 70 | .take(len as u64) 71 | .read_to_end(&mut buffer) 72 | .unwrap(); 73 | 74 | let channels = buffer[0] as usize; 75 | let packed = buffer[1..][..OutputSet::packed_len_for_channels(channels)].to_owned(); 76 | let bound = *pruned_sets.get(&(channels, packed)).expect("set not found"); 77 | 78 | proof_fix_file.write_u8(buffer[0]).unwrap(); 79 | proof_fix_file.write_u8(bound).unwrap(); 80 | proof_fix_file.write_all(&buffer[1..]).unwrap(); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /checker/snocheck/src/Decode.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Safe #-} 2 | module Decode 3 | ( proofSteps 4 | ) 5 | where 6 | 7 | import Data.Bits 8 | import Data.Word 9 | import Data.ByteString ( ByteString ) 10 | import qualified Data.ByteString as B 11 | import VectSet ( VectSet ) 12 | import qualified VectSet as VS 13 | import ProofStep 14 | 15 | readLE :: Int -> ByteString -> Int 16 | readLE len bytes = sum 17 | [ (`shift` (8 * i)) . fromIntegral $ B.index bytes i | i <- [0 .. len - 1] ] 18 | 19 | read32LE :: ByteString -> Int 20 | read32LE = readLE 4 21 | 22 | read64LE :: ByteString -> Int 23 | read64LE = readLE 8 24 | 25 | bits :: Word8 -> [Bool] 26 | bits x = [ testBit x i | i <- [0 .. 7] ] 27 | 28 | proofSteps :: ByteString -> (Int, Int -> ProofStep) 29 | proofSteps proofData = (stepCount, decodeProofStep . 
encodedStep) 30 | where (stepCount, encodedStep) = encodedProofSteps proofData 31 | 32 | encodedProofSteps :: ByteString -> (Int, Int -> ByteString) 33 | encodedProofSteps proofData = (stepCount, stepData) 34 | where 35 | stepCount = read32LE proofData 36 | stepData step = 37 | let stepHeader = B.drop (4 + 12 * step) proofData 38 | offset = read64LE stepHeader 39 | length = read32LE $ B.drop 8 stepHeader 40 | in B.take length $ B.drop offset proofData 41 | 42 | decodeProofStep :: ByteString -> ProofStep 43 | decodeProofStep stepData = ProofStep { vectSet = vectSet 44 | , bound = bound 45 | , witnesses = witnesses 46 | } 47 | where 48 | channels = fromIntegral (B.index stepData 0) 49 | bound = fromIntegral (B.index stepData 1) 50 | encodedSetLength = bit (0 `max` (channels - 3)) 51 | encodedSetBytes = B.take encodedSetLength (B.drop 2 stepData) 52 | vectSet = VS.fromBytes channels encodedSetBytes 53 | witnessData = B.drop (encodedSetLength + 2) stepData 54 | (witnessType, witnessChannels) = case B.index witnessData 0 of 55 | 0 -> (Huffman False, channels - 1) 56 | 1 -> (Huffman True, channels - 1) 57 | 2 -> (Successors, channels) 58 | witnesses = 59 | witnessType . decodeWitnesses witnessChannels $ B.drop 1 witnessData 60 | 61 | decodeWitnesses :: Int -> ByteString -> [Maybe Witness] 62 | decodeWitnesses channels bs = case B.uncons bs of 63 | Nothing -> [] 64 | Just (0, tail) -> decodeWitness False tail 65 | Just (1, tail) -> decodeWitness True tail 66 | Just (2, tail) -> (Nothing :) $! decodeWitnesses channels tail 67 | where 68 | decodeWitness invert bs = 69 | let (perm, tail) = B.splitAt channels bs 70 | in (( Just 71 | $! (Witness { invert = invert 72 | , perm = map fromIntegral . B.unpack $ perm 73 | , stepId = read32LE tail 74 | } 75 | ) 76 | ) : 77 | ) 78 | $! 
decodeWitnesses channels (B.drop 4 tail) 79 | -------------------------------------------------------------------------------- /checker/snocheck/src/VectSet.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Safe, BangPatterns #-} 2 | module VectSet 3 | ( VectSet(channels) 4 | , allVects 5 | , fromBytes 6 | , applyComp 7 | , size 8 | , permute 9 | , invert 10 | , subsumes 11 | , extremalChannels 12 | , pruneExtremal 13 | , asBoolVectList 14 | ) 15 | where 16 | 17 | import Data.Bits 18 | import Control.Monad 19 | import Data.IntSet ( IntSet ) 20 | import qualified Data.IntSet as IS 21 | import Data.ByteString ( ByteString ) 22 | import qualified Data.ByteString as B 23 | 24 | data VectSet = VectSet 25 | { channels :: !Int 26 | , vects :: !IntSet 27 | } deriving (Show, Eq) 28 | 29 | allVects :: Int -> VectSet 30 | allVects channels = 31 | VectSet { channels = channels, vects = IS.fromList [0 .. (bit channels) - 1] } 32 | 33 | fromBytes :: Int -> ByteString -> VectSet 34 | fromBytes channels bytes = VectSet { channels = channels, vects = vs } 35 | where 36 | vs = fst $ B.foldl' addByte (IS.empty, 0) bytes 37 | addByte (!set, !pos) 0 = (set, pos + 8) 38 | addByte (!set, !pos) n = 39 | addByte (IS.insert (pos + countTrailingZeros n) set, pos) (n .&. (n - 1)) 40 | 41 | applyComp :: Int -> Int -> VectSet -> Maybe VectSet 42 | applyComp a b vs = guard irredundant 43 | >> return VectSet { channels = channels vs, vects = vs' } 44 | where 45 | aMask = bit a 46 | bMask = bit b 47 | abMask = aMask .|. bMask 48 | have xMask = any ((== xMask) . (.&. abMask)) (IS.toList $ vects vs) 49 | irredundant = have aMask && have bMask 50 | vs' = IS.map (\v -> if v .&. abMask == bMask then v `xor` abMask else v) 51 | $ vects vs 52 | 53 | size :: VectSet -> Int 54 | size = IS.size . 
vects 55 | 56 | permute :: [Int] -> VectSet -> VectSet 57 | permute perm vs = VectSet { channels = channels vs 58 | , vects = IS.map permVect $ vects vs 59 | } 60 | where 61 | permVect v = 62 | foldr (\ !p !a -> (a `shiftL` 1) .|. fromEnum (testBit v p)) 0 perm 63 | 64 | invert :: VectSet -> VectSet 65 | invert vs = VectSet { channels = channels vs, vects = vs' } 66 | where 67 | mask = (bit $ channels vs) - 1 68 | vs' = IS.map (xor mask) $ vects vs 69 | 70 | subsumes :: VectSet -> VectSet -> Bool 71 | subsumes a b | channels a == channels b = IS.isSubsetOf (vects a) (vects b) 72 | 73 | extremalChannels :: Bool -> VectSet -> [Int] 74 | extremalChannels pol vs = 75 | [ i | i <- [0 .. channels vs - 1], IS.member (bit i `xor` flip) (vects vs) ] 76 | where flip = if pol then (bit $ channels vs) - 1 else 0 77 | 78 | pruneExtremal :: Bool -> Int -> VectSet -> VectSet 79 | pruneExtremal pol a vs | channels vs > 0 = VectSet { channels = channels vs - 1 80 | , vects = vs' 81 | } 82 | where 83 | mask = bit a 84 | target = if pol then 0 else mask 85 | maskLow = (bit a) - 1 86 | maskHigh = complement maskLow 87 | vs' = 88 | IS.map (\v -> v .&. maskLow .|. (v `shiftR` 1) .&. maskHigh) 89 | . IS.filter (\v -> v .&. 
mask == target) 90 | $ vects vs 91 | 92 | asBoolVectList :: VectSet -> [[Bool]] 93 | asBoolVectList vs = map toBoolVect (IS.toList (vects vs)) 94 | where toBoolVect v = [testBit v i | i <- [0..channels vs - 1]] 95 | -------------------------------------------------------------------------------- /src/prune.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | fs::{self, File}, 4 | io::{BufRead, BufReader, BufWriter, Write}, 5 | path::PathBuf, 6 | }; 7 | 8 | use indicatif::{ProgressBar, ProgressStyle}; 9 | use memmap::{Mmap, MmapOptions}; 10 | use rayon::prelude::*; 11 | 12 | use crate::output_set::{ 13 | index::{LowerInvert, OutputSetIndex}, 14 | OutputSet, 15 | }; 16 | 17 | struct PackedOutputSets { 18 | stride: usize, 19 | packed_buf: Mmap, 20 | } 21 | 22 | impl PackedOutputSets { 23 | fn new(channels: usize, packed_buf: Mmap) -> Self { 24 | let stride = OutputSet::packed_len_for_channels(channels); 25 | Self { stride, packed_buf } 26 | } 27 | 28 | fn iter_packed<'a>(&'a self) -> impl Iterator { 29 | self.packed_buf.chunks(self.stride) 30 | } 31 | 32 | fn get_packed(&self, index: usize) -> &[u8] { 33 | &self.packed_buf[self.stride * index..][..self.stride] 34 | } 35 | } 36 | 37 | pub fn prune(channels: usize, input: PathBuf) { 38 | let input_file = File::open(&input).unwrap(); 39 | let input_data = unsafe { MmapOptions::new().map(&input_file).unwrap() }; 40 | 41 | let output_path = input.with_extension("pbin"); 42 | if output_path.exists() { 43 | log::info!("{:?} already exists", output_path); 44 | return; 45 | } 46 | let output_tmp_path = input.with_extension("tmp"); 47 | let mut output_file = BufWriter::new(File::create(&output_tmp_path).unwrap()); 48 | 49 | let packed_output_sets = PackedOutputSets::new(channels, input_data); 50 | 51 | let mut by_state_len = BTreeMap::>::default(); 52 | 53 | let mut total_input_len = 0; 54 | 55 | for (i, packed) in 
packed_output_sets.iter_packed().enumerate() { 56 | let state_len = packed 57 | .iter() 58 | .map(|byte| byte.count_ones() as usize) 59 | .sum::(); 60 | 61 | by_state_len 62 | .entry(state_len) 63 | .or_insert_with(|| vec![]) 64 | .push(i); 65 | 66 | total_input_len += 1; 67 | } 68 | 69 | let mut index = OutputSetIndex::::new(channels); 70 | 71 | let template = "{elapsed_precise} [{wide_bar:.green/blue}] {percent}% {pos}/{len} {eta}"; 72 | let bar = ProgressBar::new(total_input_len); 73 | bar.set_style( 74 | ProgressStyle::default_bar() 75 | .template(template) 76 | .progress_chars("#>-"), 77 | ); 78 | 79 | for (state_len, ids) in by_state_len { 80 | log::info!("len = {}, count = {}", state_len, ids.len()); 81 | log::info!(""); 82 | 83 | let ids = ids 84 | .into_par_iter() 85 | .filter(|&id| { 86 | bar.inc(1); 87 | let packed = packed_output_sets.get_packed(id); 88 | let output_set = OutputSet::from_packed(channels, packed); 89 | 90 | index 91 | .lookup_with_abstraction(output_set.as_ref(), &output_set.abstraction()) 92 | .is_none() 93 | }) 94 | .collect::>(); 95 | 96 | for id in ids { 97 | let packed = packed_output_sets.get_packed(id); 98 | output_file.write_all(packed).unwrap(); 99 | let output_set = OutputSet::from_packed(channels, packed); 100 | 101 | index.insert_new_unchecked_with_abstraction( 102 | output_set.as_ref(), 103 | &output_set.abstraction(), 104 | 0, 105 | ); 106 | } 107 | 108 | log::info!("subsumed_len = {}", index.len()); 109 | } 110 | 111 | drop(output_file); 112 | fs::rename(output_tmp_path, output_path).unwrap(); 113 | } 114 | 115 | pub fn prune_all(input: PathBuf) { 116 | let index = input.join("index.txt"); 117 | let index_file = BufReader::new(File::open(index).unwrap()); 118 | for line in index_file.lines() { 119 | let line = line.unwrap(); 120 | let channels = str::parse::(line.split('_').nth(1).unwrap()).unwrap(); 121 | log::info!("pruning {}", line); 122 | 123 | prune(channels, input.join(line)); 124 | } 125 | } 126 | 
-------------------------------------------------------------------------------- /checker/snocheck/src/Check.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE Safe #-} 2 | module Check 3 | ( checkProof 4 | ) 5 | where 6 | 7 | import Control.Monad 8 | import Data.Foldable 9 | import Data.Maybe 10 | import Data.Traversable 11 | import Data.IntMap.Strict ( IntMap ) 12 | import qualified Data.IntMap.Strict as IM 13 | import ProofStep 14 | import VectSet ( VectSet ) 15 | import qualified VectSet as VS 16 | import Parallel 17 | 18 | type Result a = Either String a 19 | 20 | type ProofSteps = Int -> Either String ProofStep 21 | 22 | check :: String -> Bool -> Result () 23 | check msg cond = if cond then return () else Left msg 24 | 25 | checkProof :: (Int, Int -> ProofStep) -> Result (Int, Int) 26 | checkProof (stepCount, stepFn) = do 27 | parCheckRange checkStep' 0 stepCount 28 | let lastStep = stepFn $ stepCount - 1 29 | return (VS.channels $ vectSet lastStep, bound lastStep) 30 | where 31 | steps n | 0 <= n && n < stepCount = return $ stepFn n 32 | | otherwise = Left "step id out of bounds" 33 | 34 | checkStep' stepId = case checkStep steps stepId of 35 | Left err -> Left $ "step " ++ show stepId ++ ": " ++ err 36 | x -> x 37 | 38 | parCheckRange :: (Int -> Result ()) -> Int -> Int -> Result () 39 | parCheckRange checkFn low high = if (high - low) < 1000 40 | then traverse_ checkFn [low .. 
high - 1] 41 | else 42 | let mid = (low + high) `div` 2 43 | a = parCheckRange checkFn low mid 44 | b = parCheckRange checkFn mid high 45 | in par b a >> b 46 | 47 | checkStep :: ProofSteps -> Int -> Result () 48 | checkStep steps stepId = do 49 | step <- steps stepId 50 | case witnesses step of 51 | Huffman pol witnesses -> checkHuffman steps' step pol witnesses 52 | Successors witnesses -> checkSuccessors steps' step witnesses 53 | where 54 | steps' i = do 55 | check "witness step id out of bounds" (i < stepId) 56 | steps i 57 | 58 | checkHuffman :: ProofSteps -> ProofStep -> Bool -> [Maybe Witness] -> Result () 59 | checkHuffman steps step pol witnesses = do 60 | let vs = vectSet step 61 | channels = VS.channels vs 62 | extremalChannels = VS.extremalChannels pol vs 63 | check "wrong number of huffman witnesses" 64 | (length extremalChannels == length witnesses) 65 | bounds <- for (zip extremalChannels witnesses) $ \(c, witness) -> do 66 | let prunedSet = VS.pruneExtremal pol c vs 67 | getBound steps witness prunedSet 68 | check "huffman bound too low" (huffmanBound bounds >= bound step) 69 | return () 70 | 71 | huffmanBound :: [Int] -> Int 72 | huffmanBound = huffmanBound' . IM.fromListWith (+) . 
map (\k -> (k, 1)) 73 | 74 | huffmanBound' :: IntMap Int -> Int 75 | huffmanBound' queue = case queuePop queue of 76 | Just (item, queue') -> case queuePop queue' of 77 | Nothing -> item 78 | Just (item2, queue'') -> 79 | huffmanBound' $ queuePush (1 + max item item2) queue'' 80 | 81 | queuePop :: IntMap Int -> Maybe (Int, IntMap Int) 82 | queuePop queue = do 83 | ((item, count), rest) <- IM.minViewWithKey queue 84 | let rest' = if count > 1 then IM.insert item (count - 1) rest else rest 85 | return (item, rest') 86 | 87 | queuePush :: Int -> IntMap Int -> IntMap Int 88 | queuePush item = IM.insertWith (+) item 1 89 | 90 | checkSuccessors :: ProofSteps -> ProofStep -> [Maybe Witness] -> Result () 91 | checkSuccessors steps step witnesses = do 92 | let vs = vectSet step 93 | channels = VS.channels vs 94 | successors = catMaybes 95 | [ VS.applyComp i j vs | i <- [0 .. channels - 1], j <- [0 .. i - 1] ] 96 | 97 | check "set might already be sorted" (VS.size vs > 1 + VS.channels vs) 98 | 99 | check "wrong number of successor witnesses" 100 | (length successors == length witnesses) 101 | 102 | for_ (zip successors witnesses) $ \(successor, witness) -> do 103 | b <- getBound steps witness successor 104 | check "successor bound too low" (b + 1 >= bound step) 105 | return () 106 | 107 | return () 108 | 109 | getBound :: ProofSteps -> Maybe Witness -> VectSet -> Result Int 110 | getBound _ Nothing vs | VS.size vs <= 1 + VS.channels vs = return 0 111 | | otherwise = return 1 112 | getBound steps (Just witness) vs = do 113 | witnessStep <- steps $ stepId witness 114 | let witnessSet = vectSet witnessStep 115 | witnessSet' = if invert witness then VS.invert witnessSet else witnessSet 116 | witnessBound = bound witnessStep 117 | permutedWitnessSet = VS.permute (perm witness) witnessSet' 118 | check "witness does not subsume target" (VS.subsumes permutedWitnessSet vs) 119 | return witnessBound 120 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # Lower Size Bounds for Sorting Networks 2 | 3 | This is an implementation of a new search procedure for finding lower bounds 4 | for the size of sorting networks. Using this search procedure I was able to 5 | show that sorting 11 inputs requires 35 comparisons, improving upon the 6 | previously best known lower bound of 33 comparisons. This in turn improves the 7 | lower bound for 12 inputs to 39 comparisons. As sorting networks matching these 8 | bounds are already known, both new lower bounds are tight. 9 | 10 | This repository contains the following components: 11 | 12 | * A log generating search procedure, implemented in Rust. 13 | * Pruning and post-processing of search logs, implemented in Rust. 14 | * A formally verified checker for the certificates generated from search 15 | logs, implemented using Haskell and Isabelle/HOL. 16 | 17 | For a full explanation see the [corresponding paper on arXiv][paper]. If you 18 | have any questions feel free to reach out to [me on Twitter][twitter]. 19 | 20 | ## Usage 21 | 22 | To perform the search and verify the result, use: 23 | 24 | ``` 25 | bash search_and_verify.sh CHANNELS DATADIR 26 | ``` 27 | 28 | If you have a proof certificate generated from the search logs, it can be 29 | checked using: 30 | 31 | ``` 32 | bash verify_proof_cert.sh CERTBIN 33 | ``` 34 | 35 | ## System Requirements 36 | 37 | To compile the search procedure, a recent Rust compiler and the Rust package 38 | manager Cargo (e.g. as installed via [rustup]) are 39 | required. 40 | 41 | The formally verified checker uses the ["Haskell Tool Stack"][stack]. 42 | 43 | Checking the formal proof of the verified checker and re-exporting the verified 44 | Haskell code requires [Isabelle2019 or Isabelle2020][isabelle]. 
45 | 46 | Performing the search for 9 input channels (replicating the previously best 47 | known result) requires very few resources and is a good way to make sure 48 | that a system is set up correctly. The first time `bash search_and_verify.sh 9 49 | data` is run, it will compile the required Rust and Haskell code. Subsequent 50 | runs will use cached binaries. For 9 input channels the search and verification 51 | require less than 50MB of RAM and finish within 10 seconds on my laptop. 52 | 53 | Performing the search and certificate generation for 11 input channels required 54 | a bit below 200GB of RAM and took below 80 hours on an AMD "EPYC 7401P" 24-core 55 | (48-thread) processor. Available threads will be used automatically. 56 | Distribution across multiple machines is not supported. 57 | 58 | Checking the resulting certificate for 11 input channels is less demanding. It 59 | takes below 8GB and around 3 hours on my laptop. 60 | 61 | ## Certificates 62 | 63 | A 1.2GB compressed certificate for 11 input channels is available for download: 64 | [proof_cert_11.bin.zst][cert11]. The certificate is also archived [at zenodo 65 | (DOI: 10.5281/zenodo.4108365)][cert11-zenodo]. It is compressed using 66 | [Zstandard][zst] and needs to be unpacked before it can be checked. The SHA-256 67 | hash of the uncompressed certificate is 68 | `7fe9f5cd694714bf83da0bcab162a290eb076ad4257265507a74cea8fab85b7e` 69 | 70 | ## Formal Verification 71 | 72 | The certificate checker consists of an unverified part that parses the 73 | certificate and a formally verified part that checks the parsed data. The 74 | formally verified part is extracted from a specification and proof written 75 | using Isabelle/HOL. This part is contained in the subdirectory 76 | `checker/verified`. 77 | 78 | A small patch, `strict_and_parallel.patch`, is applied to the extracted code.
It 79 | is easy to manually verify that this patch only affects the evaluation order, 80 | not the result of running the extracted program. This is done to speed up the 81 | checking and to allow the use of multiple threads. Given enough time it is 82 | possible to run the checking without this patch. 83 | 84 | A current snapshot of the extracted code is part of the repository. This allows 85 | running the checker without requiring Isabelle. To update and reverify the 86 | extracted code run `checker/update_extracted_code.sh`. 87 | 88 | Currently the formal proof is without comments and not intended for humans to 89 | read. An exception to this is the definition of sorting networks and lower size 90 | bounds for sorting networks as well as the final lemma 91 | `check_proof_get_bound_spec` showing the correctness of the checker. These are 92 | contained in `Sorting_Network_Bound.thy` and `Checker_Codegen.thy` 93 | respectively. 94 | 95 | As Isabelle/HOL code uses escape sequences for various non-ASCII symbols, it is 96 | best viewed using Isabelle/jEdit. Alternatively a [current snapshot of the 97 | formal proof is available in PDF form][document.pdf]. 
98 | 99 | [rustup]: https://rustup.rs/ 100 | [stack]: https://haskellstack.org 101 | [isabelle]: https://isabelle.in.tum.de/ 102 | [cert11]: https://files.jix.one/sortnetopt/proof_cert_11.bin.zst 103 | [cert11-zenodo]: https://doi.org/10.5281/zenodo.4108365 104 | [zst]: https://facebook.github.io/zstd/ 105 | [document.pdf]: https://files.jix.one/sortnetopt/document.pdf 106 | [twitter]: https://twitter.com/jix_ 107 | [paper]: https://arxiv.org/abs/2012.04400 108 | -------------------------------------------------------------------------------- /src/bin/sortnetopt.rs: -------------------------------------------------------------------------------- 1 | #[global_allocator] 2 | static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; 3 | 4 | use std::path::PathBuf; 5 | 6 | use rustc_hash::FxHashSet as HashSet; 7 | use structopt::StructOpt; 8 | 9 | use sortnetopt::{ 10 | fix::fix_proof, 11 | logging, 12 | output_set::{ 13 | index::{Lower, OutputSetIndex}, 14 | OutputSet, MAX_CHANNELS, 15 | }, 16 | proof::gen_proof, 17 | prune::{prune, prune_all}, 18 | search::Search, 19 | }; 20 | 21 | #[derive(Debug, StructOpt)] 22 | struct Opt { 23 | /// Print peak memory usage 24 | #[structopt(short, long)] 25 | mem_usage: bool, 26 | #[structopt(subcommand)] 27 | command: OptCommand, 28 | } 29 | 30 | #[derive(Debug, StructOpt)] 31 | enum OptCommand { 32 | Search(OptSearch), 33 | Prune(OptPrune), 34 | PruneAll(OptPruneAll), 35 | GenProof(OptGenProof), 36 | FixProof(OptFixProof), 37 | Gnp(OptGnp), 38 | } 39 | 40 | #[derive(Debug, StructOpt)] 41 | struct OptSearch { 42 | /// Number of channels in the sorting network 43 | channels: usize, 44 | #[structopt(parse(from_os_str))] 45 | output: Option, 46 | #[structopt(short, long)] 47 | limit: Option, 48 | #[structopt(short, long)] 49 | prefix: Vec, 50 | } 51 | 52 | #[derive(Debug, StructOpt)] 53 | struct OptPrune { 54 | /// Number of channels in the sorting network 55 | channels: usize, 56 | #[structopt(parse(from_os_str))] 57 | input: 
PathBuf, 58 | } 59 | 60 | #[derive(Debug, StructOpt)] 61 | struct OptPruneAll { 62 | #[structopt(parse(from_os_str))] 63 | input: PathBuf, 64 | } 65 | 66 | #[derive(Debug, StructOpt)] 67 | struct OptGenProof { 68 | #[structopt(parse(from_os_str))] 69 | input: PathBuf, 70 | } 71 | 72 | #[derive(Debug, StructOpt)] 73 | struct OptFixProof { 74 | #[structopt(parse(from_os_str))] 75 | input: PathBuf, 76 | } 77 | 78 | #[derive(Debug, StructOpt)] 79 | struct OptGnp { 80 | /// Number of channels in the sorting network 81 | channels: usize, 82 | /// Dump index trees as graphviz graph 83 | #[structopt(short, long)] 84 | dump_index: bool, 85 | } 86 | 87 | fn main() { 88 | let opt = Opt::from_args(); 89 | logging::setup(opt.mem_usage); 90 | 91 | match opt.command { 92 | OptCommand::Search(opt) => cmd_search(opt), 93 | OptCommand::Prune(opt) => cmd_prune(opt), 94 | OptCommand::PruneAll(opt) => cmd_prune_all(opt), 95 | OptCommand::GenProof(opt) => cmd_gen_proof(opt), 96 | OptCommand::FixProof(opt) => cmd_fix_proof(opt), 97 | OptCommand::Gnp(opt) => cmd_gnp(opt), 98 | } 99 | } 100 | 101 | fn cmd_search(opt: OptSearch) { 102 | log::info!("options: {:?}", opt); 103 | 104 | let mut initial = OutputSet::all_values(opt.channels); 105 | 106 | for pair in opt.prefix.chunks(2) { 107 | initial.apply_comparator([pair[0], pair[1]]); 108 | } 109 | 110 | log::info!( 111 | "result = {}", 112 | Search::search(initial.as_ref(), opt.limit, opt.output.clone()) 113 | ); 114 | } 115 | 116 | fn cmd_prune(opt: OptPrune) { 117 | log::info!("options: {:?}", opt); 118 | prune(opt.channels, opt.input); 119 | } 120 | 121 | fn cmd_prune_all(opt: OptPruneAll) { 122 | log::info!("options: {:?}", opt); 123 | prune_all(opt.input); 124 | } 125 | 126 | fn cmd_gen_proof(opt: OptGenProof) { 127 | log::info!("options: {:?}", opt); 128 | gen_proof(opt.input); 129 | } 130 | 131 | fn cmd_fix_proof(opt: OptFixProof) { 132 | log::info!("options: {:?}", opt); 133 | fix_proof(opt.input); 134 | } 135 | 136 | fn cmd_gnp(opt: 
OptGnp) { 137 | log::info!("options: {:?}", opt); 138 | 139 | assert!(opt.channels <= MAX_CHANNELS); 140 | 141 | let initial = OutputSet::all_values(opt.channels); 142 | let abstraction = initial.abstraction(); 143 | 144 | let mut layer = OutputSetIndex::::new(opt.channels); 145 | 146 | layer.insert_with_abstraction(initial.as_ref(), &abstraction, 0); 147 | 148 | let mut layer_count = 0; 149 | 150 | while !layer.is_empty() { 151 | let mut next_layer = OutputSetIndex::::new(opt.channels); 152 | 153 | let mut next_output_sets = HashSet::default(); 154 | 155 | layer.for_each(|output_set: OutputSet<&[bool]>, _abstraction, _value| { 156 | for i in 0..opt.channels { 157 | for j in 0..i { 158 | let mut next_output_set = output_set.to_owned(); 159 | if next_output_set.apply_comparator([i, j]) { 160 | next_output_set.canonicalize(false); 161 | next_output_sets.insert(next_output_set); 162 | } 163 | } 164 | } 165 | 166 | for next_output_set in next_output_sets.drain() { 167 | let abstraction = next_output_set.abstraction(); 168 | next_layer.insert_with_abstraction(next_output_set.as_ref(), &abstraction[..], 0); 169 | } 170 | }); 171 | 172 | layer_count += 1; 173 | 174 | log::info!("layer {} size is {}", layer_count, next_layer.len(),); 175 | 176 | layer = next_layer; 177 | 178 | if opt.dump_index { 179 | layer 180 | .dump_dot(&mut std::io::BufWriter::new( 181 | std::fs::File::create(format!("_layer_{}.dot", layer_count)).unwrap(), 182 | )) 183 | .unwrap(); 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/thread_pool.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | future::Future, 4 | mem::{replace, transmute}, 5 | pin::Pin, 6 | sync::{ 7 | atomic::{AtomicUsize, Ordering}, 8 | Arc, RwLock, Weak, 9 | }, 10 | task::{Context, Poll}, 11 | thread, 12 | }; 13 | 14 | use abort_on_panic::abort_on_panic; 15 | use async_task::{JoinHandle, 
Task}; 16 | use crossbeam_channel::{unbounded, Sender}; 17 | 18 | type WeakSchedule = Weak>>>; 19 | 20 | #[derive(Default)] 21 | struct Pending { 22 | next_gc: usize, 23 | queue: BTreeMap<(Prio, usize), WeakSchedule>, 24 | } 25 | 26 | pub struct ThreadPool { 27 | queue: Sender>, 28 | pending_queue: Sender<((Prio, usize), WeakSchedule)>, 29 | pending: Arc>>, 30 | pending_id: AtomicUsize, 31 | } 32 | 33 | pub struct Handle { 34 | handle: JoinHandle, 35 | } 36 | 37 | #[derive(Clone)] 38 | pub struct Schedule { 39 | task: Arc>>>, 40 | } 41 | 42 | impl Schedule { 43 | pub fn schedule(&self) -> bool { 44 | if let Some(task) = self.task.write().unwrap().take() { 45 | task.schedule(); 46 | true 47 | } else { 48 | false 49 | } 50 | } 51 | 52 | pub fn is_scheduled(&self) -> bool { 53 | self.task.read().unwrap().is_none() 54 | } 55 | } 56 | 57 | impl Future for Handle { 58 | type Output = T; 59 | 60 | fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll { 61 | Pin::new(&mut self.handle) 62 | .poll(ctx) 63 | .map(|result| result.unwrap()) 64 | } 65 | } 66 | 67 | impl Drop for Handle { 68 | fn drop(&mut self) { 69 | self.handle.cancel(); 70 | } 71 | } 72 | 73 | impl ThreadPool { 74 | pub fn spawn<'a, O>(&'a self, future: Pin + Send + 'a>>) -> Handle 75 | where 76 | O: Send + 'static, 77 | { 78 | let (handle, schedule) = self.spawn_delayed(future); 79 | schedule.schedule(); 80 | handle 81 | } 82 | 83 | pub fn spawn_delayed<'a, O>( 84 | &'a self, 85 | future: Pin + Send + 'a>>, 86 | ) -> (Handle, Schedule) 87 | where 88 | O: Send + 'static, 89 | { 90 | let queue = self.queue.clone(); 91 | 92 | let future = unsafe { 93 | transmute::< 94 | Pin + Send + '_>>, 95 | Pin + Send + 'static>>, 96 | >(future) 97 | }; 98 | 99 | let (task, handle) = async_task::spawn( 100 | future, 101 | { 102 | move |task| { 103 | queue.send(task).unwrap(); 104 | } 105 | }, 106 | (), 107 | ); 108 | ( 109 | Handle { handle }, 110 | Schedule { 111 | task: Arc::new(RwLock::new(Some(task))), 112 | }, 
113 | ) 114 | } 115 | 116 | pub fn add_pending(&self, prio: Prio, schedule: &Schedule) { 117 | let id = self.pending_id.fetch_add(1, Ordering::Relaxed); 118 | 119 | // Ignore errors here, to not panic during tear down 120 | let _ = self.pending_queue 121 | .send(((prio, id), Arc::downgrade(&schedule.task))); 122 | } 123 | 124 | pub fn scope(in_scope: impl FnOnce(&ThreadPool) -> T) -> T { 125 | let (sender, receiver) = unbounded::>(); 126 | 127 | let (pending_sender, pending_receiver) = unbounded::<((Prio, usize), WeakSchedule)>(); 128 | 129 | let pool = Self { 130 | queue: sender, 131 | pending_queue: pending_sender, 132 | pending: Arc::new(RwLock::new(Pending { 133 | next_gc: 0, 134 | queue: Default::default(), 135 | })), 136 | pending_id: 0.into(), 137 | }; 138 | 139 | let threads = num_cpus::get().max(1); 140 | 141 | let workers = (0..threads) 142 | .map(|_worker| { 143 | let receiver = receiver.clone(); 144 | let pending_receiver = pending_receiver.clone(); 145 | let pending = pool.pending.clone(); 146 | thread::spawn(move || 'outer: loop { 147 | while let Ok(task) = receiver.try_recv() { 148 | abort_on_panic!("task panicked", { task.run() }); 149 | } 150 | 151 | { 152 | let mut pending_mut = pending.write().unwrap(); 153 | 154 | let queue_limit = pending_receiver.len() * 2; 155 | 156 | let mut counter = 0; 157 | while let Ok((prio, task)) = pending_receiver.try_recv() { 158 | if pending_mut.queue.len() >= pending_mut.next_gc { 159 | let old_queue = replace(&mut pending_mut.queue, Default::default()); 160 | 161 | pending_mut.queue = old_queue 162 | .into_iter() 163 | .filter(|(_prio, schedule)| { 164 | if let Some(schedule) = schedule.upgrade() { 165 | !(Schedule { task: schedule }).is_scheduled() 166 | } else { 167 | false 168 | } 169 | }) 170 | .collect(); 171 | 172 | pending_mut.next_gc = pending_mut.queue.len().max(5000) * 2; 173 | } 174 | 175 | pending_mut.queue.insert(prio, task); 176 | counter += 1; 177 | if counter >= queue_limit { 178 | break; 179 | } 
180 | } 181 | 182 | while let Some(&prio) = pending_mut.queue.keys().next() { 183 | let schedule = pending_mut.queue.remove(&prio).unwrap(); 184 | if let Some(schedule) = schedule.upgrade() { 185 | if (Schedule { task: schedule }).schedule() { 186 | continue 'outer; 187 | } 188 | } 189 | } 190 | } 191 | 192 | match receiver.recv_timeout(std::time::Duration::from_millis(10)) { 193 | Ok(task) => { 194 | abort_on_panic!("task panicked", { task.run() }); 195 | } 196 | Err(err) => { 197 | if err.is_disconnected() { 198 | break; 199 | } 200 | } 201 | } 202 | }) 203 | }) 204 | .collect::>(); 205 | 206 | let _guard = scopeguard::guard((), |_| { 207 | for worker in workers { 208 | let _ignored = worker.join(); 209 | } 210 | }); 211 | 212 | let result = in_scope(&pool); 213 | 214 | drop(pool); 215 | 216 | result 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /src/proof.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | fs::File, 4 | io::{BufRead, BufReader, BufWriter, Write}, 5 | path::PathBuf, 6 | sync::Arc, 7 | }; 8 | 9 | use byteorder::{LittleEndian, WriteBytesExt}; 10 | use indicatif::{ProgressBar, ProgressStyle}; 11 | use memmap::MmapOptions; 12 | use rayon::prelude::*; 13 | use rustc_hash::FxHashMap as HashMap; 14 | 15 | use crate::{ 16 | huffman::max_plus_1_huffman, 17 | output_set::{ 18 | index::{LowerInvert, OutputSetIndex}, 19 | CVec, OutputSet, 20 | }, 21 | }; 22 | 23 | pub fn gen_proof(input: PathBuf) { 24 | let mut pruned_indexes = BTreeMap::default(); 25 | 26 | let index = input.join("index.txt"); 27 | let index_file = BufReader::new(File::open(index).unwrap()); 28 | for line in index_file.lines() { 29 | let line = line.unwrap(); 30 | let channels = str::parse::(line.split('_').nth(1).unwrap()).unwrap(); 31 | let bound = 32 | str::parse::(line.split('_').nth(2).unwrap().split('.').next().unwrap()).unwrap(); 33 | 34 | 
log::info!("loading {}_{}", channels, bound); 35 | 36 | let pruned_path = input.join(line).with_extension("pbin"); 37 | 38 | let pruned_file = File::open(pruned_path).unwrap(); 39 | let pruned_data = unsafe { MmapOptions::new().map(&pruned_file).unwrap() }; 40 | 41 | let pruned_index = pruned_indexes 42 | .entry(channels) 43 | .or_insert_with(|| OutputSetIndex::::new(channels)); 44 | 45 | let stride = OutputSet::packed_len_for_channels(channels); 46 | for packed in pruned_data.chunks(stride) { 47 | let output_set = OutputSet::from_packed(channels, packed); 48 | pruned_index.insert_new_unchecked_with_abstraction( 49 | output_set.as_ref(), 50 | &output_set.abstraction(), 51 | bound, 52 | ); 53 | } 54 | } 55 | 56 | for (channels, pruned_index) in pruned_indexes.iter() { 57 | log::info!("{} channels: {}", channels, pruned_index.len()) 58 | } 59 | 60 | let max_channels = *pruned_indexes.keys().rev().next().unwrap(); 61 | let mut gen = GenProof::new(pruned_indexes); 62 | 63 | gen.prove_all(); 64 | 65 | gen.encode_proof(&Arc::new(OutputSet::all_values(max_channels))); 66 | 67 | log::info!("proof using {} steps", gen.step_data.len()); 68 | 69 | let proof_path = input.join("proof.bin"); 70 | let mut proof_file = BufWriter::new(File::create(proof_path).unwrap()); 71 | 72 | gen.write_proof(&mut proof_file); 73 | } 74 | 75 | #[derive(Clone)] 76 | enum ProofStep { 77 | Trivial, 78 | NotSorted, 79 | Huffman { 80 | pol: bool, 81 | witnesses: Vec, Arc)>>, 82 | }, 83 | Successors { 84 | witnesses: Vec, Arc)>>, 85 | }, 86 | } 87 | 88 | struct GenProof { 89 | pruned_indexes: BTreeMap>, 90 | output_sets: HashMap, (u8, Arc)>, 91 | steps: HashMap, ProofStep>, 92 | step_ids: HashMap, usize>, 93 | step_data: Vec>, 94 | } 95 | 96 | impl GenProof { 97 | pub fn new(pruned_indexes: BTreeMap>) -> Self { 98 | let mut output_sets = HashMap::default(); 99 | for pruned_index in pruned_indexes.values() { 100 | pruned_index.for_each(|output_set, _abstraction, bound| { 101 | let shared_output_set = 
Arc::new(output_set.to_owned()); 102 | output_sets.insert(shared_output_set.clone(), (bound, shared_output_set)); 103 | }) 104 | } 105 | Self { 106 | output_sets, 107 | pruned_indexes, 108 | steps: Default::default(), 109 | step_ids: Default::default(), 110 | step_data: vec![], 111 | } 112 | } 113 | 114 | pub fn prove_all(&mut self) { 115 | let template = "{elapsed_precise} [{wide_bar:.green/blue}] {percent}% {pos}/{len} {eta}"; 116 | let bar = ProgressBar::new(self.output_sets.len() as u64); 117 | bar.set_style( 118 | ProgressStyle::default_bar() 119 | .template(template) 120 | .progress_chars("#>-"), 121 | ); 122 | 123 | self.steps = self 124 | .output_sets 125 | .iter() 126 | .collect::>() 127 | .into_par_iter() 128 | .map(|(output_set, &(bound, _))| { 129 | bar.inc(1); 130 | let proof_step = self 131 | .trivial_step(bound, output_set.as_ref().as_ref()) 132 | .or_else(|| self.huffman_step(bound, output_set.as_ref().as_ref())) 133 | .or_else(|| self.successor_step(bound, output_set.as_ref().as_ref())); 134 | 135 | ( 136 | output_set.clone(), 137 | proof_step.expect("no valid proof step found"), 138 | ) 139 | }) 140 | .collect(); 141 | } 142 | 143 | pub fn encode_proof(&mut self, target: &Arc) -> Option { 144 | if let Some(&id) = self.step_ids.get(target) { 145 | Some(id) 146 | } else { 147 | let mut step_data = vec![]; 148 | 149 | let bound = self.output_sets.get(target).unwrap().0; 150 | 151 | step_data.push(target.channels() as u8); 152 | step_data.push(bound); 153 | step_data.extend(target.packed()); 154 | 155 | let step_witnesses; 156 | match self.steps.get(target).unwrap().clone() { 157 | ProofStep::Trivial | ProofStep::NotSorted => return None, 158 | ProofStep::Huffman { pol, witnesses } => { 159 | step_data.push(pol as u8); 160 | step_witnesses = witnesses; 161 | } 162 | ProofStep::Successors { witnesses } => { 163 | step_data.push(2); 164 | step_witnesses = witnesses; 165 | } 166 | } 167 | 168 | for witness in step_witnesses { 169 | if let Some((invert, 
perm, id)) = witness.and_then(|witness| { 170 | Some((witness.0, witness.1, self.encode_proof(&witness.2)?)) 171 | }) { 172 | step_data.push(invert as u8); 173 | step_data.extend(perm.iter().map(|&index| index as u8)); 174 | step_data.write_u32::(id as u32).unwrap(); 175 | } else { 176 | step_data.push(2); 177 | } 178 | } 179 | 180 | let id = self.step_data.len(); 181 | 182 | self.step_ids.insert(target.clone(), id); 183 | self.step_data.push(step_data.into_boxed_slice()); 184 | 185 | Some(id) 186 | } 187 | } 188 | 189 | pub fn write_proof(&self, target: &mut impl Write) { 190 | target 191 | .write_u32::(self.step_data.len() as u32) 192 | .unwrap(); 193 | 194 | let mut offset = 4 + (8 + 4) * self.step_data.len() as u64; 195 | 196 | for step in self.step_data.iter() { 197 | target.write_u64::(offset).unwrap(); 198 | target.write_u32::(step.len() as u32).unwrap(); 199 | offset += step.len() as u64; 200 | } 201 | 202 | for step in self.step_data.iter() { 203 | target.write_all(step).unwrap(); 204 | } 205 | } 206 | 207 | fn lookup_witness( 208 | &self, 209 | target: OutputSet<&[bool]>, 210 | ) -> (u8, Option<(bool, CVec, Arc)>) { 211 | let subsuming = self 212 | .pruned_indexes 213 | .get(&target.channels()) 214 | .and_then(|index| { 215 | index.lookup_subsuming_with_abstraction(target, &target.abstraction()) 216 | }); 217 | 218 | if let Some((bound, (invert, perm), output_set)) = subsuming { 219 | ( 220 | bound, 221 | Some(( 222 | invert, 223 | perm, 224 | self.output_sets.get(&output_set).unwrap().1.clone(), 225 | )), 226 | ) 227 | } else { 228 | (if target.is_sorted() { 0 } else { 1 }, None) 229 | } 230 | } 231 | 232 | fn trivial_step(&self, bound: u8, target: OutputSet<&[bool]>) -> Option { 233 | if bound == 0 { 234 | Some(ProofStep::Trivial) 235 | } else if bound == 1 && !target.is_sorted() { 236 | Some(ProofStep::NotSorted) 237 | } else { 238 | None 239 | } 240 | } 241 | 242 | fn huffman_step(&self, bound: u8, target: OutputSet<&[bool]>) -> Option { 243 | let mut 
extremal_channels = [CVec::new(), CVec::new()]; 244 | 245 | for (pol, pol_channels) in extremal_channels.iter_mut().enumerate() { 246 | for channel in 0..target.channels() { 247 | if target.channel_is_extremal(pol > 0, channel) { 248 | pol_channels.push(channel); 249 | } 250 | } 251 | } 252 | 253 | let mut pols = [0, 1]; 254 | pols.sort_unstable_by_key(|&pol| extremal_channels[pol].len()); 255 | 256 | for &pol in pols.iter() { 257 | let mut bounds = vec![]; 258 | let mut witnesses = vec![]; 259 | let mut pruned_output_sets = vec![]; 260 | for &channel in extremal_channels[pol].iter() { 261 | let mut pruned = OutputSet::all_values(target.channels() - 1); 262 | target.prune_extremal_channel_into(pol > 0, channel, pruned.as_mut()); 263 | let (bound, witness) = self.lookup_witness(pruned.as_ref()); 264 | bounds.push(bound); 265 | witnesses.push(witness); 266 | pruned_output_sets.push(pruned); 267 | } 268 | 269 | let huffman_bound = max_plus_1_huffman(&bounds); 270 | if huffman_bound >= bound { 271 | return Some(ProofStep::Huffman { 272 | pol: pol > 0, 273 | witnesses, 274 | }); 275 | } 276 | } 277 | 278 | None 279 | } 280 | 281 | fn successor_step(&self, bound: u8, target: OutputSet<&[bool]>) -> Option { 282 | let mut witnesses = vec![]; 283 | for i in 0..target.channels() { 284 | for j in 0..i { 285 | let mut successor = target.to_owned(); 286 | if successor.apply_comparator([i, j]) { 287 | let (successor_bound, witness) = self.lookup_witness(successor.as_ref()); 288 | if successor_bound + 1 < bound { 289 | return None; 290 | } 291 | witnesses.push(witness); 292 | } 293 | } 294 | } 295 | 296 | Some(ProofStep::Successors { witnesses }) 297 | } 298 | } 299 | -------------------------------------------------------------------------------- /src/output_set/subsume.rs: -------------------------------------------------------------------------------- 1 | use std::{iter::repeat, mem::replace}; 2 | 3 | use super::{CVec, OutputSet, MAX_CHANNELS}; 4 | 5 | #[derive(Copy, Clone)] 6 
| enum UndoAction { 7 | Matching { 8 | channels: [usize; 2], 9 | }, 10 | Swap { 11 | target_channel: usize, 12 | source_channels: [usize; 2], 13 | }, 14 | } 15 | 16 | pub struct Subsume { 17 | channels: usize, 18 | output_sets: [OutputSet; 2], 19 | perms: [CVec; 2], 20 | matching: [CVec; 2], 21 | undo_stack: Vec, 22 | fixed_channels: usize, 23 | buffer: Vec, 24 | } 25 | 26 | impl Subsume { 27 | pub fn new(output_sets: [OutputSet; 2]) -> Self { 28 | assert_eq!(output_sets[0].channels(), output_sets[1].channels()); 29 | let channels = output_sets[0].channels(); 30 | let identity_perm = (0..channels).collect::>(); 31 | let all_mask = (1 << channels) - 1; 32 | let full_matching = repeat(all_mask).take(channels).collect::>(); 33 | Self { 34 | channels, 35 | output_sets, 36 | perms: [identity_perm.clone(), identity_perm], 37 | matching: [full_matching.clone(), full_matching], 38 | undo_stack: vec![], 39 | fixed_channels: 0, 40 | buffer: vec![], 41 | } 42 | } 43 | 44 | pub fn search(&mut self) -> Option> { 45 | loop { 46 | loop { 47 | if self.fixed_channels == self.channels { 48 | if self.output_sets[0].subsumes_unpermuted(self.output_sets[1].as_ref()) { 49 | let mut perm = (0..self.channels).collect::>(); 50 | 51 | for i in 0..self.channels { 52 | perm[self.perms[1][i]] = self.perms[0][i]; 53 | } 54 | 55 | return Some(perm); 56 | } else { 57 | return None; 58 | } 59 | } 60 | if !self.filter_matching() { 61 | return None; 62 | } 63 | if !self.move_unique() { 64 | break; 65 | } 66 | } 67 | 68 | let guess = self.select_guess(); 69 | 70 | let stack_depth = self.undo_stack.len(); 71 | let fixed_channels = self.fixed_channels; 72 | 73 | if self.isolate_matching(guess) { 74 | self.move_unique(); 75 | let result = self.search(); 76 | if result.is_some() { 77 | self.rollback(stack_depth, fixed_channels); 78 | return result; 79 | } 80 | } 81 | 82 | self.rollback(stack_depth, fixed_channels); 83 | 84 | if !self.remove_matching(guess) { 85 | return None; 86 | } 87 | 
self.move_unique(); 88 | } 89 | } 90 | 91 | fn rollback(&mut self, depth: usize, fixed_channels: usize) { 92 | self.fixed_channels = fixed_channels; 93 | for action in self.undo_stack.drain(depth..).rev() { 94 | match action { 95 | UndoAction::Matching { channels } => { 96 | for side in 0..2 { 97 | let side_a = side; 98 | let side_b = 1 - side; 99 | let channel_a = channels[side_a]; 100 | let channel_b = channels[side_b]; 101 | self.matching[side_a][channel_a] |= 1 << channel_b; 102 | self.matching[side_b][channel_b] |= 1 << channel_a; 103 | } 104 | } 105 | UndoAction::Swap { 106 | target_channel, 107 | source_channels, 108 | } => { 109 | for (&source_channel, (output_set, perm)) in source_channels 110 | .iter() 111 | .zip(self.output_sets.iter_mut().zip(self.perms.iter_mut())) 112 | { 113 | output_set.swap_channels([target_channel, source_channel]); 114 | perm.swap(target_channel, source_channel); 115 | } 116 | } 117 | } 118 | } 119 | } 120 | 121 | fn get_matching(&self, channels: [usize; 2]) -> bool { 122 | self.matching[0][channels[0]] & (1 << channels[1]) != 0 123 | } 124 | 125 | fn remove_matching(&mut self, channels: [usize; 2]) -> bool { 126 | debug_assert!(self.get_matching(channels)); 127 | 128 | for side in 0..2 { 129 | let mask = self.matching[side][channels[side]]; 130 | if mask & (mask - 1) == 0 { 131 | return false; 132 | } 133 | } 134 | 135 | self.undo_stack.push(UndoAction::Matching { channels }); 136 | 137 | for side in 0..2 { 138 | self.matching[side][channels[side]] &= !(1 << channels[1 - side]); 139 | } 140 | 141 | for side in 0..2 { 142 | let mask = self.matching[side][channels[side]]; 143 | if mask & (mask - 1) == 0 { 144 | let other_channel = mask.trailing_zeros() as usize; 145 | 146 | loop { 147 | let other_mask = 148 | self.matching[1 - side][other_channel] & !(1 << channels[side]); 149 | if other_mask == 0 { 150 | break; 151 | } 152 | 153 | let mut channels_rec = [0, 0]; 154 | channels_rec[1 - side] = other_channel; 155 | 156 | let 
target_channel = other_mask.trailing_zeros() as usize; 157 | channels_rec[side] = target_channel; 158 | 159 | if !self.remove_matching(channels_rec) { 160 | return false; 161 | } 162 | } 163 | } 164 | } 165 | 166 | true 167 | } 168 | 169 | fn isolate_matching(&mut self, channels: [usize; 2]) -> bool { 170 | loop { 171 | let other_mask = self.matching[1][channels[1]] & !(1 << channels[0]); 172 | if other_mask == 0 { 173 | break; 174 | } 175 | let target_channel = other_mask.trailing_zeros() as usize; 176 | 177 | if !self.remove_matching([target_channel, channels[1]]) { 178 | return false; 179 | } 180 | } 181 | true 182 | } 183 | 184 | fn select_guess(&self) -> [usize; 2] { 185 | let mut min_choice = (MAX_CHANNELS + 1, 0, 0); 186 | 187 | for side in 0..2 { 188 | for channel in self.fixed_channels..self.channels { 189 | let matching_channel = self.perms[side][channel]; 190 | let weight = self.matching[side][matching_channel].count_ones() as usize; 191 | let choice = (weight, side, matching_channel); 192 | min_choice = min_choice.min(choice); 193 | } 194 | } 195 | 196 | let (_weight, min_side, min_matching_channel) = min_choice; 197 | 198 | let mut min_other_choice = (MAX_CHANNELS + 1, 0); 199 | 200 | let mut mask = self.matching[min_side][min_matching_channel]; 201 | 202 | while mask != 0 { 203 | let matching_channel = mask.trailing_zeros() as usize; 204 | mask &= mask - 1; 205 | let weight = self.matching[1 - min_side][matching_channel].count_ones() as usize; 206 | let other_choice = (weight, matching_channel); 207 | min_other_choice = min_other_choice.min(other_choice); 208 | } 209 | 210 | let (_weight, min_other_matching_channel) = min_other_choice; 211 | 212 | let mut result = [0, 0]; 213 | result[min_side] = min_matching_channel; 214 | result[1 - min_side] = min_other_matching_channel; 215 | 216 | result 217 | } 218 | 219 | fn filter_matching(&mut self) -> bool { 220 | let abstraction_len = 221 | 
self.output_sets[0].low_channels_channel_abstraction_len(self.fixed_channels); 222 | 223 | self.buffer.resize( 224 | abstraction_len * (self.channels - self.fixed_channels) * 2, 225 | 0, 226 | ); 227 | 228 | let mut buffer_vec = replace(&mut self.buffer, vec![]); 229 | let mut buffer = &mut buffer_vec[..]; 230 | 231 | let mut abstraction_buffers = [CVec::new(), CVec::new()]; 232 | 233 | for (side, buffers) in abstraction_buffers.iter_mut().enumerate() { 234 | for channel in self.fixed_channels..self.channels { 235 | let (current_buffer, buffer_rest) = buffer.split_at_mut(abstraction_len); 236 | 237 | self.output_sets[side].low_channels_channel_abstraction( 238 | self.fixed_channels, 239 | channel, 240 | current_buffer, 241 | ); 242 | 243 | buffer = buffer_rest; 244 | buffers.push(current_buffer); 245 | } 246 | } 247 | 248 | let mut result = true; 249 | 250 | 'outer: for (channel_lo, buffer_lo) in 251 | (self.fixed_channels..self.channels).zip(abstraction_buffers[0].iter()) 252 | { 253 | let matching_channel_lo = self.perms[0][channel_lo]; 254 | 255 | for (channel_hi, buffer_hi) in 256 | (self.fixed_channels..self.channels).zip(abstraction_buffers[1].iter()) 257 | { 258 | let matching_channel_hi = self.perms[1][channel_hi]; 259 | if self.get_matching([matching_channel_lo, matching_channel_hi]) { 260 | if buffer_lo 261 | .iter() 262 | .zip(buffer_hi.iter()) 263 | .any(|(&lo, &hi)| lo > hi) 264 | { 265 | if !self.remove_matching([matching_channel_lo, matching_channel_hi]) { 266 | result = false; 267 | break 'outer; 268 | } 269 | } 270 | } 271 | } 272 | } 273 | 274 | drop(abstraction_buffers); 275 | 276 | self.buffer = buffer_vec; 277 | result 278 | } 279 | 280 | fn move_unique(&mut self) -> bool { 281 | let mut moved = false; 282 | 283 | for channel_lo in self.fixed_channels..self.channels { 284 | let matching_channel_lo = self.perms[0][channel_lo]; 285 | let mask = self.matching[0][matching_channel_lo]; 286 | if mask & (mask - 1) == 0 { 287 | let matching_channel_hi 
= mask.trailing_zeros() as usize; 288 | let channel_hi = self.perms[1] 289 | .iter() 290 | .position(|&matching_channel| matching_channel == matching_channel_hi) 291 | .unwrap(); 292 | 293 | self.undo_stack.push(UndoAction::Swap { 294 | target_channel: self.fixed_channels, 295 | source_channels: [channel_lo, channel_hi], 296 | }); 297 | 298 | for (side, &channel) in [channel_lo, channel_hi].iter().enumerate() { 299 | self.output_sets[side].swap_channels([self.fixed_channels, channel]); 300 | self.perms[side].swap(self.fixed_channels, channel); 301 | } 302 | 303 | self.fixed_channels += 1; 304 | moved = true; 305 | } 306 | } 307 | moved 308 | } 309 | } 310 | -------------------------------------------------------------------------------- /src/output_set/canon.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | cmp::Reverse, 3 | mem::{replace, swap}, 4 | }; 5 | 6 | use super::{CVec, OutputSet, Perm}; 7 | 8 | #[derive(Clone)] 9 | struct CandidateData { 10 | perm: Perm, 11 | fingerprint: CVec, 12 | partition: CVec<(u16, u16)>, 13 | partition_score: (usize, usize), 14 | } 15 | 16 | impl CandidateData { 17 | pub fn score(&self) -> (usize, usize, &[u64]) { 18 | ( 19 | self.partition_score.0, 20 | self.partition_score.1, 21 | self.fingerprint.as_ref(), 22 | ) 23 | } 24 | 25 | pub fn sort_partition(&mut self) { 26 | self.partition 27 | .sort_unstable_by_key(|&(part_mask, part_id)| { 28 | (Reverse(part_mask.count_ones()), part_id, part_mask) 29 | }); 30 | 31 | let singleton_count = self 32 | .partition 33 | .iter() 34 | .rev() 35 | .take_while(|&(part_mask, _)| part_mask & (part_mask - 1) == 0) 36 | .count(); 37 | 38 | let max_partition_size = self.partition[0].0.count_ones() as usize; 39 | 40 | self.partition_score = (singleton_count, !max_partition_size); 41 | } 42 | } 43 | 44 | pub struct Canonicalize { 45 | bitmap_len: usize, 46 | channels: usize, 47 | used_channels: usize, 48 | bitmaps: Vec, 49 | data: Vec, 50 | 
free_list: Vec, 51 | layer: Vec, 52 | next_layer: Vec, 53 | fixed: usize, 54 | buffer: Vec, 55 | temp: OutputSet, 56 | } 57 | 58 | macro_rules! get_mut { 59 | ($s:ident, $i:expr) => { 60 | match $i { 61 | index => OutputSet::from_bitmap( 62 | $s.channels, 63 | &mut $s.bitmaps[index * $s.bitmap_len..(index + 1) * $s.bitmap_len], 64 | ), 65 | } 66 | }; 67 | } 68 | 69 | macro_rules! get { 70 | ($s:ident, $i:expr) => { 71 | match $i { 72 | index => OutputSet::from_bitmap( 73 | $s.channels, 74 | &$s.bitmaps[index * $s.bitmap_len..(index + 1) * $s.bitmap_len], 75 | ), 76 | } 77 | }; 78 | } 79 | 80 | macro_rules! reborrow { 81 | ($s:ident) => {{ 82 | struct Reborrow<'a> { 83 | bitmaps: &'a [bool], 84 | channels: usize, 85 | bitmap_len: usize, 86 | }; 87 | Reborrow { 88 | bitmaps: $s.bitmaps.as_ref(), 89 | channels: $s.channels, 90 | bitmap_len: $s.bitmap_len, 91 | } 92 | }}; 93 | } 94 | 95 | macro_rules! alloc { 96 | ($s:ident, $output_set:expr, $data:expr) => { 97 | match ($output_set, $data) { 98 | (output_set, data) => { 99 | if let Some(index) = $s.free_list.pop() { 100 | $s.data[index] = data; 101 | $s.bitmaps[index * $s.bitmap_len..(index + 1) * $s.bitmap_len] 102 | .copy_from_slice(output_set.bitmap()); 103 | index 104 | } else { 105 | let index = $s.data.len(); 106 | $s.data.push(data); 107 | $s.bitmaps.extend_from_slice(output_set.bitmap()); 108 | index 109 | } 110 | } 111 | } 112 | }; 113 | } 114 | 115 | impl Canonicalize { 116 | pub fn new(mut output_set: OutputSet<&mut [bool]>, inversion: bool) -> Self { 117 | let mut new = Self { 118 | bitmap_len: 1 << output_set.channels(), 119 | channels: output_set.channels(), 120 | used_channels: output_set.channels(), 121 | bitmaps: vec![], 122 | data: vec![], 123 | free_list: vec![], 124 | layer: vec![], 125 | next_layer: vec![], 126 | fixed: 0, 127 | buffer: vec![], 128 | temp: output_set.to_owned(), 129 | }; 130 | 131 | let mut data = CandidateData { 132 | perm: Perm::identity(output_set.channels()), 133 | 
fingerprint: CVec::new(), 134 | partition: CVec::new(), 135 | partition_score: Default::default(), 136 | }; 137 | 138 | for channel in (0..new.channels).rev() { 139 | if output_set.is_channel_unconstrained(channel) { 140 | new.used_channels -= 1; 141 | output_set.swap_channels([channel, new.used_channels]); 142 | data.perm.perm.swap(channel, new.used_channels); 143 | } 144 | } 145 | 146 | let identity = alloc!(new, output_set.as_ref(), data.clone()); 147 | new.layer.push(identity); 148 | 149 | if inversion { 150 | data.perm.invert = true; 151 | let inverted = alloc!(new, output_set.as_ref(), data); 152 | get_mut!(new, inverted).invert(); 153 | if get!(new, inverted) == get!(new, identity) { 154 | new.free_list.push(inverted); 155 | } else { 156 | new.layer.push(inverted); 157 | } 158 | } 159 | 160 | new 161 | } 162 | 163 | pub fn canonicalize(&mut self) -> (OutputSet<&[bool]>, Perm) { 164 | if self.used_channels > 0 { 165 | self.initialize_partitions(); 166 | 167 | loop { 168 | self.prune_using_fingerprints(); 169 | let mut prune = false; 170 | while self.move_singleton() { 171 | prune = true; 172 | } 173 | if prune { 174 | self.prune(true); 175 | } 176 | if self.fixed == self.used_channels { 177 | break; 178 | } 179 | self.individualize(); 180 | self.prune(false); 181 | if self.fixed == self.used_channels { 182 | break; 183 | } 184 | self.refine(); 185 | } 186 | } 187 | 188 | let index = self.layer[0]; 189 | (get!(self, index), self.data[index].perm.clone()) 190 | } 191 | 192 | fn initialize_partitions(&mut self) { 193 | for &index in self.layer.iter() { 194 | let output_set = get!(self, index); 195 | let mut fingerprints = (0..self.used_channels) 196 | .map(|channel| (output_set.channel_fingerprint(channel), channel)) 197 | .collect::>(); 198 | 199 | fingerprints.sort_unstable(); 200 | 201 | let mut part_fingerprint = fingerprints[0].0; 202 | let mut part_id = 0; 203 | let mut part_mask = 0u16; 204 | 205 | let data = &mut self.data[index]; 206 | 207 | for 
&(fingerprint, channel) in fingerprints.iter() { 208 | if part_fingerprint != fingerprint { 209 | data.partition.push((part_mask, part_id)); 210 | part_mask = 0; 211 | part_fingerprint = fingerprint; 212 | part_id += 1; 213 | } 214 | data.fingerprint.push(fingerprint); 215 | part_mask |= 1 << channel; 216 | } 217 | 218 | data.partition.push((part_mask, part_id)); 219 | 220 | data.sort_partition(); 221 | } 222 | } 223 | 224 | fn prune_using_fingerprints(&mut self) { 225 | let data = &self.data; 226 | let min_index = self 227 | .layer 228 | .iter() 229 | .cloned() 230 | .min_by(|&a, &b| data[b].score().cmp(&data[a].score())) 231 | .unwrap(); 232 | let min_score = self.data[min_index].clone(); 233 | 234 | let free_list = &mut self.free_list; 235 | 236 | self.layer.retain(|&index| { 237 | if data[index].score() == min_score.score() { 238 | true 239 | } else { 240 | free_list.push(index); 241 | false 242 | } 243 | }); 244 | } 245 | 246 | fn move_singleton(&mut self) -> bool { 247 | for &index in self.layer.iter() { 248 | let data = &mut self.data[index]; 249 | 250 | let source_channel = match data.partition.last() { 251 | None => return false, 252 | Some((part, _)) if part & (part - 1) != 0 => return false, 253 | Some((part, _)) => part.trailing_zeros() as usize, 254 | }; 255 | 256 | data.partition.pop(); 257 | 258 | if source_channel != self.fixed { 259 | get_mut!(self, index).swap_channels([source_channel, self.fixed]); 260 | data.perm.perm.swap(source_channel, self.fixed); 261 | 262 | let source_mask = 1 << source_channel; 263 | let fixed_mask = 1 << self.fixed; 264 | 265 | for (part, _id) in data.partition.iter_mut() { 266 | let fixed_present = *part & fixed_mask != 0; 267 | 268 | *part = (*part & !fixed_mask) | (source_mask * fixed_present as u16); 269 | } 270 | } 271 | } 272 | self.fixed += 1; 273 | true 274 | } 275 | 276 | fn prune(&mut self, recompute_fingerprints: bool) { 277 | if self.layer.len() == 1 { 278 | return; 279 | } 280 | 281 | let reborrow = 
reborrow!(self); 282 | let free_list = &mut self.free_list; 283 | 284 | if self.fixed == self.used_channels { 285 | let min_output_set_index = *self 286 | .layer 287 | .iter() 288 | .min_by_key(|&index| get!(reborrow, index)) 289 | .unwrap(); 290 | 291 | self.layer.retain(|&index| { 292 | if index == min_output_set_index { 293 | true 294 | } else { 295 | free_list.push(index); 296 | false 297 | } 298 | }); 299 | return; 300 | } 301 | 302 | self.layer.sort_by_key(|&index| get!(reborrow, index)); 303 | self.layer.dedup_by(|&mut test, &mut repr| { 304 | if get!(reborrow, test) == get!(reborrow, repr) { 305 | free_list.push(test); 306 | true 307 | } else { 308 | false 309 | } 310 | }); 311 | 312 | if self.layer.len() == 1 || !recompute_fingerprints { 313 | return; 314 | } 315 | 316 | let mut max_fingerprint = 0; 317 | 318 | for &index in self.layer.iter() { 319 | let fingerprint = 320 | get!(self, index).low_channels_fingerprint(self.fixed, &mut self.buffer); 321 | self.data[index].fingerprint[0] = fingerprint; 322 | max_fingerprint = max_fingerprint.max(fingerprint); 323 | } 324 | 325 | let data = &self.data; 326 | 327 | self.layer.retain(|&index| { 328 | if data[index].fingerprint[0] == max_fingerprint { 329 | true 330 | } else { 331 | free_list.push(index); 332 | false 333 | } 334 | }) 335 | } 336 | 337 | fn individualize(&mut self) { 338 | let mut max_fingerprint = 0; 339 | for index in self.layer.drain(..) 
{ 340 | let mut part_iter = self.data[index].partition.last().unwrap().0; 341 | 342 | while part_iter != 0 { 343 | let source_channel = part_iter.trailing_zeros() as usize; 344 | part_iter = part_iter & (part_iter - 1); 345 | self.temp 346 | .bitmap_mut() 347 | .copy_from_slice(get!(self, index).bitmap()); 348 | 349 | self.temp.swap_channels([source_channel, self.fixed]); 350 | let fingerprint = self 351 | .temp 352 | .low_channels_fingerprint(self.fixed + 1, &mut self.buffer); 353 | 354 | if fingerprint > max_fingerprint { 355 | for free_index in self.next_layer.drain(..) { 356 | self.free_list.push(free_index); 357 | } 358 | max_fingerprint = fingerprint 359 | } else if fingerprint < max_fingerprint { 360 | continue; 361 | } 362 | 363 | let mut data = self.data[index].clone(); 364 | 365 | data.perm.perm.swap(source_channel, self.fixed); 366 | let last_part = data.partition.last_mut().unwrap(); 367 | last_part.0 &= !(1 << source_channel); 368 | 369 | let source_mask = 1 << source_channel; 370 | let fixed_mask = 1 << self.fixed; 371 | 372 | for (part, _id) in data.partition.iter_mut() { 373 | let fixed_present = *part & fixed_mask != 0; 374 | 375 | *part = (*part & !fixed_mask) | (source_mask * fixed_present as u16); 376 | } 377 | 378 | let next_index = alloc!(self, self.temp.as_ref(), data); 379 | self.next_layer.push(next_index); 380 | } 381 | 382 | self.free_list.push(index); 383 | } 384 | 385 | swap(&mut self.layer, &mut self.next_layer); 386 | 387 | self.fixed += 1; 388 | } 389 | 390 | fn refine(&mut self) { 391 | for &index in self.layer.iter() { 392 | let data = &mut self.data[index]; 393 | let output_set = get!(self, index); 394 | 395 | data.fingerprint.clear(); 396 | 397 | let mut part_id = 0; 398 | 399 | for (part, _id) in replace(&mut data.partition, CVec::new()) { 400 | let mut part_iter = part; 401 | 402 | let mut fingerprints = CVec::<(u64, usize)>::new(); 403 | 404 | while part_iter != 0 { 405 | let channel = part_iter.trailing_zeros() as usize; 406 
| part_iter = part_iter & (part_iter - 1); 407 | fingerprints.push(( 408 | output_set.low_channels_channel_fingerprint( 409 | self.fixed, 410 | channel, 411 | &mut self.buffer, 412 | ), 413 | channel, 414 | )); 415 | } 416 | 417 | fingerprints.sort_unstable(); 418 | 419 | let mut part_fingerprint = fingerprints[0].0; 420 | let mut part_mask = 0u16; 421 | 422 | for &(fingerprint, channel) in fingerprints.iter() { 423 | if part_fingerprint != fingerprint { 424 | data.partition.push((part_mask, part_id)); 425 | part_mask = 0; 426 | part_fingerprint = fingerprint; 427 | part_id += 1; 428 | } 429 | data.fingerprint.push(fingerprint); 430 | part_mask |= 1 << channel; 431 | } 432 | 433 | data.partition.push((part_mask, part_id)); 434 | 435 | part_id += 1; 436 | } 437 | 438 | data.sort_partition(); 439 | } 440 | } 441 | } 442 | -------------------------------------------------------------------------------- /src/search/states.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | hash::{Hash, Hasher}, 4 | mem::{replace, transmute}, 5 | sync::RwLock, 6 | }; 7 | 8 | use arrayref::array_ref; 9 | use futures::{ 10 | channel::oneshot, 11 | future::{FutureExt, Shared}, 12 | }; 13 | use rustc_hash::FxHasher; 14 | 15 | use crate::output_set::OutputSet; 16 | 17 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 18 | pub struct State { 19 | pub bounds: [u8; 2], 20 | pub huffman_bounds: [u8; 2], 21 | } 22 | 23 | pub struct StateMap { 24 | state_shards: Vec>>, 25 | lock_shards: Vec>>>>, 26 | } 27 | 28 | impl Default for StateMap { 29 | fn default() -> Self { 30 | let threads = num_cpus::get(); 31 | let shards = (threads * threads * 8).next_power_of_two(); 32 | 33 | Self { 34 | state_shards: (0..shards).map(|_| Default::default()).collect(), 35 | lock_shards: (0..shards).map(|_| Default::default()).collect(), 36 | } 37 | } 38 | } 39 | 40 | impl StateMap { 41 | pub fn len(&self) -> usize { 42 | self.state_shards 43 
| .iter() 44 | .map(|shard| shard.read().unwrap().len()) 45 | .sum() 46 | } 47 | 48 | pub fn get(&self, output_set: OutputSet<&[bool]>) -> State { 49 | let packed = output_set.packed_pvec(); 50 | 51 | let mut hasher = FxHasher::default(); 52 | packed.hash(&mut hasher); 53 | let hash = hasher.finish(); 54 | let shard = (hash & (self.state_shards.len() - 1) as u64) as usize; 55 | 56 | if let Some(result) = self.state_shards[shard] 57 | .read() 58 | .unwrap() 59 | .get_with_packed(output_set.channels(), &packed) 60 | { 61 | result 62 | } else { 63 | if output_set.is_sorted() { 64 | State { 65 | bounds: [0, 0], 66 | huffman_bounds: [0, 0], 67 | } 68 | } else if output_set.channels() <= 2 { 69 | State { 70 | bounds: [1, 1], 71 | huffman_bounds: [1, 1], 72 | } 73 | } else { 74 | let quadratic_bound = (output_set.channels() * (output_set.channels() - 1)) / 2; 75 | 76 | let known_bounds = [0, 0, 1, 3, 5, 9, 12, 16, 19, 25, 29, 35]; 77 | 78 | let bound = known_bounds 79 | .get(output_set.channels()) 80 | .cloned() 81 | .unwrap_or(quadratic_bound as u8); 82 | 83 | State { 84 | bounds: [1, bound], 85 | huffman_bounds: [1, bound], 86 | } 87 | } 88 | } 89 | } 90 | 91 | pub fn set(&self, output_set: OutputSet<&[bool]>, state: State) { 92 | let packed = output_set.packed_pvec(); 93 | 94 | let mut hasher = FxHasher::default(); 95 | packed.hash(&mut hasher); 96 | let hash = hasher.finish(); 97 | let shard = (hash & (self.state_shards.len() - 1) as u64) as usize; 98 | 99 | self.state_shards[shard].write().unwrap().set_with_packed( 100 | output_set.channels(), 101 | &output_set.packed_pvec(), 102 | state, 103 | ) 104 | } 105 | 106 | pub async fn lock<'a>(&'a self, output_set: OutputSet<&'a [bool]>) -> Option> { 107 | let packed = output_set.packed(); 108 | 109 | let mut hasher = FxHasher::default(); 110 | packed.hash(&mut hasher); 111 | let hash = hasher.finish(); 112 | let shard = (hash & (self.lock_shards.len() - 1) as u64) as usize; 113 | 114 | match self.get_lock(output_set, 
shard, &packed) { 115 | Ok(lock) => { 116 | lock.await.unwrap(); 117 | None 118 | } 119 | Err(unlock) => Some(StateLock { 120 | states: self, 121 | unlock: Some(unlock), 122 | channels: output_set.channels(), 123 | packed, 124 | shard, 125 | }), 126 | } 127 | } 128 | 129 | fn get_lock<'a>( 130 | &'a self, 131 | output_set: OutputSet<&'a [bool]>, 132 | shard: usize, 133 | packed: &[u8], 134 | ) -> Result>, oneshot::Sender<()>> { 135 | let mut shard_mut = self.lock_shards[shard].write().unwrap(); 136 | 137 | if let Some(existing) = shard_mut.get_with_packed(output_set.channels(), &packed) { 138 | Ok(existing) 139 | } else { 140 | let (unlock, lock) = oneshot::channel(); 141 | shard_mut.set_with_packed(output_set.channels(), &packed, lock.shared()); 142 | Err(unlock) 143 | } 144 | } 145 | 146 | pub fn into_shards(self) -> Vec> { 147 | self.state_shards 148 | .into_iter() 149 | .map(|shard| shard.into_inner().unwrap()) 150 | .collect() 151 | } 152 | } 153 | 154 | pub struct StateLock<'a> { 155 | states: &'a StateMap, 156 | unlock: Option>, 157 | channels: usize, 158 | packed: Vec, 159 | shard: usize, 160 | } 161 | 162 | impl<'a> Drop for StateLock<'a> { 163 | fn drop(&mut self) { 164 | let mut shard_mut = self.states.lock_shards[self.shard].write().unwrap(); 165 | 166 | self.unlock.take().unwrap().send(()).unwrap(); 167 | shard_mut.remove_with_packed(self.channels, &self.packed); 168 | } 169 | } 170 | 171 | pub struct OutputSetMap { 172 | states_3_channels: BTreeMap<[u8; 1 << 0], T>, 173 | states_4_channels: BTreeMap<[u8; 1 << 1], T>, 174 | states_5_channels: BTreeMap<[u8; 1 << 2], T>, 175 | states_6_channels: BTreeMap<[u8; 1 << 3], T>, 176 | states_7_channels: BTreeMap<[u8; 1 << 4], T>, 177 | states_8_channels: BTreeMap<[u8; 1 << 5], T>, 178 | states_9_channels: BTreeMap<[[u8; 1 << 5]; 1 << 1], T>, 179 | states_10_channels: BTreeMap<[[u8; 1 << 5]; 1 << 2], T>, 180 | states_11_channels: BTreeMap<[[u8; 1 << 5]; 1 << 3], T>, 181 | } 182 | 183 | impl Default for 
OutputSetMap { 184 | fn default() -> Self { 185 | Self { 186 | states_3_channels: Default::default(), 187 | states_4_channels: Default::default(), 188 | states_5_channels: Default::default(), 189 | states_6_channels: Default::default(), 190 | states_7_channels: Default::default(), 191 | states_8_channels: Default::default(), 192 | states_9_channels: Default::default(), 193 | states_10_channels: Default::default(), 194 | states_11_channels: Default::default(), 195 | } 196 | } 197 | } 198 | 199 | impl OutputSetMap { 200 | pub fn len(&self) -> usize { 201 | self.states_3_channels.len() 202 | + self.states_4_channels.len() 203 | + self.states_5_channels.len() 204 | + self.states_6_channels.len() 205 | + self.states_7_channels.len() 206 | + self.states_8_channels.len() 207 | + self.states_9_channels.len() 208 | + self.states_10_channels.len() 209 | + self.states_11_channels.len() 210 | } 211 | 212 | pub fn drain_packed(&mut self) -> impl Iterator, T)> { 213 | let result = replace(&mut self.states_3_channels, Default::default()) 214 | .into_iter() 215 | .map(|(packed, state)| (3, packed.as_ref().to_owned(), state)) 216 | .chain( 217 | replace(&mut self.states_4_channels, Default::default()) 218 | .into_iter() 219 | .map(|(packed, state)| (4, packed.as_ref().to_owned(), state)), 220 | ) 221 | .chain( 222 | replace(&mut self.states_5_channels, Default::default()) 223 | .into_iter() 224 | .map(|(packed, state)| (5, packed.as_ref().to_owned(), state)), 225 | ) 226 | .chain( 227 | replace(&mut self.states_6_channels, Default::default()) 228 | .into_iter() 229 | .map(|(packed, state)| (6, packed.as_ref().to_owned(), state)), 230 | ) 231 | .chain( 232 | replace(&mut self.states_7_channels, Default::default()) 233 | .into_iter() 234 | .map(|(packed, state)| (7, packed.as_ref().to_owned(), state)), 235 | ) 236 | .chain( 237 | replace(&mut self.states_8_channels, Default::default()) 238 | .into_iter() 239 | .map(|(packed, state)| (8, packed.as_ref().to_owned(), state)), 240 | ) 
241 | .chain( 242 | replace(&mut self.states_9_channels, Default::default()) 243 | .into_iter() 244 | .map(|(packed, state)| { 245 | ( 246 | 9, 247 | unsafe { transmute::<_, [u8; 1 << 6]>(packed) } 248 | .as_ref() 249 | .to_owned(), 250 | state, 251 | ) 252 | }), 253 | ) 254 | .chain( 255 | replace(&mut self.states_10_channels, Default::default()) 256 | .into_iter() 257 | .map(|(packed, state)| { 258 | ( 259 | 10, 260 | unsafe { transmute::<_, [u8; 1 << 7]>(packed) } 261 | .as_ref() 262 | .to_owned(), 263 | state, 264 | ) 265 | }), 266 | ) 267 | .chain( 268 | replace(&mut self.states_11_channels, Default::default()) 269 | .into_iter() 270 | .map(|(packed, state)| { 271 | ( 272 | 11, 273 | unsafe { transmute::<_, [u8; 1 << 8]>(packed) } 274 | .as_ref() 275 | .to_owned(), 276 | state, 277 | ) 278 | }), 279 | ); 280 | 281 | // Works around vscode's broken syntax highlighting 282 | let _ignored: [(); 0 >> 1] = []; 283 | let _ignored: [(); 0 >> 1] = []; 284 | let _ignored: [(); 0 >> 1] = []; 285 | 286 | result 287 | } 288 | 289 | pub fn get_with_packed(&self, channels: usize, packed: &[u8]) -> Option { 290 | let result = match channels { 291 | 3 => self 292 | .states_3_channels 293 | .get(array_ref!(packed, 0, 1 << 0)) 294 | .cloned(), 295 | 4 => self 296 | .states_4_channels 297 | .get(array_ref!(packed, 0, 1 << 1)) 298 | .cloned(), 299 | 5 => self 300 | .states_5_channels 301 | .get(array_ref!(packed, 0, 1 << 2)) 302 | .cloned(), 303 | 6 => self 304 | .states_6_channels 305 | .get(array_ref!(packed, 0, 1 << 3)) 306 | .cloned(), 307 | 7 => self 308 | .states_7_channels 309 | .get(array_ref!(packed, 0, 1 << 4)) 310 | .cloned(), 311 | 8 => self 312 | .states_8_channels 313 | .get(array_ref!(packed, 0, 1 << 5)) 314 | .cloned(), 315 | 316 | 9 => self 317 | .states_9_channels 318 | .get(unsafe { &transmute::<_, [_; 1 << 1]>(*array_ref!(packed, 0, 1 << 6)) }) 319 | .cloned(), 320 | 10 => self 321 | .states_10_channels 322 | .get(unsafe { &transmute::<_, [_; 1 << 
2]>(*array_ref!(packed, 0, 1 << 7)) }) 323 | .cloned(), 324 | 11 => self 325 | .states_11_channels 326 | .get(unsafe { &transmute::<_, [_; 1 << 3]>(*array_ref!(packed, 0, 1 << 8)) }) 327 | .cloned(), 328 | _ => None, 329 | }; 330 | 331 | // Works around vscode's broken syntax highlighting 332 | let _ignored: [(); 0 >> 1] = []; 333 | let _ignored: [(); 0 >> 1] = []; 334 | let _ignored: [(); 0 >> 1] = []; 335 | 336 | result 337 | } 338 | 339 | pub fn set_with_packed(&mut self, channels: usize, packed: &[u8], value: T) { 340 | assert!(channels > 2); 341 | 342 | match channels { 343 | 3 => self 344 | .states_3_channels 345 | .insert(*array_ref!(packed, 0, 1 << 0), value), 346 | 4 => self 347 | .states_4_channels 348 | .insert(*array_ref!(packed, 0, 1 << 1), value), 349 | 5 => self 350 | .states_5_channels 351 | .insert(*array_ref!(packed, 0, 1 << 2), value), 352 | 6 => self 353 | .states_6_channels 354 | .insert(*array_ref!(packed, 0, 1 << 3), value), 355 | 7 => self 356 | .states_7_channels 357 | .insert(*array_ref!(packed, 0, 1 << 4), value), 358 | 8 => self 359 | .states_8_channels 360 | .insert(*array_ref!(packed, 0, 1 << 5), value), 361 | 362 | 9 => self.states_9_channels.insert( 363 | unsafe { transmute::<_, [_; 1 << 1]>(*array_ref!(packed, 0, 1 << 6)) }, 364 | value, 365 | ), 366 | 10 => self.states_10_channels.insert( 367 | unsafe { transmute::<_, [_; 1 << 2]>(*array_ref!(packed, 0, 1 << 7)) }, 368 | value, 369 | ), 370 | 11 => self.states_11_channels.insert( 371 | unsafe { transmute::<_, [_; 1 << 3]>(*array_ref!(packed, 0, 1 << 8)) }, 372 | value, 373 | ), 374 | _ => unreachable!(), 375 | }; 376 | 377 | // Works around vscode's broken syntax highlighting 378 | let _ignored: [(); 0 >> 1] = []; 379 | let _ignored: [(); 0 >> 1] = []; 380 | let _ignored: [(); 0 >> 1] = []; 381 | } 382 | 383 | pub fn remove_with_packed(&mut self, channels: usize, packed: &[u8]) { 384 | assert!(channels > 2); 385 | 386 | match channels { 387 | 3 => 
self.states_3_channels.remove(array_ref!(packed, 0, 1 << 0)), 388 | 4 => self.states_4_channels.remove(array_ref!(packed, 0, 1 << 1)), 389 | 5 => self.states_5_channels.remove(array_ref!(packed, 0, 1 << 2)), 390 | 6 => self.states_6_channels.remove(array_ref!(packed, 0, 1 << 3)), 391 | 7 => self.states_7_channels.remove(array_ref!(packed, 0, 1 << 4)), 392 | 8 => self.states_8_channels.remove(array_ref!(packed, 0, 1 << 5)), 393 | 394 | 9 => self 395 | .states_9_channels 396 | .remove(unsafe { &transmute::<_, [_; 1 << 1]>(*array_ref!(packed, 0, 1 << 6)) }), 397 | 10 => self 398 | .states_10_channels 399 | .remove(unsafe { &transmute::<_, [_; 1 << 2]>(*array_ref!(packed, 0, 1 << 7)) }), 400 | 11 => self 401 | .states_11_channels 402 | .remove(unsafe { &transmute::<_, [_; 1 << 3]>(*array_ref!(packed, 0, 1 << 8)) }), 403 | _ => unreachable!(), 404 | }; 405 | 406 | // Works around vscode's broken syntax highlighting 407 | let _ignored: [(); 0 >> 1] = []; 408 | let _ignored: [(); 0 >> 1] = []; 409 | let _ignored: [(); 0 >> 1] = []; 410 | } 411 | } 412 | -------------------------------------------------------------------------------- /src/search.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::hash_map::Entry, 3 | fs::{create_dir, File}, 4 | io::{self, BufWriter, ErrorKind, Write}, 5 | path::PathBuf, 6 | pin::Pin, 7 | sync::Arc, 8 | time::{Duration, Instant}, 9 | }; 10 | 11 | use async_std::{prelude::*, task, task::sleep}; 12 | use futures::{pending, poll, task::Poll}; 13 | use rustc_hash::FxHashMap as HashMap; 14 | 15 | mod states; 16 | 17 | use self::states::{State, StateMap}; 18 | use crate::{ 19 | huffman::max_plus_1_huffman, 20 | output_set::{CVec, OutputSet}, 21 | thread_pool::{Handle, Schedule, ThreadPool}, 22 | }; 23 | 24 | pub struct Search { 25 | states: StateMap, 26 | } 27 | 28 | type Prio = (usize, usize, u8); 29 | 30 | impl Search { 31 | pub fn search( 32 | initial: OutputSet<&[bool]>, 33 | 
limit: Option, 34 | output: Option, 35 | ) -> usize { 36 | let search = Self { 37 | states: StateMap::default(), 38 | }; 39 | 40 | let mut initial = initial.to_owned_bvec(); 41 | 42 | let limit = limit.unwrap_or(usize::max_value()); 43 | 44 | initial.canonicalize(true); 45 | 46 | let final_bound = ThreadPool::scope(|pool| { 47 | let _info_thread = pool.spawn(Box::pin(async { 48 | let mut last_msg = Instant::now(); 49 | loop { 50 | let next_msg = last_msg + Duration::from_secs(10); 51 | let sleep_for = next_msg.saturating_duration_since(Instant::now()); 52 | last_msg = next_msg; 53 | sleep(sleep_for).await; 54 | search.log_stats(); 55 | } 56 | })); 57 | 58 | let main_loop = pool.spawn(Box::pin(async { 59 | let mut state = search.states.get(initial.as_ref()); 60 | loop { 61 | if state.bounds[0] == state.bounds[1] || state.bounds[0] as usize >= limit { 62 | break state.bounds[0] as usize; 63 | } 64 | state = search.improve(pool, 0, state, initial.as_ref()).await; 65 | log::info!("bounds: {:?}", state); 66 | search.log_stats(); 67 | } 68 | })); 69 | 70 | task::block_on(main_loop) 71 | }); 72 | 73 | if let Some(output) = output { 74 | search.dump_states(output).unwrap(); 75 | } 76 | 77 | final_bound 78 | } 79 | 80 | fn dump_states(self, output: PathBuf) -> io::Result<()> { 81 | match create_dir(&output) { 82 | Err(err) if err.kind() == ErrorKind::AlreadyExists => Ok(()), 83 | res => res, 84 | }?; 85 | 86 | let mut index = BufWriter::new(File::create(output.join("index.txt"))?); 87 | 88 | let mut group_files = HashMap::<(usize, u8), BufWriter>::default(); 89 | for mut shard in self.states.into_shards() { 90 | for (channels, packed, state) in shard.drain_packed() { 91 | let group = (channels, state.bounds[0]); 92 | let group_file = match group_files.entry(group) { 93 | Entry::Occupied(entry) => entry.into_mut(), 94 | Entry::Vacant(entry) => { 95 | let group_name = format!("group_{}_{}.bin", group.0, group.1); 96 | writeln!(&mut index, "{}", group_name)?; 97 | let 
group_file = BufWriter::new(File::create(output.join(group_name))?); 98 | entry.insert(group_file) 99 | } 100 | }; 101 | 102 | group_file.write_all(&packed)?; 103 | } 104 | } 105 | 106 | Ok(()) 107 | } 108 | 109 | fn log_stats(&self) { 110 | log::info!("states: {:?}", self.states.len()); 111 | } 112 | 113 | fn improve_boxed<'a>( 114 | &'a self, 115 | pool: &'a ThreadPool, 116 | level: usize, 117 | previous_state: State, 118 | output_set: OutputSet, 119 | ) -> Pin + Send + 'a>> { 120 | Box::pin(async move { 121 | self.improve(pool, level, previous_state, output_set.as_ref()) 122 | .await 123 | }) 124 | } 125 | 126 | #[allow(unreachable_code)] 127 | async fn improve( 128 | &self, 129 | pool: &ThreadPool, 130 | level: usize, 131 | previous_state: State, 132 | output_set: OutputSet<&[bool]>, 133 | ) -> State { 134 | let _locked = loop { 135 | let state = self.states.get(output_set); 136 | if state != previous_state { 137 | return state; 138 | } 139 | if state.bounds[0] == state.bounds[1] { 140 | return state; 141 | } 142 | 143 | if let Some(locked) = self.states.lock(output_set).await { 144 | break locked; 145 | } 146 | }; 147 | 148 | let mut state = self.states.get(output_set); 149 | 150 | if state != previous_state { 151 | return state; 152 | } 153 | if state.bounds[0] == state.bounds[1] { 154 | return state; 155 | } 156 | 157 | let mut extremal_channels = [CVec::new(), CVec::new()]; 158 | 159 | for (pol, pol_channels) in extremal_channels.iter_mut().enumerate() { 160 | for channel in 0..output_set.channels() { 161 | if output_set.channel_is_extremal(pol > 0, channel) { 162 | pol_channels.push(channel); 163 | } 164 | } 165 | } 166 | 167 | for (pol, pol_channels) in extremal_channels.iter().enumerate() { 168 | if pol_channels.len() == 1 { 169 | let mut pruned_output_set = OutputSet::all_values(output_set.channels() - 1); 170 | output_set.prune_extremal_channel_into( 171 | pol > 0, 172 | pol_channels[0], 173 | pruned_output_set.as_mut(), 174 | ); 175 | 176 | 
pruned_output_set.canonicalize(true); 177 | 178 | let mut pruned_state = self.states.get(pruned_output_set.as_ref()); 179 | 180 | loop { 181 | if pruned_state.bounds[1] < state.bounds[1] { 182 | state.bounds[1] = pruned_state.bounds[1]; 183 | state.huffman_bounds[1] = pruned_state.bounds[1]; 184 | 185 | self.states.set(output_set.as_ref(), state); 186 | 187 | if state.bounds[0] == state.bounds[1] { 188 | return state; 189 | } 190 | } 191 | 192 | if pruned_state.bounds[0] > state.bounds[0] { 193 | state.bounds[0] = pruned_state.bounds[0]; 194 | state.huffman_bounds[0] = pruned_state.bounds[0]; 195 | 196 | self.states.set(output_set.as_ref(), state); 197 | 198 | return state; 199 | } 200 | 201 | pruned_state = self 202 | .improve_boxed(pool, level, pruned_state, pruned_output_set.clone()) 203 | .await; 204 | } 205 | } 206 | } 207 | 208 | if state.huffman_bounds[1] > state.bounds[0] { 209 | let new_state = self.improve_huffman(pool, level, output_set).await; 210 | if new_state.bounds[0] > state.bounds[0] { 211 | return new_state; 212 | } 213 | state = new_state; 214 | } 215 | 216 | let mut edges = Edges::default(); 217 | 218 | for i in 0..output_set.channels() { 219 | for j in 0..i { 220 | let mut target = output_set.to_owned(); 221 | if target.apply_comparator([i, j]) { 222 | target.canonicalize(true); 223 | 224 | edges.add_edge(self, target); 225 | } 226 | } 227 | } 228 | 229 | loop { 230 | let combined_upper_bound = 231 | edges.states().map(|state| state.bounds[1]).min().unwrap() + 1; 232 | 233 | if combined_upper_bound < state.bounds[1] { 234 | state.bounds[1] = combined_upper_bound; 235 | if combined_upper_bound < state.huffman_bounds[1] { 236 | state.huffman_bounds[1] = combined_upper_bound; 237 | } 238 | self.states.set(output_set, state); 239 | if state.bounds[0] == state.bounds[1] { 240 | return state; 241 | } 242 | } 243 | 244 | let combined_lower_bound = 245 | edges.states().map(|state| state.bounds[0]).min().unwrap() + 1; 246 | 247 | if 
combined_lower_bound > state.bounds[0] { 248 | let mut state = self.states.get(output_set); 249 | state.bounds[0] = combined_lower_bound; 250 | self.states.set(output_set, state); 251 | 252 | return state; 253 | } 254 | 255 | edges 256 | .improve_next(self, pool, level + 1, Some(state.bounds[0])) 257 | .await; 258 | } 259 | } 260 | 261 | async fn improve_huffman( 262 | &self, 263 | pool: &ThreadPool, 264 | level: usize, 265 | output_set: OutputSet<&[bool]>, 266 | ) -> State { 267 | let mut state = self.states.get(output_set); 268 | 269 | if state.huffman_bounds[1] <= state.bounds[0] { 270 | return state; 271 | } 272 | 273 | let mut extremal_channels = [CVec::new(), CVec::new()]; 274 | 275 | for (pol, pol_channels) in extremal_channels.iter_mut().enumerate() { 276 | for channel in 0..output_set.channels() { 277 | if output_set.channel_is_extremal(pol > 0, channel) { 278 | pol_channels.push(channel); 279 | } 280 | } 281 | } 282 | 283 | let mut edges = Edges::default(); 284 | 285 | let mut pruned_ids = [vec![], vec![]]; 286 | 287 | for (pol, (pol_channels, pol_ids)) in extremal_channels 288 | .iter() 289 | .zip(pruned_ids.iter_mut()) 290 | .enumerate() 291 | { 292 | for &channel in pol_channels.iter() { 293 | let mut target = OutputSet::all_values(output_set.channels() - 1); 294 | output_set.prune_extremal_channel_into(pol > 0, channel, target.as_mut()); 295 | target.canonicalize(true); 296 | 297 | pol_ids.push(edges.add_edge(self, target)); 298 | } 299 | } 300 | 301 | let mut upper_huffman_bounds = [state.huffman_bounds[1], state.huffman_bounds[1]]; 302 | 303 | while state.huffman_bounds[1] > state.bounds[0] { 304 | for (pol, pol_ids) in pruned_ids.iter().enumerate() { 305 | if upper_huffman_bounds[pol] > state.bounds[0] { 306 | let lower_bounds = pol_ids 307 | .iter() 308 | .map(|&id| edges.state(id).bounds[0]) 309 | .collect::>(); 310 | let upper_bounds = pol_ids 311 | .iter() 312 | .map(|&id| edges.state(id).bounds[1]) 313 | .collect::>(); 314 | 315 | let 
lower_huffman_bound = max_plus_1_huffman(&lower_bounds); 316 | let upper_huffman_bound = max_plus_1_huffman(&upper_bounds); 317 | 318 | if upper_huffman_bound <= state.bounds[0] { 319 | edges.retain_edges(|id| pruned_ids[1 - pol].contains(&id)); 320 | } 321 | 322 | upper_huffman_bounds[pol] = upper_huffman_bound; 323 | 324 | if lower_huffman_bound > state.huffman_bounds[0] { 325 | state.huffman_bounds[0] = lower_huffman_bound; 326 | if lower_huffman_bound > state.bounds[0] { 327 | state.bounds[0] = lower_huffman_bound; 328 | self.states.set(output_set, state); 329 | return state; 330 | } 331 | self.states.set(output_set, state); 332 | } 333 | } 334 | } 335 | 336 | let max_upper_huffman_bound = *upper_huffman_bounds.iter().max().unwrap(); 337 | 338 | if max_upper_huffman_bound < state.huffman_bounds[1] { 339 | state.huffman_bounds[1] = max_upper_huffman_bound; 340 | self.states.set(output_set, state); 341 | } 342 | 343 | edges.improve_next(self, pool, level + 1, None).await; 344 | } 345 | 346 | state 347 | } 348 | } 349 | 350 | #[derive(Default)] 351 | struct Edges { 352 | target_to_id: HashMap, usize>, 353 | targets: Vec, 354 | active_ids: Vec, 355 | } 356 | 357 | struct EdgeTarget { 358 | output_set: Arc, 359 | state: State, 360 | len: usize, 361 | running: Option<(Handle, Schedule)>, 362 | } 363 | 364 | impl Edges { 365 | fn add_edge(&mut self, search: &Search, target: OutputSet) -> usize { 366 | let target = Arc::new(target); 367 | 368 | let targets = &mut self.targets; 369 | let active_ids = &mut self.active_ids; 370 | match self.target_to_id.entry(target.clone()) { 371 | Entry::Occupied(entry) => *entry.get(), 372 | Entry::Vacant(entry) => { 373 | let id = targets.len(); 374 | let state = search.states.get(target.as_ref().as_ref()); 375 | let len = target.len(); 376 | targets.push(EdgeTarget { 377 | output_set: target, 378 | state, 379 | len, 380 | running: None, 381 | }); 382 | if state.bounds[0] != state.bounds[1] { 383 | active_ids.push(id); 384 | } 385 | 
entry.insert(id); 386 | id 387 | } 388 | } 389 | } 390 | 391 | fn states<'a>(&'a self) -> impl Iterator + 'a { 392 | self.targets.iter().map(|target| target.state) 393 | } 394 | 395 | fn state(&self, id: usize) -> State { 396 | self.targets[id].state 397 | } 398 | 399 | fn retain_edges(&mut self, mut should_retain: impl FnMut(usize) -> bool) { 400 | let targets = &mut self.targets; 401 | self.active_ids.retain(move |&id| { 402 | let retain = should_retain(id); 403 | if !retain { 404 | targets[id].running = None; 405 | } 406 | retain 407 | }) 408 | } 409 | 410 | #[allow(unreachable_code)] 411 | async fn improve_next( 412 | &mut self, 413 | search: &Search, 414 | pool: &ThreadPool, 415 | level: usize, 416 | order: Option, 417 | ) { 418 | let targets = &mut self.targets; 419 | 420 | let mut block = false; 421 | 422 | loop { 423 | let mut finished = false; 424 | 425 | let mut num_running = 0; 426 | 427 | for &id in self.active_ids.iter() { 428 | let target = &mut targets[id]; 429 | 430 | if let Some((handle, schedule)) = target.running.as_mut() { 431 | if let Poll::Ready(new_state) = poll!(handle) { 432 | target.state = new_state; 433 | target.running = None; 434 | finished = true; 435 | } else if schedule.is_scheduled() { 436 | num_running += 1; 437 | } 438 | } 439 | } 440 | 441 | self.active_ids.retain(|&id| { 442 | let state = targets[id].state; 443 | let retain = state.bounds[0] != state.bounds[1]; 444 | if !retain { 445 | assert!(targets[id].running.is_none()); 446 | } 447 | retain 448 | }); 449 | 450 | if finished || self.active_ids.is_empty() { 451 | return; 452 | } 453 | 454 | if block && num_running > 0 { 455 | pending!(); 456 | } 457 | 458 | block = true; 459 | 460 | if let Some(limit) = order { 461 | self.active_ids.sort_by_key(|&id| { 462 | let target = &targets[id]; 463 | 464 | ( 465 | target.state.bounds[0] >= limit, 466 | target.len, 467 | target.state.bounds[0], 468 | target.state.bounds[1], 469 | ) 470 | }); 471 | } else { 472 | 
self.active_ids.sort_by_key(|&id| { 473 | let target = &targets[id]; 474 | 475 | (target.state.bounds[0], target.len, target.state.bounds[1]) 476 | }); 477 | } 478 | 479 | for (index, &id) in self.active_ids.iter().enumerate() { 480 | let target = &mut targets[id]; 481 | if target.running.is_none() { 482 | block = false; 483 | assert_ne!(target.state.bounds[0], target.state.bounds[1]); 484 | 485 | let (handle, schedule) = pool.spawn_delayed(search.improve_boxed( 486 | pool, 487 | level, 488 | target.state, 489 | target.output_set.as_ref().clone(), 490 | )); 491 | if index != 0 { 492 | if let Some(limit) = order { 493 | if target.state.bounds[0] < limit { 494 | pool.add_pending( 495 | (level, target.len, target.state.bounds[0]), 496 | &schedule, 497 | ); 498 | } 499 | } 500 | } 501 | target.running = Some((handle, schedule)); 502 | } 503 | 504 | if index == 0 { 505 | target.running.as_ref().unwrap().1.schedule(); 506 | } 507 | } 508 | } 509 | } 510 | } 511 | -------------------------------------------------------------------------------- /checker/snocheck/src/Verified/Checker.hs: -------------------------------------------------------------------------------- 1 | -- Generated using update_extracted_code.sh, do not edit -- 2 | {-# LANGUAGE EmptyDataDecls, RankNTypes, ScopedTypeVariables #-} 3 | 4 | module 5 | Verified.Checker(Int(..), Proof_witness(..), Proof_step_witnesses(..), 6 | Proof_step(..), Proof_cert(..), integer_of_int, 7 | check_proof_get_bound) 8 | where { 9 | 10 | import Prelude ((==), (/=), (<), (<=), (>=), (>), (+), (-), (*), (/), (**), 11 | (>>=), (>>), (=<<), (&&), (||), (^), (^^), (.), ($), ($!), (++), (!!), Eq, 12 | error, id, return, not, fst, snd, map, filter, concat, concatMap, reverse, 13 | zip, null, takeWhile, dropWhile, all, any, Integer, negate, abs, divMod, 14 | String, Bool(True, False), Maybe(Nothing, Just)); 15 | import qualified Prelude; 16 | import qualified Parallel; 17 | 18 | newtype Nat = Nat Integer; 19 | 20 | integer_of_nat :: 
Nat -> Integer; 21 | integer_of_nat (Nat x) = x; 22 | 23 | equal_nat :: Nat -> Nat -> Bool; 24 | equal_nat m n = integer_of_nat m == integer_of_nat n; 25 | 26 | instance Eq Nat where { 27 | a == b = equal_nat a b; 28 | }; 29 | 30 | less_eq_nat :: Nat -> Nat -> Bool; 31 | less_eq_nat m n = integer_of_nat m <= integer_of_nat n; 32 | 33 | class Ord a where { 34 | less_eq :: a -> a -> Bool; 35 | less :: a -> a -> Bool; 36 | }; 37 | 38 | less_nat :: Nat -> Nat -> Bool; 39 | less_nat m n = integer_of_nat m < integer_of_nat n; 40 | 41 | instance Ord Nat where { 42 | less_eq = less_eq_nat; 43 | less = less_nat; 44 | }; 45 | 46 | class (Ord a) => Preorder a where { 47 | }; 48 | 49 | class (Preorder a) => Order a where { 50 | }; 51 | 52 | instance Preorder Nat where { 53 | }; 54 | 55 | instance Order Nat where { 56 | }; 57 | 58 | class (Order a) => Linorder a where { 59 | }; 60 | 61 | instance Linorder Nat where { 62 | }; 63 | 64 | instance Ord Integer where { 65 | less_eq = (\ a b -> a <= b); 66 | less = (\ a b -> a < b); 67 | }; 68 | 69 | newtype Int = Int_of_integer Integer; 70 | 71 | data Num = One | Bit0 Num | Bit1 Num; 72 | 73 | data Set a = Set [a] | Coset [a]; 74 | 75 | data Vect_trie = VtEmpty | VtNode !Bool !Vect_trie !Vect_trie; 76 | 77 | newtype Multiset a = Mset [a]; 78 | 79 | data Proof_witness = ProofWitness Int Bool [Int]; 80 | 81 | data Proof_step_witnesses = HuffmanWitnesses Bool [Maybe Proof_witness] 82 | | SuccessorWitnesses [Maybe Proof_witness]; 83 | 84 | data Proof_step = ProofStep Int [[Bool]] Int Proof_step_witnesses; 85 | 86 | data Proof_cert = ProofCert Int (Int -> Proof_step); 87 | 88 | integer_of_int :: Int -> Integer; 89 | integer_of_int (Int_of_integer k) = k; 90 | 91 | max :: forall a. 
(Ord a) => a -> a -> a; 92 | max a b = (if less_eq a b then b else a); 93 | 94 | nat :: Int -> Nat; 95 | nat k = Nat (max (0 :: Integer) (integer_of_int k)); 96 | 97 | plus_nat :: Nat -> Nat -> Nat; 98 | plus_nat m n = Nat (integer_of_nat m + integer_of_nat n); 99 | 100 | one_nat :: Nat; 101 | one_nat = Nat (1 :: Integer); 102 | 103 | suc :: Nat -> Nat; 104 | suc n = plus_nat n one_nat; 105 | 106 | minus_nat :: Nat -> Nat -> Nat; 107 | minus_nat m n = Nat (max (0 :: Integer) (integer_of_nat m - integer_of_nat n)); 108 | 109 | zero_nat :: Nat; 110 | zero_nat = Nat (0 :: Integer); 111 | 112 | nth :: forall a. [a] -> Nat -> a; 113 | nth (x : xs) n = 114 | (if equal_nat n zero_nat then x else nth xs (minus_nat n one_nat)); 115 | 116 | upt :: Nat -> Nat -> [Nat]; 117 | upt i j = (if less_nat i j then i : upt (suc i) j else []); 118 | 119 | fold :: forall a b. (a -> b -> b) -> [a] -> b -> b; 120 | fold f (x : xs) s = fold f xs (f x s); 121 | fold f [] s = s; 122 | 123 | par :: forall a b. a -> b -> b; 124 | par = Parallel.par; 125 | 126 | member :: forall a. (Eq a) => [a] -> a -> Bool; 127 | member [] y = False; 128 | member (x : xs) y = x == y || member xs y; 129 | 130 | pred :: Nat -> Nat; 131 | pred = nat_pred_code; 132 | 133 | nat_pred_code :: Nat -> Nat; 134 | nat_pred_code n = 135 | (if equal_nat n zero_nat then pred zero_nat else minus_nat n one_nat); 136 | 137 | the_elem :: forall a. Set a -> a; 138 | the_elem (Set [x]) = x; 139 | 140 | distinct :: forall a. (Eq a) => [a] -> Bool; 141 | distinct [] = True; 142 | distinct (x : xs) = not (member xs x) && distinct xs; 143 | 144 | is_none :: forall a. Maybe a -> Bool; 145 | is_none (Just x) = False; 146 | is_none Nothing = True; 147 | 148 | list_vt_extend :: Vect_trie -> ([Bool] -> [Bool]) -> [[Bool]] -> [[Bool]]; 149 | list_vt_extend VtEmpty el_prefix suffix = suffix; 150 | list_vt_extend (VtNode True lo hi) el_prefix suffix = 151 | el_prefix [] : 152 | list_vt_extend lo (el_prefix . 
(\ a -> False : a)) 153 | (list_vt_extend hi (el_prefix . (\ a -> True : a)) suffix); 154 | list_vt_extend (VtNode False lo hi) el_prefix suffix = 155 | list_vt_extend lo (el_prefix . (\ a -> False : a)) 156 | (list_vt_extend hi (el_prefix . (\ a -> True : a)) suffix); 157 | 158 | list_vt :: Vect_trie -> [[Bool]]; 159 | list_vt a = list_vt_extend a id []; 160 | 161 | vt_singleton :: [Bool] -> Vect_trie; 162 | vt_singleton [] = VtNode True VtEmpty VtEmpty; 163 | vt_singleton (False : xs) = VtNode False (vt_singleton xs) VtEmpty; 164 | vt_singleton (True : xs) = VtNode False VtEmpty (vt_singleton xs); 165 | 166 | vt_union :: Vect_trie -> Vect_trie -> Vect_trie; 167 | vt_union VtEmpty VtEmpty = VtEmpty; 168 | vt_union (VtNode v va vb) VtEmpty = VtNode v va vb; 169 | vt_union VtEmpty (VtNode v va vb) = VtNode v va vb; 170 | vt_union (VtNode a a_lo a_hi) (VtNode b b_lo b_hi) = 171 | VtNode (a || b) (vt_union a_lo b_lo) (vt_union a_hi b_hi); 172 | 173 | vt_list :: [[Bool]] -> Vect_trie; 174 | vt_list ls = fold (vt_union . vt_singleton) ls VtEmpty; 175 | 176 | gen_length :: forall a. Nat -> [a] -> Nat; 177 | gen_length n (x : xs) = gen_length (suc n) xs; 178 | gen_length n [] = n; 179 | 180 | list_update :: forall a. [a] -> Nat -> a -> [a]; 181 | list_update [] i y = []; 182 | list_update (x : xs) i y = 183 | (if equal_nat i zero_nat then y : xs 184 | else x : list_update xs (minus_nat i one_nat) y); 185 | 186 | witness_step_id :: Proof_witness -> Int; 187 | witness_step_id (ProofWitness x1 x2 x3) = x1; 188 | 189 | witness_invert :: Proof_witness -> Bool; 190 | witness_invert (ProofWitness x1 x2 x3) = x2; 191 | 192 | witness_perm :: Proof_witness -> [Int]; 193 | witness_perm (ProofWitness x1 x2 x3) = x3; 194 | 195 | step_vect_list :: Proof_step -> [[Bool]]; 196 | step_vect_list (ProofStep x1 x2 x3 x4) = x2; 197 | 198 | size_list :: forall a. 
[a] -> Nat; 199 | size_list = gen_length zero_nat; 200 | 201 | step_bound :: Proof_step -> Int; 202 | step_bound (ProofStep x1 x2 x3 x4) = x3; 203 | 204 | less_eq_int :: Int -> Int -> Bool; 205 | less_eq_int k l = integer_of_int k <= integer_of_int l; 206 | 207 | zero_int :: Int; 208 | zero_int = Int_of_integer (0 :: Integer); 209 | 210 | less_int :: Int -> Int -> Bool; 211 | less_int k l = integer_of_int k < integer_of_int l; 212 | 213 | length_vt :: Vect_trie -> Nat; 214 | length_vt VtEmpty = zero_nat; 215 | length_vt (VtNode True lo hi) = suc (plus_nat (length_vt lo) (length_vt hi)); 216 | length_vt (VtNode False lo hi) = plus_nat (length_vt lo) (length_vt hi); 217 | 218 | is_unsorted_vt :: Nat -> Vect_trie -> Bool; 219 | is_unsorted_vt n a = less_nat (suc n) (length_vt a); 220 | 221 | is_subset_vt :: Vect_trie -> Vect_trie -> Bool; 222 | is_subset_vt VtEmpty a = True; 223 | is_subset_vt (VtNode True a_lo a_hi) VtEmpty = False; 224 | is_subset_vt (VtNode False a_lo a_hi) VtEmpty = 225 | is_subset_vt a_lo VtEmpty && is_subset_vt a_hi VtEmpty; 226 | is_subset_vt (VtNode True a_lo a_hi) (VtNode False b_lo b_hi) = False; 227 | is_subset_vt (VtNode False a_lo a_hi) (VtNode b b_lo b_hi) = 228 | is_subset_vt a_lo b_lo && is_subset_vt a_hi b_hi; 229 | is_subset_vt (VtNode True a_lo a_hi) (VtNode True b_lo b_hi) = 230 | is_subset_vt a_lo b_lo && is_subset_vt a_hi b_hi; 231 | 232 | permute_list_vect :: [Nat] -> [Bool] -> [Bool]; 233 | permute_list_vect ps xs = map (nth xs) ps; 234 | 235 | permute_vt :: [Nat] -> Vect_trie -> Vect_trie; 236 | permute_vt ps a = vt_list (map (permute_list_vect ps) (list_vt a)); 237 | 238 | invert_vt :: Bool -> Vect_trie -> Vect_trie; 239 | invert_vt z a = (if z then vt_list (map (map not) (list_vt a)) else a); 240 | 241 | get_bound :: 242 | (Int -> Proof_step) -> 243 | Int -> Maybe Proof_witness -> Nat -> Vect_trie -> Maybe Nat; 244 | get_bound proof_steps step_limit Nothing width a = 245 | Just (if is_unsorted_vt width a then one_nat else 
zero_nat); 246 | get_bound proof_steps step_limit (Just witness) width a = 247 | let { 248 | witness_id = witness_step_id witness; 249 | perm = map nat (witness_perm witness); 250 | step = proof_steps witness_id; 251 | b_list = step_vect_list step; 252 | b = vt_list b_list; 253 | ba = permute_vt perm (invert_vt (witness_invert witness) b); 254 | } in (if not (less_eq_int zero_int witness_id && 255 | less_int witness_id step_limit) || 256 | (not (all (\ i -> less_eq_nat zero_nat i && less_nat i width) 257 | perm) || 258 | (not (equal_nat (size_list perm) width) || 259 | (not (distinct perm) || 260 | (not (all (\ xs -> equal_nat (size_list xs) width) b_list) || 261 | not (is_subset_vt ba a))))) 262 | then Nothing else Just (nat (step_bound step))); 263 | 264 | ocmp_list :: Nat -> [(Nat, Nat)]; 265 | ocmp_list n = 266 | concatMap (\ i -> map (\ j -> (j, i)) (upt zero_nat i)) (upt zero_nat n); 267 | 268 | set_mset :: forall a. Multiset a -> Set a; 269 | set_mset (Mset xs) = Set xs; 270 | 271 | the :: forall a. Maybe a -> a; 272 | the (Just x2) = x2; 273 | 274 | step_witnesses :: Proof_step -> Proof_step_witnesses; 275 | step_witnesses (ProofStep x1 x2 x3 x4) = x4; 276 | 277 | step_width :: Proof_step -> Int; 278 | step_width (ProofStep x1 x2 x3 x4) = x1; 279 | 280 | is_redundant_cmp_vt :: (Nat, Nat) -> Vect_trie -> Bool; 281 | is_redundant_cmp_vt (aa, b) a = 282 | let { 283 | vs = list_vt a; 284 | } in not (not (all (\ x -> not (nth x aa && not (nth x b))) vs) && 285 | not (all (\ x -> not (not (nth x aa) && nth x b)) vs)); 286 | 287 | apply_cmp_list :: (Nat, Nat) -> [Bool] -> [Bool]; 288 | apply_cmp_list (a, b) xs = 289 | let { 290 | xa = nth xs a; 291 | xb = nth xs b; 292 | } in list_update (list_update xs a (xa && xb)) b (xa || xb); 293 | 294 | list_all2 :: forall a b. 
(a -> b -> Bool) -> [a] -> [b] -> Bool; 295 | list_all2 p (x : xs) (y : ys) = p x y && list_all2 p xs ys; 296 | list_all2 p xs [] = null xs; 297 | list_all2 p [] ys = null ys; 298 | 299 | check_successors :: (Int -> Proof_step) -> Int -> Proof_step -> Bool; 300 | check_successors proof_steps step_limit step = 301 | (case step_witnesses step of { 302 | HuffmanWitnesses _ _ -> False; 303 | SuccessorWitnesses witnesses -> 304 | let { 305 | width = nat (step_width step); 306 | bound = nat (step_bound step); 307 | ocmps = ocmp_list width; 308 | a_list = step_vect_list step; 309 | a = vt_list a_list; 310 | nrcmps = filter (\ c -> not (is_redundant_cmp_vt c a)) ocmps; 311 | bs = map (\ c -> vt_list (map (apply_cmp_list c) a_list)) nrcmps; 312 | } in not (equal_nat bound zero_nat) && 313 | is_unsorted_vt width a && 314 | equal_nat (size_list nrcmps) (size_list witnesses) && 315 | all (\ xs -> equal_nat (size_list xs) width) a_list && 316 | list_all2 317 | (\ b w -> 318 | (case get_bound proof_steps step_limit w width b of { 319 | Nothing -> False; 320 | Just ba -> less_eq_nat bound (suc ba); 321 | })) 322 | bs witnesses; 323 | }); 324 | 325 | sucmax :: Nat -> Nat -> Nat; 326 | sucmax a b = suc (max a b); 327 | 328 | sucmax_huffman_step_sorted_list :: [Nat] -> Multiset Nat; 329 | sucmax_huffman_step_sorted_list (a1 : a2 : asa) = Mset (sucmax a1 a2 : asa); 330 | sucmax_huffman_step_sorted_list [] = Mset []; 331 | sucmax_huffman_step_sorted_list [v] = Mset [v]; 332 | 333 | apsnd :: forall a b c. 
(a -> b) -> (c, a) -> (c, b); 334 | apsnd f (x, y) = (x, f y); 335 | 336 | divmod_integer :: Integer -> Integer -> (Integer, Integer); 337 | divmod_integer k l = 338 | (if k == (0 :: Integer) then ((0 :: Integer), (0 :: Integer)) 339 | else (if (0 :: Integer) < l 340 | then (if (0 :: Integer) < k then divMod (abs k) (abs l) 341 | else (case divMod (abs k) (abs l) of { 342 | (r, s) -> 343 | (if s == (0 :: Integer) 344 | then (negate r, (0 :: Integer)) 345 | else (negate r - (1 :: Integer), l - s)); 346 | })) 347 | else (if l == (0 :: Integer) then ((0 :: Integer), k) 348 | else apsnd negate 349 | (if k < (0 :: Integer) then divMod (abs k) (abs l) 350 | else (case divMod (abs k) (abs l) of { 351 | (r, s) -> 352 | (if s == (0 :: Integer) 353 | then (negate r, (0 :: Integer)) 354 | else (negate r - (1 :: Integer), 355 | negate l - s)); 356 | }))))); 357 | 358 | divide_integer :: Integer -> Integer -> Integer; 359 | divide_integer k l = fst (divmod_integer k l); 360 | 361 | divide_nat :: Nat -> Nat -> Nat; 362 | divide_nat m n = Nat (divide_integer (integer_of_nat m) (integer_of_nat n)); 363 | 364 | part :: forall a b. (Linorder b) => (a -> b) -> b -> [a] -> ([a], ([a], [a])); 365 | part f pivot (x : xs) = 366 | (case part f pivot xs of { 367 | (lts, (eqs, gts)) -> 368 | let { 369 | xa = f x; 370 | } in (if less xa pivot then (x : lts, (eqs, gts)) 371 | else (if less pivot xa then (lts, (eqs, x : gts)) 372 | else (lts, (x : eqs, gts)))); 373 | }); 374 | part f pivot [] = ([], ([], [])); 375 | 376 | nat_of_integer :: Integer -> Nat; 377 | nat_of_integer k = Nat (max (0 :: Integer) k); 378 | 379 | sort_key :: forall a b. 
(Linorder b) => (a -> b) -> [a] -> [a]; 380 | sort_key f xs = 381 | (case xs of { 382 | [] -> []; 383 | [_] -> xs; 384 | [x, y] -> (if less_eq (f x) (f y) then xs else [y, x]); 385 | _ : _ : _ : _ -> 386 | (case part f 387 | (f (nth xs 388 | (divide_nat (size_list xs) (nat_of_integer (2 :: Integer))))) 389 | xs 390 | of { 391 | (lts, (eqs, gts)) -> sort_key f lts ++ eqs ++ sort_key f gts; 392 | }); 393 | }); 394 | 395 | sorted_list_of_multiset :: forall a. (Linorder a) => Multiset a -> [a]; 396 | sorted_list_of_multiset (Mset xs) = sort_key (\ x -> x) xs; 397 | 398 | size_multiset :: forall a. Multiset a -> Nat; 399 | size_multiset (Mset xs) = size_list xs; 400 | 401 | min :: forall a. (Ord a) => a -> a -> a; 402 | min a b = (if less_eq a b then a else b); 403 | 404 | mina :: forall a. (Linorder a) => Set a -> a; 405 | mina (Set (x : xs)) = fold min xs x; 406 | 407 | bot_set :: forall a. Set a; 408 | bot_set = Set []; 409 | 410 | sucmax_value_bound_huffman :: Multiset Nat -> Nat; 411 | sucmax_value_bound_huffman a = 412 | (if equal_nat (size_multiset a) zero_nat then mina bot_set 413 | else (if equal_nat (minus_nat (size_multiset a) one_nat) zero_nat 414 | then the_elem (set_mset a) 415 | else sucmax_value_bound_huffman 416 | (sucmax_huffman_step_sorted_list 417 | (sorted_list_of_multiset a)))); 418 | 419 | is_member_vt :: [Bool] -> Vect_trie -> Bool; 420 | is_member_vt uu VtEmpty = False; 421 | is_member_vt [] (VtNode a uv uw) = a; 422 | is_member_vt (False : xs) (VtNode ux a_lo uy) = is_member_vt xs a_lo; 423 | is_member_vt (True : xs) (VtNode uz va a_hi) = is_member_vt xs a_hi; 424 | 425 | extremal_channels_vt :: Vect_trie -> Nat -> Bool -> [Nat]; 426 | extremal_channels_vt a n pol = 427 | filter 428 | (\ i -> 429 | is_member_vt (map (\ j -> not (equal_nat j i == pol)) (upt zero_nat n)) a) 430 | (upt zero_nat n); 431 | 432 | prune_extremal_vt :: Bool -> Nat -> Vect_trie -> Vect_trie; 433 | prune_extremal_vt uu uv VtEmpty = VtEmpty; 434 | prune_extremal_vt True i 
(VtNode uw a_lo ux) = 435 | (if equal_nat i zero_nat then a_lo 436 | else VtNode uw (prune_extremal_vt True (minus_nat i one_nat) a_lo) 437 | (prune_extremal_vt True (minus_nat i one_nat) ux)); 438 | prune_extremal_vt False i (VtNode uy uz a_hi) = 439 | (if equal_nat i zero_nat then a_hi 440 | else VtNode uy (prune_extremal_vt False (minus_nat i one_nat) uz) 441 | (prune_extremal_vt False (minus_nat i one_nat) a_hi)); 442 | 443 | check_huffman :: (Int -> Proof_step) -> Int -> Proof_step -> Bool; 444 | check_huffman proof_steps step_limit step = 445 | (case step_witnesses step of { 446 | HuffmanWitnesses pol witnesses -> 447 | let { 448 | width = nat (step_width step); 449 | widtha = pred width; 450 | bound = nat (step_bound step); 451 | a_list = step_vect_list step; 452 | a = vt_list a_list; 453 | extremal_channels = extremal_channels_vt a width pol; 454 | bs = map (\ c -> prune_extremal_vt pol c a) extremal_channels; 455 | bounds = 456 | map (\ (x, y) -> get_bound proof_steps step_limit y widtha x) 457 | (zip bs witnesses); 458 | huffman_bound = sucmax_value_bound_huffman (Mset (map the bounds)); 459 | } in not (equal_nat width zero_nat) && 460 | all (\ xs -> equal_nat (size_list xs) width) a_list && 461 | not (null witnesses) && 462 | equal_nat (size_list extremal_channels) 463 | (size_list witnesses) && 464 | all (\ b -> not (is_none b)) bounds && 465 | less_eq_nat bound huffman_bound; 466 | SuccessorWitnesses _ -> False; 467 | }); 468 | 469 | check_step :: (Int -> Proof_step) -> Int -> Proof_step -> Bool; 470 | check_step proof_steps step_limit step = 471 | (case step_witnesses step of { 472 | HuffmanWitnesses _ _ -> check_huffman proof_steps step_limit step; 473 | SuccessorWitnesses _ -> check_successors proof_steps step_limit step; 474 | }); 475 | 476 | cert_length :: Proof_cert -> Int; 477 | cert_length (ProofCert x1 x2) = x1; 478 | 479 | cert_step :: Proof_cert -> Int -> Proof_step; 480 | cert_step (ProofCert x1 x2) = x2; 481 | 482 | int_of_nat :: Nat -> 
Int;
int_of_nat n = Int_of_integer (integer_of_nat n);

-- True iff p holds for every natural number in the half-open range [i, j).
all_interval_nat :: (Nat -> Bool) -> Nat -> Nat -> Bool;
all_interval_nat p i j = less_eq_nat j i || p i && all_interval_nat p (suc i) j;

-- Parallel range check: verifies p on [lo, lo + n). Ranges shorter than 1000
-- are checked sequentially; longer ranges are halved and the upper half is
-- sparked with `par` while the lower half is demanded first.
-- NOTE(review): `par b a && b` parses as `(par b a) && b`, i.e. spark b,
-- return a, then conjoin -- the intended strictness/parallelism idiom from
-- strict_and_parallel.patch.
par_range_all :: (Nat -> Bool) -> Nat -> Nat -> Bool;
par_range_all f lo n =
  (if less_nat n (nat_of_integer (1000 :: Integer))
    then all_interval_nat f lo (plus_nat lo n)
    else let {
           na = divide_nat n (nat_of_integer (2 :: Integer));
           a = par_range_all f lo na;
           b = par_range_all f (plus_nat lo na) (minus_nat n na);
         } in par b a && b);

-- A certificate is valid iff its length is non-negative and every step
-- passes check_step; the per-step checks run in parallel.
check_proof :: Proof_cert -> Bool;
check_proof cert =
  let {
    steps = cert_step cert;
    n = cert_length cert;
  } in less_eq_int zero_int n &&
    par_range_all
      (\ i -> check_step steps (int_of_nat i) (steps (int_of_nat i)))
      zero_nat (nat n);

one_int :: Int;
one_int = Int_of_integer (1 :: Integer);

minus_int :: Int -> Int -> Int;
minus_int k l = Int_of_integer (integer_of_int k - integer_of_int l);

-- Runs the checker and, on success, returns (width, bound) taken from the
-- certificate's final step.
check_proof_get_bound :: Proof_cert -> Maybe (Int, Int);
check_proof_get_bound cert =
  (if check_proof cert && less_int zero_int (cert_length cert)
    then let {
           last_step = cert_step cert (minus_int (cert_length cert) one_int);
         } in Just (step_width last_step, step_bound last_step)
    else Nothing);

}
--------------------------------------------------------------------------------
/src/output_set/index.rs:
--------------------------------------------------------------------------------
//! Direction-parameterized index over output sets: caches bound values keyed
//! by packed output sets plus a u16 abstraction vector, with subsumption-aware
//! lookup and replacement.
use std::{cmp::Reverse, iter::repeat, mem::replace};

use super::{BVec, CVec, OutputSet};

mod tree;

use tree::{Augmentation, TraversalMut, Tree};

/// Lookup direction: find cached entries that subsume the query (lower bounds).
pub enum Lower {}
/// Like `Lower`, but candidates may also match with all channels inverted.
pub enum LowerInvert {}

/// Lookup direction: find cached entries subsumed by the query (upper bounds).
pub enum Upper {}

/// Number of flat (unindexed) entries buffered before they are merged with the
/// smaller trees and rebuilt into a new `Tree`.
const TREE_THRESHOLD: usize = 32;

/// Strategy trait fixing the direction of subsumption queries and how cached
/// bound values compare in that direction.
/// NOTE(review): the dump this was recovered from stripped angle-bracket
/// generics; `Option<u8>` / `Option<Self::Perm>` are reconstructed from usage.
pub trait IndexDirection {
    /// Witness returned by a successful precise subsumption test
    /// (a channel permutation, possibly with an inversion flag).
    type Perm;

    /// Traversal direction passed to `Tree::traverse`.
    fn lookup_dir() -> bool;

    /// Can a subtree whose values span `range = [min, max]` still beat
    /// `best_so_far`?
    fn can_improve(best_so_far: Option<u8>, range: [u8; 2]) -> bool;
    /// Does a concrete `value` beat `best_so_far`?
    fn does_improve(best_so_far: Option<u8>, value: u8) -> bool;

    /// Can any entry in a subtree with value `range` be displaced by an
    /// insertion carrying `value`?
    fn can_be_updated(range: [u8; 2], value: u8) -> bool;
    /// Would an entry with `candidate_value` be displaced by `lookup_value`?
    fn would_be_updated(candidate_value: u8, lookup_value: u8) -> bool;

    /// Abstraction-level pruning of whole subtrees during lookups.
    fn test_abstraction_range(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool;

    /// Abstraction-level pruning of whole subtrees during updates.
    fn test_abstraction_range_update(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool;

    /// Abstraction-level filter for a single candidate during lookups.
    fn test_abstraction(candidate: &[u16], lookup: &[u16]) -> bool;

    /// Update-direction filter; by default the lookup test with roles swapped.
    fn test_abstraction_update(candidate: &[u16], lookup: &[u16]) -> bool {
        Self::test_abstraction(lookup, candidate)
    }

    /// Exact (permutation-searching) subsumption test; returns a witness.
    fn test_precise(
        candidate: OutputSet<&[bool]>,
        candidate_abstraction: &[u16],
        lookup: OutputSet<&[bool]>,
        lookup_abstraction: &[u16],
    ) -> Option<Self::Perm>;

    /// Exact test for updates; by default the lookup test with roles swapped.
    fn test_precise_update(
        candidate: OutputSet<&[bool]>,
        candidate_abstraction: &[u16],
        lookup: OutputSet<&[bool]>,
        lookup_abstraction: &[u16],
    ) -> bool {
        Self::test_precise(lookup, lookup_abstraction, candidate, candidate_abstraction).is_some()
    }

    /// Identity permutation witness for `channels` channels.
    fn id_perm(channels: usize) -> Self::Perm;
}

/// Lower bounds: an entry helps when it subsumes the query; larger values win.
impl IndexDirection for Lower {
    type Perm = CVec<usize>;

    fn lookup_dir() -> bool {
        false
    }

    // Larger cached values improve a lower bound.
    fn can_improve(best_so_far: Option<u8>, range: [u8; 2]) -> bool {
        if let Some(best_so_far) = best_so_far {
            range[1] > best_so_far
        } else {
            true
        }
    }

    fn does_improve(best_so_far: Option<u8>, value: u8) -> bool {
        if let Some(best_so_far) = best_so_far {
            value > best_so_far
        } else {
            true
        }
    }

    fn can_be_updated(range: [u8; 2], value: u8) -> bool {
        range[0] <= value
    }

    fn would_be_updated(candidate_value: u8, lookup_value: u8) -> bool {
        candidate_value <= lookup_value
    }

    // Subsuming candidates are componentwise <= in the abstraction, so a
    // subtree can only contain one if every per-dimension minimum is <=.
    fn test_abstraction_range(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool {
        candidate_range
            .iter()
            .zip(lookup.iter())
            .all(|(candidate, &lookup)| candidate[0] <= lookup)
    }

    fn test_abstraction_range_update(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool {
        candidate_range
            .iter()
            .zip(lookup.iter())
            .all(|(candidate, &lookup)| candidate[1] >= lookup)
    }

    fn test_abstraction(candidate: &[u16], lookup: &[u16]) -> bool {
        candidate
            .iter()
            .zip(lookup.iter())
            .all(|(&candidate, &lookup)| candidate <= lookup)
    }

    fn test_precise(
        candidate: OutputSet<&[bool]>,
        _candidate_abstraction: &[u16],
        lookup: OutputSet<&[bool]>,
        _lookup_abstraction: &[u16],
    ) -> Option<Self::Perm> {
        candidate.subsumes_permuted(lookup)
    }

    fn id_perm(channels: usize) -> Self::Perm {
        (0..channels).collect()
    }
}

/// Like `Lower`, but a candidate may also subsume the query after inverting
/// all channels; the witness carries the inversion flag.
impl IndexDirection for LowerInvert {
    type Perm = (bool, CVec<usize>);

    fn lookup_dir() -> bool {
        false
    }

    fn can_improve(best_so_far: Option<u8>, range: [u8; 2]) -> bool {
        if let Some(best_so_far) = best_so_far {
            range[1] > best_so_far
        } else {
            true
        }
    }

    fn does_improve(best_so_far: Option<u8>, value: u8) -> bool {
        if let Some(best_so_far) = best_so_far {
            value > best_so_far
        } else {
            true
        }
    }

    fn can_be_updated(range: [u8; 2], value: u8) -> bool {
        range[0] <= value
    }

    fn would_be_updated(candidate_value: u8, lookup_value: u8) -> bool {
        candidate_value <= lookup_value
    }

    // NOTE(review): `mask = invert * 3` xors the low two index bits; this
    // assumes the abstraction layout pairs entries so that `i ^ 3` addresses
    // the inverted counterpart -- confirm against OutputSet::abstraction.
    fn test_abstraction_range(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool {
        (0..2).any(|invert| {
            let mask = invert * 3;
            candidate_range
                .iter()
                .enumerate()
                .all(|(i, candidate)| candidate[0] <= lookup[i ^ mask])
        })
    }

    fn test_abstraction_range_update(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool {
        (0..2).any(|invert| {
            let mask = invert * 3;
            candidate_range
                .iter()
                .enumerate()
                .all(|(i, candidate)| candidate[1] >= lookup[i ^ mask])
        })
    }

    fn test_abstraction(candidate: &[u16], lookup: &[u16]) -> bool {
        (0..2).any(|invert| {
            let mask = invert * 3;
            candidate
                .iter()
                .enumerate()
                .all(|(i, &candidate)| candidate <= lookup[i ^ mask])
        })
    }

    fn test_precise(
        candidate: OutputSet<&[bool]>,
        candidate_abstraction: &[u16],
        lookup: OutputSet<&[bool]>,
        lookup_abstraction: &[u16],
    ) -> Option<Self::Perm> {
        // Try the uninverted orientation only when its (cheaper) abstraction
        // test passes; the inverted orientation is always tried.
        if Lower::test_abstraction(candidate_abstraction, lookup_abstraction) {
            if let Some(perm) = candidate.subsumes_permuted(lookup) {
                return Some((false, perm));
            }
        }
        let mut inverted = candidate.to_owned();
        inverted.invert();
        if let Some(perm) = inverted.subsumes_permuted(lookup) {
            return Some((true, perm));
        }
        None
    }

    fn id_perm(channels: usize) -> Self::Perm {
        (false, (0..channels).collect())
    }
}

/// Upper bounds: an entry helps when the query subsumes it; smaller values win.
impl IndexDirection for Upper {
    type Perm = CVec<usize>;

    fn lookup_dir() -> bool {
        true
    }

    fn can_improve(best_so_far: Option<u8>, range: [u8; 2]) -> bool {
        if let Some(best_so_far) = best_so_far {
            range[0] < best_so_far
        } else {
            true
        }
    }

    fn does_improve(best_so_far: Option<u8>, value: u8) -> bool {
        if let Some(best_so_far) = best_so_far {
            value < best_so_far
        } else {
            true
        }
    }

    fn can_be_updated(range: [u8; 2], value: u8) -> bool {
        range[1] >= value
    }

    fn would_be_updated(candidate_value: u8, lookup_value: u8) -> bool {
        candidate_value >= lookup_value
    }

    fn test_abstraction_range(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool {
        candidate_range
            .iter()
            .zip(lookup.iter())
            .all(|(candidate, &lookup)| candidate[1] >= lookup)
    }

    fn test_abstraction_range_update(candidate_range: &[[u16; 2]], lookup: &[u16]) -> bool {
        candidate_range
            .iter()
            .zip(lookup.iter())
            .all(|(candidate, &lookup)| candidate[0] <= lookup)
    }

    fn test_abstraction(candidate: &[u16], lookup: &[u16]) -> bool {
        candidate
            .iter()
            .zip(lookup.iter())
            .all(|(&candidate, &lookup)| candidate >= lookup)
    }

    fn test_precise(
        candidate: OutputSet<&[bool]>,
        _candidate_abstraction: &[u16],
        lookup: OutputSet<&[bool]>,
        _lookup_abstraction: &[u16],
    ) -> Option<Self::Perm> {
        // Roles swapped relative to Lower: the query must subsume the entry.
        lookup.subsumes_permuted(candidate)
    }

    fn id_perm(channels: usize) -> Self::Perm {
        (0..channels).collect()
    }
}

/// Index storing (packed output set, abstraction point, bound value) triples:
/// recent entries live in flat arrays (`points`/`packed`/`values`); once
/// `TREE_THRESHOLD` accumulate they are rebuilt into kd-`trees`.
pub struct OutputSetIndex<Dir: IndexDirection> {
    direction: std::marker::PhantomData<Dir>,
    channels: usize,
    trees: Vec<Tree>,
    // One abstraction point of `point_dim` u16s per entry.
    point_dim: usize,
    points: Vec<u16>,
    // One packed bitmap of `packed_dim` bytes per entry.
    packed_dim: usize,
    packed: Vec<u8>,
    values: Vec<u8>,
}

impl<Dir: IndexDirection> OutputSetIndex<Dir> {
    /// Creates an empty index for output sets over `channels` channels.
    pub fn new(channels: usize) -> Self {
        Self {
            direction: std::marker::PhantomData,
            channels,
            trees: vec![],
            point_dim: OutputSet::abstraction_len_for_channels(channels),
            points: vec![],
            packed_dim: OutputSet::packed_len_for_channels(channels),
            packed: vec![],
            values: vec![],
        }
    }

    /// Finds the best matching entry (per `Dir`), returning its value, the
    /// permutation witness and the stored output set.
    pub fn lookup_subsuming_with_abstraction(
        &self,
        output_set: OutputSet<&[bool]>,
        abstraction: &[u16],
    ) -> Option<(u8, Dir::Perm, OutputSet<Vec<bool>>)> {
        let mut best_so_far = None;

        // Scratch bitmap reused to unpack each candidate.
        let mut bitmap = repeat(false).take(1 << self.channels).collect::<BVec<_>>();
        let mut candidate_output_set = OutputSet::from_bitmap(self.channels, &mut bitmap[..]);

        // Prune subtrees that cannot beat the current best or whose
        // abstraction ranges rule out a match.
        let mut node_filter = |best_so_far: &mut Option<(u8, Dir::Perm, OutputSet<Vec<bool>>)>,
                               augmentation: &Augmentation,
                               ranges: &[[u16; 2]]|
         -> bool {
            Dir::can_improve(best_so_far.as_ref().map(|x| x.0), augmentation.value_range)
                && Dir::test_abstraction_range(ranges, abstraction)
        };

        // Per-entry check: cheap filters first, then the precise test.
        let mut action = |best_so_far: &mut Option<(u8, Dir::Perm, OutputSet<Vec<bool>>)>,
                          candidate_abstraction: &[u16],
                          packed_candidate: &[u8],
                          value: u8|
         -> bool {
            if !Dir::does_improve(best_so_far.as_ref().map(|x| x.0), value) {
                return true;
            }
            if !Dir::test_abstraction(candidate_abstraction, abstraction) {
                return true;
            }
            candidate_output_set.unpack_from_slice(packed_candidate);
            let perm = if candidate_output_set == output_set {
                // Identical sets need no permutation search.
                Dir::id_perm(output_set.channels())
            } else {
                let perm = Dir::test_precise(
                    candidate_output_set.as_ref(),
                    candidate_abstraction,
                    output_set,
                    abstraction,
                );
                if let Some(perm) = perm {
                    perm
                } else {
                    return true;
                }
            };

            *best_so_far = Some((value, perm, candidate_output_set.to_owned()));

            true
        };

        for tree in self.trees.iter() {
            best_so_far = tree.traverse(
                best_so_far,
                Dir::lookup_dir(),
                &mut node_filter,
                &mut action,
            );
        }

        // Also scan the flat, not-yet-tree'd entries.
        for (index, &value) in self.values.iter().enumerate() {
            action(
                &mut best_so_far,
                &self.points[index * self.point_dim..][..self.point_dim],
                &self.packed[index * self.packed_dim..][..self.packed_dim],
                value,
            );
        }

        best_so_far
    }

    /// Like `lookup_subsuming_with_abstraction`, but returns only the value.
    pub fn lookup_with_abstraction(
        &self,
        output_set: OutputSet<&[bool]>,
        abstraction: &[u16],
    ) -> Option<u8> {
        let mut best_so_far = None;

        let mut bitmap = repeat(false).take(1 << self.channels).collect::<BVec<_>>();
        let mut candidate_output_set = OutputSet::from_bitmap(self.channels, &mut bitmap[..]);

        let mut node_filter = |best_so_far: &mut Option<u8>,
                               augmentation: &Augmentation,
                               ranges: &[[u16; 2]]|
         -> bool {
            Dir::can_improve(*best_so_far, augmentation.value_range)
                && Dir::test_abstraction_range(ranges, abstraction)
        };

        let mut action = |best_so_far: &mut Option<u8>,
                          candidate_abstraction: &[u16],
                          packed_candidate: &[u8],
                          value: u8|
         -> bool {
            if !Dir::does_improve(*best_so_far, value) {
                return true;
            }
            if !Dir::test_abstraction(candidate_abstraction, abstraction) {
                return true;
            }
            candidate_output_set.unpack_from_slice(packed_candidate);
            if candidate_output_set != output_set {
                if Dir::test_precise(
                    candidate_output_set.as_ref(),
                    candidate_abstraction,
                    output_set,
                    abstraction,
                )
                .is_none()
                {
                    return true;
                }
            }

            *best_so_far = Some(value);

            true
        };

        for tree in self.trees.iter() {
            best_so_far = tree.traverse(
                best_so_far,
                Dir::lookup_dir(),
                &mut node_filter,
                &mut action,
            );
        }

        for (index, &value) in self.values.iter().enumerate() {
            action(
                &mut best_so_far,
                &self.points[index * self.point_dim..][..self.point_dim],
                &self.packed[index * self.packed_dim..][..self.packed_dim],
                value,
            );
        }

        best_so_far
    }

    /// Appends an entry without checking for subsumed/subsuming duplicates.
    /// When enough flat entries accumulate, merges them with all trees that
    /// are no larger and rebuilds one combined tree (logarithmic merging).
    pub fn insert_new_unchecked_with_abstraction(
        &mut self,
        output_set: OutputSet<&[bool]>,
        abstraction: &[u16],
        value: u8,
    ) {
        let old_size = self.packed.len();
        self.packed.resize(old_size + self.packed_dim, 0);
        output_set.pack_into_slice(&mut self.packed[old_size..]);

        self.points.extend_from_slice(abstraction);
        self.values.push(value);

        if self.values.len() >= TREE_THRESHOLD {
            // Largest tree first, so pop() yields the smallest.
            self.trees.sort_by_key(|tree| Reverse(tree.len()));

            while let Some(tree) = self.trees.pop() {
                if tree.len() > self.values.len() {
                    self.trees.push(tree);
                    break;
                }
                // Drain the small tree into the flat buffers.
                tree.traverse(
                    (),
                    Dir::lookup_dir(),
                    |_, _, _| true,
                    |_, point, packed, value| {
                        self.points.extend_from_slice(point);
                        self.packed.extend_from_slice(packed);
                        self.values.push(value);
                        true
                    },
                );
            }

            let tree = Tree::new(
                self.point_dim,
                replace(&mut self.points, vec![]),
                self.packed_dim,
                replace(&mut self.packed, vec![]),
                replace(&mut self.values, vec![]),
            );

            self.trees.push(tree);
        }
    }

    /// Inserts an entry, removing entries it makes redundant (per `Dir`).
    /// Returns the effective value: the existing better bound when the new
    /// entry does not improve, otherwise `value`.
    pub fn insert_with_abstraction(
        &mut self,
        output_set: OutputSet<&[bool]>,
        abstraction: &[u16],
        value: u8,
    ) -> u8 {
        let best_so_far = self.lookup_with_abstraction(output_set, abstraction);

        let mut bitmap = repeat(false).take(1 << self.channels).collect::<BVec<_>>();
        let mut candidate_output_set = OutputSet::from_bitmap(self.channels, &mut bitmap[..]);

        let mut updated_in_place = false;

        if !Dir::does_improve(best_so_far, value) {
            // An at-least-as-good entry already covers this output set.
            return best_so_far.unwrap();
        }

        let mut node_filter =
            |_: &mut (), augmentation: &Augmentation, ranges: &[[u16; 2]]| -> bool {
                Dir::can_be_updated(augmentation.value_range, value)
                    && Dir::test_abstraction_range_update(ranges, abstraction)
            };

        // Remove every stored entry the new one makes redundant; an exactly
        // equal output set is updated in place instead.
        let mut action = |_: &mut (),
                          candidate_abstraction: &[u16],
                          packed_candidate: &[u8],
                          candidate_value: &mut u8|
         -> TraversalMut {
            if !Dir::would_be_updated(*candidate_value, value) {
                return TraversalMut::Retain;
            }
            if !Dir::test_abstraction_update(candidate_abstraction, abstraction) {
                return TraversalMut::Retain;
            }
            candidate_output_set.unpack_from_slice(packed_candidate);
            if candidate_output_set == output_set {
                assert!(!updated_in_place);
                *candidate_value = value;
                updated_in_place = true;
                return TraversalMut::Retain;
            }
            if !Dir::test_precise_update(
                candidate_output_set.as_ref(),
                candidate_abstraction,
                output_set,
                abstraction,
            ) {
                return TraversalMut::Retain;
            }

            TraversalMut::Remove
        };

        for tree in self.trees.iter_mut() {
            tree.traverse_mut((), Dir::lookup_dir(), &mut node_filter, &mut action);
        }

        self.trees.retain(|tree| !tree.is_empty());

        // Same sweep over the flat entries, with manual swap-remove keeping
        // the parallel points/packed/values arrays in sync.
        let mut index = 0;

        while index < self.values.len() {
            let action_result = action(
                &mut (),
                &self.points[index * self.point_dim..][..self.point_dim],
                &self.packed[index * self.packed_dim..][..self.packed_dim],
                &mut self.values[index],
            );

            if action_result == TraversalMut::Remove {
                self.values.swap_remove(index);

                if index != self.values.len() {
                    // Move the (former) last point/packed entry into the hole.
                    let (keep, last) = self.points.split_at_mut(self.values.len() * self.point_dim);
                    keep[index * self.point_dim..][..self.point_dim].copy_from_slice(last);

                    let (keep, last) = self
                        .packed
                        .split_at_mut(self.values.len() * self.packed_dim);
                    keep[index * self.packed_dim..][..self.packed_dim].copy_from_slice(last);
                }
                self.points.truncate(self.values.len() * self.point_dim);
                self.packed.truncate(self.values.len() * self.packed_dim);
            } else {
                index += 1;
            }
        }

        if updated_in_place {
            return value;
        }

        self.insert_new_unchecked_with_abstraction(output_set, abstraction, value);
        value
    }

    /// Visits every stored entry (tree-resident and flat).
    pub fn for_each(&self, mut action: impl FnMut(OutputSet<&[bool]>, &[u16], u8)) {
        let mut bitmap = repeat(false).take(1 << self.channels).collect::<BVec<_>>();
        let mut output_set = OutputSet::from_bitmap(self.channels, &mut bitmap[..]);

        // Adapter closure; captures (and shadows) the caller's `action`.
        let mut action = |_: &mut (), abstraction: &[u16], packed: &[u8], value: u8| -> bool {
            output_set.unpack_from_slice(packed);

            action(output_set.as_ref(), abstraction, value);
            true
        };

        for tree in self.trees.iter() {
            tree.traverse((), Dir::lookup_dir(), |_, _, _| true, &mut action);
        }

        for (index, &value) in self.values.iter().enumerate() {
            action(
                &mut (),
                &self.points[index * self.point_dim..][..self.point_dim],
                &self.packed[index * self.packed_dim..][..self.packed_dim],
                value,
            );
        }
    }

    pub fn is_empty(&self) -> bool {
        self.values.is_empty() && self.trees.is_empty()
    }

    pub fn len(&self) -> usize {
        self.values.len() + self.trees.iter().map(|tree| tree.len()).sum::<usize>()
    }

    /// Writes all trees as a single graphviz digraph (debugging aid).
    pub fn dump_dot(&self, output: &mut impl std::io::Write) -> std::io::Result<()> {
        writeln!(output, "digraph {{")?;
        for (tree_id, tree) in self.trees.iter().enumerate() {
            tree.dump_dot(tree_id, output)?;
        }
        writeln!(output, "}}")?;
        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/output_set/index/tree.rs:
--------------------------------------------------------------------------------
use std::ops::Range;

/// Leaves larger than this are split while building the kd-tree.
const SPLIT_THRESHOLD: usize = 16;

/// kd-tree over abstraction points; payloads (points, packed output sets,
/// values) live in parallel flat arrays indexed by entry position, and
/// `ranges` stores `point_dim` per-dimension bounding intervals per node.
pub struct Tree {
    point_dim: usize,
    points: Vec<u16>,
    packed_dim: usize,
    packed: Vec<u8>,
    values: Vec<u8>,
    ranges: Vec<[u16; 2]>,
    nodes: Vec<Node>,
}

#[derive(Clone)]
struct Node {
    augmentation: Augmentation,
    // Parent index; the root (index 0) is marked as a right child so that
    // upward walks terminate there.
    parent: usize,
    is_right_child: bool,
    links: Links,
}

#[derive(Clone)]
enum Links {
    // Leaf: a contiguous range of entry indices into the flat arrays.
    Leaf { values: Range<usize> },
    // Inner: [left, right] child node indices.
    Inner { children: [usize; 2] },
}

/// Per-node summary maintained on every mutation: subtree size and the
/// min/max of the stored values.
#[derive(Clone, Eq, PartialEq)]
pub struct Augmentation {
    pub size: usize,
    pub value_range: [u8; 2],
}
// Cursor for the explicit (stackless) tree traversal: either finished, about
// to enter a node, or positioned on a concrete entry inside a leaf.
#[derive(Copy, Clone)]
enum TraversalState {
    Done,
    Node { id: usize },
    Point { id: usize, index: usize },
}

/// Verdict returned by `traverse_mut` callbacks for the current entry.
#[allow(dead_code)]
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum TraversalMut {
    Break,
    Retain,
    Remove,
}

impl Augmentation {
    /// Summarizes a leaf's entries. `values` must be non-empty.
    pub fn new(_point_dim: usize, _points: &[u16], values: &[u8]) -> Self {
        let mut min_value = values[0];
        let mut max_value = values[0];
        for &value in &values[1..] {
            min_value = min_value.min(value);
            max_value = max_value.max(value);
        }
        Augmentation {
            size: values.len(),
            value_range: [min_value, max_value],
        }
    }

    /// Recomputes a leaf summary; returns whether it changed (so callers can
    /// stop propagating updates upward early).
    pub fn update_leaf(&mut self, _point_dim: usize, _points: &[u16], values: &[u8]) -> bool {
        let mut min_value = values[0];
        let mut max_value = values[0];
        for &value in &values[1..] {
            min_value = min_value.min(value);
            max_value = max_value.max(value);
        }

        let new_augmentation = Augmentation {
            size: values.len(),
            value_range: [min_value, max_value],
        };

        let updated = new_augmentation != *self;

        *self = new_augmentation;

        updated
    }

    /// Recomputes an inner-node summary from its children; returns whether it
    /// changed.
    pub fn update_inner(&mut self, children: [&Self; 2]) -> bool {
        let value_range_min = children[0].value_range[0].min(children[1].value_range[0]);
        let value_range_max = children[0].value_range[1].max(children[1].value_range[1]);
        let value_range = [value_range_min, value_range_max];
        let size = children[0].size + children[1].size;

        let updated = (self.value_range != value_range) | (self.size != size);

        self.value_range = value_range;
        self.size = size;

        updated
    }
}

impl Tree {
    /// Builds a tree taking ownership of the parallel entry arrays.
    /// `points` has `point_dim` u16s and `packed` has `packed_dim` bytes per
    /// entry in `values`.
    pub fn new(
        point_dim: usize,
        points: Vec<u16>,
        packed_dim: usize,
        packed: Vec<u8>,
        values: Vec<u8>,
    ) -> Self {
        assert_eq!(
            values.len() * point_dim,
            points.len(),
            "number of points != number of values"
        );

        assert_eq!(
            values.len() * packed_dim,
            packed.len(),
            "number of points != number of packed values"
        );

        let (nodes, mut ranges);

        if values.is_empty() {
            nodes = vec![];
            ranges = vec![];
        } else {
            // Single leaf root; marked is_right_child so upward walks stop.
            nodes = vec![Node {
                augmentation: Augmentation::new(point_dim, &points, &values),
                links: Links::Leaf {
                    values: 0..values.len(),
                },
                parent: 0,
                is_right_child: true,
            }];

            ranges = vec![[0; 2]; point_dim];

            compute_ranges(&mut ranges, &points);
        };

        let mut result = Self {
            point_dim,
            points,
            packed_dim,
            packed,
            values,
            ranges,
            nodes,
        };

        if !result.values.is_empty() {
            result.build_tree(0);
        }

        result
    }

    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    pub fn len(&self) -> usize {
        if self.nodes.is_empty() {
            0
        } else {
            self.nodes[0].augmentation.size
        }
    }

    /// Iteratively splits oversized leaves (left-to-right, depth-first) by
    /// partitioning entries at the midpoint of the widest point dimension.
    fn build_tree(&mut self, mut node_id: usize) {
        loop {
            match self.nodes[node_id].links.clone() {
                Links::Leaf { values } => {
                    if values.len() > SPLIT_THRESHOLD {
                        let ranges = &self.ranges[node_id * self.point_dim..][..self.point_dim];

                        // Split along the dimension with the widest range.
                        let (split_dim, (split_range, split_low)) = ranges
                            .iter()
                            .map(|&[low, high]| (high - low, low))
                            .enumerate()
                            .max_by_key(|&(_index, (range, _low))| range)
                            .unwrap();

                        let split_at = split_low + (split_range + 1) / 2;

                        let child_ranges = self.partition(values.clone(), split_dim, split_at);

                        // Only split if both halves are non-empty; otherwise
                        // the leaf stays as-is (all points may coincide).
                        if child_ranges.iter().all(|range| range.start != range.end) {
                            let mut children = [0; 2];

                            for (child_nr, (child_range, child_id)) in
                                child_ranges.iter().zip(children.iter_mut()).enumerate()
                            {
                                *child_id =
                                    self.new_child(node_id, child_nr > 0, child_range.clone());
                            }

                            self.nodes[node_id].links = Links::Inner { children };

                            node_id = children[0];
                            continue;
                        }
                    }
                    // Leaf is final: walk up past finished right spines, then
                    // move to the next unprocessed right sibling.
                    while self.nodes[node_id].is_right_child {
                        if node_id == 0 {
                            return;
                        }
                        node_id = self.nodes[node_id].parent;
                    }
                    let parent_node = &self.nodes[self.nodes[node_id].parent];
                    let children = match &parent_node.links {
                        Links::Inner { children } => children,
                        _ => unreachable!(),
                    };
                    node_id = children[1];
                }
                // NOTE(review): reached only if called on an Inner node, which
                // the construction above never does; if it ever were, this arm
                // would loop forever rather than panic -- confirm invariant.
                _ => (),
            }
        }
    }

    /// Allocates a leaf node for `values`, computing its bounding ranges and
    /// augmentation; returns the new node id.
    fn new_child(&mut self, parent_id: usize, is_right_child: bool, values: Range<usize>) -> usize {
        let ranges_len = self.ranges.len();
        self.ranges.resize(ranges_len + self.point_dim, [0; 2]);

        compute_ranges(
            &mut self.ranges[ranges_len..],
            &self.points[values.start * self.point_dim..values.end * self.point_dim],
        );

        let node_id = self.nodes.len();

        self.nodes.push(Node {
            augmentation: Augmentation::new(
                self.point_dim,
                &self.points[values.start * self.point_dim..values.end * self.point_dim],
                &self.values[values.clone()],
            ),
            links: Links::Leaf { values },
            parent: parent_id,
            is_right_child,
        });

        node_id
    }

    /// Hoare-style in-place partition of the entry range by
    /// `point[dim] < at`; returns the [left, right] subranges.
    fn partition(&mut self, values: Range<usize>, dim: usize, at: u16) -> [Range<usize>; 2] {
        if values.len() == 0 {
            return [values.clone(), values];
        }

        let mut left = values.start;
        let mut right = values.end - 1;

        loop {
            while left < values.end && self.points[left * self.point_dim + dim] < at {
                left += 1;
            }
            while right > values.start && self.points[right * self.point_dim + dim] >= at {
                right -= 1;
            }
            if left >= right {
                break;
            }
            self.swap_values(left, right);
        }

        [values.start..left, left..values.end]
    }

    /// Swaps two entries across all three parallel arrays.
    fn swap_values(&mut self, left: usize, right: usize) {
        if left == right {
            return;
        }

        let left_ptr = self.points[left * self.point_dim..][..self.point_dim].as_mut_ptr();
        let right_ptr = self.points[right * self.point_dim..][..self.point_dim].as_mut_ptr();

        // SAFETY: left != right (checked above), so the two point_dim-long
        // regions are disjoint slices of self.points.
        unsafe {
            std::ptr::swap_nonoverlapping(left_ptr, right_ptr, self.point_dim);
        }

        let left_ptr = self.packed[left * self.packed_dim..][..self.packed_dim].as_mut_ptr();
        let right_ptr = self.packed[right * self.packed_dim..][..self.packed_dim].as_mut_ptr();

        // SAFETY: as above, disjoint packed_dim-long regions of self.packed.
        unsafe {
            std::ptr::swap_nonoverlapping(left_ptr, right_ptr, self.packed_dim);
        }

        self.values.swap(left, right);
    }

    /// Copies node `src`'s bounding ranges over node `dst`'s.
    fn move_node_ranges(&mut self, dst: usize, src: usize) {
        if dst == src {
            return;
        }

        let dst_ptr = self.ranges[dst * self.point_dim..][..self.point_dim].as_mut_ptr();
        let src_ptr = self.ranges[src * self.point_dim..][..self.point_dim].as_ptr();

        // SAFETY: dst != src (checked above), so the regions are disjoint.
        unsafe {
            std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, self.point_dim);
        }
    }

    fn traversal_root(&self) -> TraversalState {
        if self.nodes.is_empty() {
            TraversalState::Done
        } else {
            TraversalState::Node { id: 0 }
        }
    }

    /// Descends into a node: leaves yield their first entry, inner nodes
    /// yield the first child in traversal order.
    fn traversal_enter(&self, left_to_right: bool, node_id: usize) -> TraversalState {
        match &self.nodes[node_id].links {
            Links::Leaf { values } => TraversalState::Point {
                id: node_id,
                index: values.start,
            },
            Links::Inner { children } => TraversalState::Node {
                id: children[!left_to_right as usize],
            },
        }
    }

    /// Skips the subtree at `node_id`: climbs while on the trailing side,
    /// then steps to the next sibling in traversal order.
    fn traversal_skip(&self, left_to_right: bool, mut node_id: usize) -> TraversalState {
        loop {
            if node_id == 0 {
                return TraversalState::Done;
            } else {
                let node = &self.nodes[node_id];
                if node.is_right_child == left_to_right {
                    node_id = node.parent;
                    continue;
                }
                let parent_node = &self.nodes[node.parent];
                let children = match &parent_node.links {
                    Links::Inner { children } => children,
                    _ => unreachable!(),
                };
                return TraversalState::Node {
                    id: children[left_to_right as usize],
                };
            }
        }
    }

    /// Advances to the next entry in the current leaf, or past the leaf.
    fn traversal_next(
        &self,
        left_to_right: bool,
        node_id: usize,
        mut index: usize,
    ) -> TraversalState {
        let node = &self.nodes[node_id];
        let values = match &node.links {
            Links::Leaf { values } => values,
            _ => unreachable!(),
        };
        index += 1;
        if index == values.end {
            self.traversal_skip(left_to_right, node_id)
        } else {
            TraversalState::Point { id: node_id, index }
        }
    }

    /// Removes the entry at `index` from its leaf by swapping in the leaf's
    /// last entry, unlinking the leaf if it becomes empty.
    fn traversal_remove(
        &mut self,
        left_to_right: bool,
        node_id: usize,
        index: usize,
    ) -> TraversalState {
        let node = &mut self.nodes[node_id];
        let values = match &mut node.links {
            Links::Leaf { values } => values,
            _ => unreachable!(),
        };

        values.end -= 1;
        if values.start == values.end {
            self.traversal_unlink(left_to_right, node_id)
        } else {
            let last_index = values.end;

            self.swap_values(index, last_index);

            self.update_augmentation(node_id);

            if index == last_index {
                // We removed the leaf's last slot; nothing swapped in.
                self.traversal_skip(left_to_right, node_id)
            } else {
                // Revisit `index`: it now holds the swapped-in entry.
                TraversalState::Point { id: node_id, index }
            }
        }
    }

    /// Removes an empty leaf by hoisting its sibling into the parent.
    fn traversal_unlink(&mut self, left_to_right: bool, node_id: usize) -> TraversalState {
        if node_id == 0 {
            // Root leaf emptied: the whole tree is now empty.
            self.nodes.clear();
            TraversalState::Done
        } else {
            let node = &self.nodes[node_id];
            let parent = node.parent;
            let is_right_child = node.is_right_child;

            let parent_node = &self.nodes[parent];
            let children = match &parent_node.links {
                Links::Inner { children } => children,
                _ => unreachable!(),
            };

            let sibling = children[!is_right_child as usize];

            assert_ne!(sibling, node_id);

            let moved_node = self.nodes[sibling].clone();

            // Re-parent the sibling's children onto the grandparent slot.
            match &moved_node.links {
                Links::Inner { children } => {
                    for &child in children.iter() {
                        self.nodes[child].parent = parent;
                    }
                }
                _ => (),
            }

            let parent_node = &mut self.nodes[parent];

            parent_node.links = moved_node.links;
            parent_node.augmentation = moved_node.augmentation;

            let parent_parent = parent_node.parent;

            self.move_node_ranges(parent, sibling);

            if parent != 0 {
                self.update_augmentation(parent_parent)
            }

            if is_right_child != left_to_right {
                // The sibling (now at `parent`) has not been visited yet.
                TraversalState::Node { id: parent }
            } else {
                self.traversal_skip(left_to_right, parent)
            }
        }
    }

    /// Recomputes augmentations from `node_id` upward, stopping as soon as a
    /// node's augmentation is unchanged.
    fn update_augmentation(&mut self, mut node_id: usize) {
        loop {
            let node = &mut self.nodes[node_id];
            let parent = node.parent;
            let updated = match &node.links {
                Links::Inner { children } => {
                    let children = children.clone();

                    assert_ne!(children[0], children[1]);
                    assert_ne!(node_id, children[0]);
                    assert_ne!(node_id, children[1]);

                    let target = &mut self.nodes[node_id].augmentation as *mut Augmentation;
                    let child_0 = &self.nodes[children[0]].augmentation as *const Augmentation;
                    let child_1 = &self.nodes[children[1]].augmentation as *const Augmentation;

                    // SAFETY: the three indices are pairwise distinct
                    // (asserted above), so the mutable borrow does not alias
                    // the shared ones.
                    unsafe { (&mut *target).update_inner([&*child_0, &*child_1]) }
                }
                Links::Leaf { values } => node.augmentation.update_leaf(
                    self.point_dim,
                    &self.points[values.start * self.point_dim..values.end * self.point_dim],
                    &self.values[values.clone()],
                ),
            };

            if updated && node_id != 0 {
                node_id = parent;
            } else {
                break;
            }
        }
    }

    /// Emits this tree's inner edges as graphviz lines, namespaced by `tree`.
    pub fn dump_dot(&self, tree: usize, output: &mut impl std::io::Write) -> std::io::Result<()> {
        let mut state = self.traversal_root();
        loop {
            match state {
                TraversalState::Done => {
                    return Ok(());
                }
                TraversalState::Node { id } => match self.nodes[id].links {
                    Links::Inner { children } => {
                        writeln!(output, "  n{}_{} -> n{}_{};", tree, id, tree, children[0])?;
                        writeln!(output, "  n{}_{} -> n{}_{};", tree, id, tree, children[1])?;
                        state = self.traversal_enter(true, id);
                    }
                    Links::Leaf { .. } => {
                        state = self.traversal_skip(true, id);
                    }
                },
                TraversalState::Point { .. } => {
                    unreachable!();
                }
            }
        }
    }

    /// Read-only traversal. `node_filter` decides (from the augmentation and
    /// bounding ranges) whether to descend into a node; `action` sees each
    /// entry's (point, packed, value) and returns false to abort early.
    /// `left_to_right` selects the visiting order; `user_state` is threaded
    /// through and returned.
    pub fn traverse<S, N, A>(
        &self,
        mut user_state: S,
        left_to_right: bool,
        mut node_filter: N,
        mut action: A,
    ) -> S
    where
        N: FnMut(&mut S, &Augmentation, &[[u16; 2]]) -> bool,
        A: FnMut(&mut S, &[u16], &[u8], u8) -> bool,
    {
        let mut state = self.traversal_root();

        loop {
            match state {
                TraversalState::Done => return user_state,
                TraversalState::Node { id } => {
                    let enter = node_filter(
                        &mut user_state,
                        &self.nodes[id].augmentation,
                        &self.ranges[id * self.point_dim..][..self.point_dim],
                    );

                    if enter {
                        state = self.traversal_enter(left_to_right, id);
                    } else {
                        state = self.traversal_skip(left_to_right, id);
                    }
                }
                TraversalState::Point { id, index } => {
                    let will_continue = action(
                        &mut user_state,
                        &self.points[index * self.point_dim..][..self.point_dim],
                        &self.packed[index * self.packed_dim..][..self.packed_dim],
                        self.values[index],
                    );

                    if will_continue {
                        state = self.traversal_next(left_to_right, id, index);
                    } else {
                        return user_state;
                    }
                }
            }
        }
    }

    /// Mutating traversal: `action` may rewrite the entry's value (augmented
    /// summaries are refreshed on change) or request removal/abort via its
    /// `TraversalMut` result.
    pub fn traverse_mut<S, N, A>(
        &mut self,
        mut user_state: S,
        left_to_right: bool,
        mut node_filter: N,
        mut action: A,
    ) -> S
    where
        N: FnMut(&mut S, &Augmentation, &[[u16; 2]]) -> bool,
        A: FnMut(&mut S, &[u16], &[u8], &mut u8) -> TraversalMut,
    {
        let mut state = self.traversal_root();

        loop {
            match state {
                TraversalState::Done => return user_state,
                TraversalState::Node { id } => {
                    let enter = node_filter(
                        &mut user_state,
                        &self.nodes[id].augmentation,
                        &self.ranges[id * self.point_dim..][..self.point_dim],
                    );

                    if enter {
                        state = self.traversal_enter(left_to_right, id);
                    } else {
                        state = self.traversal_skip(left_to_right, id);
                    }
                }
                TraversalState::Point { id, index } => {
                    let old_value = self.values[index];
                    let action_result = action(
                        &mut user_state,
                        &self.points[index * self.point_dim..][..self.point_dim],
                        &self.packed[index * self.packed_dim..][..self.packed_dim],
                        &mut self.values[index],
                    );

                    match action_result {
                        TraversalMut::Break => return user_state,
                        TraversalMut::Remove => {
                            state = self.traversal_remove(left_to_right, id, index)
                        }
                        TraversalMut::Retain => {
                            if self.values[index] != old_value {
                                self.update_augmentation(id);
                            }
                            state = self.traversal_next(left_to_right, id, index);
                        }
                    }
                }
            }
        }
    }
}

/// Fills `ranges` with per-dimension [min, max] over `points`
/// (ranges.len() components per point).
fn compute_ranges(ranges: &mut [[u16; 2]], points: &[u16]) {
    for (range, &component) in ranges.iter_mut().zip(points.iter()) {
        *range = [component; 2];
    }

    for point in points.chunks(ranges.len()).skip(1) {
        for (range, &component) in ranges.iter_mut().zip(point.iter()) {
            range[0] = range[0].min(component);
            range[1] = range[1].max(component);
        }
    }
}
-------------------------------------------------------------------------------- /src/output_set.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | cmp::Ordering, 3 | hash::{Hash, Hasher}, 4 | iter::repeat, 5 | mem::replace, 6 | }; 7 | 8 | use arrayvec::ArrayVec; 9 | use rustc_hash::FxHasher; 10 | 11 | const PAIR_ABSTRACTION_GROUPS: usize = 4; 12 | 13 | pub const MAX_CHANNELS: usize = 11; 14 | pub const MAX_BITMAP_SIZE: usize = 1 << MAX_CHANNELS; 15 | pub const MAX_PACKED_SIZE: usize = 1 << (MAX_CHANNELS - 3); 16 | pub const MAX_ABSTRACTION_SIZE: usize = MAX_CHANNELS * (MAX_CHANNELS - 1) * PAIR_ABSTRACTION_GROUPS; 17 | 18 | pub type CVec = ArrayVec<[T; MAX_CHANNELS]>; 19 | pub type WVec = ArrayVec<[T; MAX_CHANNELS + 1]>; 20 | pub type BVec = ArrayVec<[T; MAX_BITMAP_SIZE]>; 21 | pub type PVec = ArrayVec<[T; MAX_PACKED_SIZE]>; 22 | pub type AVec = ArrayVec<[T; 512]>; 23 | 24 | pub mod index; 25 | 26 | mod canon; 27 | mod subsume; 28 | 29 | #[derive(Copy, Clone, Hash, Debug)] 30 | pub struct OutputSet> { 31 | channels: usize, 32 | bitmap: Bitmap, 33 | } 34 | 35 | impl PartialEq> for OutputSet 36 | where 37 | BitmapA: AsRef<[bool]>, 38 | BitmapB: AsRef<[bool]>, 39 | { 40 | fn eq(&self, other: &OutputSet) -> bool { 41 | self.bitmap() == other.bitmap() 42 | } 43 | } 44 | 45 | impl Eq for OutputSet where Bitmap: AsRef<[bool]> {} 46 | 47 | impl PartialOrd> for OutputSet 48 | where 49 | BitmapA: AsRef<[bool]>, 50 | BitmapB: AsRef<[bool]>, 51 | { 52 | fn partial_cmp(&self, other: &OutputSet) -> Option { 53 | Some(self.bitmap().cmp(other.bitmap())) 54 | } 55 | } 56 | 57 | impl Ord for OutputSet 58 | where 59 | Bitmap: AsRef<[bool]>, 60 | { 61 | fn cmp(&self, other: &Self) -> Ordering { 62 | self.partial_cmp(other).unwrap() 63 | } 64 | } 65 | 66 | impl OutputSet { 67 | pub fn all_values(channels: usize) -> Self { 68 | debug_assert!(channels <= MAX_CHANNELS); 69 | Self { 70 | channels, 71 | bitmap: vec![true; 1 << channels], 
72 | } 73 | } 74 | 75 | pub fn all_values_bvec(channels: usize) -> OutputSet> { 76 | debug_assert!(channels <= MAX_CHANNELS); 77 | OutputSet { 78 | channels, 79 | bitmap: repeat(true).take(1 << channels).collect(), 80 | } 81 | } 82 | 83 | pub fn packed_len_for_channels(channels: usize) -> usize { 84 | ((1 << channels) + 7) / 8 85 | } 86 | 87 | pub fn abstraction_len_for_channels(channels: usize) -> usize { 88 | channels * (channels - 1) * PAIR_ABSTRACTION_GROUPS 89 | } 90 | 91 | pub fn from_packed(channels: usize, packed: &[u8]) -> Self { 92 | let mut result = Self::all_values(channels); 93 | result.unpack_from_slice(packed); 94 | result 95 | } 96 | } 97 | 98 | impl OutputSet 99 | where 100 | Bitmap: AsRef<[bool]>, 101 | { 102 | pub fn channels(&self) -> usize { 103 | self.channels 104 | } 105 | 106 | pub fn from_bitmap(channels: usize, bitmap: Bitmap) -> Self { 107 | debug_assert!(channels <= MAX_CHANNELS); 108 | debug_assert_eq!(bitmap.as_ref().len(), 1 << channels); 109 | Self { channels, bitmap } 110 | } 111 | 112 | pub fn into_bitmap(self) -> Bitmap { 113 | self.bitmap 114 | } 115 | 116 | pub fn bitmap(&self) -> &[bool] { 117 | self.bitmap.as_ref() 118 | } 119 | 120 | pub fn as_ref(&self) -> OutputSet<&[bool]> { 121 | OutputSet { 122 | channels: self.channels, 123 | bitmap: self.bitmap(), 124 | } 125 | } 126 | 127 | pub fn to_owned(&self) -> OutputSet> { 128 | OutputSet { 129 | channels: self.channels, 130 | bitmap: self.bitmap().to_owned(), 131 | } 132 | } 133 | 134 | pub fn to_owned_bvec(&self) -> OutputSet> { 135 | let mut bitmap = BVec::new(); 136 | bitmap.try_extend_from_slice(self.bitmap()).unwrap(); 137 | OutputSet { 138 | channels: self.channels, 139 | bitmap, 140 | } 141 | } 142 | 143 | pub fn len(&self) -> usize { 144 | self.bitmap().iter().map(|&present| present as usize).sum() 145 | } 146 | 147 | pub fn weight_histogram(&self) -> WVec { 148 | let mut histogram = repeat(0).take(self.channels + 1).collect::>(); 149 | 150 | for (index, &present) in 
self.bitmap().iter().enumerate() { 151 | histogram[index.count_ones() as usize] += present as usize; 152 | } 153 | 154 | histogram 155 | } 156 | 157 | pub fn is_sorted(&self) -> bool { 158 | if self.channels() < 2 { 159 | return true; 160 | } 161 | 162 | let mut already_present: WVec = 163 | repeat(false).take(self.channels + 1).collect::>(); 164 | 165 | for (index, &present) in self.bitmap().iter().enumerate() { 166 | let weight = index.count_ones() as usize; 167 | if already_present[weight] & present { 168 | return false; 169 | } else { 170 | already_present[weight] |= present; 171 | } 172 | } 173 | 174 | true 175 | } 176 | 177 | pub fn is_channel_unconstrained(&self, channel: usize) -> bool { 178 | let bitmap = self.bitmap(); 179 | let mask = 1 << channel; 180 | 181 | let mut index = mask; 182 | let size = bitmap.len(); 183 | 184 | while index < size { 185 | if bitmap[index] != bitmap[index ^ mask] { 186 | return false; 187 | } 188 | index = (index + 1) | mask; 189 | } 190 | 191 | true 192 | } 193 | 194 | pub fn channel_fingerprint(&self, channel: usize) -> u64 { 195 | let bitmap = self.bitmap(); 196 | let all_mask = bitmap.len() - 1; 197 | let mask = 1 << channel; 198 | let mut fingerprint = 0; 199 | fingerprint += bitmap[mask] as u64; 200 | fingerprint += bitmap[mask ^ all_mask] as u64 * 2; 201 | 202 | let mut index = mask; 203 | let size = bitmap.len(); 204 | 205 | while index < size { 206 | fingerprint += bitmap[index] as u64 * 4; 207 | index = (index + 1) | mask; 208 | } 209 | 210 | fingerprint 211 | } 212 | 213 | pub fn low_channels_fingerprint(&self, low_channels: usize, buffer: &mut Vec) -> u64 { 214 | let weights = self.channels() + 1 - low_channels; 215 | let low_indices = 1 << low_channels; 216 | 217 | let low_mask = low_indices - 1; 218 | 219 | buffer.clear(); 220 | buffer.resize(weights * low_indices, 0); 221 | 222 | for (index, &present) in self.bitmap().iter().enumerate() { 223 | let low_index = index & low_mask; 224 | let high_weight = (index & 
!low_mask).count_ones() as usize; 225 | buffer[high_weight + weights * low_index] += present as usize; 226 | } 227 | 228 | let mut hasher = FxHasher::default(); 229 | buffer.hash(&mut hasher); 230 | hasher.finish() 231 | } 232 | 233 | pub fn low_channels_channel_fingerprint( 234 | &self, 235 | low_channels: usize, 236 | channel: usize, 237 | buffer: &mut Vec, 238 | ) -> u64 { 239 | debug_assert!(channel >= low_channels); 240 | 241 | let bitmap = self.bitmap(); 242 | let weights = self.channels() - low_channels; 243 | let low_indices = 1 << low_channels; 244 | 245 | let mask = 1 << channel; 246 | 247 | let low_mask = low_indices - 1; 248 | let high_mask = !(low_mask | mask); 249 | 250 | buffer.clear(); 251 | buffer.resize(weights * low_indices, 0); 252 | 253 | let mut index = mask; 254 | let size = bitmap.len(); 255 | 256 | while index < size { 257 | let low_index = index & low_mask; 258 | let high_weight = (index & high_mask).count_ones() as usize; 259 | buffer[high_weight + weights * low_index] += bitmap[index] as usize; 260 | index = (index + 1) | mask; 261 | } 262 | 263 | let mut hasher = FxHasher::default(); 264 | buffer.hash(&mut hasher); 265 | hasher.finish() 266 | } 267 | 268 | pub fn low_channels_channel_abstraction_len(&self, low_channels: usize) -> usize { 269 | let weights = self.channels() - low_channels; 270 | let low_indices = 1 << low_channels; 271 | weights * low_indices * 3 272 | } 273 | 274 | pub fn low_channels_channel_abstraction( 275 | &self, 276 | low_channels: usize, 277 | channel: usize, 278 | buffer: &mut [usize], 279 | ) { 280 | debug_assert!(channel >= low_channels); 281 | 282 | let bitmap = self.bitmap(); 283 | 284 | let weights = self.channels() - low_channels; 285 | let low_indices = 1 << low_channels; 286 | 287 | let mask = 1 << channel; 288 | 289 | buffer.iter_mut().for_each(|value| *value = 0); 290 | 291 | let low_mask = low_indices - 1; 292 | let high_mask = !(low_mask | mask); 293 | 294 | let mut index = mask; 295 | let size = 
bitmap.len(); 296 | 297 | while index < size { 298 | let low_index = index & low_mask; 299 | let high_weight = (index & high_mask).count_ones() as usize; 300 | let value_hi = bitmap[index]; 301 | let value_lo = bitmap[index ^ mask]; 302 | buffer[0 + 3 * (high_weight + weights * low_index)] += value_lo as usize; 303 | buffer[1 + 3 * (high_weight + weights * low_index)] += value_hi as usize; 304 | buffer[2 + 3 * (high_weight + weights * low_index)] += (value_lo & value_hi) as usize; 305 | index = (index + 1) | mask; 306 | } 307 | } 308 | 309 | pub fn channel_abstraction(&self, channel: usize) -> [CVec; 3] { 310 | let bitmap = self.bitmap(); 311 | 312 | let weights = self.channels(); 313 | let mut abstraction = [ 314 | repeat(0).take(weights).collect::>(), 315 | repeat(0).take(weights).collect::>(), 316 | repeat(0).take(weights).collect::>(), 317 | ]; 318 | 319 | let mask = 1 << channel; 320 | 321 | let mut index = mask; 322 | let size = bitmap.len(); 323 | 324 | while index < size { 325 | let weight = (index & !mask).count_ones() as usize; 326 | let value_hi = bitmap[index]; 327 | let value_lo = bitmap[index ^ mask]; 328 | abstraction[0][weight] += value_lo as usize; 329 | abstraction[1][weight] += value_hi as usize; 330 | abstraction[2][weight] += (value_lo & value_hi) as usize; 331 | index = (index + 1) | mask; 332 | } 333 | 334 | abstraction 335 | } 336 | 337 | pub fn channel_pair_abstraction( 338 | &self, 339 | channel_pair: [usize; 2], 340 | ) -> [usize; PAIR_ABSTRACTION_GROUPS] { 341 | let bitmap = self.bitmap(); 342 | 343 | let mut abstraction = [0; PAIR_ABSTRACTION_GROUPS]; 344 | 345 | let mask_0 = 1 << channel_pair[0]; 346 | let mask_1 = 1 << channel_pair[1]; 347 | 348 | let mask = mask_0 | mask_1; 349 | 350 | let mut index = mask; 351 | let size = bitmap.len(); 352 | 353 | while index < size { 354 | let value_0 = bitmap[index]; 355 | let value_1 = bitmap[index ^ mask_0]; 356 | let value_2 = bitmap[index ^ mask_1]; 357 | let value_3 = bitmap[index ^ mask]; 
358 | 359 | abstraction[0] += value_0 as usize; 360 | abstraction[1] += value_1 as usize; 361 | abstraction[2] += value_2 as usize; 362 | abstraction[3] += value_3 as usize; 363 | 364 | index = (index + 1) | mask; 365 | } 366 | 367 | abstraction 368 | } 369 | 370 | pub fn subsumes_unpermuted(&self, other: OutputSet>) -> bool { 371 | self.bitmap() 372 | .iter() 373 | .zip(other.bitmap().iter()) 374 | .all(|(&my_value, &other_value)| !my_value | other_value) 375 | } 376 | 377 | pub fn subsumes_permuted(&self, other: OutputSet>) -> Option> { 378 | subsume::Subsume::new([self.to_owned(), other.to_owned()]).search() 379 | } 380 | 381 | pub fn packed_len(&self) -> usize { 382 | (self.bitmap().len() + 7) / 8 383 | } 384 | 385 | pub fn pack_into_slice(&self, slice: &mut [u8]) { 386 | let bitmap = self.bitmap(); 387 | 388 | let mut byte_chunks = bitmap.chunks_exact(8); 389 | let mut target_bytes = slice.iter_mut(); 390 | 391 | for (byte_chunk, target_byte) in (&mut byte_chunks).zip(&mut target_bytes) { 392 | unsafe { 393 | *target_byte = (*byte_chunk.get_unchecked(0) as u8) 394 | | ((*byte_chunk.get_unchecked(1) as u8) << 1) 395 | | ((*byte_chunk.get_unchecked(2) as u8) << 2) 396 | | ((*byte_chunk.get_unchecked(3) as u8) << 3) 397 | | ((*byte_chunk.get_unchecked(4) as u8) << 4) 398 | | ((*byte_chunk.get_unchecked(5) as u8) << 5) 399 | | ((*byte_chunk.get_unchecked(6) as u8) << 6) 400 | | ((*byte_chunk.get_unchecked(7) as u8) << 7); 401 | } 402 | } 403 | 404 | let remainder = byte_chunks.remainder(); 405 | if !remainder.is_empty() { 406 | let target_byte = target_bytes.next().unwrap(); 407 | *target_byte = (remainder.get(0).cloned().unwrap_or(false) as u8) 408 | | ((remainder.get(1).cloned().unwrap_or(false) as u8) << 1) 409 | | ((remainder.get(2).cloned().unwrap_or(false) as u8) << 2) 410 | | ((remainder.get(3).cloned().unwrap_or(false) as u8) << 3) 411 | | ((remainder.get(4).cloned().unwrap_or(false) as u8) << 4) 412 | | ((remainder.get(5).cloned().unwrap_or(false) as u8) 
<< 5) 413 | | ((remainder.get(6).cloned().unwrap_or(false) as u8) << 6) 414 | | ((remainder.get(7).cloned().unwrap_or(false) as u8) << 7); 415 | } 416 | } 417 | 418 | pub fn packed_pvec(&self) -> PVec { 419 | let mut result = repeat(0).take(self.packed_len()).collect::>(); 420 | 421 | self.pack_into_slice(&mut result[..]); 422 | 423 | result 424 | } 425 | 426 | pub fn packed(&self) -> Vec { 427 | let mut result = repeat(0).take(self.packed_len()).collect::>(); 428 | 429 | self.pack_into_slice(&mut result[..]); 430 | 431 | result 432 | } 433 | 434 | pub fn abstraction_len(&self) -> usize { 435 | OutputSet::abstraction_len_for_channels(self.channels()) 436 | } 437 | 438 | pub fn write_abstraction_into(&self, abstraction: &mut [u16]) { 439 | assert_eq!(abstraction.len(), self.abstraction_len()); 440 | 441 | for channel in 0..self.channels() { 442 | let mut groups = <[CVec; PAIR_ABSTRACTION_GROUPS]>::default(); 443 | for other_channel in 0..self.channels() { 444 | if other_channel == channel { 445 | continue; 446 | } 447 | for (&abstraction_value, group_values) in self 448 | .channel_pair_abstraction([channel, other_channel]) 449 | .iter() 450 | .zip(groups.iter_mut()) 451 | { 452 | group_values.push(abstraction_value); 453 | } 454 | } 455 | for group_values in groups.iter_mut() { 456 | group_values.sort_unstable_by(|a, b| b.cmp(a)); 457 | } 458 | for (group, group_values) in groups.iter().enumerate() { 459 | for (index, &abstraction_value) in group_values.iter().enumerate() { 460 | abstraction 461 | [channel + self.channels() * (group + PAIR_ABSTRACTION_GROUPS * index)] = 462 | abstraction_value as u16; 463 | } 464 | } 465 | } 466 | 467 | for chunk in abstraction.chunks_mut(self.channels()) { 468 | chunk.sort_unstable(); 469 | } 470 | 471 | for chunk in abstraction.chunks_mut(self.channels() * 4) { 472 | let mut tmp = ArrayVec::<[u16; 64]>::new(); 473 | tmp.try_extend_from_slice(chunk).unwrap(); 474 | 475 | for i in 0..4 { 476 | for j in 0..self.channels() { 477 | 
chunk[i + j * 4] = tmp[j + i * self.channels()]; 478 | } 479 | } 480 | } 481 | } 482 | 483 | pub fn abstraction(&self) -> AVec { 484 | let mut result = repeat(0).take(self.abstraction_len()).collect::>(); 485 | 486 | self.write_abstraction_into(&mut result[..]); 487 | 488 | result 489 | } 490 | 491 | pub fn channel_is_extremal(&self, polarity: bool, channel: usize) -> bool { 492 | let bitmap = self.bitmap(); 493 | let all_mask = bitmap.len() - 1; 494 | let mask = 1 << channel; 495 | let index = mask ^ (all_mask * polarity as usize); 496 | 497 | bitmap[index] 498 | } 499 | 500 | pub fn prune_extremal_channel_into( 501 | &self, 502 | polarity: bool, 503 | channel: usize, 504 | mut target: OutputSet<&mut [bool]>, 505 | ) { 506 | assert_eq!(target.channels() + 1, self.channels()); 507 | 508 | let target_bitmap = target.bitmap_mut(); 509 | 510 | for value in target_bitmap.iter_mut() { 511 | *value = false; 512 | } 513 | 514 | let mut queue = BVec::::new(); 515 | 516 | let bitmap = self.bitmap(); 517 | let all_mask = bitmap.len() - 1; 518 | let mask = 1 << channel; 519 | let flip = all_mask * polarity as usize; 520 | let index = mask ^ flip; 521 | 522 | assert!(bitmap[index]); 523 | 524 | let new_all_mask = all_mask >> 1; 525 | 526 | let mask_low = mask - 1; 527 | let mask_high = new_all_mask ^ mask_low; 528 | 529 | let new_flip = new_all_mask * polarity as usize; 530 | 531 | queue.push(mask as u16); 532 | target_bitmap[new_flip] = true; 533 | 534 | while let Some(index) = queue.pop() { 535 | let index = index as usize; 536 | let mut bit = 1; 537 | for _ in 0..self.channels() { 538 | let next = index | bit; 539 | 540 | let next_target = ((next & mask_low) | ((next >> 1) & mask_high)) ^ new_flip; 541 | 542 | if bitmap[next ^ flip] & !target_bitmap[next_target] { 543 | target_bitmap[next_target] = true; 544 | queue.push(next as u16); 545 | } 546 | 547 | bit <<= 1; 548 | } 549 | } 550 | } 551 | } 552 | 553 | impl OutputSet 554 | where 555 | Bitmap: AsMut<[bool]> + 
AsRef<[bool]>, 556 | { 557 | pub fn as_mut(&mut self) -> OutputSet<&mut [bool]> { 558 | OutputSet { 559 | channels: self.channels, 560 | bitmap: self.bitmap_mut(), 561 | } 562 | } 563 | 564 | pub fn bitmap_mut(&mut self) -> &mut [bool] { 565 | self.bitmap.as_mut() 566 | } 567 | 568 | pub fn apply_comparator(&mut self, channels: [usize; 2]) -> bool { 569 | debug_assert_ne!(channels[0], channels[1]); 570 | debug_assert!(channels[0] < self.channels); 571 | debug_assert!(channels[1] < self.channels); 572 | 573 | let mask_0 = 1 << channels[0]; 574 | let mask_1 = 1 << channels[1]; 575 | 576 | let comparator_mask = mask_0 | mask_1; 577 | 578 | let mut index = comparator_mask; 579 | 580 | let bitmap = self.bitmap_mut(); 581 | 582 | let size = bitmap.len(); 583 | 584 | let mut out_of_order_present = false; 585 | let mut in_order_present = false; 586 | 587 | while index < size { 588 | let out_of_order = replace(&mut bitmap[index ^ mask_0], false); 589 | out_of_order_present |= out_of_order; 590 | let in_order = &mut bitmap[index ^ mask_1]; 591 | in_order_present |= *in_order; 592 | *in_order |= out_of_order; 593 | index = (index + 1) | comparator_mask; 594 | } 595 | 596 | out_of_order_present & in_order_present 597 | } 598 | 599 | pub fn swap_channels(&mut self, channels: [usize; 2]) { 600 | debug_assert!(channels[0] < self.channels); 601 | debug_assert!(channels[1] < self.channels); 602 | if channels[0] == channels[1] { 603 | return; 604 | } 605 | 606 | let mask_0 = 1 << channels[0]; 607 | let mask_1 = 1 << channels[1]; 608 | 609 | let comparator_mask = mask_0 | mask_1; 610 | 611 | let mut index = comparator_mask; 612 | 613 | let bitmap = self.bitmap_mut(); 614 | 615 | let size = bitmap.len(); 616 | 617 | while index < size { 618 | bitmap.swap(index ^ mask_0, index ^ mask_1); 619 | index = (index + 1) | comparator_mask; 620 | } 621 | } 622 | 623 | pub fn invert(&mut self) { 624 | self.bitmap_mut().reverse() 625 | } 626 | 627 | pub fn canonicalize(&mut self, inversion: bool) 
-> Perm { 628 | if self.channels() == 0 { 629 | return Perm::identity(0); 630 | } 631 | 632 | let mut canonicalize = canon::Canonicalize::new(self.as_mut(), inversion); 633 | 634 | let (result, perm) = canonicalize.canonicalize(); 635 | 636 | self.bitmap_mut().copy_from_slice(result.bitmap()); 637 | 638 | perm 639 | } 640 | 641 | pub fn unpack_from_slice(&mut self, slice: &[u8]) { 642 | let bitmap = self.bitmap_mut(); 643 | 644 | let mut byte_chunks = bitmap.chunks_exact_mut(8); 645 | let mut source_bytes = slice.iter(); 646 | 647 | for (byte_chunk, source_byte) in (&mut byte_chunks).zip(&mut source_bytes) { 648 | unsafe { 649 | *byte_chunk.get_unchecked_mut(0) = source_byte & (1 << 0) != 0; 650 | *byte_chunk.get_unchecked_mut(1) = source_byte & (1 << 1) != 0; 651 | *byte_chunk.get_unchecked_mut(2) = source_byte & (1 << 2) != 0; 652 | *byte_chunk.get_unchecked_mut(3) = source_byte & (1 << 3) != 0; 653 | *byte_chunk.get_unchecked_mut(4) = source_byte & (1 << 4) != 0; 654 | *byte_chunk.get_unchecked_mut(5) = source_byte & (1 << 5) != 0; 655 | *byte_chunk.get_unchecked_mut(6) = source_byte & (1 << 6) != 0; 656 | *byte_chunk.get_unchecked_mut(7) = source_byte & (1 << 7) != 0; 657 | } 658 | } 659 | 660 | let remainder = byte_chunks.into_remainder(); 661 | if !remainder.is_empty() { 662 | let source_byte = source_bytes.next().unwrap(); 663 | for (i, target_bit) in remainder.iter_mut().enumerate() { 664 | *target_bit = source_byte & (1 << i) != 0 665 | } 666 | } 667 | } 668 | 669 | pub fn copy_from(&mut self, other: OutputSet>) { 670 | self.bitmap_mut().copy_from_slice(other.bitmap()) 671 | } 672 | } 673 | 674 | #[derive(Clone, Eq, PartialEq, Debug)] 675 | pub struct Perm { 676 | pub invert: bool, 677 | pub perm: CVec, 678 | } 679 | 680 | impl Perm { 681 | fn identity(channels: usize) -> Self { 682 | Self { 683 | invert: false, 684 | perm: (0..channels).collect(), 685 | } 686 | } 687 | } 688 | 689 | #[cfg(test)] 690 | mod test { 691 | use super::*; 692 | 693 | 
#[rustfmt::skip] 694 | static SORT_11: &[[usize; 2]] = &[ 695 | [0, 9], [1, 6], [2, 4], [3, 7], [5, 8], 696 | [0, 1], [3, 5], [4, 10], [6, 9], [7, 8], 697 | [1, 3], [2, 5], [4, 7], [8, 10], 698 | [0, 4], [1, 2], [3, 7], [5, 9], [6, 8], 699 | [0, 1], [2, 6], [4, 5], [7, 8], [9, 10], 700 | [2, 4], [3, 6], [5, 7], [8, 9], 701 | [1, 2], [3, 4], [5, 6], [7, 8], 702 | [2, 3], [4, 5], [6, 7], 703 | ]; 704 | 705 | #[test] 706 | fn sort_11_sorts() { 707 | crate::logging::setup(false); 708 | 709 | let mut output_set = OutputSet::all_values(11); 710 | 711 | for (i, &comparator) in SORT_11.iter().enumerate() { 712 | assert!(!output_set.is_sorted()); 713 | output_set.apply_comparator(comparator); 714 | log::info!( 715 | "step {}: histogram = {:?}", 716 | i, 717 | output_set.weight_histogram() 718 | ); 719 | } 720 | 721 | assert!(output_set.is_sorted()); 722 | } 723 | 724 | #[test] 725 | fn sort_11_pack_unpack() { 726 | crate::logging::setup(false); 727 | 728 | let mut output_set = OutputSet::all_values(11); 729 | 730 | for (i, &comparator) in SORT_11.iter().enumerate() { 731 | assert!(!output_set.is_sorted()); 732 | output_set.apply_comparator(comparator); 733 | let unpacked = OutputSet::from_packed(11, &output_set.packed()); 734 | assert_eq!(output_set, unpacked); 735 | 736 | log::info!( 737 | "step {}: histogram = {:?}", 738 | i, 739 | output_set.weight_histogram() 740 | ); 741 | } 742 | 743 | assert!(output_set.is_sorted()); 744 | } 745 | 746 | #[test] 747 | fn sort_canonicalize() { 748 | crate::logging::setup(false); 749 | 750 | for &limit in &[1, 4, 8, 16, 30] { 751 | let mut output_set = OutputSet::all_values(11); 752 | for (i, &comparator) in SORT_11.iter().enumerate() { 753 | assert!(!output_set.is_sorted()); 754 | output_set.apply_comparator(comparator); 755 | let mut canonical = output_set.clone(); 756 | canonical.canonicalize(true); 757 | 758 | let mut canonical_2 = output_set.clone(); 759 | 760 | if i & 1 != 0 { 761 | canonical_2.invert(); 762 | } 763 | 764 | for 
&pair in SORT_11[..limit].iter() { 765 | canonical_2.swap_channels(pair); 766 | } 767 | canonical_2.canonicalize(true); 768 | 769 | assert_eq!(canonical, canonical_2); 770 | 771 | log::info!( 772 | "step {}: histogram = {:?}", 773 | i, 774 | output_set.weight_histogram(), 775 | ); 776 | } 777 | 778 | assert!(output_set.is_sorted()); 779 | } 780 | } 781 | 782 | #[test] 783 | fn sort_subsume_permuted() { 784 | crate::logging::setup(false); 785 | 786 | for &limit in &[1, 4, 8, 16, 30] { 787 | let mut output_set = OutputSet::all_values(11); 788 | for (i, &comparator) in SORT_11.iter().enumerate() { 789 | assert!(!output_set.is_sorted()); 790 | let previous_output_set = output_set.clone(); 791 | output_set.apply_comparator(comparator); 792 | 793 | let mut permuted = output_set.clone(); 794 | 795 | for &pair in SORT_11[..limit].iter() { 796 | permuted.swap_channels(pair); 797 | } 798 | 799 | assert!(permuted.subsumes_permuted(output_set.as_ref()).is_some()); 800 | assert!(output_set.subsumes_permuted(permuted.as_ref()).is_some()); 801 | assert!(previous_output_set 802 | .subsumes_permuted(permuted.as_ref()) 803 | .is_none()); 804 | let strict_progress = permuted.subsumes_permuted(previous_output_set.as_ref()); 805 | 806 | log::info!( 807 | "step {}: histogram = {:?} strict progress = {:?}", 808 | i, 809 | output_set.weight_histogram(), 810 | strict_progress 811 | ); 812 | } 813 | assert!(output_set.is_sorted()); 814 | } 815 | } 816 | 817 | #[test] 818 | fn prune_extremal_channel() { 819 | let mut output_set_large = OutputSet::all_values(11); 820 | let mut output_set_small = OutputSet::all_values(10); 821 | let mut output_set_ref = OutputSet::all_values(10); 822 | 823 | output_set_large.prune_extremal_channel_into(false, 0, output_set_small.as_mut()); 824 | assert_eq!(output_set_small, output_set_ref); 825 | 826 | output_set_large.apply_comparator([0, 1]); 827 | output_set_large.prune_extremal_channel_into(false, 0, output_set_small.as_mut()); 828 | 
assert_eq!(output_set_small, output_set_ref); 829 | 830 | output_set_large.apply_comparator([2, 3]); 831 | output_set_ref.apply_comparator([1, 2]); 832 | output_set_large.prune_extremal_channel_into(false, 0, output_set_small.as_mut()); 833 | assert_eq!(output_set_small, output_set_ref); 834 | 835 | let mut output_set_large = OutputSet::all_values(11); 836 | let mut output_set_small = OutputSet::all_values(10); 837 | let mut output_set_ref = OutputSet::all_values(10); 838 | 839 | output_set_large.prune_extremal_channel_into(true, 1, output_set_small.as_mut()); 840 | assert_eq!(output_set_small, output_set_ref); 841 | 842 | output_set_large.apply_comparator([0, 1]); 843 | output_set_large.prune_extremal_channel_into(true, 1, output_set_small.as_mut()); 844 | assert_eq!(output_set_small, output_set_ref); 845 | 846 | output_set_large.apply_comparator([2, 3]); 847 | output_set_ref.apply_comparator([1, 2]); 848 | output_set_large.prune_extremal_channel_into(true, 1, output_set_small.as_mut()); 849 | assert_eq!(output_set_small, output_set_ref); 850 | } 851 | } 852 | --------------------------------------------------------------------------------