├── .gitignore ├── .merlin ├── .travis.yml ├── CHANGES.md ├── LICENSE.md ├── Makefile ├── README.md ├── bench ├── bench.ml └── dune ├── diet.opam ├── dune-project ├── fuzz ├── dune └── fuzz.ml ├── lib ├── diet.ml ├── diet.mli └── dune └── lib_test ├── dune └── test.ml /.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | *.install 3 | .merlin 4 | -------------------------------------------------------------------------------- /.merlin: -------------------------------------------------------------------------------- 1 | PKG astring cmdliner cstruct logs lwt mirage-block mirage-block-unix ppx_sexp_conv ppx_tools ppx_type_conv 2 | PKG io-page io-page.unix logs.fmt result sexplib 3 | PKG ezjsonm mirage-block-ramdisk nbd ounit 4 | S lib 5 | S lib_test 6 | B _build/** 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | install: 3 | - wget https://raw.githubusercontent.com/ocaml/ocaml-travisci-skeleton/master/.travis-opam.sh 4 | - wget https://raw.githubusercontent.com/simonjbeaumont/ocaml-travis-coveralls/master/travis-coveralls.sh 5 | script: bash -ex .travis-opam.sh 6 | sudo: required 7 | dist: trusty 8 | env: 9 | matrix: # each entry becomes its own build job (entries under global would be merged into every job) 10 | - PACKAGE="diet" OCAML_VERSION=4.06 11 | - PACKAGE="diet" OCAML_VERSION=4.07 12 | - PACKAGE="diet" OCAML_VERSION=4.08 13 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | ## v0.4 (2019-07-16) 2 | 3 | - support OCaml 4.08 deprecations by using Stdlib (@avsm) 4 | 5 | ## v0.3 (2019-03-07) 6 | - switch to `dune-release` (@avsm) 7 | - run the tests as well as build them (@avsm) 8 | - update metadata to opam 2.0 format (@avsm) 9 | 10 | ## 0.2 (2018-10-04): 11 | - Add an `iter` function (#5 from @g2p) 12 | - Build via `dune`
(was `jbuilder`) 13 | 14 | ## 0.1 (2018-06-07): 15 | - Initial split of the DIET code from mirage/ocaml-qcow 16 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | (* 2 | * Copyright (c) 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | * 16 | *) 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | .PHONY: build clean test fuzz doc 3 | 4 | build: 5 | dune build 6 | 7 | test: 8 | dune runtest --force 9 | 10 | fuzz: 11 | dune build fuzz/fuzz.exe 12 | mkdir -p test/input 13 | echo abcd > test/input/case 14 | afl-fuzz -i test/input -o output ./_build/default/fuzz/fuzz.exe @@ 15 | 16 | doc: 17 | dune build @doc 18 | open _build/default/_doc/_html/diet/Diet/module-type-INTERVAL_SET/index.html || echo 'Try pointing your browser at _build/default/_doc/_html/index.html' 19 | 20 | clean: 21 | dune clean 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | OCaml Discrete Interval Encoding Trees 2 | ====================================== 3 | 4 | [![Build Status](https://travis-ci.org/mirage/ocaml-diet.svg?branch=master)](https://travis-ci.org/mirage/ocaml-diet) [![Coverage Status](https://coveralls.io/repos/github/mirage/ocaml-diet/badge.svg?branch=master)](https://coveralls.io/github/mirage/ocaml-diet?branch=master) 5 | 6 | Please read [the API documentation](https://mirage.github.io/ocaml-diet/). 7 | 8 | This is based on the paper 9 | [Functional Pearls: Diets for Fat Sets](https://web.engr.oregonstate.edu/~erwig/papers/Diet_JFP98.pdf) 10 | by Martin Erwig.
11 | -------------------------------------------------------------------------------- /bench/bench.ml: -------------------------------------------------------------------------------- 1 | open Core 2 | open Core_bench.Std 3 | 4 | module IntDiet = Diet.Make(struct 5 | type t = int 6 | let compare (x: t) (y: t) = Int.compare x y (* Core's Int.compare (Core is opened above); Pervasives was deprecated in OCaml 4.08 and removed in 5.0 *) 7 | let zero = 0 8 | let succ x = x + 1 9 | let pred x = x - 1 10 | let add x y = x + y 11 | let sub x y = x - y 12 | let to_string = string_of_int 13 | end) 14 | 15 | let state = Random.State.make_self_init () 16 | 17 | let fisher_yates_shuffle a = 18 | for i = Array.length a-1 downto 1 do 19 | let j = Random.State.int state (i + 1) in 20 | let tmp = a.(i) in 21 | a.(i) <- a.(j); 22 | a.(j) <- tmp; 23 | done 24 | 25 | let gen_array size = 26 | let gen_interval i = 27 | let length = Random.State.int state 8 in 28 | IntDiet.Interval.make (10 * i) (10 * i + length) 29 | in 30 | Array.init size ~f:gen_interval 31 | 32 | let diet_from_array arr = 33 | Array.fold 34 | ~init:IntDiet.empty 35 | ~f:(fun diet intvl -> IntDiet.add intvl diet) 36 | arr 37 | 38 | let gen_equal_diets n = 39 | let intervals = gen_array n in 40 | let regular = diet_from_array intervals in 41 | fisher_yates_shuffle intervals; 42 | let shuffled = diet_from_array intervals in 43 | regular, shuffled 44 | 45 | let gen_non_equal_diets n = 46 | let one = diet_from_array @@ gen_array n in 47 | let other = diet_from_array @@ gen_array n in 48 | one, other 49 | 50 | let create_indexed_with_initialization ~name ~args f = 51 | Bench.Test.create_group ~name @@ 52 | List.map args 53 | ~f: 54 | (fun size -> 55 | let name = Printf.sprintf "size %d" size in 56 | Bench.Test.create_with_initialization ~name (f size) 57 | ) 58 | 59 | let () = 60 | Command.run 61 | (Bench.make_command 62 | [ create_indexed_with_initialization ~name:"Equal" ~args:[10; 100; 1000] 63 | (fun size `init -> 64 | let d, d' = gen_equal_diets size in 65 | (fun () -> IntDiet.equal d d')) 66 | ;
create_indexed_with_initialization ~name:"Not equal" ~args:[10; 100; 1000] 67 | (fun size `init -> 68 | let d, d' = gen_non_equal_diets size in 69 | (fun () -> ignore @@ IntDiet.equal d d')) 70 | ]) 71 | -------------------------------------------------------------------------------- /bench/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (name bench) 3 | (libraries 4 | core 5 | core_bench 6 | diet 7 | ) 8 | ) 9 | 10 | (alias 11 | (name bench) 12 | (action (run ./bench.exe)) 13 | ) 14 | -------------------------------------------------------------------------------- /diet.opam: -------------------------------------------------------------------------------- 1 | opam-version: "2.0" 2 | maintainer: "dave@recoil.org" 3 | authors: "David Scott" 4 | license: "ISC" 5 | homepage: "https://github.com/mirage/ocaml-diet" 6 | doc: "https://mirage.github.io/ocaml-diet/" 7 | bug-reports: "https://github.com/mirage/ocaml-diet/issues" 8 | depends: [ 9 | "ocaml" {>= "4.03.0"} 10 | "dune" 11 | "stdlib-shims" 12 | "ounit" {with-test} 13 | ] 14 | build: [ 15 | ["dune" "subst"] {pinned} 16 | ["dune" "build" "-p" name "-j" jobs] 17 | ["dune" "runtest" "-p" name "-j" jobs] {with-test} 18 | ] 19 | dev-repo: "git+https://github.com/mirage/ocaml-diet.git" 20 | synopsis: "Discrete Interval Encoding Trees" 21 | description: """ 22 | This data structure is based on the 23 | [Functional Pearls: Diets for Fat Sets](https://web.engr.oregonstate.edu/~erwig/papers/Diet_JFP98.pdf) 24 | by Martin Erwig.""" 25 | -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 1.0) 2 | (name diet) 3 | -------------------------------------------------------------------------------- /fuzz/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (name fuzz) 3 | (libraries
diet crowbar)) 4 | -------------------------------------------------------------------------------- /fuzz/fuzz.ml: -------------------------------------------------------------------------------- 1 | (* 2 | * Copyright (C) 2018 Docker Inc 3 | * 4 | * Permission to use, copy, modify, and/or distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH 9 | * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY 10 | * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, 11 | * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM 12 | * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR 13 | * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR 14 | * PERFORMANCE OF THIS SOFTWARE. 15 | *) 16 | 17 | open Crowbar 18 | 19 | module Int = struct 20 | type t = int 21 | let compare (x: t) (y: t) = Stdlib.compare x y (* Pervasives was deprecated in OCaml 4.08 and removed in 5.0 *) 22 | let zero = 0 23 | let succ x = x + 1 24 | let pred x = x - 1 25 | let add x y = x + y 26 | let sub x y = x - y 27 | let to_string = string_of_int 28 | end 29 | module IntDiet = Diet.Make(Int) 30 | 31 | (* Avoid max_int because the library needs to call succ, and we don't want it 32 | to wrap.
*) 33 | let positive_int = range (max_int - 1) 34 | 35 | let interval = map [ positive_int; positive_int ] (fun a b -> 36 | if a <= b then IntDiet.Interval.make a b else IntDiet.Interval.make b a 37 | ) 38 | 39 | let diet = fix (fun diet -> 40 | choose [ 41 | const IntDiet.empty; 42 | map [ interval; diet ] (fun interval diet -> IntDiet.add interval diet); 43 | ] 44 | ) 45 | 46 | let pp_diet = IntDiet.pp 47 | let diet = with_printer pp_diet diet 48 | 49 | (* Oracle for the algebraic properties below: compares the in-order lists of maximal intervals, independently of [IntDiet.equal], which check_equality exercises separately. *) 50 | let eq a b = 51 | let intervals t = IntDiet.fold (fun x acc -> x :: acc) t [] |> List.rev in 52 | intervals a = (intervals b) 53 | 54 | let shuffle_a st a = 55 | for i = Array.length a-1 downto 1 do 56 | let j = Random.State.int st (i+1) in 57 | let tmp = a.(i) in 58 | a.(i) <- a.(j); 59 | a.(j) <- tmp; 60 | done 61 | 62 | let check_equality interval_list rng_state = 63 | let state = Random.State.make (Array.of_list rng_state) in 64 | let interval_array = Array.of_list interval_list in 65 | let diet_of_array array = 66 | Array.fold_left (fun diet interval -> IntDiet.add interval diet) IntDiet.empty array 67 | in 68 | let diet1 = diet_of_array interval_array in 69 | shuffle_a state interval_array; 70 | let diet2 = diet_of_array interval_array in 71 | check (IntDiet.equal diet2 diet1) 72 | 73 | let () = 74 | add_test ~name:"union is commutative" [diet; diet] 75 | (fun d1 d2 -> 76 | check_eq ~pp:pp_diet ~eq IntDiet.(union d1 d2) IntDiet.(union d2 d1)); 77 | add_test ~name:"union is associative" [diet; diet; diet] 78 | (fun d1 d2 d3 -> 79 | check_eq ~pp:pp_diet ~eq IntDiet.(union d1 (union d2 d3)) IntDiet.(union (union d1 d2) d3)); 80 | add_test ~name:"intersection is commutative" [diet; diet] 81 | (fun d1 d2 -> 82 | check_eq ~pp:pp_diet ~eq IntDiet.(inter d1 d2) IntDiet.(inter d2 d1)); 83 | add_test ~name:"intersection is associative" [diet; diet; diet] 84 | (fun d1 d2 d3 -> 85 | check_eq ~pp:pp_diet ~eq IntDiet.(inter d1 (inter d2 d3)) IntDiet.(inter
(inter d1 d2) d3)); 86 | add_test ~name:"distributive 1" [diet; diet; diet] 87 | (fun d1 d2 d3 -> 88 | check_eq ~pp:pp_diet ~eq IntDiet.(union d1 (inter d2 d3)) IntDiet.(inter (union d1 d2) (union d1 d3))); 89 | add_test ~name:"distributive 2" [diet; diet; diet] 90 | (fun d1 d2 d3 -> 91 | check_eq ~pp:pp_diet ~eq IntDiet.(inter d1 (union d2 d3)) IntDiet.(union (inter d1 d2) (inter d1 d3))); 92 | add_test ~name:"equality" [list1 interval; list1 int] check_equality; 93 | () 94 | -------------------------------------------------------------------------------- /lib/diet.ml: -------------------------------------------------------------------------------- 1 | (* 2 | * Copyright (C) 2016 David Scott 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 | * 16 | *) 17 | 18 | module type ELT = sig 19 | type t 20 | val compare: t -> t -> int 21 | val zero: t 22 | val pred: t -> t 23 | val succ: t -> t 24 | val sub: t -> t -> t 25 | val add: t -> t -> t 26 | val to_string: t -> string 27 | end 28 | 29 | module type INTERVAL_SET = sig 30 | type elt 31 | type interval 32 | module Interval: sig 33 | val make: elt -> elt -> interval 34 | val x: interval -> elt 35 | val y: interval -> elt 36 | end 37 | type t 38 | val equal: t -> t -> bool 39 | val compare: t -> t -> int 40 | val pp: Format.formatter -> t -> unit 41 | val empty: t 42 | val is_empty: t -> bool 43 | val cardinal: t -> elt 44 | val mem: elt -> t -> bool 45 | val fold: (interval -> 'a -> 'a) -> t -> 'a -> 'a 46 | val fold_individual: (elt -> 'a -> 'a) -> t -> 'a -> 'a 47 | val iter: (interval -> unit) -> t -> unit 48 | val add: interval -> t -> t 49 | val remove: interval -> t -> t 50 | val min_elt: t -> interval 51 | val max_elt: t -> interval 52 | val choose: t -> interval 53 | val take: t -> elt -> (t * t) option 54 | val union: t -> t -> t 55 | val diff: t -> t -> t 56 | val inter: t -> t -> t 57 | val find_next_gap: elt -> t -> elt 58 | val check_invariants : t -> (unit, string) result 59 | val height : t -> int 60 | end 61 | 62 | module Make(Elt: ELT) = struct 63 | type elt = Elt.t 64 | 65 | module Elt = struct 66 | include Elt 67 | let ( - ) = sub 68 | let ( + ) = add 69 | end 70 | 71 | type interval = elt * elt 72 | 73 | module Interval = struct 74 | let make x y = 75 | (* Order with [Elt.compare], not polymorphic [>]: the Elt-aware operators are only defined further down, and a custom element ordering need not match structural order. *) if Elt.compare x y > 0 then invalid_arg "Interval.make"; 76 | x, y 77 | let x = fst 78 | let y = snd 79 | end 80 | 81 | let ( > ) x y = Elt.compare x y > 0 82 | let ( >= ) x y = Elt.compare x y >= 0 83 | let ( < ) x y = Elt.compare x y < 0 84 | let ( <= ) x y = Elt.compare x y <= 0 85 | let eq x y = Elt.compare x y = 0 86 | let succ, pred = Elt.succ, Elt.pred 87 | 88 | type t = 89 | | Empty 90 | | Node: node -> t 91 | and node = { x: elt; y: elt; l: t; r: t; h: int; cardinal: elt } 92 | 93 | let
rec cons_enum t enum = 94 | match t with 95 | | Empty -> enum 96 | | Node ({l; _} as node) -> cons_enum l (node::enum) 97 | 98 | (* Lexicographic order on (x, y). Intervals drawn from two different sets may overlap, so disjointness must not be assumed here; otherwise [compare] is not antisymmetric. *) let compare_interval {x; y; _} {x = x'; y = y'; _} = 99 | let c = Elt.compare x x' in 100 | if c <> 0 then c 101 | else Elt.compare y y' 102 | 103 | let rec compare_aux enum enum' = 104 | match enum, enum' with 105 | | [], [] -> 0 106 | | [], _ -> -1 107 | | _, [] -> 1 108 | | node::enum, node'::enum' -> 109 | (match compare_interval node node' with 110 | | 0 -> compare_aux (cons_enum node.r enum) (cons_enum node'.r enum') 111 | | c -> c) 112 | 113 | (* Sets are ordered by their in-order sequences of maximal intervals. *) let compare t t' = compare_aux (cons_enum t []) (cons_enum t' []) 114 | 115 | let equal t t' = compare t t' = 0 116 | 117 | let rec pp fmt = function 118 | | Empty -> Format.fprintf fmt "Empty" 119 | | Node n -> pp_node fmt n 120 | 121 | and pp_node fmt {x ;y; l; r; h; cardinal } = 122 | Format.pp_open_vbox fmt 0; 123 | Format.fprintf fmt "x: %s@," (Elt.to_string x); 124 | Format.fprintf fmt "y: %s@," (Elt.to_string y); 125 | Format.fprintf fmt "l:@[@\n%a@]@," pp l; 126 | Format.fprintf fmt "r:@[@\n%a@]@," pp r; 127 | Format.fprintf fmt "h: %d@," h; 128 | Format.fprintf fmt "cardinal: %s" (Elt.to_string cardinal); 129 | Format.pp_close_box fmt () 130 | 131 | let height = function 132 | | Empty -> 0 133 | | Node n -> n.h 134 | 135 | let cardinal = function 136 | | Empty -> Elt.zero 137 | | Node n -> n.cardinal 138 | 139 | let create x y l r = 140 | let h = max (height l) (height r) + 1 in 141 | let cardinal = Elt.(succ (y - x) + (cardinal l) + (cardinal r)) in 142 | Node { x; y; l; r; h; cardinal } 143 | 144 | let rec node x y l r = 145 | let hl = height l and hr = height r in 146 | let open Stdlib in 147 | if hl > hr + 2 then begin 148 | match l with 149 | | Empty -> assert false 150 | | Node { x = lx; y = ly; l = ll; r = lr; _ } -> 151 | if height ll >= (height lr) 152 | then node lx ly ll (node x y lr r) 153 | else match lr with 154 | | Empty -> assert false 155 | | Node { x = lrx; y =
lry; l = lrl; r = lrr; _ } -> 156 | node lrx lry (node lx ly ll lrl) (node x y lrr r) 157 | end else if hr > hl + 2 then begin 158 | match r with 159 | | Empty -> assert false 160 | | Node { x = rx; y = ry; l = rl; r = rr; _ } -> 161 | if height rr >= height rl 162 | then node rx ry (node x y l rl) rr 163 | else match rl with 164 | | Empty -> assert false 165 | | Node { x = rlx; y = rly; l = rll; r = rlr; _ } -> 166 | node rlx rly (node x y l rll) (node rx ry rlr rr) 167 | end else create x y l r 168 | 169 | let depth tree = 170 | let rec depth tree k = match tree with 171 | | Empty -> k 0 172 | | Node n -> 173 | depth n.l (fun dl -> 174 | depth n.r (fun dr -> 175 | k (1 + (max dl dr)))) 176 | in depth tree (fun d -> d) 177 | 178 | module Invariant = struct 179 | 180 | let (>>=) xr f = 181 | match xr with 182 | | Ok x -> f x 183 | | e -> e 184 | 185 | let ensure b msg t = 186 | if b then 187 | Ok () 188 | else 189 | Error (Format.asprintf "%s: %a" msg pp t) 190 | 191 | let rec on_every_node d f = 192 | match d with 193 | | Empty -> Ok () 194 | | Node n -> 195 | f n d >>= fun () -> 196 | on_every_node n.l f >>= fun () -> 197 | on_every_node n.r f 198 | 199 | (* The pairs (x, y) in each interval are ordered such that x <= y *) 200 | let ordered { x; y; _ } = 201 | ensure 202 | (x <= y) 203 | "Pairs within each interval should be ordered" 204 | 205 | (* The intervals don't overlap *) 206 | let no_overlap { x; y; l; r; _ } n = 207 | let error = "Intervals should be ordered without overlap" in 208 | begin match l with 209 | | Empty -> Ok () 210 | | Node left -> 211 | ensure (left.y < x) error n 212 | end >>= fun () -> 213 | begin match r with 214 | | Empty -> Ok () 215 | | Node right -> 216 | ensure (right.x > y) error n 217 | end 218 | 219 | let no_adjacent { x; y; l; r; _ } n = 220 | let error = "Intervals should not be adjacent" in 221 | begin match l with 222 | | Empty -> Ok () 223 | | Node left -> 224 | ensure (Elt.succ left.y < x) error n 225 | end >>= fun () ->
226 | begin match r with 227 | | Empty -> Ok () 228 | | Node right -> 229 | ensure (Elt.pred right.x > y) error n 230 | end 231 | 232 | let node_height n = 233 | n.h 234 | 235 | let node_depth n = 236 | depth (Node n) 237 | 238 | (* The height is being stored correctly *) 239 | let height_equals_depth n = 240 | ensure 241 | (node_height n = node_depth n) 242 | "The height is not being maintained correctly" 243 | 244 | let balanced { l; r; _ } = 245 | let diff = height l - (height r) in 246 | let open Stdlib in 247 | ensure 248 | (-2 <= diff && diff <= 2) 249 | "The tree has become imbalanced" 250 | 251 | (* Each node caches cardinal l + cardinal r + (y - x + 1); compared with [eq] rather than polymorphic equality. *) let check_cardinal { x; y; l; r; cardinal = c; _ } = 252 | ensure 253 | (eq Elt.(c - cardinal l - cardinal r - y + x) (succ Elt.zero)) 254 | "The cardinal value stored in the node is wrong" 255 | 256 | let check t = 257 | on_every_node t ordered >>= fun () -> 258 | on_every_node t no_overlap >>= fun () -> 259 | on_every_node t height_equals_depth >>= fun () -> 260 | on_every_node t balanced >>= fun () -> 261 | on_every_node t check_cardinal >>= fun () -> 262 | on_every_node t no_adjacent 263 | end 264 | 265 | let empty = Empty 266 | 267 | let is_empty = function 268 | | Empty -> true 269 | | _ -> false 270 | 271 | let rec mem elt = function 272 | | Empty -> false 273 | | Node n -> 274 | (* consider this interval *) 275 | (elt >= n.x && elt <= n.y) 276 | || 277 | (* or search left or search right *) 278 | (if elt < n.x then mem elt n.l else mem elt n.r) 279 | 280 | let rec min_elt = function 281 | | Empty -> raise Not_found 282 | | Node { x; y; l = Empty; _ } -> x, y 283 | | Node { l; _ } -> min_elt l 284 | 285 | let rec max_elt = function 286 | | Empty -> raise Not_found 287 | | Node { x; y; r = Empty; _ } -> x, y 288 | | Node { r; _ } -> max_elt r 289 | 290 | let choose = function 291 | | Empty -> raise Not_found 292 | | Node { x; y; _ } -> x, y 293 | 294 | (* fold over the maximal contiguous intervals *) 295 | let rec fold f t acc = match t with 296 | | Empty -> acc 297 |
| Node n -> 298 | let acc = fold f n.l acc in 299 | let acc = f (n.x, n.y) acc in 300 | fold f n.r acc 301 | 302 | (* fold over individual elements, in increasing order *) 303 | let fold_individual f t acc = 304 | let range (from, upto) acc = 305 | (* Stop right after visiting [upto] instead of testing against [succ upto], which may wrap or fail when [upto] is the largest representable element. *) let rec loop acc x = 306 | let acc = f x acc in if eq x upto then acc else loop acc (succ x) in 307 | loop acc from in 308 | fold range t acc 309 | 310 | (* iterate over maximal contiguous intervals *) 311 | let iter f t = 312 | let f' itl () = 313 | f itl in 314 | fold f' t () 315 | 316 | (* return (x, y, l) where (x, y) is the maximal interval and [l] is 317 | the rest of the tree on the left (whose intervals are all smaller). *) 318 | let rec splitMax = function 319 | | { x; y; l; r = Empty; _} -> x, y, l 320 | | { r = Node r; _ } as n -> 321 | let u, v, r' = splitMax r in 322 | u, v, node n.x n.y n.l r' 323 | 324 | (* return (x, y, r) where (x, y) is the minimal interval and [r] is 325 | the rest of the tree on the right (whose intervals are all larger) *) 326 | let rec splitMin = function 327 | | { x; y; l = Empty; r; _} -> x, y, r 328 | | { l = Node l; _ } as n -> 329 | let u, v, l' = splitMin l in 330 | u, v, node n.x n.y l' n.r 331 | 332 | let addL = function 333 | | { l = Empty; _ } as n -> n 334 | | { l = Node l; _ } as n -> 335 | (* we might have to merge the new element with the maximal interval from 336 | the left *) 337 | let x', y', l' = splitMax l in 338 | if eq (succ y') n.x then { n with x = x'; l = l' } else n 339 | 340 | let addR = function 341 | | { r = Empty; _ } as n -> n 342 | | { r = Node r; _ } as n -> 343 | (* we might have to merge the new element with the minimal interval on 344 | the right *) 345 | let x', y', r' = splitMin r in 346 | if eq (succ n.y) x' then { n with y = y'; r = r' } else n 347 | 348 | let rec add (x, y) t = 349 | if y < x then invalid_arg "interval reversed"; 350 | match t with 351 | | Empty -> node x y Empty Empty 352 | (* completely to the left *) 353 | | Node n when y < (Elt.pred n.x) -> 354 |
let l = add (x, y) n.l in 355 | node n.x n.y l n.r 356 | (* completely to the right *) 357 | | Node n when (Elt.succ n.y) < x -> 358 | let r = add (x, y) n.r in 359 | node n.x n.y n.l r 360 | (* overlap on the left only *) 361 | | Node n when x < n.x && y <= n.y -> 362 | let l = add (x, pred n.x) n.l in 363 | let n = addL { n with l } in 364 | node n.x n.y n.l n.r 365 | (* overlap on the right only *) 366 | | Node n when y > n.y && x >= n.x -> 367 | let r = add (succ n.y, y) n.r in 368 | let n = addR { n with r } in 369 | node n.x n.y n.l n.r 370 | (* overlap on both sides *) 371 | | Node n when x < n.x && y > n.y -> 372 | let l = add (x, pred n.x) n.l in 373 | let r = add (succ n.y, y) n.r in 374 | let n = addL { (addR { n with r }) with l } in 375 | node n.x n.y n.l n.r 376 | (* completely within *) 377 | | Node n -> Node n 378 | 379 | let union a b = 380 | let a' = cardinal a and b' = cardinal b in 381 | if a' > b' 382 | then fold add b a 383 | else fold add a b 384 | 385 | let merge l r = match l, r with 386 | | l, Empty -> l 387 | | Empty, r -> r 388 | | Node l, r -> 389 | let x, y, l' = splitMax l in 390 | node x y l' r 391 | 392 | let rec remove (x, y) t = 393 | if y < x then invalid_arg "interval reversed"; 394 | match t with 395 | | Empty -> Empty 396 | (* completely to the left *) 397 | | Node n when y < n.x -> 398 | let l = remove (x, y) n.l in 399 | node n.x n.y l n.r 400 | (* completely to the right *) 401 | | Node n when n.y < x -> 402 | let r = remove (x, y) n.r in 403 | node n.x n.y n.l r 404 | (* overlap on the left only *) 405 | | Node n when x < n.x && y < n.y -> 406 | let n' = node (succ y) n.y n.l n.r in 407 | remove (x, pred n.x) n' 408 | (* overlap on the right only *) 409 | | Node n when y > n.y && x > n.x -> 410 | let n' = node n.x (pred x) n.l n.r in 411 | remove (succ n.y, y) n' 412 | (* overlap on both sides *) 413 | | Node n when x <= n.x && y >= n.y -> 414 | let l = remove (x, n.x) n.l in 415 | let r = remove (n.y, y) n.r in 416 |
merge l r 417 | (* completely within *) 418 | | Node n when eq y n.y -> 419 | node n.x (pred x) n.l n.r 420 | | Node n when eq x n.x -> 421 | node (succ y) n.y n.l n.r 422 | | Node n -> 423 | assert (n.x <= pred x); 424 | assert (succ y <= n.y); 425 | let r = node (succ y) n.y Empty n.r in 426 | node n.x (pred x) n.l r 427 | 428 | let diff a b = fold remove b a 429 | 430 | let inter a b = diff a (diff a b) 431 | 432 | let rec find_next_gap from = function 433 | | Empty -> from 434 | | Node n -> 435 | (* consider this interval *) 436 | if (from >= n.x && from <= n.y) then 437 | succ n.y 438 | (* or search left *) 439 | else if from < n.x then 440 | find_next_gap from n.l 441 | (* or search right *) 442 | else 443 | find_next_gap from n.r 444 | 445 | (* Move intervals, or a prefix of one, from [free] into [acc] until [n] elements have been taken; [None] if [free] is exhausted first. *) let take t n = 446 | let rec loop acc free n = 447 | if eq n Elt.zero 448 | then Some (acc, free) 449 | else begin 450 | match ( 451 | try 452 | let i = choose free in 453 | let x, y = Interval.(x i, y i) in 454 | let len = Elt.(succ @@ y - x) in 455 | let will_use = if Stdlib.(Elt.compare n len < 0) then n else len in 456 | let i' = Interval.make x Elt.(pred @@ x + will_use) in 457 | Some ((add i' acc), (remove i' free), Elt.(n - will_use)) 458 | with 459 | | Not_found -> None 460 | ) with 461 | | Some (acc', free', n') -> loop acc' free' n' 462 | | None -> None 463 | end in 464 | loop empty t n 465 | 466 | let check_invariants = Invariant.check 467 | end 468 | 469 | module Int_elt = struct 470 | type t = int 471 | let compare a b = compare (a:int) b 472 | let zero = 0 473 | let pred = pred 474 | let succ = succ 475 | let sub = (-) 476 | let add = (+) 477 | let to_string = string_of_int 478 | end 479 | 480 | module Int = Make(Int_elt) 481 | module Int64 = Make(Int64) 482 | -------------------------------------------------------------------------------- /lib/diet.mli: -------------------------------------------------------------------------------- 1 | (* 2 | * Copyright (C) 2016 David Scott 3 | * 4 | * Permission to use,
copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 | * 16 | *) 17 | 18 | module type ELT = sig 19 | type t 20 | (** The type of the set elements. *) 21 | 22 | include Set.OrderedType with type t := t 23 | 24 | val zero: t 25 | (** The zero element; it is the cardinal of the empty set *) 26 | 27 | val pred: t -> t 28 | (** Predecessor of an element *) 29 | 30 | val succ: t -> t 31 | (** Successor of an element *) 32 | 33 | val sub: t -> t -> t 34 | (** [sub a b] returns [a] - [b] *) 35 | 36 | val add: t -> t -> t 37 | (** [add a b] returns [a] + [b] *) 38 | 39 | val to_string: t -> string 40 | (** Convert an element to a string, as used by [pp]. *) 41 | end 42 | 43 | module type INTERVAL_SET = sig 44 | type elt 45 | (** The type of the set elements *) 46 | 47 | type interval 48 | (** An interval: a range (x, y) of set values where all the elements from 49 | x to y inclusive are in the set *) 50 | 51 | module Interval: sig 52 | val make: elt -> elt -> interval 53 | (** [make first last] constructs an interval describing all the elements from 54 | [first] to [last] inclusive. Raises [Invalid_argument] if [first] is greater than [last].
*) 55 | 56 | val x: interval -> elt 57 | (** the starting element of the interval *) 58 | 59 | val y: interval -> elt 60 | (** the ending element of the interval *) 61 | end 62 | 63 | type t 64 | (** The type of sets *) 65 | 66 | val equal : t -> t -> bool 67 | (** Equality over sets *) 68 | 69 | val compare : t -> t -> int 70 | (** Comparison over sets *) 71 | 72 | val pp: Format.formatter -> t -> unit 73 | (** Pretty-print a set *) 74 | 75 | val empty: t 76 | (** The empty set *) 77 | 78 | val is_empty: t -> bool 79 | (** Test whether a set is empty or not *) 80 | 81 | val cardinal: t -> elt 82 | (** [cardinal t] is the number of elements in the set [t] *) 83 | 84 | val mem: elt -> t -> bool 85 | (** [mem elt t] tests whether [elt] is in set [t] *) 86 | 87 | val fold: (interval -> 'a -> 'a) -> t -> 'a -> 'a 88 | (** [fold f t acc] folds [f] across all the maximal intervals in [t], in increasing order *) 89 | 90 | val fold_individual: (elt -> 'a -> 'a) -> t -> 'a -> 'a 91 | (** [fold_individual f t acc] folds [f] across all the individual elements of [t], in increasing order *) 92 | 93 | val iter: (interval -> unit) -> t -> unit 94 | (** [iter f t] iterates [f] across all the maximal intervals in [t], in increasing order *) 95 | 96 | val add: interval -> t -> t 97 | (** [add interval t] returns the set consisting of [t] plus [interval] *) 98 | 99 | val remove: interval -> t -> t 100 | (** [remove interval t] returns the set consisting of [t] minus [interval] *) 101 | 102 | val min_elt: t -> interval 103 | (** [min_elt t] returns the smallest (in terms of the ordering) interval in 104 | [t], or raises [Not_found] if the set is empty. *) 105 | 106 | val max_elt: t -> interval 107 | (** [max_elt t] returns the largest (in terms of the ordering) interval in 108 | [t], or raises [Not_found] if the set is empty.
*) 109 | 110 | val choose: t -> interval 111 | (** [choose t] returns one interval, or raises [Not_found] if the set is empty *) 112 | 113 | val take: t -> elt -> (t * t) option 114 | (** [take t n] returns [Some (a, b)] where [cardinal a = n] and [diff t a = b], 115 | or [None] if [cardinal t < n] *) 116 | 117 | val union: t -> t -> t 118 | (** set union *) 119 | 120 | val diff: t -> t -> t 121 | (** set difference *) 122 | 123 | val inter: t -> t -> t 124 | (** set intersection *) 125 | 126 | val find_next_gap: elt -> t -> elt 127 | (** [find_next_gap from t] returns the smallest element that is 128 | absent from set [t] and greater than or equal to [from] *) 129 | 130 | (**/**) 131 | 132 | val check_invariants : t -> (unit, string) result 133 | (** [check_invariants t] returns [Ok ()] if the underlying invariants hold, or 134 | an error message. *) 135 | 136 | val height : t -> int 137 | (** [height t] returns the height of the corresponding tree. *) 138 | end 139 | 140 | 141 | module Make(Elt: ELT): INTERVAL_SET with type elt = Elt.t 142 | 143 | module Int : INTERVAL_SET with type elt = int 144 | module Int64 : INTERVAL_SET with type elt = int64 145 | -------------------------------------------------------------------------------- /lib/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name diet) 3 | (public_name diet) 4 | (libraries stdlib-shims)) 5 | -------------------------------------------------------------------------------- /lib_test/dune: -------------------------------------------------------------------------------- 1 | (tests 2 | (names test) 3 | (libraries diet oUnit)) 4 | -------------------------------------------------------------------------------- /lib_test/test.ml: -------------------------------------------------------------------------------- 1 | (* 2 | * Copyright (C) 2013 Citrix Inc 3 | * 4 | * Permission to use, copy, modify, and/or distribute this software for any 5 | * purpose with or without fee
is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
 * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 *)

(* OUnit2 test suite for [Diet.Int]. Most checks drive an [IntDiet] and a
   reference [IntSet] (a standard [Set] of ints) through the same operations
   and compare the individual elements they end up holding. *)

open OUnit2

(* Element ordering for the reference [IntSet]. *)
module Int_comparable = struct
  type t = int
  let compare (x: t) (y: t) = Stdlib.compare x y
end

(* [Diet.Int] extended with pair-based [add]/[remove] and an [elements]
   listing, so tests can write intervals as [(lo, hi)] tuples. *)
module IntDiet = struct
  include Diet.Int

  (* [add (x, y) t] adds the interval [Interval.make x y] to [t]. *)
  let add (x, y) t =
    add (Interval.make x y) t

  (* [remove (x, y) t] removes the interval [Interval.make x y] from [t]. *)
  let remove (x, y) t =
    remove (Interval.make x y) t

  (* Every individual element of [t]. [fold_individual] conses onto the
     accumulator, so the result is reversed to restore the visit order. *)
  let elements t = fold_individual (fun x acc -> x :: acc) t [] |> List.rev
end

(* Reference implementation that diets are checked against. *)
module IntSet = Set.Make(Int_comparable)

(* Pins the exact output of [pp] for a two-node tree. *)
let test_printer ctxt =
  let open IntDiet in
  let t = add (1, 2) @@ add (4, 5) empty in
  let got = Format.asprintf "%a" pp t in
  (* NOTE(review): the leading indentation inside this literal was
     reconstructed from a whitespace-collapsed copy of the file; confirm it
     against the actual [pp] output. *)
  let expected = String.trim {|
x: 4
y: 5
l:
  x: 1
  y: 2
  l:
    Empty
  r:
    Empty
  h: 1
  cardinal: 2
r:
  Empty
h: 2
cardinal: 4|}
  in
  assert_equal ~ctxt ~printer:(fun s -> s) ~cmp:String.equal expected got

(* [find_next_gap n set] must return the next element >= [n] that is absent
   from [set]. *)
let test_find_next_gap ctxt =
  let open IntDiet in
  (* set = {5, 6, 7, 9} *)
  let set = add (9, 9) @@ add (5, 7) empty in
  let test n ~expected =
    let got = find_next_gap n set in
    assert_equal ~ctxt ~printer:string_of_int expected got
  in
  test 0 ~expected:0;
  test 5 ~expected:8;
  test 9 ~expected:10;
  (* Properties that must hold for every start point in and around the set. *)
  for i = 0 to 12 do
    let e = find_next_gap i set in
    assert (e >= i);
    assert (not @@ mem e set);
    (* [i] itself is the answer unless [i] is a member. Structural [=] is the
       intended comparison; [==] only happened to work because ints are
       immediate values. *)
    assert (e = i || mem i set);
    (* A gap is its own next gap. *)
    assert (find_next_gap e set = e)
  done

(* Fails if [IntDiet.check_invariants] reports an error for [diet]. [ctxt] is
   optional because [make_random] has no test context to pass. *)
let check_invariants_ok ?ctxt diet =
  let expected = Ok () in
  let got = IntDiet.check_invariants diet in
  let printer = function
    | Ok () -> "no error"
    | Error e -> e
  in
  assert_equal ?ctxt ~printer expected got

(* [make_random n m] applies [m] random steps to an empty [IntSet] and an
   empty [IntDiet] in lockstep. Each step picks an interval [r..r'] with
   [0 <= r <= r' < n] and either adds or removes it in both structures. The
   diet's invariants are checked after every step. Returns [(set, diet)]. *)
let make_random n m =
  (* [range from upto] is the list [from; from + 1; ...; upto]. *)
  let rec range from upto =
    if from > upto then [] else from :: range (from + 1) upto
  in
  let rec loop set diet = function
    | 0 -> set, diet
    | m ->
      (* Keep the Random calls in this order so a given seed yields the same
         sequence of operations. *)
      let r = Random.int n in
      let r' = Random.int (n - r) + r in
      let add = Random.bool () in
      let set_op = if add then IntSet.add else IntSet.remove in
      let set = List.fold_left (fun set elt -> set_op elt set) set (range r r') in
      let diet' = (if add then IntDiet.add else IntDiet.remove) (r, r') diet in
      check_invariants_ok diet';
      loop set diet' (m - 1)
  in
  loop IntSet.empty IntDiet.empty m

(* [show_list show l] renders [l] as "[a; b; c]". *)
let show_list show l =
  Printf.sprintf "[%s]" (String.concat "; " (List.map show l))

(* [assert_equal] specialised to int lists, with a readable printer. *)
let assert_equal_int_list ?msg ~ctxt expected got =
  let printer = show_list string_of_int in
  assert_equal ?msg ~ctxt ~printer expected got

(* Asserts that [diet] holds exactly the elements of the reference [set]. *)
let check_equals ?msg ~ctxt set diet =
  assert_equal_int_list
    ?msg
    ~ctxt
    (IntSet.elements set)
    (IntDiet.elements diet)

(* For 100 random operand pairs, checks that each diet operator in [ops]
   agrees with its [IntSet] counterpart. Each entry of [ops] is
   [(name, set_op, diet_op)]. *)
let test_operators ops ctxt =
  for _ = 1 to 100 do
    let set1, diet1 = make_random 1000 1000 in
    let set2, diet2 = make_random 1000 1000 in
    (* Both operands must match their reference sets; otherwise a mismatch
       below could come from a bad operand rather than a bad operator. *)
    check_equals ~msg:"When checking operand 1" ~ctxt set1 diet1;
    check_equals ~msg:"When checking operand 2" ~ctxt set2 diet2;
    List.iter (fun (op_name, set_op, diet_op) ->
        let msg = "When checking " ^ op_name in
        let set3 = set_op set1 set2 in
        let diet3 = diet_op diet1 diet2 in
        check_equals ~msg ~ctxt set3 diet3
      ) ops
  done

(* Removing every other element of [0..n] produces about n/2 disjoint
   intervals; the tree height must stay within log2(n) + 1. *)
let test_depth ctxt =
  let n = 0x100000 in
  let init = IntDiet.add (0, n) IntDiet.empty in
  (* [sub m acc] removes the single elements m, m-2, m-4, ... while > 0. *)
  let rec sub m acc =
    if m <= 0 then acc
    else sub (m - 2) (IntDiet.remove (m, m) acc) in
  (* Removes n, n-2, ..., 2. *)
  let set = sub n init in
  let d = IntDiet.height set in
  let bound = int_of_float (log (float_of_int n) /. (log 2.)) + 1 in
  assert_bool
    (Printf.sprintf "height %d exceeds bound %d" d bound)
    (d <= bound);
  (* Also removing n-1, n-3, ..., 1 leaves only {0}: a single node. *)
  let set = sub (n - 1) set in
  let got = IntDiet.height set in
  let expected = 1 in
  assert_equal ~ctxt ~printer:string_of_int expected got

(* {3} plus {3..4} gives {3, 4}. *)
let test_add_1 ctxt =
  let open IntDiet in
  assert_equal_int_list ~ctxt
    [3; 4]
    (elements @@ add (3, 4) @@ add (3, 3) empty)

(* {7..8} minus {6..7} gives {8}. *)
let test_remove_1 ctxt =
  let open IntDiet in
  assert_equal_int_list ~ctxt
    [8]
    (elements @@ remove (6, 7) @@ add (7, 8) empty)

(* ({5..7} with {9}) minus {7..9} gives {5, 6}. *)
let test_remove_2 ctxt =
  let open IntDiet in
  assert_equal_int_list ~ctxt
    [5; 6]
    (elements @@ diff (add (9, 9) @@ add (5, 7) empty) (add (7, 9) empty))

(* Adding the adjacent singletons 8 and 9 must leave the tree invariants
   intact. *)
let test_adjacent_1 ctxt =
  let open IntDiet in
  let set = add (9, 9) @@ add (8, 8) empty in
  check_invariants_ok ~ctxt set

(* [IntDiet.equal] must compare contents rather than tree shape: diets built
   by inserting the same intervals in different orders are equal. *)
let test_equal =
  let open IntDiet in
  let make l = List.fold_left (fun diet intvl -> add intvl diet) empty l in
  let test ~nodes ~nodes' ~expected ctxt =
    let diet = make nodes in
    let diet' = make nodes' in
    assert_equal ~ctxt ~printer:string_of_bool expected (IntDiet.equal diet diet')
  in
  [ "Empty" >:: test ~nodes:[] ~nodes':[] ~expected:true
  ; "Single node" >:: test ~nodes:[(1, 2)] ~nodes':[(1, 2)] ~expected:true
  ; "Two nodes swapped" >:: test ~nodes:[(1, 2); (4, 5)] ~nodes':[(4, 5); (1, 2)] ~expected:true
  ; "Swapped nodes 1" >:: test
      ~nodes:[(7, 8); (1, 2); (10, 11); (4, 5); (13, 14)]
      ~nodes':[(7, 8); (4, 5); (13, 14); (1, 2); (10, 11)]
      ~expected:true
  ; "Swapped nodes 2" >:: test
      ~nodes:[(4, 5); (1, 2); (10, 11); (7, 8)]
      ~nodes':[(7, 8); (4, 5); (10, 11); (1, 2)]
      ~expected:true
  ; "Swapped nodes 3" >:: test
      ~nodes:[(7, 8); (4, 5); (1, 2)]
      ~nodes':[(1, 2); (7, 8); (4, 5)]
      ~expected:true
  ; "Non-empty and empty" >:: test ~nodes:[(1, 2)] ~nodes':[] ~expected:false
  ; "Different roots" >:: test ~nodes:[(1, 2)] ~nodes':[(4, 5)] ~expected:false
  ]

(* Every test in this file, run by [run_test_tt_main] below. *)
let suite =
  "diet" >:::
  [ "adding an element to the right" >:: test_add_1
  ; "removing an element on the left" >:: test_remove_1
  ; "removing elements from two intervals" >:: test_remove_2
  ; "test adjacent intervals are coalesced" >:: test_adjacent_1
  ; "logarithmic depth" >:: test_depth
  ; "operators" >:: test_operators
      [ ("union", IntSet.union, IntDiet.union)
      ; ("diff", IntSet.diff, IntDiet.diff)
      ; ("intersection", IntSet.inter, IntDiet.inter)
      ]
  ; "finding the next gap" >:: test_find_next_gap
  ; "printer" >:: test_printer
  ; "equality" >::: test_equal
  ]

let () = run_test_tt_main suite