├── src ├── dune ├── mcavl.ml ├── mcavl.mli ├── pure.ml ├── mcmap.ml ├── core.ml ├── mcset.ml └── s.ml ├── tests ├── tests.ml ├── qclin_set.ml ├── dune ├── qc_map.ml ├── qc_set.ml ├── test_mc.ml ├── bench.ml ├── test_seq.ml └── test_view.ml ├── .gitignore ├── mcavl.opam.template ├── .ocamlformat ├── dune-project ├── LICENSE ├── mcavl.opam └── README.md /src/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (public_name mcavl)) 3 | -------------------------------------------------------------------------------- /src/mcavl.ml: -------------------------------------------------------------------------------- 1 | module type Ordered = S.Ordered 2 | 3 | module Set = Mcset.Make 4 | module Map = Mcmap.Make 5 | -------------------------------------------------------------------------------- /tests/tests.ml: -------------------------------------------------------------------------------- 1 | let () = 2 | Alcotest.run "Mcavl" 3 | [ "Sequential", Test_seq.tests 4 | ; "View", Test_view.tests 5 | ; "Domains", Test_mc.tests ] 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.annot 2 | *.cmo 3 | *.cma 4 | *.cmi 5 | *.a 6 | *.o 7 | *.cmx 8 | *.cmxs 9 | *.cmxa 10 | 11 | .merlin 12 | *.install 13 | *.coverage 14 | *.sw[lmnop] 15 | 16 | _build/ 17 | _doc/ 18 | _coverage/ 19 | _opam/ 20 | -------------------------------------------------------------------------------- /mcavl.opam.template: -------------------------------------------------------------------------------- 1 | pin-depends: [ 2 | [ "lin.dev" "git+https://github.com/jmid/multicoretests.git#opam-pkg-rename" ] 3 | [ "qcheck-stm.dev" "git+https://github.com/jmid/multicoretests.git#opam-pkg-rename" ] 4 | [ "multicoretests-util.dev" "git+https://github.com/jmid/multicoretests.git#opam-pkg-rename" ] 5 | ] 6 | 
-------------------------------------------------------------------------------- /.ocamlformat: -------------------------------------------------------------------------------- 1 | version = 0.24.1 2 | profile = ocamlformat 3 | let-binding-spacing = compact 4 | sequence-style = separator 5 | doc-comments = after-when-possible 6 | exp-grouping = preserve 7 | break-cases = toplevel 8 | cases-exp-indent = 4 9 | cases-matching-exp-indent = normal 10 | if-then-else = keyword-first 11 | parens-tuple = multi-line-only 12 | -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 3.3) 2 | 3 | (name mcavl) 4 | 5 | (generate_opam_files true) 6 | 7 | (source (github art-w/mcavl)) 8 | 9 | (authors "Arthur Wendling") 10 | 11 | (maintainers "art.wendling@gmail.com") 12 | 13 | (documentation "https://art-w.github.io/mcavl/mcavl/Mcavl") 14 | 15 | (license MIT) 16 | 17 | (package 18 | (name mcavl) 19 | (synopsis "Lock-free Sets and Maps for OCaml multicore") 20 | (depends 21 | (ocaml (>= 5.0.0)) 22 | dune 23 | (alcotest :with-test) 24 | (domainslib :with-test) 25 | (lin :with-test) 26 | (qcheck-stm :with-test) 27 | (multicoretests-util :with-test) 28 | )) 29 | -------------------------------------------------------------------------------- /tests/qclin_set.ml: -------------------------------------------------------------------------------- 1 | module Set_sig = struct 2 | module S = Mcavl.Set (Int) 3 | 4 | type t = S.t 5 | 6 | let init () = S.empty () 7 | 8 | let cleanup _ = () 9 | 10 | open Lin_api 11 | 12 | let nat = nat_small 13 | 14 | let api = 15 | [ val_ "S.add" S.add (nat @-> t @-> returning unit) 16 | ; val_ "S.remove" S.remove (nat @-> t @-> returning bool) 17 | ; val_ "S.mem" S.mem (nat @-> t @-> returning bool) 18 | ; val_ "S.cardinal" S.cardinal (t @-> returning int) ] 19 | end 20 | 21 | module HT = Lin_api.Make (Set_sig) ;; 22 | 23 | 
QCheck_base_runner.run_tests_main 24 | [HT.lin_test `Domain ~count:10_000 ~name:"Mcavl.Set"] 25 | -------------------------------------------------------------------------------- /tests/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (name tests) 3 | (modules test_seq test_view test_mc tests) 4 | (libraries mcavl alcotest)) 5 | 6 | (executable 7 | (name bench) 8 | (modules bench) 9 | (libraries mcavl domainslib unix)) 10 | 11 | (executable 12 | (name qc_set) 13 | (modules qc_set) 14 | (libraries mcavl qcheck qcheck-stm) 15 | (preprocess 16 | (pps ppx_deriving.show))) 17 | 18 | (executable 19 | (name qclin_set) 20 | (modules qclin_set) 21 | (libraries mcavl qcheck lin) 22 | (preprocess 23 | (pps ppx_deriving.show))) 24 | 25 | (executable 26 | (name qc_map) 27 | (modules qc_map) 28 | (libraries mcavl qcheck qcheck-stm) 29 | (preprocess 30 | (pps ppx_deriving.show))) 31 | -------------------------------------------------------------------------------- /src/mcavl.mli: -------------------------------------------------------------------------------- 1 | (** An imperative, mutable implementation of totally ordered Sets and Maps 2 | data structures, with the following properties: 3 | 4 | - {b Thread-safe:} concurrent updates from multiple domains are possible. 5 | These collections can typically be used to synchronize multiple workers 6 | towards a common goal. 7 | - {b Lock-free:} concurrent modifications are able to make progress without 8 | waiting for others to finish. 9 | - {b Linearizable:} concurrent operations appear to take place sequentially 10 | on a linearized timeline, providing a coherent collection of elements at 11 | all times. It is furthermore possible to take an {b O(1)} snapshot/copy 12 | to observe its collection of elements at a given point in time. 
13 | *) 14 | 15 | module type Ordered = S.Ordered 16 | 17 | module Set (Ord : Ordered) : S.Set(Ord).S 18 | 19 | module Map (Ord : Ordered) : S.Map(Ord).S 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Arthur Wendling 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /mcavl.opam: -------------------------------------------------------------------------------- 1 | # This file is generated by dune, edit dune-project instead 2 | opam-version: "2.0" 3 | synopsis: "Lock-free Sets and Maps for OCaml multicore" 4 | maintainer: ["art.wendling@gmail.com"] 5 | authors: ["Arthur Wendling"] 6 | license: "MIT" 7 | homepage: "https://github.com/art-w/mcavl" 8 | doc: "https://art-w.github.io/mcavl/mcavl/Mcavl" 9 | bug-reports: "https://github.com/art-w/mcavl/issues" 10 | depends: [ 11 | "ocaml" {>= "5.0.0"} 12 | "dune" {>= "3.3"} 13 | "alcotest" {with-test} 14 | "domainslib" {with-test} 15 | "lin" {with-test} 16 | "qcheck-stm" {with-test} 17 | "multicoretests-util" {with-test} 18 | "odoc" {with-doc} 19 | ] 20 | build: [ 21 | ["dune" "subst"] {dev} 22 | [ 23 | "dune" 24 | "build" 25 | "-p" 26 | name 27 | "-j" 28 | jobs 29 | "@install" 30 | "@runtest" {with-test} 31 | "@doc" {with-doc} 32 | ] 33 | ] 34 | dev-repo: "git+https://github.com/art-w/mcavl.git" 35 | pin-depends: [ 36 | [ "lin.dev" "git+https://github.com/jmid/multicoretests.git#opam-pkg-rename" ] 37 | [ "qcheck-stm.dev" "git+https://github.com/jmid/multicoretests.git#opam-pkg-rename" ] 38 | [ "multicoretests-util.dev" "git+https://github.com/jmid/multicoretests.git#opam-pkg-rename" ] 39 | ] 40 | -------------------------------------------------------------------------------- /tests/qc_map.ml: -------------------------------------------------------------------------------- 1 | module M = Map.Make (Int) 2 | module T = Mcavl.Map (Int) 3 | open QCheck 4 | open STM 5 | 6 | module Conf = struct 7 | type cmd = 8 | | Add of int * int 9 | | Remove of int 10 | | Find_opt of int 11 | | Copy_find_opt of int 12 | | Cardinal 13 | [@@deriving show {with_path= false}] 14 | 15 | type state = int M.t 16 | 17 | type sut = int T.t 18 | 19 | let arb_cmd _ = 20 | let int_gen = Gen.int_bound 10 in 21 | QCheck.make 
~print:show_cmd 22 | (Gen.oneof 23 | [ Gen.map2 (fun i j -> Add (i, j)) int_gen Gen.nat 24 | ; Gen.map (fun i -> Remove i) int_gen 25 | ; Gen.map (fun i -> Find_opt i) int_gen 26 | ; Gen.map (fun i -> Copy_find_opt i) int_gen 27 | ; Gen.return Cardinal ] ) 28 | 29 | let init_state = M.empty 30 | 31 | let init_sut () = T.empty () 32 | 33 | let cleanup _ = () 34 | 35 | let next_state c s = 36 | match c with 37 | | Add (k, v) -> M.add k v s 38 | | Remove k -> M.remove k s 39 | | Find_opt _ | Copy_find_opt _ | Cardinal -> s 40 | 41 | let run c r = 42 | match c with 43 | | Add (k, v) -> Res (unit, T.add k v r) 44 | | Remove k -> Res (bool, T.remove k r) 45 | | Find_opt k -> Res (option int, T.find_opt k r) 46 | | Copy_find_opt k -> Res (option int, T.find_opt k (T.copy r)) 47 | | Cardinal -> Res (int, T.cardinal r) 48 | 49 | let precond _ _ = true 50 | 51 | let postcond c s res = 52 | match c, res with 53 | | Add _, Res ((Unit, _), _) -> true 54 | | Remove k, Res ((Bool, _), found) -> found = M.mem k s 55 | | (Find_opt k | Copy_find_opt k), Res ((Option Int, _), (m : int option)) -> 56 | m = M.find_opt k s 57 | | Cardinal, Res ((Int, _), m) -> m = M.cardinal s 58 | | _ -> assert false 59 | end 60 | 61 | module CT = STM.Make (Conf) 62 | 63 | let () = 64 | QCheck_runner.run_tests_main 65 | [ CT.agree_test ~count:10_000 ~name:"seq" 66 | ; CT.agree_test_par ~count:10_000 ~name:"par" ] 67 | -------------------------------------------------------------------------------- /tests/qc_set.ml: -------------------------------------------------------------------------------- 1 | module S = Set.Make (Int) 2 | module T = Mcavl.Set (Int) 3 | open QCheck 4 | open STM 5 | 6 | module Conf = struct 7 | type cmd = 8 | | Add of int 9 | | Remove of int 10 | | Mem of int 11 | | Cardinal 12 | | Min_elt_opt 13 | | Choose_opt 14 | [@@deriving show {with_path= false}] 15 | 16 | type state = S.t 17 | 18 | type sut = T.t 19 | 20 | let arb_cmd _ = 21 | let int_gen = Gen.int_bound 10 in 22 | 
QCheck.make ~print:show_cmd 23 | (Gen.oneof 24 | [ Gen.map (fun i -> Add i) int_gen 25 | ; Gen.map (fun i -> Remove i) int_gen 26 | ; Gen.map (fun i -> Mem i) int_gen 27 | ; Gen.return Cardinal 28 | ; Gen.return Choose_opt 29 | ; Gen.return Min_elt_opt ] ) 30 | 31 | let init_state = S.empty 32 | 33 | let init_sut () = T.empty () 34 | 35 | let cleanup _ = () 36 | 37 | let next_state c s = 38 | match c with 39 | | Add i -> S.add i s 40 | | Remove i -> S.remove i s 41 | | Mem _ | Cardinal | Choose_opt | Min_elt_opt -> s 42 | 43 | let run c r = 44 | match c with 45 | | Add i -> Res (unit, T.add i r) 46 | | Remove i -> Res (bool, T.remove i r) 47 | | Mem i -> Res (bool, T.mem i r) 48 | | Cardinal -> Res (int, T.cardinal r) 49 | | Choose_opt -> Res (option int, T.choose_opt r) 50 | | Min_elt_opt -> Res (option int, T.min_elt_opt r) 51 | 52 | let precond _ _ = true 53 | 54 | let postcond c s res = 55 | match c, res with 56 | | Add _, Res ((Unit, _), _) -> true 57 | | Remove i, Res ((Bool, _), found) -> found = S.mem i s 58 | | Mem i, Res ((Bool, _), m) -> m = S.mem i s 59 | | Cardinal, Res ((Int, _), m) -> m = S.cardinal s 60 | | Choose_opt, Res ((Option Int, _), None) -> S.is_empty s 61 | | Choose_opt, Res ((Option Int, _), Some x) -> S.mem x s 62 | | Min_elt_opt, Res ((Option Int, _), m) -> m = S.min_elt_opt s 63 | | _ -> assert false 64 | end 65 | 66 | module CT = STM.Make (Conf) 67 | 68 | let () = 69 | QCheck_runner.run_tests_main 70 | [ CT.agree_test ~count:10_000 ~name:"seq" 71 | ; CT.agree_test_par ~count:10_000 ~name:"par" ] 72 | -------------------------------------------------------------------------------- /tests/test_mc.ml: -------------------------------------------------------------------------------- 1 | module S = Mcavl.Set (Int) 2 | 3 | let test_all () = 4 | let t = S.empty () in 5 | for i = 1 to 1000 do 6 | S.add ((4 * i) + 2) t 7 | done ; 8 | let started = Array.init 4 (fun _ -> Atomic.make false) in 9 | let started_other = Atomic.make false in 10 | let 
await ~d i = 11 | if i = 100 12 | then begin 13 | Atomic.set started.(d) true ; 14 | while not (Atomic.get started_other) do 15 | Domain.cpu_relax () 16 | done 17 | end 18 | in 19 | let copy = ref t in 20 | Array.iter Domain.join 21 | [| Domain.spawn (fun () -> 22 | for i = 1 to 1000 do 23 | S.add (4 * i) t ; 24 | await ~d:0 i 25 | done ) 26 | ; Domain.spawn (fun () -> 27 | for i = 1 to 1000 do 28 | S.add (4 * i) t ; 29 | await ~d:1 i 30 | done ) 31 | ; Domain.spawn (fun () -> 32 | for i = 1 to 1000 do 33 | S.add ((4 * i) + 1) t ; 34 | await ~d:2 i 35 | done ) 36 | ; Domain.spawn (fun () -> 37 | for i = 1 to 1000 do 38 | assert (S.remove ((4 * i) + 2) t) ; 39 | await ~d:3 i 40 | done ) 41 | ; Domain.spawn (fun () -> 42 | while not (Array.for_all Atomic.get started) do 43 | Domain.cpu_relax () 44 | done ; 45 | Atomic.set started_other true ; 46 | let t' = S.copy t in 47 | Array.iter Domain.join 48 | [| Domain.spawn (fun () -> 49 | let ok = ref true in 50 | assert (S.mem 4 t') ; 51 | for i = 1 to 1000 do 52 | let rem = S.remove ((4 * i) + 0) t' in 53 | if !ok then ok := rem else assert (not rem) 54 | done ; 55 | assert (not !ok) ) 56 | ; Domain.spawn (fun () -> 57 | let ok = ref true in 58 | assert (S.mem 5 t') ; 59 | for i = 1 to 1000 do 60 | let rem = S.remove ((4 * i) + 1) t' in 61 | if !ok then ok := rem else assert (not rem) 62 | done ; 63 | assert (not !ok) ) 64 | ; Domain.spawn (fun () -> 65 | let ok = ref false in 66 | for i = 1 to 1000 do 67 | let rem = S.remove ((4 * i) + 2) t' in 68 | if not !ok then ok := rem else assert rem 69 | done ; 70 | assert !ok ) 71 | ; Domain.spawn (fun () -> 72 | for i = 1 to 1000 do 73 | S.add ((4 * i) + 3) t' 74 | done ) |] ; 75 | copy := t' ) |] ; 76 | Alcotest.(check int) "cardinal" 2000 (S.cardinal t) ; 77 | Alcotest.(check int) "cardinal copy" 1000 (S.cardinal !copy) 78 | 79 | let test_all () = 80 | for _ = 0 to 100 do 81 | test_all () 82 | done 83 | 84 | let tests = 85 | let open Alcotest in 86 | [test_case "all" `Quick 
test_all] 87 | -------------------------------------------------------------------------------- /tests/bench.ml: -------------------------------------------------------------------------------- 1 | let nb = try int_of_string Sys.argv.(1) with _ -> 100_000 2 | 3 | let max_domains = try int_of_string Sys.argv.(2) with _ -> 8 4 | 5 | module T = Domainslib.Task 6 | 7 | module type SET = sig 8 | type elt = int 9 | 10 | type t 11 | 12 | val empty : unit -> t 13 | 14 | val copy : t -> t 15 | 16 | val add : elt -> t -> unit 17 | 18 | val remove : elt -> t -> bool 19 | 20 | val cardinal : t -> int 21 | end 22 | 23 | module Mcset_int = Mcavl.Set (Int) 24 | 25 | module Naive = struct 26 | module S = Set.Make (Int) 27 | 28 | type elt = S.elt 29 | 30 | type t = S.t Atomic.t 31 | 32 | let empty () = Atomic.make S.empty 33 | 34 | let copy t = 35 | let s = Atomic.get t in 36 | Atomic.make s 37 | 38 | let rec add x t = 39 | let s = Atomic.get t in 40 | let s' = S.add x s in 41 | if Atomic.compare_and_set t s s' then () else add x t 42 | 43 | let rec remove x t = 44 | let s = Atomic.get t in 45 | let s' = S.remove x s in 46 | if Atomic.compare_and_set t s s' then S.mem x s else remove x t 47 | 48 | let cardinal t = S.cardinal (Atomic.get t) 49 | end 50 | 51 | module Test (Config : sig 52 | val pool : T.pool 53 | 54 | val nb_threads : int 55 | 56 | module S : SET 57 | end) = 58 | struct 59 | module S = Config.S 60 | 61 | let () = Printf.printf "%i%!" Config.nb_threads 62 | 63 | let bench ~init fn = 64 | let last = ref None in 65 | let metrics = 66 | Array.init 11 67 | @@ fun _ -> 68 | let input = init () in 69 | let t0 = Unix.gettimeofday () in 70 | let r = fn input in 71 | let t1 = Unix.gettimeofday () in 72 | last := Some r ; 73 | t1 -. t0 74 | in 75 | Array.sort Float.compare metrics ; 76 | let median = metrics.(Array.length metrics / 2) in 77 | Printf.printf "\t%f%!" (1000.0 *. 
median) ; 78 | match !last with 79 | | Some result -> result 80 | | None -> assert false 81 | 82 | let iter start finish body = T.parallel_for Config.pool ~start ~finish ~body 83 | 84 | let t_full = 85 | let t = 86 | bench 87 | ~init:(fun () -> S.empty ()) 88 | (fun t -> 89 | iter 1 nb (fun i -> S.add i t) ; 90 | t ) 91 | in 92 | assert (S.cardinal t = nb) ; 93 | t 94 | 95 | let () = 96 | bench 97 | ~init:(fun () -> 98 | let t = S.empty () in 99 | iter 1 nb (fun i -> S.add i t) ; 100 | t ) 101 | (fun t -> 102 | iter 1 nb (fun i -> assert (S.remove i t)) ; 103 | assert (S.cardinal t = 0) ) 104 | 105 | let () = 106 | bench 107 | ~init:(fun () -> S.copy t_full) 108 | (fun t -> 109 | iter 1 nb (fun i -> assert (S.remove i t)) ; 110 | assert (S.cardinal t = 0) ) ; 111 | assert (S.cardinal t_full = nb) 112 | end 113 | 114 | let run domains (module Impl : SET) = 115 | let module Config = struct 116 | let pool = T.setup_pool ~num_additional_domains:domains () 117 | 118 | let nb_threads = domains + 1 119 | 120 | module S = Impl 121 | end in 122 | T.run Config.pool (fun () -> 123 | let module Run = Test (Config) in 124 | () ) ; 125 | T.teardown_pool Config.pool ; 126 | Printf.printf "\n%!" ; 127 | () 128 | 129 | let () = 130 | Printf.printf "Mcset.t\n%!" ; 131 | Printf.printf "CPU\tADD\tREMOVE\tCOPY REMOVE\n%!" ; 132 | for domains = 0 to max_domains - 1 do 133 | run domains (module Mcset_int) 134 | done ; 135 | Printf.printf "\n%!" 136 | 137 | let () = 138 | Printf.printf "Stdlib.Set.t Atomic.t\n%!" ; 139 | Printf.printf "CPU\tADD\tREMOVE\tCOPY REMOVE\n%!" ; 140 | for domains = 0 to max_domains - 1 do 141 | run domains (module Naive) 142 | done ; 143 | Printf.printf "\n%!" 
144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Lock-free Sets and Maps for OCaml multicore: [**online documentation**](https://art-w.github.io/mcavl/mcavl/Mcavl) 2 | 3 | - The queries and updates are linearizable, such that you can pretend that they happened in a specific order when debugging your algorithms (even though their execution was interleaved.) The `copy` operation provides an `O(1)` snapshot of the collection to observe a coherent copy-on-write view of the elements in the linearized timeline: 4 | 5 | ![Linearized timeline](https://art-w.github.io/mcavl/linearized.png) 6 | 7 | - The algorithms are a crossover between the imperative and purely functional AVL tree. Atomic updates are done in place such that other threads get immediate access... but rebalancing may need to update multiple nodes at once and so it creates new nodes to swap them in one atomic update: 8 | 9 | ![Atomic rebalancing](https://art-w.github.io/mcavl/balance.png) 10 | 11 | Since the old nodes are still valid, this doesn't impact concurrent threads that could be traversing the old branches. The replaced nodes are marked "dead" to avoid losing concurrent removes of their corresponding values (which must now be done on the new rebalanced nodes). 12 | 13 | - The binary tree rebalancing is optimistic (!) The AVL tree can end up locally imbalanced when a node rebalancing was disturbed by nearby concurrent updates, either because it observed the wrong subtree's heights or because it was unable to complete the rebalancing without losing progress from another thread. This seems to work well in practice, because this type of contention happens near the leaves: The higher up the tree a node is, the less frequently it will be rebalanced and the less likely contention is to happen. 
Further operations on an imbalanced subtree will also generate new opportunities to repair it. 14 | 15 | - The `add` operation always inserts new elements at the leaves, and all other operations are careful to not drop a leaf that could concurrently welcome an insertion. 16 | 17 | - The `remove` function requires a bit more care to play nice with concurrent threads. The standard procedure would "teleport" a leaf value to replace the removed one... but this risks hiding the teleported element from concurrent traversals that would then believe it didn't exist in the set. So the remove operation first marks the element as removed, then performs safe rotations to push that node towards the leaves where it can finally be cleared. 18 | 19 | - The `copy`/`snapshot` functionality is in itself trivial, as it only signals that the original tree is now read-only and that further modifications have to perform copy-on-write on the visited nodes. However, a concurrent `add` or `remove` is now in a difficult situation as it is unclear if their effect happened before or after the copy... and choosing wrong breaks linearizability. The solution is to first signal their intent, then check that no other threads have indicated a causality violation before actually committing their effect. Time traveling is fine if no one can catch us! 20 | 21 | - [QuickCheck State-Machine Testing](https://github.com/jmid/multicoretests) was instrumental in discovering issues and validating the final design. 22 | 23 | - Given enough cores, performance is reasonable! In the presence of contention, a concurrent operation will help other threads finish their work in order to avoid having to wait for them. There should be little busy looping to make progress: The worst case is a concurrent `copy` that requires the operation to restart from scratch on the new copy-on-write root. 
The expected cases are contention by inserting new elements at the exact same leaf (which only requires a local retry at that spot), or when encountering "dead" nodes (where it doesn't take long to walk back up the tree to discover the freshly rebalanced path.) 24 | 25 | ![Add 1m elements](https://art-w.github.io/mcavl/test_add.png) ![Remove 1m elements](https://art-w.github.io/mcavl/test_remove.png) 26 | -------------------------------------------------------------------------------- /src/pure.ml: -------------------------------------------------------------------------------- 1 | module Make (E : S.Ordered_poly) = struct 2 | type 'a elt = 'a E.t 3 | 4 | type 'a t = 'a r Atomic.t 5 | 6 | and 'a r = 'a s Atomic.t 7 | 8 | and 'a s = 9 | | Leaf of 'a state 10 | | Node of 'a state * int * 'a r * 'a elt * 'a r 11 | | Copy of 'a r 12 | 13 | and 'a state = 14 | | Alive 15 | | Dead 16 | | Read_only 17 | | Removing 18 | | Balancing_left of 'a s * 'a s 19 | | Balancing_left_center of 'a s * 'a s * 'a s * 'a s 20 | | Balancing_right of 'a s * 'a s 21 | | Balancing_right_center of 'a s * 'a s * 'a s * 'a s 22 | | Attempt_add of attempt * 'a t * 'a r 23 | | Attempt_remove of attempt * 'a t * 'a r 24 | | Attempt_replace of attempt * 'a t * 'a r * 'a elt 25 | 26 | and attempt = attempt_state Atomic.t 27 | 28 | and attempt_state = Unknown | Success | Failure 29 | 30 | module Pure = struct 31 | let empty ~s = Atomic.make (Leaf s) 32 | 33 | let singleton ~s elt = 34 | let leaf = Leaf s in 35 | Atomic.make (Node (s, 1, Atomic.make leaf, elt, Atomic.make leaf)) 36 | 37 | let rec height ~get t = s_height ~get (get t) 38 | 39 | and s_height ~get = function 40 | | Leaf _ -> 0 41 | | Node (_, h, _, _, _) -> h 42 | | Copy t -> height ~get t 43 | 44 | let create ~s ~h left pivot right = 45 | Atomic.make (Node (s, h, left, pivot, right)) 46 | 47 | let balance ~s ~get left pivot right = 48 | let s_left = get left in 49 | let s_right = get right in 50 | let hl = s_height ~get s_left in 51 | 
let hr = s_height ~get s_right in 52 | if hl > hr + 2 53 | then begin 54 | match s_left with 55 | | Node (_, _, left_left, left_pivot, left_right) -> 56 | let hll = height ~get left_left in 57 | let s_left_right = get left_right in 58 | let hlr = s_height ~get s_left_right in 59 | if hll >= hlr 60 | then 61 | let h = 1 + max hlr hr in 62 | create ~s 63 | ~h:(1 + max hll h) 64 | left_left left_pivot 65 | (create ~s ~h left_right pivot right) 66 | else begin 67 | match s_left_right with 68 | | Node (_, _, center_left, center_pivot, center_right) -> 69 | let hlc = 1 + max hll (height ~get center_left) in 70 | let hcr = 1 + max (height ~get center_right) hr in 71 | let h = 1 + max hlc hcr in 72 | create ~s ~h 73 | (create ~s ~h:hlc left_left left_pivot center_left) 74 | center_pivot 75 | (create ~s ~h:hcr center_right pivot right) 76 | | _ -> assert false 77 | end 78 | | _ -> assert false 79 | end 80 | else if hr > hl + 2 81 | then begin 82 | match s_right with 83 | | Node (_, _, right_left, right_pivot, right_right) -> 84 | let hrr = height ~get right_right in 85 | let s_right_left = get right_left in 86 | let hrl = s_height ~get s_right_left in 87 | if hrr >= hrl 88 | then 89 | let h = 1 + max hl hrl in 90 | create ~s 91 | ~h:(1 + max h hrr) 92 | (create ~s ~h left pivot right_left) 93 | right_pivot right_right 94 | else begin 95 | match s_right_left with (* same snapshot [hrl] was measured on, mirroring the left case; a second [get] could observe a concurrent update *) 96 | | Node (_, _, center_left, center_pivot, center_right) -> 97 | let hlc = 1 + max hl (height ~get center_left) in 98 | let hcr = 1 + max (height ~get center_right) hrr in 99 | let h = 1 + max hlc hcr in 100 | create ~s ~h 101 | (create ~s ~h:hlc left pivot center_left) 102 | center_pivot 103 | (create ~s ~h:hcr center_right right_pivot right_right) 104 | | _ -> assert false 105 | end 106 | | _ -> assert false 107 | end 108 | else 109 | let height = 1 + max hl hr in 110 | Atomic.make (Node (s, height, left, pivot, right)) 111 | 112 | let rec add ~s ~get x t = 113 | match get t with 114 | | Leaf _ -> singleton 
~s x 115 | | Node (_, _, left, pivot, right) -> begin 116 | match E.compare x pivot with 117 | | 0 -> t 118 | | c when c < 0 -> 119 | let left' = add ~s ~get x left in 120 | if left' == left then t else balance ~s ~get left' pivot right 121 | | _ -> 122 | let right' = add ~s ~get x right in 123 | if right' == right then t else balance ~s ~get left pivot right' 124 | end 125 | | _ -> assert false 126 | 127 | let rec add_or_replace ~s ~get x t = 128 | match get t with 129 | | Leaf _ -> singleton ~s x 130 | | Node (_, height, left, pivot, right) -> begin 131 | match E.compare x pivot with 132 | | 0 when x == pivot -> t 133 | | 0 -> Atomic.make (Node (s, height, left, x, right)) 134 | | c when c < 0 -> 135 | let left' = add_or_replace ~s ~get x left in 136 | if left' == left then t else balance ~s ~get left' pivot right 137 | | _ -> 138 | let right' = add_or_replace ~s ~get x right in 139 | if right' == right then t else balance ~s ~get left pivot right' 140 | end 141 | | _ -> assert false 142 | 143 | let of_list ~s lst = 144 | List.fold_left 145 | (fun t x -> add_or_replace ~s ~get:Atomic.get x t) 146 | (empty ~s) lst 147 | 148 | let of_seq ~s seq = 149 | Seq.fold_left 150 | (fun t x -> add_or_replace ~s ~get:Atomic.get x t) 151 | (empty ~s) seq 152 | end 153 | end 154 | -------------------------------------------------------------------------------- /tests/test_seq.ml: -------------------------------------------------------------------------------- 1 | module S = Mcavl.Set (Int) 2 | 3 | let shuffle a = 4 | let n = Array.length a in 5 | (* Fisher-Yates shuffle of [a] in place: permuting a private copy and returning unit would leave the caller's array unshuffled *) 6 | for i = n - 1 downto 1 do 7 | let k = Random.int (i + 1) in 8 | let x = a.(k) in 9 | a.(k) <- a.(i) ; 10 | a.(i) <- x 11 | done 12 | 13 | let shuffle_list lst = 14 | let a = Array.of_list lst in 15 | shuffle a ; Array.to_list a 16 | 17 | let test_empty () = 18 | let t = S.empty () in 19 | Alcotest.(check int) "cardinal" 0 (S.cardinal t) 20 | 21 | let test_singleton () = 22 | let t = S.singleton 42 in 23 | 
Alcotest.(check int) "singleton" 1 (S.cardinal t) ; 24 | Alcotest.(check bool) "member" true (S.mem 42 t) ; 25 | () 26 | 27 | let test_shuffle () = 28 | let lst = List.init 1000 (fun i -> 2 * i) in 29 | let t = S.of_list @@ shuffle_list lst in 30 | Alcotest.(check int) "cardinal" 1000 (S.cardinal t) ; 31 | Alcotest.(check bool) "is_empty" false (S.is_empty t) ; 32 | let lst' = S.to_list t in 33 | Alcotest.(check (list int)) "iso" lst lst' ; 34 | Alcotest.(check bool) "mem" true (List.for_all (fun i -> S.mem i t) lst) ; 35 | Alcotest.(check bool) 36 | "not mem" false 37 | (List.exists (fun i -> S.mem i t) @@ List.map (fun i -> i + 1) lst) ; 38 | let elt = S.choose t in 39 | Alcotest.(check bool) "choose mem" true (S.mem elt t) ; 40 | let elt = S.min_elt t in 41 | Alcotest.(check bool) "min_elt mem" true (S.mem elt t) ; 42 | Alcotest.(check bool) 43 | "min_elt smallest" true 44 | (S.for_all (fun e -> elt <= e) t) ; 45 | let elt = S.max_elt t in 46 | Alcotest.(check bool) "max_elt mem" true (S.mem elt t) ; 47 | Alcotest.(check bool) "max_elt largest" true (S.for_all (fun e -> elt >= e) t) ; 48 | Alcotest.(check int) "find" 42 (S.find 42 t) ; 49 | Alcotest.(check int) "find_first" 68 (S.find_first (fun x -> x > 66) t) ; 50 | Alcotest.(check int) "find_first bound" 0 (S.find_first (fun x -> x > -10) t) ; 51 | Alcotest.(check (option int)) 52 | "find_first_opt missing" None 53 | (S.find_first_opt (fun x -> x > 99999) t) ; 54 | Alcotest.(check int) "find_last" 522 (S.find_last (fun x -> x < 523) t) ; 55 | Alcotest.(check int) 56 | "find_last bound" (S.max_elt t) 57 | (S.find_last (fun x -> x < 9999) t) ; 58 | Alcotest.(check (option int)) 59 | "find_last_opt missing" None 60 | (S.find_last_opt (fun x -> x < 0) t) ; 61 | List.iter 62 | (fun i -> 63 | Alcotest.(check bool) "mem before" true (S.mem i t) ; 64 | Alcotest.(check bool) "remove" true (S.remove i t) ; 65 | Alcotest.(check bool) "mem after" false (S.mem i t) ) 66 | lst ; 67 | Alcotest.(check int) "cardinal after" 0 
(S.cardinal t) 68 | 69 | let test_iter () = 70 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 71 | let lst = ref [] in 72 | S.iter (fun i -> lst := i :: !lst) t ; 73 | Alcotest.(check (list int)) 74 | "order" 75 | (List.init 1000 (fun i -> i)) 76 | (List.rev !lst) 77 | 78 | let test_fold () = 79 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 80 | let lst = S.fold (fun i lst -> i :: lst) t [] in 81 | Alcotest.(check (list int)) 82 | "order" 83 | (List.init 1000 (fun i -> i)) 84 | (List.rev lst) 85 | 86 | let test_seq () = 87 | let t = S.of_seq @@ List.to_seq @@ List.init 1000 (fun i -> i) in 88 | Alcotest.(check int) "cardinal" 1000 (S.cardinal t) ; 89 | let seq = S.to_seq t in 90 | let lst = Seq.fold_left (fun lst i -> i :: lst) [] seq in 91 | Alcotest.(check (list int)) 92 | "to_seq" 93 | (List.init 1000 (fun i -> i)) 94 | (List.rev lst) ; 95 | let seq = S.to_rev_seq t in 96 | let lst = Seq.fold_left (fun lst i -> i :: lst) [] seq in 97 | Alcotest.(check (list int)) "to_rev_seq" (List.init 1000 (fun i -> i)) lst ; 98 | let seq = S.to_seq_from 42 t in 99 | let lst = Seq.fold_left (fun lst i -> i :: lst) [] seq in 100 | Alcotest.(check (list int)) 101 | "order" 102 | (List.filter (fun i -> i >= 42) (List.init 1000 (fun i -> i))) 103 | (List.rev lst) 104 | 105 | let test_exists () = 106 | Alcotest.(check bool) "is_empty" true (S.is_empty (S.empty ())) ; 107 | Alcotest.(check bool) 108 | "not exists empty" false 109 | (S.exists (fun _ -> assert false) (S.empty ())) ; 110 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 111 | let lst = ref [] in 112 | Alcotest.(check bool) 113 | "exists" true 114 | (S.exists 115 | (fun i -> 116 | lst := i :: !lst ; 117 | i = 789 ) 118 | t ) ; 119 | Alcotest.(check bool) "exists shortcut" true (List.length !lst < 999) ; 120 | let lst = ref [] in 121 | Alcotest.(check bool) 122 | "not exists" false 123 | (S.exists 124 | (fun i -> 125 | lst := i :: !lst ; 126 | i = -1 ) 127 | t ) ; 128 | Alcotest.(check (list int)) 129 
| "not exists no shortcut" 130 | (List.init 1000 (fun i -> i)) 131 | (List.sort Int.compare !lst) 132 | 133 | let test_for_all () = 134 | Alcotest.(check bool) 135 | "for_all empty" true 136 | (S.for_all (fun _ -> assert false) (S.empty ())) ; 137 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 138 | let lst = ref [] in 139 | Alcotest.(check bool) 140 | "for_all" true 141 | (S.for_all 142 | (fun i -> 143 | lst := i :: !lst ; 144 | i >= 0 ) 145 | t ) ; 146 | Alcotest.(check (list int)) 147 | "for_all no shortcut" 148 | (List.init 1000 (fun i -> i)) 149 | (List.sort Int.compare !lst) ; 150 | let lst = ref [] in 151 | Alcotest.(check bool) 152 | "not for_all" false 153 | (S.for_all 154 | (fun i -> 155 | lst := i :: !lst ; 156 | i = 0 ) 157 | t ) ; 158 | Alcotest.(check int) "for_all shortcut" 1 (List.length !lst) 159 | 160 | let test_view () = 161 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 162 | let v = S.snapshot t in 163 | for i = 500 to 1999 do 164 | S.add i t 165 | done ; 166 | Alcotest.(check int) "cardinal" 1000 (S.View.cardinal v) ; 167 | Alcotest.(check int) "cardinal" 2000 (S.cardinal t) ; 168 | let t = S.of_view v in 169 | for i = -500 to 499 do 170 | S.add i t 171 | done ; 172 | for i = -500 to 499 do 173 | Alcotest.(check bool) "remove" true (S.remove i t) 174 | done ; 175 | Alcotest.(check int) "cardinal" 1000 (S.View.cardinal v) ; 176 | Alcotest.(check int) "cardinal" 500 (S.cardinal t) 177 | 178 | let tests = 179 | let open Alcotest in 180 | [ test_case "empty" `Quick test_empty 181 | ; test_case "singleton" `Quick test_singleton 182 | ; test_case "iter" `Quick test_iter 183 | ; test_case "fold" `Quick test_fold 184 | ; test_case "seq" `Quick test_seq 185 | ; test_case "exists" `Quick test_exists 186 | ; test_case "for_all" `Quick test_for_all 187 | ; test_case "shuffle" `Quick test_shuffle 188 | ; test_case "view" `Quick test_view ] 189 | -------------------------------------------------------------------------------- /src/mcmap.ml: 
--------------------------------------------------------------------------------
(* Maps implemented on top of the sets of [Mcset]: a binding is stored as a
   single set element that is ordered by its key only. *)
module Make_poly (Ord : S.Ordered_poly) = struct
  type 'a key = 'a Ord.t

  (* Set element.
     - [Kv (k, v)] is an actual binding; it is the only constructor inserted by
       the functions of this module.
     - [K k] is a key-only probe, built for lookups and removals. It compares
       equal to any binding sharing the same key (see [Kv.compare]). *)
  type 'a kv = K of 'a key | Kv of 'a key * 'a

  (* Helpers to convert between set elements and user-facing bindings. *)
  module Kv = struct
    type 'a t = 'a kv

    (* Pack a [(key, value)] pair as a set element. *)
    let make (k, v) = Kv (k, v)

    (* Key of an element, whether it is a probe or a binding. *)
    let key = function
      | K k -> k
      | Kv (k, _) -> k

    (* Value of a binding. Raises [Invalid_argument] on a key-only probe. *)
    let value = function
      | K _ -> invalid_arg "Kv.value missing"
      | Kv (_, v) -> v

    (* Same as [value], lifted to options. Raises [Invalid_argument] on a
       key-only probe. *)
    let value_opt = function
      | Some (Kv (_, v)) -> Some v
      | None -> None
      | Some (K _) -> invalid_arg "Kv.value_opt missing"

    (* [(key, value)] pair of a binding. Raises [Invalid_argument] on a
       key-only probe. *)
    let binding = function
      | K _ -> invalid_arg "Kv.binding missing"
      | Kv (k, v) -> k, v

    (* Same as [binding], lifted to options. Raises [Invalid_argument] on a
       key-only probe; the message names this function so that the failure
       is not mistaken for one coming from [value_opt]. *)
    let binding_opt = function
      | Some (Kv (k, v)) -> Some (k, v)
      | None -> None
      | Some (K _) -> invalid_arg "Kv.binding_opt missing"

    (* Call the curried [f key value] on a binding. *)
    let apply f kv =
      let k, v = binding kv in
      f k v

    (* Elements are ordered by key alone: the value never takes part in the
       comparison. *)
    let compare a b = Ord.compare (key a) (key b)
  end

  (* Underlying set of [kv] elements. *)
  module S = Mcset.Make_poly (Kv)

  type 'a t = 'a S.t

  let empty = S.empty

  let singleton k v = S.singleton (Kv (k, v))

  (* Uses [add_or_replace] so that the new element, carrying [v], is the one
     handed to the set even when the key is already present. *)
  let add k v t = S.add_or_replace (Kv (k, v)) t

  let remove k t = S.remove (K k) t

  let cardinal = S.cardinal

  let is_empty = S.is_empty

  let mem k t = S.mem (K k) t

  let min_binding t = Kv.binding (S.min_elt t)

  let min_binding_opt t = Kv.binding_opt (S.min_elt_opt t)

  let max_binding t = Kv.binding (S.max_elt t)

  let max_binding_opt t = Kv.binding_opt (S.max_elt_opt t)

  let find k t = Kv.value (S.find (K k) t)

  let find_opt k t = Kv.value_opt (S.find_opt (K k) t)

  (* The [find_first]/[find_last] predicates only ever see the key. *)
  let find_first f t = Kv.binding (S.find_first (fun kv -> f (Kv.key kv)) t)

  let find_first_opt f t =
    Kv.binding_opt (S.find_first_opt (fun kv -> f (Kv.key kv)) t)

  let find_last f t = Kv.binding (S.find_last (fun kv -> f (Kv.key kv)) t)

  let find_last_opt f t =
    Kv.binding_opt (S.find_last_opt (fun kv -> f (Kv.key kv)) t)

  let choose t = Kv.binding (S.choose t)

  let choose_opt t = Kv.binding_opt (S.choose_opt t)

  let fold f t init = S.fold (fun kv acc -> Kv.apply f kv acc) t init

  let iter f t = S.iter (fun kv -> Kv.apply f kv) t

  let for_all f t = S.for_all (fun kv -> Kv.apply f kv) t

  let exists f t = S.exists (fun kv -> Kv.apply f kv) t

  let bindings t = List.map Kv.binding (S.elements t)

  let to_list t = List.map Kv.binding (S.to_list t)

  let of_list lst = S.of_list (List.map Kv.make lst)

  let to_seq t = Seq.map Kv.binding (S.to_seq t)

  let of_seq seq = S.of_seq (Seq.map Kv.make seq)

  let to_rev_seq t = Seq.map Kv.binding (S.to_rev_seq t)

  let to_seq_from k t = Seq.map Kv.binding (S.to_seq_from (K k) t)

  (* The same map API, expressed over the views of the underlying set. *)
  module View = struct
    (* Inside this module, [S] designates the view API of the set. *)
    module S = S.View

    type 'a key = 'a Ord.t

    type 'a t = 'a S.t

    let empty = S.empty

    let singleton k v = S.singleton (Kv (k, v))

    let add k v t = S.add_or_replace (Kv (k, v)) t

    let remove k t = S.remove (K k) t

    let cardinal = S.cardinal

    let is_empty = S.is_empty

    let mem k t = S.mem (K k) t

    let min_binding t = Kv.binding (S.min_elt t)

    let min_binding_opt t = Kv.binding_opt (S.min_elt_opt t)

    let max_binding t = Kv.binding (S.max_elt t)

    let max_binding_opt t = Kv.binding_opt (S.max_elt_opt t)

    let find k t = Kv.value (S.find (K k) t)

    let find_opt k t = Kv.value_opt (S.find_opt (K k) t)

    let find_first f t = Kv.binding (S.find_first (fun kv -> f (Kv.key kv)) t)

    let find_first_opt f t =
      Kv.binding_opt (S.find_first_opt (fun kv -> f (Kv.key kv)) t)

    let find_last f t = Kv.binding (S.find_last (fun kv -> f (Kv.key kv)) t)

    let find_last_opt f t =
      Kv.binding_opt (S.find_last_opt (fun kv -> f (Kv.key kv)) t)

    let choose t = Kv.binding (S.choose t)

    let choose_opt t = Kv.binding_opt (S.choose_opt t)

    let fold f t init = S.fold (fun kv acc -> Kv.apply f kv acc) t init

    let iter f t = S.iter (fun kv -> Kv.apply f kv) t

    let for_all f t = S.for_all (fun kv -> Kv.apply f kv) t

    let exists f t = S.exists (fun kv -> Kv.apply f kv) t

    let bindings t = List.map Kv.binding (S.elements t)

    let to_list t = List.map Kv.binding (S.to_list t)

    let of_list lst = S.of_list (List.map Kv.make lst)

    let to_seq t = Seq.map Kv.binding (S.to_seq t)

    let of_seq seq = S.of_seq (Seq.map Kv.make seq)

    let to_rev_seq t = Seq.map Kv.binding (S.to_rev_seq t)

    let to_seq_from k t = Seq.map Kv.binding (S.to_seq_from (K k) t)

    let union = S.union

    let inter = S.inter

    let diff = S.diff

    (* Apply [f] to every value, keys untouched. When [f] returns a value
       physically equal to its input, the original element is handed back
       unchanged instead of allocating a new one (presumably so that [S.map]
       can preserve sharing -- to be confirmed against [Mcset]). *)
    let map f t =
      S.map
        (fun kv ->
          let k, v = Kv.binding kv in
          let v' = f v in
          if v == v' then kv else Kv (k, v') )
        t

    (* Same as [map], but [f] also receives the key. *)
    let mapi f t =
      S.map
        (fun kv ->
          let k, v = Kv.binding kv in
          let v' = f k v in
          if v == v' then kv else Kv (k, v') )
        t

    let filter f t = S.filter (fun kv -> Kv.apply f kv) t

    (* Keep the bindings for which [f k v] is [Some v'], bound to [v']. *)
    let filter_map f t =
      S.filter_map
        (fun kv ->
          let k, v = Kv.binding kv in
          match f k v with
          | None -> None
          | Some v -> Some (Kv (k, v)) )
        t

    let partition f t = S.partition (fun kv -> Kv.apply f kv) t

    (* [S.split] only reports whether the key was present; when it was, the
       associated value is fetched from the original view [t]. *)
    let split k t =
      let k = K k in
      let smaller, found, larger = S.split k t in
      let found = if found then Some (Kv.value (S.find k t)) else None in
      smaller, found, larger

    (* The [pop_*] functions flatten the popped binding and the remaining
       view into a single triple [(key, value, rest)]. *)
    let pop_min t =
      let kv, t = S.pop_min t in
      let k, v = Kv.binding kv in
      k, v, t

    let pop_min_opt t =
      match S.pop_min_opt t with
      | None -> None
      | Some (kv, t) ->
          let k, v = Kv.binding kv in
          Some (k, v, t)

    let pop_max t =
      let kv, t = S.pop_max t in
      let k, v = Kv.binding kv in
      k, v, t

    let pop_max_opt t =
      match S.pop_max_opt t with
      | None -> None
      | Some (kv, t) ->
          let k, v = Kv.binding kv in
          Some (k, v, t)
  end

  let copy = S.copy

  let snapshot = S.snapshot

  let to_view = S.to_view

  let of_view = S.of_view
end

(* Maps over a monomorphic key type, obtained by instantiating [Make_poly]
   with an ordering whose type parameter is ignored. *)
module Make (Ord : S.Ordered) = struct
  (* [Ord] seen as a polymorphic ordering: ['a t] is [Ord.t] for every ['a]. *)
  module Ord_poly = struct
    type 'a t = Ord.t

    let compare = Ord.compare
  end

  module I = Make_poly (Ord_poly)

  module View = struct
    type key = Ord.t

    type 'a t = 'a I.View.t

    (* Re-export [I.View] with its parametrised key type replaced by the
       monomorphic [key]. *)
    include (
      I.View :
        S.View_map_poly(Ord_poly).S with type _ key := key and type 'a t := 'a t )
  end

  type key = Ord.t

  include (I : S.Map(Ord).S with type key := key and module View := View)
end
--------------------------------------------------------------------------------
/tests/test_view.ml:
--------------------------------------------------------------------------------
module MS = Mcavl.Set (Int)
module S = MS.View

(* Fisher-Yates shuffle of [a], in place.
   The array must not be copied first: the function returns [unit], so a
   shuffled copy would be discarded and callers such as [shuffle_list] would
   read back the untouched, still ordered, input. *)
let shuffle a =
  let n = Array.length a in
  for i = n - 1 downto 1 do
    let k = Random.int (i + 1) in
    let x = a.(k) in
    a.(k) <- a.(i) ;
    a.(i) <- x
  done

(* Random permutation of [lst]; the input list itself is left untouched since
   the shuffle happens on a fresh array. *)
let shuffle_list lst =
  let a = Array.of_list lst in
  shuffle a ; Array.to_list a

(* The empty view has no element. *)
let test_empty () =
  let t = S.empty in
  Alcotest.(check int) "cardinal" 0 (S.cardinal t)

(* A singleton view has exactly one element, the one it was built from. *)
let test_singleton () =
  let t = S.singleton 42 in
  Alcotest.(check int) "singleton" 1 (S.cardinal t) ;
  Alcotest.(check bool) "member" true (S.mem 42 t) ;
  ()
27 | 28 | let test_shuffle () = 29 | let lst = List.init 1000 (fun i -> 2 * i) in 30 | let t = S.of_list @@ shuffle_list lst in 31 | Alcotest.(check int) "cardinal" 1000 (S.cardinal t) ; 32 | Alcotest.(check bool) "is_empty" false (S.is_empty t) ; 33 | let lst' = S.to_list t in 34 | Alcotest.(check (list int)) "iso" lst lst' ; 35 | let elt = S.choose t in 36 | Alcotest.(check bool) "choose mem" true (S.mem elt t) ; 37 | let elt = S.min_elt t in 38 | Alcotest.(check bool) "min_elt mem" true (S.mem elt t) ; 39 | Alcotest.(check bool) 40 | "min_elt smallest" true 41 | (S.for_all (fun e -> elt <= e) t) ; 42 | let elt = S.max_elt t in 43 | Alcotest.(check bool) "max_elt mem" true (S.mem elt t) ; 44 | Alcotest.(check bool) "max_elt largest" true (S.for_all (fun e -> elt >= e) t) ; 45 | Alcotest.(check int) "find" 42 (S.find 42 t) ; 46 | Alcotest.(check int) "find_first" 68 (S.find_first (fun x -> x > 66) t) ; 47 | Alcotest.(check (option int)) 48 | "find_first_opt missing" None 49 | (S.find_first_opt (fun x -> x > 99999) t) ; 50 | Alcotest.(check int) "find_last" 522 (S.find_last (fun x -> x < 523) t) ; 51 | Alcotest.(check int) 52 | "find_last bound" (S.max_elt t) 53 | (S.find_last (fun x -> x < 9999) t) ; 54 | Alcotest.(check (option int)) 55 | "find_last_opt missing" None 56 | (S.find_last_opt (fun x -> x < 0) t) ; 57 | Alcotest.(check bool) "mem" true (List.for_all (fun i -> S.mem i t) lst) ; 58 | Alcotest.(check bool) 59 | "not mem" false 60 | (List.exists (fun i -> S.mem i t) @@ List.map (fun i -> i + 1) lst) ; 61 | let t' = 62 | List.fold_left 63 | (fun t i -> 64 | let x, t' = S.pop_min t in 65 | Alcotest.(check int) "pop_min" i x ; 66 | let card = S.cardinal t - 1 in 67 | Alcotest.(check int) "pop_min cardinal" card (S.cardinal t') ; 68 | let t'' = S.remove i t in 69 | Alcotest.(check bool) "remove not equal" false (t == t'') ; 70 | Alcotest.(check int) "remove cardinal" card (S.cardinal t'') ; 71 | t'' ) 72 | t lst 73 | in 74 | Alcotest.(check bool) "is_empty" 
true (S.is_empty t') ; 75 | Alcotest.(check bool) "not mem 43" false (S.mem 43 t) ; 76 | let t' = S.remove 43 t in 77 | Alcotest.(check bool) "remove physeq" true (t' == t) ; 78 | Alcotest.(check bool) "mem 78" true (S.mem 78 t) ; 79 | let t' = S.add 78 t in 80 | Alcotest.(check bool) "add physeq" true (t' == t) ; 81 | let lst', t' = 82 | List.fold_left 83 | (fun (acc, t) _ -> 84 | let x, t = S.pop_max t in 85 | x :: acc, t ) 86 | ([], t) lst 87 | in 88 | Alcotest.(check bool) "pop_max empty" true (S.is_empty t') ; 89 | Alcotest.(check (list int)) "pop_max sorted" lst lst' 90 | 91 | let test_union () = 92 | let t1 = S.of_list @@ List.init 1000 (fun i -> 3 * i) in 93 | let t2 = S.of_list @@ List.init 1000 (fun i -> (3 * i) + 1) in 94 | let t3 = S.of_list @@ List.init 1000 (fun i -> 100 + i) in 95 | let t12 = S.union t1 t2 in 96 | Alcotest.(check bool) 97 | "union t1 t2: subset t1" true 98 | (S.for_all (fun e -> S.mem e t12) t1) ; 99 | Alcotest.(check bool) 100 | "union t1 t2: subset t2" true 101 | (S.for_all (fun e -> S.mem e t12) t2) ; 102 | let t123 = S.union t12 t3 in 103 | Alcotest.(check bool) 104 | "union t12 t3: subset t12" true 105 | (S.for_all (fun e -> S.mem e t123) t12) ; 106 | Alcotest.(check bool) 107 | "union t12 t3: subset t3" true 108 | (S.for_all (fun e -> S.mem e t123) t3) 109 | 110 | let test_inter () = 111 | let t1 = S.of_list @@ List.init 1000 (fun i -> 3 * i) in 112 | let t2 = S.of_list @@ List.init 1000 (fun i -> (3 * i) + 1) in 113 | let t3 = S.of_list @@ List.init 1000 (fun i -> 100 + i) in 114 | let t12 = S.inter t1 t2 in 115 | Alcotest.(check bool) "inter empty" true (S.is_empty t12) ; 116 | let t13 = S.inter t1 t3 in 117 | Alcotest.(check bool) 118 | "inter t1 t3: t1" true 119 | (S.for_all (fun e -> S.mem e t13 = S.mem e t3) t1) ; 120 | Alcotest.(check bool) 121 | "inter t1 t3: t3" true 122 | (S.for_all (fun e -> S.mem e t13 = S.mem e t1) t3) 123 | 124 | let test_diff () = 125 | let t1 = S.of_list @@ List.init 1000 (fun i -> 3 * i) in 126 
| let t3 = S.of_list @@ List.init 1000 (fun i -> 100 + i) in 127 | let t13 = S.diff t1 t3 in 128 | Alcotest.(check bool) 129 | "diff t1 t3: t1" true 130 | (S.for_all (fun e -> S.mem e t13 = not (S.mem e t3)) t1) ; 131 | Alcotest.(check bool) 132 | "diff t1 t3: t3" true 133 | (S.for_all (fun e -> not (S.mem e t13)) t3) 134 | 135 | let test_split () = 136 | let t = S.of_list @@ List.init 1000 (fun i -> 2 * i) in 137 | let smaller, found, larger = S.split 789 t in 138 | Alcotest.(check bool) "split: not found" false found ; 139 | Alcotest.(check int) "split: cardinal smaller" 395 (S.cardinal smaller) ; 140 | Alcotest.(check bool) 141 | "split: is smaller" true 142 | (S.for_all (fun e -> e < 789) smaller) ; 143 | Alcotest.(check int) "split: cardinal larger" 605 (S.cardinal larger) ; 144 | Alcotest.(check bool) 145 | "split: is larger" true 146 | (S.for_all (fun e -> e > 789) larger) ; 147 | Alcotest.(check bool) 148 | "split: found" true 149 | (S.for_all (fun e -> S.mem e (if e < 789 then smaller else larger)) t) ; 150 | let smaller, found, larger = S.split 100 t in 151 | Alcotest.(check bool) "split: found" true found ; 152 | Alcotest.(check bool) "split: not smaller" false (S.mem 100 smaller) ; 153 | Alcotest.(check bool) "split: not larger" false (S.mem 100 larger) 154 | 155 | let test_map () = 156 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 157 | let t' = S.map (fun x -> x) t in 158 | Alcotest.(check bool) "map physeq" true (t == t') ; 159 | let t' = S.map (fun x -> if x > 500 then -x else x) t in 160 | Alcotest.(check bool) 161 | "map partial physeq" true 162 | (S.for_all (fun x -> S.mem (if x > 500 then -x else x) t') t) ; 163 | let t' = S.map (fun x -> -x) t in 164 | Alcotest.(check bool) 165 | "map worst-case" true 166 | (S.for_all (fun x -> S.mem (-x) t') t) ; 167 | let t' = S.map (fun x -> x + 1) t in 168 | Alcotest.(check bool) 169 | "map in-order" true 170 | (S.for_all (fun x -> S.mem (x + 1) t') t) 171 | 172 | let test_filter () = 173 | let t = 
S.of_list @@ List.init 1000 (fun i -> i) in 174 | let t' = S.filter (fun _ -> true) t in 175 | Alcotest.(check bool) "filter physeq" true (t == t') ; 176 | let t' = S.filter (fun _ -> false) t in 177 | Alcotest.(check bool) "filter all" true (S.is_empty t') ; 178 | let t' = S.filter (fun x -> x > 123) t in 179 | Alcotest.(check bool) 180 | "filter some half" true 181 | (S.for_all (fun x -> x > 123 = S.mem x t') t) ; 182 | let t' = S.filter (fun x -> x mod 2 = 0) t in 183 | Alcotest.(check bool) 184 | "filter some even" true 185 | (S.for_all (fun x -> x mod 2 = 0 = S.mem x t') t) 186 | 187 | let test_partition () = 188 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 189 | let t', tn = S.partition (fun _ -> true) t in 190 | Alcotest.(check bool) "partition physeq" true (t == t') ; 191 | Alcotest.(check bool) "partition is_empty" true (S.is_empty tn) ; 192 | let t', tn = S.partition (fun _ -> false) t in 193 | Alcotest.(check bool) "partition all is_empty" true (S.is_empty t') ; 194 | Alcotest.(check bool) "partition all physeq" true (t == tn) ; 195 | let t', tn = S.partition (fun x -> x > 123) t in 196 | Alcotest.(check bool) 197 | "partition some half" true 198 | (S.for_all (fun x -> x > 123 = S.mem x t' && x <= 123 = S.mem x tn) t) ; 199 | let t', tn = S.partition (fun x -> x mod 2 = 0) t in 200 | Alcotest.(check bool) 201 | "partition some even" true 202 | (S.for_all 203 | (fun x -> x mod 2 = 0 = S.mem x t' && x mod 2 <> 0 = S.mem x tn) 204 | t ) 205 | 206 | let test_comparisons () = 207 | let t1 = S.of_list @@ List.init 1000 (fun i -> 3 * i) in 208 | let t2 = S.of_list @@ List.init 1000 (fun i -> (3 * i) + 1) in 209 | Alcotest.(check bool) "not equal" false (S.equal t1 t2) ; 210 | Alcotest.(check bool) "smaller" true (S.compare t1 t2 < 0) ; 211 | Alcotest.(check bool) "larger" true (S.compare t2 t1 > 0) ; 212 | Alcotest.(check bool) "disjoint" true (S.disjoint t1 t2) ; 213 | let t3 = S.of_list @@ List.init 1000 (fun i -> i) in 214 | Alcotest.(check bool) 
"not equal" false (S.equal t1 t3) ; 215 | Alcotest.(check bool) "larger" true (S.compare t1 t3 > 0) ; 216 | Alcotest.(check bool) "smaller" true (S.compare t3 t1 < 0) ; 217 | Alcotest.(check bool) "not disjoint" false (S.disjoint t1 t3) ; 218 | let t4 = S.diff (S.union t1 t2) t2 in 219 | Alcotest.(check bool) "different structure" true (t1 <> t4) ; 220 | Alcotest.(check bool) "equal" true (S.equal t1 t4) ; 221 | Alcotest.(check bool) "not disjoint" false (S.disjoint t1 t4) 222 | 223 | let test_filter_map () = 224 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 225 | let t' = S.filter_map (fun x -> Some x) t in 226 | Alcotest.(check bool) "filter_map physeq" true (t == t') ; 227 | let t' = S.filter_map (fun _ -> None) t in 228 | Alcotest.(check bool) "filter_map all" true (S.is_empty t') ; 229 | let t' = S.filter_map (fun x -> if x > 123 then None else Some x) t in 230 | Alcotest.(check bool) 231 | "filter_map some half" true 232 | (S.for_all (fun x -> x <= 123 = S.mem x t') t) ; 233 | let t' = S.filter_map (fun x -> if x mod 2 = 0 then Some (-x) else None) t in 234 | Alcotest.(check bool) 235 | "filter_map some even" true 236 | (S.for_all (fun x -> x mod 2 = 0 = S.mem (-x) t') t) 237 | 238 | let test_iter () = 239 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 240 | let lst = ref [] in 241 | S.iter (fun i -> lst := i :: !lst) t ; 242 | Alcotest.(check (list int)) 243 | "order" 244 | (List.init 1000 (fun i -> i)) 245 | (List.rev !lst) 246 | 247 | let test_fold () = 248 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 249 | let lst = S.fold (fun i lst -> i :: lst) t [] in 250 | Alcotest.(check (list int)) 251 | "order" 252 | (List.init 1000 (fun i -> i)) 253 | (List.rev lst) 254 | 255 | let test_seq () = 256 | let t = S.of_seq @@ List.to_seq @@ List.init 1000 (fun i -> i) in 257 | Alcotest.(check int) "cardinal" 1000 (S.cardinal t) ; 258 | let seq = S.to_seq t in 259 | let lst = Seq.fold_left (fun lst i -> i :: lst) [] seq in 260 | Alcotest.(check 
(list int)) 261 | "to_seq" 262 | (List.init 1000 (fun i -> i)) 263 | (List.rev lst) ; 264 | let seq = S.to_rev_seq t in 265 | let lst = Seq.fold_left (fun lst i -> i :: lst) [] seq in 266 | Alcotest.(check (list int)) "to_rev_seq" (List.init 1000 (fun i -> i)) lst ; 267 | let seq = S.to_seq_from 42 t in 268 | let lst = Seq.fold_left (fun lst i -> i :: lst) [] seq in 269 | Alcotest.(check (list int)) 270 | "order" 271 | (List.filter (fun i -> i >= 42) (List.init 1000 (fun i -> i))) 272 | (List.rev lst) 273 | 274 | let test_exists () = 275 | Alcotest.(check bool) "is_empty" true (S.is_empty S.empty) ; 276 | Alcotest.(check bool) 277 | "not exists empty" false 278 | (S.exists (fun _ -> assert false) S.empty) ; 279 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 280 | let lst = ref [] in 281 | Alcotest.(check bool) 282 | "exists" true 283 | (S.exists 284 | (fun i -> 285 | lst := i :: !lst ; 286 | i = 789 ) 287 | t ) ; 288 | Alcotest.(check bool) "exists shortcut" true (List.length !lst < 999) ; 289 | let lst = ref [] in 290 | Alcotest.(check bool) 291 | "not exists" false 292 | (S.exists 293 | (fun i -> 294 | lst := i :: !lst ; 295 | i = -1 ) 296 | t ) ; 297 | Alcotest.(check (list int)) 298 | "not exists no shortcut" 299 | (List.init 1000 (fun i -> i)) 300 | (List.sort Int.compare !lst) 301 | 302 | let test_for_all () = 303 | Alcotest.(check bool) 304 | "for_all empty" true 305 | (S.for_all (fun _ -> assert false) S.empty) ; 306 | let t = S.of_list @@ List.init 1000 (fun i -> i) in 307 | let lst = ref [] in 308 | Alcotest.(check bool) 309 | "for_all" true 310 | (S.for_all 311 | (fun i -> 312 | lst := i :: !lst ; 313 | i >= 0 ) 314 | t ) ; 315 | Alcotest.(check (list int)) 316 | "for_all no shortcut" 317 | (List.init 1000 (fun i -> i)) 318 | (List.sort Int.compare !lst) ; 319 | let lst = ref [] in 320 | Alcotest.(check bool) 321 | "not for_all" false 322 | (S.for_all 323 | (fun i -> 324 | lst := i :: !lst ; 325 | i = 0 ) 326 | t ) ; 327 | Alcotest.(check int) 
"for_all shortcut" 1 (List.length !lst) 328 | 329 | let tests = 330 | let open Alcotest in 331 | [ test_case "empty" `Quick test_empty 332 | ; test_case "singleton" `Quick test_singleton 333 | ; test_case "union" `Quick test_union 334 | ; test_case "inter" `Quick test_inter 335 | ; test_case "diff" `Quick test_diff 336 | ; test_case "split" `Quick test_split 337 | ; test_case "map" `Quick test_map 338 | ; test_case "filter" `Quick test_filter 339 | ; test_case "partition" `Quick test_partition 340 | ; test_case "filter_map" `Quick test_filter_map 341 | ; test_case "comparisons" `Quick test_comparisons 342 | ; test_case "iter" `Quick test_iter 343 | ; test_case "fold" `Quick test_fold 344 | ; test_case "seq" `Quick test_seq 345 | ; test_case "exists" `Quick test_exists 346 | ; test_case "for_all" `Quick test_for_all 347 | ; test_case "shuffle" `Quick test_shuffle ] 348 | -------------------------------------------------------------------------------- /src/core.ml: -------------------------------------------------------------------------------- 1 | module Make (E : S.Ordered_poly) = struct 2 | include Pure.Make (E) 3 | 4 | type result = Ok | Overflow | Looking_for_life | Retry 5 | 6 | let empty () = Atomic.make (Pure.empty ~s:Alive) 7 | 8 | let singleton elt = Atomic.make (Pure.singleton ~s:Alive elt) 9 | 10 | let rec height t = s_height (Atomic.get t) 11 | 12 | and s_height = function 13 | | Leaf _ -> 0 14 | | Node (_, h, _, _, _) -> h 15 | | Copy t -> height t 16 | 17 | let rec create left pivot right = Atomic.make (s_create left pivot right) 18 | 19 | and s_create left pivot right = 20 | let hl, hr = height left, height right in 21 | let h = if hl >= hr then hl + 1 else hr + 1 in 22 | Node (Alive, h, left, pivot, right) 23 | 24 | let rec fixup t = 25 | let s = Atomic.get t in 26 | match s with 27 | | Leaf _ -> Ok 28 | | Node ((Alive | Read_only | Dead), _, _, _, _) -> Ok 29 | | Node (Removing, _, _, _, _) -> 30 | let res = push_down t s in 31 | begin 32 | match 
balance t with 33 | | Ok -> res 34 | | other -> other 35 | end 36 | | Node ((Attempt_add _ | Attempt_remove _ | Attempt_replace _), _, _, _, _) 37 | -> 38 | finalize_op t s ; fixup t 39 | | Copy t' -> 40 | let repr = make_copy t' in 41 | let (_ : bool) = Atomic.compare_and_set t s repr in 42 | fixup t 43 | | _ -> balance t 44 | 45 | and finalize_op t s = 46 | match s with 47 | | Node (Attempt_add (attempt, root, expected), height, left, pivot, right) 48 | -> begin 49 | match Atomic.get attempt with 50 | | Unknown -> 51 | let result = 52 | if Atomic.get root == expected then Success else Failure 53 | in 54 | let (_ : bool) = Atomic.compare_and_set attempt Unknown result in 55 | finalize_op t s 56 | | Success -> 57 | let repr = Node (Alive, height, left, pivot, right) in 58 | let (_ : bool) = Atomic.compare_and_set t s repr in 59 | () 60 | | Failure -> 61 | let repr = Leaf Alive in 62 | let (_ : bool) = Atomic.compare_and_set t s repr in 63 | () 64 | end 65 | | Node (Attempt_remove (attempt, root, expected), height, left, pivot, right) 66 | -> begin 67 | match Atomic.get attempt with 68 | | Unknown -> 69 | let result = 70 | if Atomic.get root == expected then Success else Failure 71 | in 72 | let (_ : bool) = Atomic.compare_and_set attempt Unknown result in 73 | finalize_op t s 74 | | Success -> 75 | let repr = Node (Removing, height, left, pivot, right) in 76 | let (_ : bool) = Atomic.compare_and_set t s repr in 77 | () 78 | | Failure -> 79 | let repr = Node (Alive, height, left, pivot, right) in 80 | let (_ : bool) = Atomic.compare_and_set t s repr in 81 | () 82 | end 83 | | Node 84 | ( Attempt_replace (attempt, root, expected, new_pivot) 85 | , height 86 | , left 87 | , pivot 88 | , right ) -> begin 89 | match Atomic.get attempt with 90 | | Unknown -> 91 | let result = 92 | if Atomic.get root == expected then Success else Failure 93 | in 94 | let (_ : bool) = Atomic.compare_and_set attempt Unknown result in 95 | finalize_op t s 96 | | Success -> 97 | let repr = 
Node (Alive, height, left, new_pivot, right) in 98 | let (_ : bool) = Atomic.compare_and_set t s repr in 99 | () 100 | | Failure -> 101 | let repr = Node (Alive, height, left, pivot, right) in 102 | let (_ : bool) = Atomic.compare_and_set t s repr in 103 | () 104 | end 105 | | _ -> () 106 | 107 | and get_stable t = 108 | let s = Atomic.get t in 109 | match s with 110 | | Leaf Alive | Node (Alive, _, _, _, _) -> s 111 | | Leaf Dead | Node (Dead, _, _, _, _) -> s 112 | | Leaf Read_only | Node (Read_only, _, _, _, _) -> s 113 | | _ -> 114 | let (_ : result) = fixup t in 115 | get_stable t 116 | 117 | and ensure_read_only t = 118 | let s = Atomic.get t in 119 | match s with 120 | | Leaf Read_only -> s 121 | | Node (Read_only, _, _, _, _) -> s 122 | | Node (Dead, _, _, _, _) -> assert false 123 | | Leaf _ -> 124 | let ro = Leaf Read_only in 125 | if Atomic.compare_and_set t s ro then ro else ensure_read_only t 126 | | Node (Alive, h, left, pivot, right) -> 127 | let ro = Node (Read_only, h, left, pivot, right) in 128 | if Atomic.compare_and_set t s ro then ro else ensure_read_only t 129 | | Copy t' -> 130 | let ro = ensure_read_only t' in 131 | if Atomic.compare_and_set t s ro then ro else ensure_read_only t 132 | | _ -> 133 | let (_ : result) = fixup t in 134 | ensure_read_only t 135 | 136 | and make_copy t = 137 | let s = ensure_read_only t in 138 | match s with 139 | | Leaf Read_only -> Leaf Alive 140 | | Node (Read_only, h, left, pivot, right) -> 141 | Node (Alive, h, Atomic.make (Copy left), pivot, Atomic.make (Copy right)) 142 | | _ -> assert false 143 | 144 | and push_down t s = 145 | match s with 146 | | _ when s != Atomic.get t -> Ok 147 | | Node (Removing, _, left, pivot_to_remove, right) -> 148 | let s_left = Atomic.get left in 149 | begin 150 | match s_left with 151 | | Leaf Dead -> push_down_left_locked t s right 152 | | Leaf Alive -> 153 | let s_left_dead = Leaf Dead in 154 | if Atomic.compare_and_set left s_left s_left_dead 155 | then 
push_down_left_locked t s right 156 | else push_down t s 157 | | Node (Removing, _, _, _, _) -> 158 | let result = push_down left s_left in 159 | begin 160 | match push_down t s with 161 | | Ok -> result 162 | | r -> r 163 | end 164 | | Node (Dead, _, left_left, left_pivot, left_right) -> 165 | let expected_height = 166 | 1 + max (height left_right) (height right) 167 | in 168 | let s_t_new = 169 | Node 170 | (Removing, expected_height, left_right, pivot_to_remove, right) 171 | in 172 | let t_new = Atomic.make s_t_new in 173 | let expected_height = 174 | 1 + max expected_height (height left_left) 175 | in 176 | let s_new = 177 | Node (Alive, expected_height, left_left, left_pivot, t_new) 178 | in 179 | if Atomic.compare_and_set t s s_new 180 | then begin 181 | match push_down t_new s_t_new with 182 | | Ok -> Ok 183 | | Overflow -> balance t 184 | | Retry | Looking_for_life -> assert false 185 | end 186 | else Ok 187 | | Node (Alive, left_height, left_left, left_pivot, left_right) -> 188 | let s_left_dead = 189 | Node (Dead, left_height, left_left, left_pivot, left_right) 190 | in 191 | let (_ : bool) = Atomic.compare_and_set left s_left s_left_dead in 192 | push_down t s 193 | | _ -> 194 | let (_ : result) = fixup left in 195 | push_down t s 196 | end 197 | | _ -> Ok 198 | 199 | and push_down_left_locked t s right = 200 | if Atomic.get t != s 201 | then Ok 202 | else begin 203 | let s_right = Atomic.get right in 204 | match s_right with 205 | | Leaf Dead -> 206 | if Atomic.compare_and_set t s (Leaf Alive) then Overflow else Ok 207 | | Leaf Alive -> 208 | let s_right_dead = Leaf Dead in 209 | if Atomic.compare_and_set right s_right s_right_dead 210 | then if Atomic.compare_and_set t s s_right then Overflow else Ok 211 | else push_down_left_locked t s right 212 | | Node (Dead, right_height, right_left, right_pivot, right_right) -> 213 | let s_right_alive = 214 | Node (Alive, right_height, right_left, right_pivot, right_right) 215 | in 216 | if Atomic.compare_and_set 
t s s_right_alive then Overflow else Ok 217 | | Node (Alive, right_height, right_left, right_pivot, right_right) -> 218 | let s_right_dead = 219 | Node (Dead, right_height, right_left, right_pivot, right_right) 220 | in 221 | if Atomic.compare_and_set right s_right s_right_dead 222 | then if Atomic.compare_and_set t s s_right then Overflow else Ok 223 | else push_down_left_locked t s right 224 | | _ -> 225 | let result = fixup right in 226 | begin 227 | match push_down_left_locked t s right with 228 | | Ok -> result 229 | | other -> other 230 | end 231 | end 232 | 233 | and balance t = 234 | let s = Atomic.get t in 235 | match s with 236 | | Node (Balancing_left (s_left, s_left_dead), _, left, pivot, right) -> 237 | if Atomic.compare_and_set left s_left s_left_dead 238 | || Atomic.get left == s_left_dead 239 | then begin 240 | match s_left with 241 | | Node (Alive, _, a, b, c) -> 242 | let repr = s_create a b (create c pivot right) in 243 | if Atomic.compare_and_set t s repr then Overflow else Ok 244 | | _ -> assert false 245 | end 246 | else begin 247 | let (_ : bool) = 248 | Atomic.compare_and_set t s (s_create left pivot right) 249 | in 250 | Ok 251 | end 252 | | Node 253 | ( Balancing_left_center (s_left, s_left_dead, s_center, s_center_dead) 254 | , _ 255 | , left 256 | , pivot 257 | , right ) -> 258 | if Atomic.compare_and_set left s_left s_left_dead 259 | || Atomic.get left == s_left_dead 260 | then begin 261 | match s_left with 262 | | Node (Alive, _, a, b, c) -> 263 | if Atomic.compare_and_set c s_center s_center_dead 264 | || Atomic.get c == s_center_dead 265 | then begin 266 | match s_center with 267 | | Node (Alive, _, c, d, e) -> 268 | let repr = 269 | s_create (create a b c) d (create e pivot right) 270 | in 271 | if Atomic.compare_and_set t s repr then Overflow else Ok 272 | | _ -> assert false 273 | end 274 | else begin 275 | let s_left_new = s_create a b c in 276 | let (_ : bool) = 277 | Atomic.compare_and_set left s_left_dead s_left_new 278 | in 
279 | let (_ : bool) = 280 | Atomic.compare_and_set t s (s_create left pivot right) 281 | in 282 | Ok 283 | end 284 | | _ -> assert false 285 | end 286 | else begin 287 | let (_ : bool) = 288 | Atomic.compare_and_set t s (s_create left pivot right) 289 | in 290 | Ok 291 | end 292 | | Node (Balancing_right (s_right, s_right_dead), _, left, pivot, right) -> 293 | if Atomic.compare_and_set right s_right s_right_dead 294 | || Atomic.get right == s_right_dead 295 | then begin 296 | match s_right with 297 | | Node (Alive, _, a, b, c) -> 298 | let repr = s_create (create left pivot a) b c in 299 | if Atomic.compare_and_set t s repr then Overflow else Ok 300 | | _ -> assert false 301 | end 302 | else begin 303 | let (_ : bool) = 304 | Atomic.compare_and_set t s (s_create left pivot right) 305 | in 306 | Ok 307 | end 308 | | Node 309 | ( Balancing_right_center (s_right, s_right_dead, s_center, s_center_dead) 310 | , _ 311 | , left 312 | , pivot 313 | , right ) -> 314 | if Atomic.compare_and_set right s_right s_right_dead 315 | || Atomic.get right == s_right_dead 316 | then begin 317 | match s_right with 318 | | Node (Alive, _, x, y, z) -> 319 | if Atomic.compare_and_set x s_center s_center_dead 320 | || Atomic.get x == s_center_dead 321 | then begin 322 | match s_center with 323 | | Node (Alive, _, v, w, x) -> 324 | let repr = 325 | s_create (create left pivot v) w (create x y z) 326 | in 327 | if Atomic.compare_and_set t s repr then Overflow else Ok 328 | | _ -> assert false 329 | end 330 | else begin 331 | let s_right_new = s_create x y z in 332 | let (_ : bool) = 333 | Atomic.compare_and_set right s_right_dead s_right_new 334 | in 335 | let (_ : bool) = 336 | Atomic.compare_and_set t s (s_create left pivot right) 337 | in 338 | Ok 339 | end 340 | | _ -> assert false 341 | end 342 | else begin 343 | let (_ : bool) = 344 | Atomic.compare_and_set t s (s_create left pivot right) 345 | in 346 | Ok 347 | end 348 | | Node (Alive, ht, left, pivot, right) -> 349 | let s_left = 
get_stable left in 350 | let s_right = get_stable right in 351 | let hl = s_height s_left in 352 | let hr = s_height s_right in 353 | if hl > hr + 2 354 | then begin 355 | match s_left with 356 | | Node (Alive, hl, left_left, left_pivot, left_right) -> 357 | let hll = height left_left in 358 | let s_left_right = get_stable left_right in 359 | if hll >= s_height s_left_right 360 | then begin 361 | let s_left_dead = 362 | Node (Dead, hl, left_left, left_pivot, left_right) 363 | in 364 | let plan = Balancing_left (s_left, s_left_dead) in 365 | if Atomic.compare_and_set t s 366 | (Node (plan, ht, left, pivot, right)) 367 | then balance t 368 | else Ok 369 | end 370 | else begin 371 | match s_left_right with 372 | | Node (Alive, hlr, a, b, c) -> 373 | let s_left_right_dead = Node (Dead, hlr, a, b, c) in 374 | let s_left_dead = 375 | Node (Dead, hl, left_left, left_pivot, left_right) 376 | in 377 | let plan = 378 | Balancing_left_center 379 | (s_left, s_left_dead, s_left_right, s_left_right_dead) 380 | in 381 | if Atomic.compare_and_set t s 382 | (Node (plan, ht, left, pivot, right)) 383 | then balance t 384 | else Ok 385 | | _ -> Ok 386 | end 387 | | _ -> Ok 388 | end 389 | else if hr > hl + 2 390 | then begin 391 | match s_right with 392 | | Node (Alive, hr, right_left, right_pivot, right_right) -> 393 | let hrr = height right_right in 394 | let s_right_left = get_stable right_left in 395 | if hrr >= s_height s_right_left 396 | then begin 397 | let s_right_dead = 398 | Node (Dead, hr, right_left, right_pivot, right_right) 399 | in 400 | let plan = Balancing_right (s_right, s_right_dead) in 401 | if Atomic.compare_and_set t s 402 | (Node (plan, ht, left, pivot, right)) 403 | then balance t 404 | else Ok 405 | end 406 | else begin 407 | match s_right_left with 408 | | Node (Alive, hrl, a, b, c) -> 409 | let s_right_left_dead = Node (Dead, hrl, a, b, c) in 410 | let s_right_dead = 411 | Node (Dead, hr, right_left, right_pivot, right_right) 412 | in 413 | let plan = 414 | 
Balancing_right_center 415 | (s_right, s_right_dead, s_right_left, s_right_left_dead) 416 | in 417 | if Atomic.compare_and_set t s 418 | (Node (plan, ht, left, pivot, right)) 419 | then balance t 420 | else Ok 421 | | _ -> Ok 422 | end 423 | | _ -> Ok 424 | end 425 | else begin 426 | let expected_height = 1 + max hl hr in 427 | if ht = expected_height 428 | then Ok 429 | else begin 430 | let repr = Node (Alive, expected_height, left, pivot, right) in 431 | if Atomic.compare_and_set t s repr then Overflow else Ok 432 | end 433 | end 434 | | _ -> Ok 435 | end 436 | -------------------------------------------------------------------------------- /src/mcset.ml: -------------------------------------------------------------------------------- 1 | module Make_poly (E : S.Ordered_poly) = struct 2 | include Core.Make (E) 3 | 4 | let rec add ~gen x t = 5 | let s = Atomic.get t in 6 | match s with 7 | | Leaf Dead | Node (Dead, _, _, _, _) -> Looking_for_life 8 | | Leaf Read_only | Node (Read_only, _, _, _, _) -> Retry 9 | | Leaf Alive -> 10 | let root, expected = gen in 11 | if Atomic.get root != expected 12 | then Retry 13 | else begin 14 | let attempt = Atomic.make Unknown in 15 | let state = Attempt_add (attempt, root, expected) in 16 | let repr = Node (state, 1, Atomic.make s, x, Atomic.make s) in 17 | if Atomic.compare_and_set t s repr 18 | then begin 19 | finalize_op t repr ; 20 | match Atomic.get attempt with 21 | | Success -> Overflow 22 | | Failure -> Retry 23 | | Unknown -> assert false 24 | end 25 | else add ~gen x t 26 | end 27 | | Node (Alive, _, left, pivot, right) -> begin 28 | match E.compare x pivot with 29 | | 0 -> Ok 30 | | c when c < 0 -> begin 31 | match add ~gen x left with 32 | | Ok -> Ok 33 | | Retry -> Retry 34 | | Overflow -> balance t 35 | | Looking_for_life -> add ~gen x t 36 | end 37 | | _ -> begin 38 | match add ~gen x right with 39 | | Ok -> Ok 40 | | Retry -> Retry 41 | | Overflow -> balance t 42 | | Looking_for_life -> add ~gen x t 43 | end 44 
| end 45 | | _ -> 46 | let res = fixup t in 47 | begin 48 | match add ~gen x t with 49 | | Ok -> res 50 | | other -> other 51 | end 52 | 53 | let rec add_retry x t = 54 | let root = Atomic.get t in 55 | match add ~gen:(t, root) x root with 56 | | Ok | Overflow -> () 57 | | Retry -> add_retry x t 58 | | Looking_for_life -> assert false 59 | 60 | let add x t = add_retry x t 61 | 62 | let rec add_or_replace ~gen x t = 63 | let s = Atomic.get t in 64 | match s with 65 | | Leaf Dead | Node (Dead, _, _, _, _) -> Looking_for_life 66 | | Leaf Read_only | Node (Read_only, _, _, _, _) -> Retry 67 | | Leaf Alive -> 68 | let root, expected = gen in 69 | if Atomic.get root != expected 70 | then Retry 71 | else begin 72 | let attempt = Atomic.make Unknown in 73 | let state = Attempt_add (attempt, root, expected) in 74 | let repr = Node (state, 1, Atomic.make s, x, Atomic.make s) in 75 | if Atomic.compare_and_set t s repr 76 | then begin 77 | finalize_op t repr ; 78 | match Atomic.get attempt with 79 | | Success -> Overflow 80 | | Failure -> Retry 81 | | Unknown -> assert false 82 | end 83 | else add_or_replace ~gen x t 84 | end 85 | | Node (Alive, height, left, pivot, right) -> begin 86 | match E.compare x pivot with 87 | | 0 when pivot == x -> Ok 88 | | 0 -> 89 | let root, expected = gen in 90 | if Atomic.get root != expected 91 | then Retry 92 | else begin 93 | let attempt = Atomic.make Unknown in 94 | let state = Attempt_replace (attempt, root, expected, x) in 95 | let repr = Node (state, height, left, x, right) in 96 | if Atomic.compare_and_set t s repr 97 | then begin 98 | finalize_op t repr ; 99 | match Atomic.get attempt with 100 | | Success -> Ok 101 | | Failure -> Retry 102 | | Unknown -> assert false 103 | end 104 | else add_or_replace ~gen x t 105 | end 106 | | c when c < 0 -> begin 107 | match add_or_replace ~gen x left with 108 | | Ok -> Ok 109 | | Retry -> Retry 110 | | Overflow -> balance t 111 | | Looking_for_life -> add_or_replace ~gen x t 112 | end 113 | | _ -> 
begin 114 | match add_or_replace ~gen x right with 115 | | Ok -> Ok 116 | | Retry -> Retry 117 | | Overflow -> balance t 118 | | Looking_for_life -> add_or_replace ~gen x t 119 | end 120 | end 121 | | _ -> 122 | let res = fixup t in 123 | begin 124 | match add_or_replace ~gen x t with 125 | | Ok -> res 126 | | other -> other 127 | end 128 | 129 | let rec add_or_replace_retry x t = 130 | let root = Atomic.get t in 131 | match add_or_replace ~gen:(t, root) x root with 132 | | Ok | Overflow -> () 133 | | Retry -> add_or_replace_retry x t 134 | | Looking_for_life -> assert false 135 | 136 | let add_or_replace x t = add_or_replace_retry x t 137 | 138 | let rec remove ~gen elt t = 139 | let s = Atomic.get t in 140 | match s with 141 | | Leaf Alive -> Ok, false (* not found *) 142 | | Leaf Read_only | Node (Read_only, _, _, _, _) -> Retry, false 143 | | Leaf Dead | Node (Dead, _, _, _, _) -> Looking_for_life, false 144 | | Node (Alive, h, left, pivot, right) -> begin 145 | match E.compare elt pivot with 146 | | 0 -> 147 | let root, expected = gen in 148 | if Atomic.get root != expected 149 | then Retry, false 150 | else begin 151 | let attempt = Atomic.make Unknown in 152 | let state = Attempt_remove (attempt, root, expected) in 153 | let s_removing = Node (state, h, left, pivot, right) in 154 | if Atomic.compare_and_set t s s_removing 155 | then begin 156 | finalize_op t s_removing ; 157 | match Atomic.get attempt with 158 | | Success -> fixup t, true 159 | | Failure -> Retry, false 160 | | Unknown -> assert false 161 | end 162 | else remove ~gen elt t 163 | end 164 | | c when c < 0 -> begin 165 | match remove ~gen elt left with 166 | | Ok, found -> Ok, found 167 | | Retry, found -> Retry, found 168 | | Overflow, found -> balance t, found 169 | | Looking_for_life, _ -> remove ~gen elt t 170 | end 171 | | _ -> begin 172 | match remove ~gen elt right with 173 | | Ok, found -> Ok, found 174 | | Retry, found -> Retry, found 175 | | Overflow, found -> balance t, found 176 | | 
Looking_for_life, _ -> remove ~gen elt t 177 | end 178 | end 179 | | _ -> 180 | let res = fixup t in 181 | begin 182 | match remove ~gen elt t with 183 | | Ok, found -> res, found 184 | | other, found -> other, found 185 | end 186 | 187 | let rec remove_retry elt t = 188 | let root = Atomic.get t in 189 | match remove ~gen:(t, root) elt root with 190 | | Ok, found | Overflow, found -> found 191 | | Retry, _ -> remove_retry elt t 192 | | Looking_for_life, _ -> assert false 193 | 194 | let remove elt t = remove_retry elt t 195 | 196 | let rec is_empty t = 197 | match Atomic.get t with 198 | | Leaf _ -> true 199 | | Node ((Alive | Dead | Read_only), _, _, _, _) -> false 200 | | _ -> 201 | let (_ : result) = fixup t in 202 | is_empty t 203 | 204 | let is_empty t = is_empty (Atomic.get t) 205 | 206 | let rec choose_opt t = 207 | match Atomic.get t with 208 | | Leaf _ -> None 209 | | Node ((Alive | Dead | Read_only), _, _, pivot, _) -> Some pivot 210 | | _ -> 211 | let (_ : result) = fixup t in 212 | choose_opt t 213 | 214 | let choose_opt t = choose_opt (Atomic.get t) 215 | 216 | let choose t = 217 | match choose_opt t with 218 | | Some x -> x 219 | | None -> raise Not_found 220 | 221 | let rec mem x t = 222 | match Atomic.get t with 223 | | Leaf _ -> false 224 | | Node ((Alive | Dead | Read_only), _, left, pivot, right) -> begin 225 | match E.compare x pivot with 226 | | 0 -> true 227 | | c when c < 0 -> mem x left 228 | | _ -> mem x right 229 | end 230 | | _ -> 231 | let (_ : result) = fixup t in 232 | mem x t 233 | 234 | let mem x t = mem x (Atomic.get t) 235 | 236 | let rec find_opt x t = 237 | match Atomic.get t with 238 | | Leaf _ -> None 239 | | Node ((Alive | Dead | Read_only), _, left, pivot, right) -> begin 240 | match E.compare x pivot with 241 | | 0 -> Some pivot 242 | | c when c < 0 -> find_opt x left 243 | | _ -> find_opt x right 244 | end 245 | | _ -> 246 | let (_ : result) = fixup t in 247 | find_opt x t 248 | 249 | let find_opt x t = find_opt x 
(Atomic.get t) 250 | 251 | let find x t = 252 | match find_opt x t with 253 | | Some x -> x 254 | | None -> raise Not_found 255 | 256 | let rec find_first_opt f t = 257 | match Atomic.get t with 258 | | Leaf _ -> None 259 | | Node ((Alive | Dead | Read_only), _, left, pivot, right) -> 260 | if f pivot 261 | then 262 | match find_first_opt f left with 263 | | None -> Some pivot 264 | | some -> some 265 | else find_first_opt f right 266 | | _ -> 267 | let (_ : result) = fixup t in 268 | find_first_opt f t 269 | 270 | let find_first_opt f t = find_first_opt f (Atomic.get t) 271 | 272 | let find_first f t = 273 | match find_first_opt f t with 274 | | Some x -> x 275 | | None -> raise Not_found 276 | 277 | let rec find_last_opt f t = 278 | match Atomic.get t with 279 | | Leaf _ -> None 280 | | Node ((Alive | Dead | Read_only), _, left, pivot, right) -> 281 | if f pivot 282 | then 283 | match find_last_opt f right with 284 | | None -> Some pivot 285 | | some -> some 286 | else find_last_opt f left 287 | | _ -> 288 | let (_ : result) = fixup t in 289 | find_last_opt f t 290 | 291 | let find_last_opt f t = find_last_opt f (Atomic.get t) 292 | 293 | let find_last f t = 294 | match find_last_opt f t with 295 | | Some x -> x 296 | | None -> raise Not_found 297 | 298 | let rec min_elt_opt t = 299 | match Atomic.get t with 300 | | Leaf _ -> None 301 | | Node ((Alive | Dead | Read_only), _, left, pivot, _) -> begin 302 | match min_elt_opt left with 303 | | None -> Some pivot 304 | | some -> some 305 | end 306 | | _ -> 307 | let (_ : result) = fixup t in 308 | min_elt_opt t 309 | 310 | let min_elt_opt t = min_elt_opt (Atomic.get t) 311 | 312 | let min_elt t = 313 | match min_elt_opt t with 314 | | Some x -> x 315 | | None -> raise Not_found 316 | 317 | let rec max_elt_opt t = 318 | match Atomic.get t with 319 | | Leaf _ -> None 320 | | Node ((Alive | Dead | Read_only), _, _, pivot, right) -> begin 321 | match max_elt_opt right with 322 | | None -> Some pivot 323 | | some -> some 
324 | end 325 | | _ -> 326 | let (_ : result) = fixup t in 327 | max_elt_opt t 328 | 329 | let max_elt_opt t = max_elt_opt (Atomic.get t) 330 | 331 | let max_elt t = 332 | match max_elt_opt t with 333 | | Some x -> x 334 | | None -> raise Not_found 335 | 336 | let rec snapshot t = 337 | let root = Atomic.get t in 338 | if Atomic.compare_and_set t root (Atomic.make (Copy root)) 339 | then root 340 | else snapshot t 341 | 342 | let to_view = snapshot 343 | 344 | let of_view r = Atomic.make (Atomic.make (Copy r)) 345 | 346 | let copy t = of_view (to_view t) 347 | 348 | module View = struct 349 | type 'a elt = 'a E.t 350 | 351 | type 'a t = 'a r 352 | 353 | let empty_fresh () = Pure.empty ~s:Read_only 354 | 355 | let empty = Obj.magic (empty_fresh ()) (* weak! *) 356 | 357 | let singleton elt = Pure.singleton ~s:Read_only elt 358 | 359 | let of_list lst = Pure.of_list ~s:Read_only lst 360 | 361 | let of_seq lst = Pure.of_seq ~s:Read_only lst 362 | 363 | let add elt t = Pure.add ~s:Read_only ~get:ensure_read_only elt t 364 | 365 | let add_or_replace elt t = 366 | Pure.add_or_replace ~s:Read_only ~get:ensure_read_only elt t 367 | 368 | let balance left pivot right = 369 | Pure.balance ~s:Read_only ~get:ensure_read_only left pivot right 370 | 371 | let rec pop_min_opt t = 372 | match ensure_read_only t with 373 | | Leaf Read_only -> None 374 | | Node (Read_only, _, left, pivot, right) -> begin 375 | match pop_min_opt left with 376 | | None -> Some (pivot, right) 377 | | Some (min, left) -> Some (min, balance left pivot right) 378 | end 379 | | _ -> assert false 380 | 381 | let pop_min t = 382 | match pop_min_opt t with 383 | | Some x -> x 384 | | None -> raise Not_found 385 | 386 | let rec pop_max_opt t = 387 | match ensure_read_only t with 388 | | Leaf Read_only -> None 389 | | Node (Read_only, _, left, pivot, right) -> begin 390 | match pop_max_opt right with 391 | | None -> Some (pivot, left) 392 | | Some (max, right) -> Some (max, balance left pivot right) 393 | end 
394 | | _ -> assert false 395 | 396 | let pop_max t = 397 | match pop_max_opt t with 398 | | Some x -> x 399 | | None -> raise Not_found 400 | 401 | let rec add_min elt t = 402 | match ensure_read_only t with 403 | | Leaf Read_only -> singleton elt 404 | | Node (Read_only, _, left, pivot, right) -> 405 | balance (add_min elt left) pivot right 406 | | _ -> assert false 407 | 408 | let rec add_max elt t = 409 | match ensure_read_only t with 410 | | Leaf Read_only -> singleton elt 411 | | Node (Read_only, _, left, pivot, right) -> 412 | balance left pivot (add_max elt right) 413 | | _ -> assert false 414 | 415 | let rec join left pivot right = 416 | match ensure_read_only left, ensure_read_only right with 417 | | Leaf Read_only, Leaf Read_only -> singleton pivot 418 | | Leaf Read_only, _ -> add_min pivot right 419 | | _, Leaf Read_only -> add_max pivot left 420 | | Node (Read_only, lh, ll, lp, lr), Node (Read_only, rh, rl, rp, rr) -> 421 | if lh > rh + 2 422 | then balance ll lp (join lr pivot right) 423 | else if rh > lh + 2 424 | then balance (join left pivot rl) rp rr 425 | else create left pivot right 426 | | _ -> assert false 427 | 428 | let append left right = 429 | match pop_min_opt right with 430 | | None -> left 431 | | Some (pivot, right) -> join left pivot right 432 | 433 | let rec remove elt t = 434 | match ensure_read_only t with 435 | | Leaf Read_only -> t 436 | | Node (Read_only, _, left, pivot, right) -> begin 437 | match E.compare elt pivot with 438 | | 0 -> append left right 439 | | c when c < 0 -> 440 | let left' = remove elt left in 441 | if left == left' then t else balance left' pivot right 442 | | _ -> 443 | let right' = remove elt right in 444 | if right == right' then t else balance left pivot right' 445 | end 446 | | _ -> assert false 447 | 448 | let rec split elt t = 449 | match ensure_read_only t with 450 | | Leaf Read_only -> t, false, t 451 | | Node (Read_only, _, left, pivot, right) -> begin 452 | match E.compare elt pivot with 453 | | 0 
-> left, true, right 454 | | c when c < 0 -> 455 | let l, found, r = split elt left in 456 | l, found, join r pivot right 457 | | _ -> 458 | let l, found, r = split elt right in 459 | join left pivot l, found, r 460 | end 461 | | _ -> assert false 462 | 463 | let rec union t1 t2 = 464 | match ensure_read_only t1, ensure_read_only t2 with 465 | | Leaf Read_only, _ -> t2 466 | | _, Leaf Read_only -> t1 467 | | Node (Read_only, h1, l1, p1, r1), Node (Read_only, h2, _, _, _) 468 | when h1 >= h2 -> 469 | let l2, _, r2 = split p1 t2 in 470 | join (union l1 l2) p1 (union r1 r2) 471 | | Node (Read_only, _, _, _, _), Node (Read_only, _, l2, p2, r2) -> 472 | let l1, _, r1 = split p2 t1 in 473 | join (union l1 l2) p2 (union r1 r2) 474 | | _ -> assert false 475 | 476 | let rec inter t1 t2 = 477 | match ensure_read_only t1, ensure_read_only t2 with 478 | | Leaf Read_only, _ -> t1 479 | | _, Leaf Read_only -> t2 480 | | Node (Read_only, h1, l1, p1, r1), Node (Read_only, h2, _, _, _) 481 | when h1 >= h2 -> 482 | let l2, found, r2 = split p1 t2 in 483 | let l12, r12 = inter l1 l2, inter r1 r2 in 484 | if found then join l12 p1 r12 else append l12 r12 485 | | Node (Read_only, _, _, _, _), Node (Read_only, _, l2, p2, r2) -> 486 | let l1, found, r1 = split p2 t1 in 487 | let l12, r12 = inter l1 l2, inter r1 r2 in 488 | if found then join l12 p2 r12 else append l12 r12 489 | | _ -> assert false 490 | 491 | let rec diff t1 t2 = 492 | match ensure_read_only t1, ensure_read_only t2 with 493 | | Leaf Read_only, _ | _, Leaf Read_only -> t1 494 | | Node (Read_only, _, l1, p1, r1), _ -> 495 | let l2, found, r2 = split p1 t2 in 496 | let l12, r12 = diff l1 l2, diff r1 r2 in 497 | if found then append l12 r12 else join l12 p1 r12 498 | | _ -> assert false 499 | 500 | let rec mem x t = 501 | match ensure_read_only t with 502 | | Leaf Read_only -> false 503 | | Node (Read_only, _, left, pivot, right) -> begin 504 | match E.compare x pivot with 505 | | 0 -> true 506 | | c when c < 0 -> mem x left 
507 | | _ -> mem x right 508 | end 509 | | _ -> assert false 510 | 511 | let rec find_opt x t = 512 | match ensure_read_only t with 513 | | Leaf Read_only -> None 514 | | Node (Read_only, _, left, pivot, right) -> begin 515 | match E.compare x pivot with 516 | | 0 -> Some pivot 517 | | c when c < 0 -> find_opt x left 518 | | _ -> find_opt x right 519 | end 520 | | _ -> assert false 521 | 522 | let find x t = 523 | match find_opt x t with 524 | | Some x -> x 525 | | None -> raise Not_found 526 | 527 | let rec find_first_opt f t = (* smallest element accepted by [f]: when the pivot is accepted, prefer a match from the left subtree *) 528 | match ensure_read_only t with 529 | | Leaf Read_only -> None 530 | | Node (Read_only, _, left, pivot, right) -> 531 | if f pivot 532 | then 533 | match find_first_opt f left with 534 | | None -> Some pivot 535 | | some -> some 536 | else find_first_opt f right 537 | | _ -> 538 | (* presumably unreachable after [ensure_read_only]; same handling as the sibling View queries *) 539 | assert false 540 | 541 | let find_first f t = 542 | match find_first_opt f t with 543 | | Some x -> x 544 | | None -> raise Not_found 545 | 546 | let rec find_last_opt f t = (* largest element accepted by [f]: when the pivot is accepted, prefer a match from the right subtree *) 547 | match ensure_read_only t with 548 | | Leaf Read_only -> None 549 | | Node (Read_only, _, left, pivot, right) -> 550 | if f pivot 551 | then 552 | match find_last_opt f right with 553 | | None -> Some pivot 554 | | some -> some 555 | else find_last_opt f left 556 | | _ -> 557 | (* presumably unreachable after [ensure_read_only]; same handling as the sibling View queries *) 558 | assert false 559 | 560 | let find_last f t = 561 | match find_last_opt f t with 562 | | Some x -> x 563 | | None -> raise Not_found 564 | 565 | let is_empty t = 566 | match ensure_read_only t with 567 | | Leaf Read_only -> true 568 | | Node (Read_only, _, _, _, _) -> false 569 | | _ -> assert false 570 | 571 | let choose_opt t = 572 | match ensure_read_only t with 573 | | Leaf Read_only -> None 574 | | Node (Read_only, _, _, pivot, _) -> Some pivot 575 | | _ -> assert false 576 | 577 | let choose t = 578 | match choose_opt t with 579 | | Some x -> x 580 | | None -> raise Not_found 581 | 582 | let rec min_elt_opt t = 583 | match
ensure_read_only t with 584 | | Leaf Read_only -> None 585 | | Node (Read_only, _, left, pivot, _) -> begin 586 | match min_elt_opt left with 587 | | None -> Some pivot 588 | | some -> some 589 | end 590 | | _ -> assert false 591 | 592 | let min_elt t = 593 | match min_elt_opt t with 594 | | Some x -> x 595 | | None -> raise Not_found 596 | 597 | let rec max_elt_opt t = 598 | match ensure_read_only t with 599 | | Leaf Read_only -> None 600 | | Node (Read_only, _, _, pivot, right) -> begin 601 | match max_elt_opt right with 602 | | None -> Some pivot 603 | | some -> some 604 | end 605 | | _ -> assert false 606 | 607 | let max_elt t = 608 | match max_elt_opt t with 609 | | Some x -> x 610 | | None -> raise Not_found 611 | 612 | let rec cardinal acc t = 613 | match ensure_read_only t with 614 | | Leaf Read_only -> acc 615 | | Node (Read_only, _, left, _, right) -> 616 | let acc = acc + 1 in 617 | let acc = cardinal acc left in 618 | cardinal acc right 619 | | _ -> assert false 620 | 621 | let cardinal t = cardinal 0 t 622 | 623 | let rec fold f t acc = 624 | match ensure_read_only t with 625 | | Leaf Read_only -> acc 626 | | Node (Read_only, _, left, pivot, right) -> 627 | let acc = fold f left acc in 628 | let acc = f pivot acc in 629 | fold f right acc 630 | | _ -> assert false 631 | 632 | let rec iter f t = 633 | match ensure_read_only t with 634 | | Leaf Read_only -> () 635 | | Node (Read_only, _, left, pivot, right) -> 636 | iter f left ; f pivot ; iter f right 637 | | _ -> assert false 638 | 639 | let rec for_all f t = 640 | match ensure_read_only t with 641 | | Leaf Read_only -> true 642 | | Node (Read_only, _, left, pivot, right) -> 643 | f pivot && for_all f left && for_all f right 644 | | _ -> assert false 645 | 646 | let rec exists f t = 647 | match ensure_read_only t with 648 | | Leaf Read_only -> false 649 | | Node (Read_only, _, left, pivot, right) -> 650 | f pivot || exists f left || exists f right 651 | | _ -> assert false 652 | 653 | let elements t = 
List.rev (fold (fun x xs -> x :: xs) t []) 654 | 655 | let to_list = elements 656 | 657 | let rec to_seq t k = 658 | match ensure_read_only t with 659 | | Leaf Read_only -> k 660 | | Node (Read_only, _, left, pivot, right) -> 661 | to_seq left (fun () -> Seq.Cons (pivot, to_seq right k)) 662 | | _ -> assert false 663 | 664 | let rec to_seq_from elt t k = 665 | match ensure_read_only t with 666 | | Leaf Read_only -> k 667 | | Node (Read_only, _, left, pivot, right) -> begin 668 | match E.compare elt pivot with 669 | | 0 -> fun () -> Seq.Cons (pivot, to_seq right k) 670 | | c when c < 0 -> 671 | to_seq_from elt left (fun () -> Seq.Cons (pivot, to_seq right k)) 672 | | _ -> to_seq_from elt right k 673 | end 674 | | _ -> assert false 675 | 676 | let to_seq_from elt t = to_seq_from elt t (fun () -> Seq.Nil) 677 | 678 | let to_seq t = to_seq t (fun () -> Seq.Nil) 679 | 680 | let rec to_rev_seq t k = 681 | match ensure_read_only t with 682 | | Leaf Read_only -> k 683 | | Node (Read_only, _, left, pivot, right) -> 684 | to_rev_seq right (fun () -> Seq.Cons (pivot, to_rev_seq left k)) 685 | | _ -> assert false 686 | 687 | let to_rev_seq t = to_rev_seq t (fun () -> Seq.Nil) 688 | 689 | let fast_union left pivot right = 690 | match max_elt_opt left, min_elt_opt right with 691 | | None, None -> singleton pivot 692 | | Some max, None when E.compare max pivot < 0 -> add_max pivot left 693 | | None, Some min when E.compare pivot min < 0 -> add_min pivot right 694 | | Some max, Some min 695 | when E.compare max pivot < 0 && E.compare pivot min < 0 -> 696 | join left pivot right 697 | | _ -> union left (add pivot right) 698 | 699 | let fast_append left right = 700 | match max_elt_opt left, min_elt_opt right with 701 | | _, None -> left 702 | | None, _ -> right 703 | | Some max, Some min when E.compare max min < 0 -> append left right 704 | | _ -> union left right 705 | 706 | let rec map f t = 707 | match ensure_read_only t with 708 | | Leaf Read_only -> t 709 | | Node (Read_only, 
_, left, pivot, right) -> 710 | let left' = map f left in 711 | let pivot' = f pivot in 712 | let right' = map f right in 713 | if left == left' && pivot == pivot' && right == right' 714 | then t 715 | else fast_union left' pivot' right' 716 | | _ -> assert false 717 | 718 | let rec filter f t = 719 | match ensure_read_only t with 720 | | Leaf Read_only -> t 721 | | Node (Read_only, _, left, pivot, right) -> 722 | let left' = filter f left in 723 | let keep = f pivot in 724 | let right' = filter f right in 725 | if left == left' && keep && right == right' 726 | then t 727 | else if keep 728 | then join left' pivot right' 729 | else append left' right' 730 | | _ -> assert false 731 | 732 | let rec filter_map f t = 733 | match ensure_read_only t with 734 | | Leaf Read_only -> t 735 | | Node (Read_only, _, left, pivot, right) -> 736 | let left' = filter_map f left in 737 | let keep = f pivot in 738 | let right' = filter_map f right in 739 | begin 740 | match keep with 741 | | None -> fast_append left' right' 742 | | Some pivot' -> 743 | if left == left' && pivot == pivot' && right == right' 744 | then t 745 | else fast_union left' pivot' right' 746 | end 747 | | _ -> assert false 748 | 749 | let rec partition f t = 750 | match ensure_read_only t with 751 | | Leaf Read_only -> t, t 752 | | Node (Read_only, _, left, pivot, right) -> 753 | let left', not_left' = partition f left in 754 | let keep = f pivot in 755 | let right', not_right' = partition f right in 756 | if left == left' && keep && right == right' 757 | then t, empty_fresh () 758 | else if left == not_left' && (not keep) && right == not_right' 759 | then empty_fresh (), t 760 | else if keep 761 | then join left' pivot right', append not_left' not_right' 762 | else append left' right', join not_left' pivot not_right' 763 | | _ -> assert false 764 | 765 | let compare t1 t2 = Seq.compare E.compare (to_seq t1) (to_seq t2) 766 | 767 | let equal t1 t2 = compare t1 t2 = 0 768 | 769 | let rec disjoint t1 t2 = 770 | 
match ensure_read_only t1, ensure_read_only t2 with 771 | | Leaf Read_only, _ | _, Leaf Read_only -> true 772 | | Node (Read_only, h1, l1, p1, r1), Node (Read_only, h2, l2, p2, r2) -> 773 | if h1 > h2 774 | then 775 | let l2, found, r2 = split p1 t2 in 776 | (not found) && disjoint l1 l2 && disjoint r1 r2 777 | else 778 | let l1, found, r1 = split p2 t1 in 779 | (not found) && disjoint l1 l2 && disjoint r1 r2 780 | | _ -> assert false 781 | end 782 | 783 | let cardinal t = View.cardinal (snapshot t) 784 | 785 | let fold f t acc = View.fold f (snapshot t) acc 786 | 787 | let iter f t = View.iter f (snapshot t) 788 | 789 | let for_all f t = View.for_all f (snapshot t) 790 | 791 | let exists f t = View.exists f (snapshot t) 792 | 793 | let elements t = View.elements (snapshot t) 794 | 795 | let to_list = elements 796 | 797 | let of_list lst = Atomic.make (Pure.of_list ~s:Alive lst) 798 | 799 | let of_seq seq = Atomic.make (Pure.of_seq ~s:Alive seq) 800 | 801 | let to_seq t = View.to_seq (snapshot t) 802 | 803 | let to_rev_seq t = View.to_rev_seq (snapshot t) 804 | 805 | let to_seq_from elt t = View.to_seq_from elt (snapshot t) 806 | end 807 | 808 | module Make (Ord : S.Ordered) = struct 809 | module Ord_poly = struct 810 | type _ t = Ord.t 811 | 812 | let compare = Ord.compare 813 | end 814 | 815 | module I = Make_poly (Ord_poly) 816 | 817 | type void = | 818 | 819 | type elt = Ord.t 820 | 821 | type t = void I.t 822 | 823 | module View = struct 824 | type elt = Ord.t 825 | 826 | type t = void I.View.t 827 | 828 | include ( 829 | I.View : S.View_poly(Ord_poly).S with type _ elt := elt and type _ t := t ) 830 | end 831 | 832 | include ( 833 | I : 834 | S.Set_poly(Ord_poly).S 835 | with type _ elt := elt 836 | and type _ t := t 837 | and type _ View.elt := View.elt 838 | and type _ View.t := View.t 839 | and module View := View ) 840 | end 841 | -------------------------------------------------------------------------------- /src/s.ml: 
-------------------------------------------------------------------------------- 1 | module type QUERY = sig 2 | type 'a elt 3 | 4 | type 'a t 5 | 6 | (** {1 Queries} *) 7 | 8 | val is_empty : 'a t -> bool 9 | (** [is_empty t] returns [true] when the set [t] contains no elements, 10 | [false] if it has at least one member. {b O(1)} *) 11 | 12 | val mem : 'a elt -> 'a t -> bool 13 | (** [mem x t] returns [true] if the element [x] belongs to the set [t], 14 | [false] otherwise. {b O(logN)} *) 15 | 16 | val min_elt : 'a t -> 'a elt 17 | (** [min_elt t] returns the smallest element of the set [t], 18 | or raises [Not_found] if the set is empty. {b O(logN)} *) 19 | 20 | val min_elt_opt : 'a t -> 'a elt option 21 | (** [min_elt_opt t] returns the smallest element of the set [t], 22 | or [None] if the set is empty. {b O(logN)} *) 23 | 24 | val max_elt : 'a t -> 'a elt 25 | (** [max_elt t] returns the largest element of the set [t], 26 | or raises [Not_found] if the set is empty. {b O(logN)} *) 27 | 28 | val max_elt_opt : 'a t -> 'a elt option 29 | (** [max_elt_opt t] returns the largest element of the set [t], 30 | or [None] if the set is empty. {b O(logN)} *) 31 | 32 | val choose : 'a t -> 'a elt 33 | (** [choose t] returns an arbitrary element of the set [t], 34 | or raises [Not_found] if the set is empty. {b O(1)} *) 35 | 36 | val choose_opt : 'a t -> 'a elt option 37 | (** [choose_opt t] returns an arbitrary element of the set [t], 38 | or [None] if the set is empty. {b O(1)} *) 39 | 40 | val find : 'a elt -> 'a t -> 'a elt 41 | (** [find x t] returns the element of the set that compares equal to [x], 42 | or raises [Not_found] if no such element exists. {b O(logN)} *) 43 | 44 | val find_opt : 'a elt -> 'a t -> 'a elt option 45 | (** [find_opt x t] returns the element of the set that compares equal to [x], 46 | or [None] if no such element exists. 
{b O(logN)} *) 47 | 48 | val find_first : ('a elt -> bool) -> 'a t -> 'a elt 49 | (** [find_first predicate t] returns the smallest element of the set [t] 50 | that satisfies that monotonically increasing [predicate], 51 | or raises [Not_found] if no such element exists. {b O(logN)} *) 52 | 53 | val find_first_opt : ('a elt -> bool) -> 'a t -> 'a elt option 54 | (** [find_first_opt predicate t] returns the smallest element of the set [t] 55 | that satisfies that monotonically increasing [predicate], 56 | or [None] if no such element exists. {b O(logN)} *) 57 | 58 | val find_last : ('a elt -> bool) -> 'a t -> 'a elt 59 | (** [find_last predicate t] returns the largest element of the set [t] 60 | that satisfies that monotonically decreasing [predicate], 61 | or raises [Not_found] if no such element exists. {b O(logN)} *) 62 | 63 | val find_last_opt : ('a elt -> bool) -> 'a t -> 'a elt option 64 | (** [find_last_opt predicate t] returns the largest element of the set [t] 65 | that satisfies that monotonically decreasing [predicate], 66 | or [None] if no such element exists. {b O(logN)} *) 67 | end 68 | 69 | module type ITER = sig 70 | type 'a elt 71 | 72 | type 'a t 73 | 74 | val cardinal : 'a t -> int 75 | (** [cardinal t] returns the number of unique elements in the set [t]. 76 | {b O(N)} *) 77 | 78 | val fold : ('a elt -> 'b -> 'b) -> 'a t -> 'b -> 'b 79 | (** [fold f t init] computes [f x0 init |> f x1 |> ... |> f xN], 80 | where [x0] ... [xN] are the elements of the set [t] in increasing order. 81 | Returns [init] if the set [t] was empty. *) 82 | 83 | val iter : ('a elt -> unit) -> 'a t -> unit 84 | (** [iter f t] calls [f] on all the elements of the set [t] in increasing 85 | order. *) 86 | 87 | val for_all : ('a elt -> bool) -> 'a t -> bool 88 | (** [for_all predicate t] returns [true] when all the elements of the set [t] 89 | satisfy the [predicate]. No ordering is guaranteed and the function 90 | will exit early if it finds an invalid element. 
*) 91 | 92 | val exists : ('a elt -> bool) -> 'a t -> bool 93 | (** [exists predicate t] returns [true] when at least one element of the set 94 | [t] satisfies the [predicate]. No ordering is guaranteed and the function 95 | will exit early if it finds a valid element. *) 96 | 97 | val elements : 'a t -> 'a elt list 98 | (** [elements t] is the list of all the elements of [t] in increasing 99 | order. *) 100 | 101 | val to_list : 'a t -> 'a elt list 102 | (** Same as [elements]. *) 103 | 104 | val of_list : 'a elt list -> 'a t 105 | (** [of_list lst] is the set containing all the elements of the list [lst]. *) 106 | 107 | val to_seq : 'a t -> 'a elt Seq.t 108 | (** [to_seq t] is the sequence containing all the elements of the set [t] in 109 | increasing order. {b O(1)} creation then {b O(1)} amortized for each 110 | consumed element of the sequence (with {b O(logN)} worst-case) *) 111 | 112 | val to_rev_seq : 'a t -> 'a elt Seq.t 113 | (** [to_rev_seq t] is the sequence containing all the elements of the set [t] 114 | in decreasing order. *) 115 | 116 | val to_seq_from : 'a elt -> 'a t -> 'a elt Seq.t 117 | (** [to_seq_from x t] is the sequence containing all the elements 118 | of the set [t] that are larger than or equal to [x], 119 | in increasing order. *) 120 | 121 | val of_seq : 'a elt Seq.t -> 'a t 122 | (** [of_seq seq] is the set containing all the elements 123 | of the sequence [seq]. *) 124 | end 125 | 126 | module type Ordered_poly = sig 127 | (** Totally ordered polymorphic type. *) 128 | 129 | type 'a t 130 | (** The type of comparable polymorphic elements. *) 131 | 132 | val compare : 'a t -> 'b t -> int 133 | (** [compare a b] must return [0] if [a] equals [b], a negative number 134 | if [a] is strictly less than [b] and a positive number if [a] is strictly 135 | larger than [b]. *) 136 | end 137 | 138 | module type Ordered = sig 139 | (** Totally ordered type. *) 140 | 141 | type t 142 | (** The type of comparable elements. 
*) 143 | 144 | include Ordered_poly with type _ t := t 145 | end 146 | 147 | module View_poly (Ord : Ordered_poly) = struct 148 | module type S = sig 149 | (** A read-only, non-mutable view of a polymorphic set (with more 150 | operations, similar to the purely functional {! Stdlib.Set} 151 | interface) *) 152 | 153 | type 'a elt = 'a Ord.t 154 | (** The type of polymorphic set elements. *) 155 | 156 | type 'a t 157 | (** The type of read-only polymorphic sets. *) 158 | 159 | val empty : 'a t 160 | (** The empty set. *) 161 | 162 | val singleton : 'a elt -> 'a t 163 | (** [singleton x] returns a set containing only the element [x]. {b O(1)} *) 164 | 165 | val add : 'a elt -> 'a t -> 'a t 166 | (** [add x t] returns a set containing the element [x] and all the 167 | elements of [t]. If [x] was already in the set [t], then the result 168 | is physically equal to [t]. {b O(logN)} *) 169 | 170 | val remove : 'a elt -> 'a t -> 'a t 171 | (** [remove x t] returns a set containing the elements of [t] without [x]. 172 | If [x] was not a member of the set [t], then the result is physically 173 | equal to [t]. {b O(logN)} *) 174 | 175 | val union : 'a t -> 'a t -> 'a t 176 | (** [union t1 t2] returns a set containing all the elements 177 | of [t1] and [t2]. {b O(N)} worst-case *) 178 | 179 | val inter : 'a t -> 'a t -> 'a t 180 | (** [inter t1 t2] returns a set containing the shared elements 181 | of [t1] and [t2]. {b O(N)} worst-case *) 182 | 183 | val diff : 'a t -> 'a t -> 'a t 184 | (** [diff t1 t2] returns a set containing the elements of [t1] that are 185 | not members of the set [t2]. {b O(N)} worst-case *) 186 | 187 | val map : ('a elt -> 'a elt) -> 'a t -> 'a t 188 | (** [map f t] returns a set containing the elements [f x0], [f x1], ..., [f xN] 189 | where [x0] ... [xN] are all the elements of the set [t]. 190 | - The elements are passed to [f] in increasing order. 
191 | - The result is physically equal to [t] if [f] always returned a physically 192 | equal element. {b O(NlogN)} worst-case 193 | *) 194 | 195 | val filter : ('a elt -> bool) -> 'a t -> 'a t 196 | (** [filter predicate t] returns the subset of elements of the set [t] that 197 | satisfy the [predicate] (called in increasing order). The resulting 198 | set is physically equal to [t] if no element was rejected. {b O(N)} *) 199 | 200 | val filter_map : ('a elt -> 'a elt option) -> 'a t -> 'a t 201 | (** [filter_map f t] returns a set containing the [Some] elements 202 | of [f x0], [f x1], ..., [f xN] where [x0] ... [xN] are all the elements 203 | of the set [t]. 204 | - The elements are passed to [f] in increasing order. 205 | - The result is physically equal to [t] if [f] always returned [Some] 206 | physically equal element. {b O(NlogN)} worst-case 207 | *) 208 | 209 | val partition : ('a elt -> bool) -> 'a t -> 'a t * 'a t 210 | (** [partition predicate t] returns two sets, the first one 211 | containing all the elements of the set [t] that satisfy [predicate], 212 | while the second contains all the rejected ones. 213 | - The elements are passed to [predicate] in increasing order. 214 | - The first set is physically equal to [t] if [predicate] always returned 215 | [true] 216 | - The second set is physically equal to [t] if [predicate] always returned 217 | [false]. {b O(N)} 218 | *) 219 | 220 | val split : 'a elt -> 'a t -> 'a t * bool * 'a t 221 | (** [split x t] returns a triple [(smaller, found, larger)] such that: 222 | - [smaller] is the subset of elements strictly smaller than [x] 223 | - [larger] is the subset of elements strictly larger than [x] 224 | - [found] is [true] if [x] is a member of the set [t], [false] otherwise. 225 | {b O(logN)} 226 | *) 227 | 228 | val pop_min : 'a t -> 'a elt * 'a t 229 | (** [pop_min t] returns the smallest element and the other elements 230 | of the set [t], or raises [Not_found] if the set [t] is empty. 
231 | {b O(logN)} *) 232 | 233 | val pop_min_opt : 'a t -> ('a elt * 'a t) option 234 | (** [pop_min_opt t] returns the smallest element and the other elements 235 | of the set [t], or [None] if the set [t] is empty. {b O(logN)} *) 236 | 237 | val pop_max : 'a t -> 'a elt * 'a t 238 | (** [pop_max t] returns the largest element and the other elements 239 | of the set [t], or raises [Not_found] if the set [t] is empty. 240 | {b O(logN)} *) 241 | 242 | val pop_max_opt : 'a t -> ('a elt * 'a t) option 243 | (** [pop_max_opt t] returns the largest element and the other elements 244 | of the set [t], or [None] if the set [t] is empty. *) 245 | 246 | (** {1 Comparisons} *) 247 | 248 | val equal : 'a t -> 'a t -> bool 249 | (** [equal t1 t2] returns [true] if the sets [t1] and [t2] contain the same 250 | elements. *) 251 | 252 | val disjoint : 'a t -> 'a t -> bool 253 | (** [disjoint t1 t2] returns [true] if the sets [t1] and [t2] have 254 | no elements in common. *) 255 | 256 | val compare : 'a t -> 'a t -> int 257 | (** [compare t1 t2] is a total order, suitable for building sets of sets. *) 258 | 259 | (** @inline *) 260 | include QUERY with type 'a elt := 'a elt and type 'a t := 'a t 261 | 262 | (** {1 Iterators} *) 263 | 264 | (** @inline *) 265 | include ITER with type 'a elt := 'a elt and type 'a t := 'a t 266 | end 267 | end 268 | 269 | module Set_poly (Ord : Ordered_poly) = struct 270 | module type S = sig 271 | (** Thread-safe mutable set structure given 272 | a totally ordered polymorphic type. *) 273 | 274 | type 'a elt = 'a Ord.t 275 | (** The type of set elements. *) 276 | 277 | type 'a t 278 | (** The type of mutable sets. *) 279 | 280 | val empty : unit -> 'a t 281 | (** [empty ()] returns a new empty set. {b O(1)} *) 282 | 283 | val singleton : 'a elt -> 'a t 284 | (** [singleton x] returns a new set containing only the element [x]. 285 | {b O(1)} *) 286 | 287 | val add : 'a elt -> 'a t -> unit 288 | (** [add x t] inserts the element [x] into the set [t]. 
If [x] was already 289 | a member, then the set is unchanged. {b O(logN)} *) 290 | 291 | val remove : 'a elt -> 'a t -> bool 292 | (** [remove x t] removes the element [x] from the set [t]. It returns 293 | a boolean indicating if the removal was successful. If [false], the 294 | element [x] was already not a member and the set is unchanged. 295 | {b O(logN)} *) 296 | 297 | (** @inline *) 298 | include QUERY with type 'a elt := 'a elt and type 'a t := 'a t 299 | 300 | (** {1 Snapshots} 301 | 302 | Concurrent modifications of a set are linearizable. The snapshot/copy 303 | provides a coherent view of the elements of a set along this linearized 304 | timeline. 305 | 306 | Further updates to the original set (or its copies) will trigger a minimal 307 | copy-on-write of the internal substructures of the set. This doesn't impact 308 | the time complexity of any operations, but induces a corresponding memory 309 | complexity for copying the modified subparts once. 310 | *) 311 | 312 | val copy : 'a t -> 'a t 313 | (** [copy t] returns an independently mutable copy of the set [t]. Further 314 | modifications of the set [t] will not affect its copies (and vice-versa.) 315 | {b O(1)} *) 316 | 317 | module View : View_poly(Ord).S 318 | 319 | val snapshot : 'a t -> 'a View.t 320 | (** [snapshot t] returns a read-only view of the elements of the set [t]. 321 | {b O(1)} *) 322 | 323 | val to_view : 'a t -> 'a View.t 324 | (** Same as [snapshot]. *) 325 | 326 | val of_view : 'a View.t -> 'a t 327 | (** [of_view v] returns a new mutable set containing all the elements 328 | of the view [v]. {b O(1)} *) 329 | 330 | (** {1 Iterators} 331 | 332 | The following functions all proceed on a coherent {! snapshot} of their 333 | set [t] argument, created at the start of their execution. Concurrent 334 | modifications of the original set [t] during their execution will not 335 | be observed by these traversals. 
336 | *) 337 | 338 | (** @inline *) 339 | include ITER with type 'a elt := 'a elt and type 'a t := 'a t 340 | end 341 | end 342 | 343 | module Set (Ord : Ordered) = struct 344 | module Ord_poly = struct 345 | type _ t = Ord.t 346 | 347 | let compare = Ord.compare 348 | end 349 | 350 | module type Sig = Set_poly(Ord_poly).S 351 | 352 | module type S = sig 353 | (** Thread-safe mutable set structure given a totally ordered type. *) 354 | 355 | type elt = Ord.t 356 | (** The type of set elements. *) 357 | 358 | type t 359 | (** The type of mutable sets. *) 360 | 361 | module View : sig 362 | (** A read-only, non-mutable view of a set (with more operations, 363 | compatible with the purely functional {! Stdlib.Set} interface) *) 364 | 365 | type elt = Ord.t 366 | (** The type of set elements. *) 367 | 368 | type t 369 | (** The type of read-only sets. *) 370 | 371 | include View_poly(Ord_poly).S with type _ elt := elt and type _ t := t 372 | end 373 | 374 | include 375 | Sig 376 | with type _ elt := elt 377 | and type _ t := t 378 | and type _ View.elt := View.elt 379 | and type _ View.t := View.t 380 | and module View := View 381 | end 382 | end 383 | 384 | (******************************************************************************) 385 | 386 | module type QUERY_MAP = sig 387 | type 'a key 388 | 389 | type 'a t 390 | 391 | (** {1 Queries} *) 392 | 393 | val is_empty : 'a t -> bool 394 | (** [is_empty t] returns [true] when the map [t] contains no bindings, 395 | [false] if it has at least one. {b O(1)} *) 396 | 397 | val mem : 'a key -> 'a t -> bool 398 | (** [mem k t] returns [true] if the key [k] is bound in the map [t], 399 | [false] otherwise. {b O(logN)} *) 400 | 401 | val min_binding : 'a t -> 'a key * 'a 402 | (** [min_binding t] returns the binding with the smallest key 403 | in the map [t], or raises [Not_found] if the map is empty. 
404 | {b O(logN)} *) 405 | 406 | val min_binding_opt : 'a t -> ('a key * 'a) option 407 | (** [min_binding_opt t] returns the binding with the smallest key 408 | in the map [t], or [None] if the map is empty. {b O(logN)} *) 409 | 410 | val max_binding : 'a t -> 'a key * 'a 411 | (** [max_binding t] returns the binding with the largest key 412 | in the map [t], or raises [Not_found] if the map is empty. 413 | {b O(logN)} *) 414 | 415 | val max_binding_opt : 'a t -> ('a key * 'a) option 416 | (** [max_binding_opt t] returns the binding with the largest key 417 | in the map [t], or [None] if the map is empty. {b O(logN)} *) 418 | 419 | val choose : 'a t -> 'a key * 'a 420 | (** [choose t] returns an arbitrary binding of the map [t], 421 | or raises [Not_found] if the map is empty. {b O(1)} *) 422 | 423 | val choose_opt : 'a t -> ('a key * 'a) option 424 | (** [choose_opt t] returns an arbitrary binding of the map [t], 425 | or [None] if the map is empty. {b O(1)} *) 426 | 427 | val find : 'a key -> 'a t -> 'a 428 | (** [find k t] returns the value associated with the key [k] in the map [t], 429 | or raises [Not_found] if no such binding existed. {b O(logN)} *) 430 | 431 | val find_opt : 'a key -> 'a t -> 'a option 432 | (** [find_opt k t] returns the value associated with the key [k] 433 | in the map [t], or returns [None] if no such binding existed. 434 | {b O(logN)} *) 435 | 436 | val find_first : ('a key -> bool) -> 'a t -> 'a key * 'a 437 | (** [find_first predicate t] returns the smallest binding of the map [t] 438 | that satisfies that monotonically increasing [predicate], 439 | or raises [Not_found] if no such binding exists. {b O(logN)} *) 440 | 441 | val find_first_opt : ('a key -> bool) -> 'a t -> ('a key * 'a) option 442 | (** [find_first_opt predicate t] returns the smallest binding of the map [t] 443 | that satisfies that monotonically increasing [predicate], 444 | or [None] if no such binding exists. 
{b O(logN)} *) 445 | 446 | val find_last : ('a key -> bool) -> 'a t -> 'a key * 'a 447 | (** [find_last predicate t] returns the largest binding of the map [t] 448 | that satisfies that monotonically decreasing [predicate], 449 | or raises [Not_found] if no such binding exists. {b O(logN)} *) 450 | 451 | val find_last_opt : ('a key -> bool) -> 'a t -> ('a key * 'a) option 452 | (** [find_last_opt predicate t] returns the largest binding of the map [t] 453 | that satisfies that monotonically decreasing [predicate], 454 | or [None] if no such binding exists. {b O(logN)} *) 455 | end 456 | 457 | module type ITER_MAP = sig 458 | type 'a key 459 | 460 | type 'a t 461 | 462 | val cardinal : 'a t -> int 463 | (** [cardinal t] returns the number of bindings in the map [t]. {b O(N)} *) 464 | 465 | val fold : ('a key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b 466 | (** [fold f t init] computes [f k0 v0 init |> f k1 v1 |> ... |> f kN vN], 467 | where [k0,v0] ... [kN,vN] are the bindings of the map [t] in increasing order. 468 | Returns [init] if the map [t] was empty. *) 469 | 470 | val iter : ('a key -> 'a -> unit) -> 'a t -> unit 471 | (** [iter f t] calls [f] on all the bindings of the map [t] in increasing 472 | order. *) 473 | 474 | val for_all : ('a key -> 'a -> bool) -> 'a t -> bool 475 | (** [for_all predicate t] returns [true] when all the bindings of the map [t] 476 | satisfy the [predicate]. No ordering is guaranteed and the function 477 | will exit early if it finds an invalid binding. *) 478 | 479 | val exists : ('a key -> 'a -> bool) -> 'a t -> bool 480 | (** [exists predicate t] returns [true] when at least one binding of the map 481 | [t] satisfies the [predicate]. No ordering is guaranteed and the function 482 | will exit early if it finds a valid binding. *) 483 | 484 | val bindings : 'a t -> ('a key * 'a) list 485 | (** [bindings t] is the list of all the bindings of [t] in increasing 486 | order. 
*) 487 | 488 | val to_list : 'a t -> ('a key * 'a) list 489 | (** Same as [bindings]. *) 490 | 491 | val of_list : ('a key * 'a) list -> 'a t 492 | (** [of_list lst] is the map containing all the bindings of the list [lst]. *) 493 | 494 | val to_seq : 'a t -> ('a key * 'a) Seq.t 495 | (** [to_seq t] is the sequence containing all the bindings of the map [t] in 496 | increasing order. {b O(1)} creation then {b O(1)} amortized for each 497 | consumed binding of the sequence (with {b O(logN) worst-case}) *) 498 | 499 | val to_rev_seq : 'a t -> ('a key * 'a) Seq.t 500 | (** [to_rev_seq t] is the sequence containing all the bindings of the map [t] 501 | in decreasing order. *) 502 | 503 | val to_seq_from : 'a key -> 'a t -> ('a key * 'a) Seq.t 504 | (** [to_seq_from k t] is the sequence containing all the bindings 505 | of the map [t] whose key is larger than or equal to [k], 506 | in increasing order. *) 507 | 508 | val of_seq : ('a key * 'a) Seq.t -> 'a t 509 | (** [of_seq seq] is the map containing all the bindings 510 | of the sequence [seq]. *) 511 | end 512 | 513 | module View_map_poly (Ord : Ordered_poly) = struct 514 | module type S = sig 515 | (** A read-only, non-mutable view of a polymorphic map (with more 516 | operations, similar to the purely functional {! Stdlib.Map} 517 | interface) *) 518 | 519 | type 'a key = 'a Ord.t 520 | (** The type of map keys. *) 521 | 522 | type 'a t 523 | (** The type of read-only maps from ['a key] to ['a] values. *) 524 | 525 | val empty : 'a t 526 | (** The empty map. *) 527 | 528 | val singleton : 'a key -> 'a -> 'a t 529 | (** [singleton k v] returns a map containing only the binding [k] for [v]. 530 | {b O(1)} *) 531 | 532 | val add : 'a key -> 'a -> 'a t -> 'a t 533 | (** [add k v t] returns a map containing the binding [k] for [v] and all the 534 | bindings of [t]. If [k] was already bound in the map [t], then its 535 | previous value is replaced by [v]. 
{b O(logN)} *) 536 | 537 | val remove : 'a key -> 'a t -> 'a t 538 | (** [remove k t] returns a map containing the bindings of [t] without [k]. 539 | If [k] was not bound in the map [t], then the result is physically 540 | equal to [t]. {b O(logN)} *) 541 | 542 | val union : 'a t -> 'a t -> 'a t 543 | (** [union t1 t2] returns a map containing all the bindings 544 | of [t1] and [t2]. {b O(N)} worst-case *) 545 | 546 | val inter : 'a t -> 'a t -> 'a t 547 | (** [inter t1 t2] returns a map containing the shared bindings 548 | of [t1] and [t2]. {b O(N)} worst-case *) 549 | 550 | val diff : 'a t -> 'a t -> 'a t 551 | (** [diff t1 t2] returns a map containing the bindings of [t1] whose 552 | keys are not bound in the map [t2]. {b O(N)} worst-case *) 553 | 554 | val map : ('a -> 'a) -> 'a t -> 'a t 555 | (** [map f t] returns a map where each bound value in the map [t] has been 556 | replaced by its mapping by [f]: 557 | - The values are passed to [f] in increasing key order. 558 | - The result is physically equal to [t] if [f] always returned a physically 559 | equal element. {b O(NlogN)} worst-case 560 | *) 561 | 562 | val mapi : ('a key -> 'a -> 'a) -> 'a t -> 'a t 563 | (** [mapi f t] returns a map where each bound value in the map [t] has been 564 | replaced by its mapping by [f]: 565 | - The values are passed to [f] in increasing key order. 566 | - The result is physically equal to [t] if [f] always returned a physically 567 | equal element. {b O(NlogN)} worst-case 568 | *) 569 | 570 | val filter : ('a key -> 'a -> bool) -> 'a t -> 'a t 571 | (** [filter predicate t] returns the subset of bindings of the map [t] that 572 | satisfy the [predicate] (called in increasing key order). The resulting 573 | map is physically equal to [t] if no binding was rejected. 
{b O(N)} *) 574 | 575 | val filter_map : ('a key -> 'a -> 'a option) -> 'a t -> 'a t 576 | (** [filter_map f t] returns a map containing the [Some] bindings 577 | of [f k0 v0], [f k1 v1], ..., [f kN vN] where [k0,v0] ... [kN,vN] are all the bindings 578 | of the map [t]. 579 | - The bindings are passed to [f] in increasing key order. 580 | - The result is physically equal to [t] if [f] always returned [Some] 581 | physically equal value. {b O(NlogN)} worst-case 582 | *) 583 | 584 | val partition : ('a key -> 'a -> bool) -> 'a t -> 'a t * 'a t 585 | (** [partition predicate t] returns two maps, the first one 586 | containing all the bindings of the map [t] that satisfy [predicate], 587 | while the second contains all the rejected ones. 588 | - The bindings are passed to [predicate] in increasing key order. 589 | - The first map is physically equal to [t] if [predicate] always returned 590 | [true] 591 | - The second map is physically equal to [t] if [predicate] always returned 592 | [false]. {b O(N)} 593 | *) 594 | 595 | val split : 'a key -> 'a t -> 'a t * 'a option * 'a t 596 | (** [split k t] returns a triple [(smaller, found, larger)] such that: 597 | - [smaller] is the subset of bindings whose key is strictly smaller than [k] 598 | - [larger] is the subset of bindings whose key is strictly larger than [k] 599 | - [found] is [Some v] if [k] is bound to [v] in the map [t], [None] otherwise. 600 | {b O(logN)} 601 | *) 602 | 603 | val pop_min : 'a t -> 'a key * 'a * 'a t 604 | (** [pop_min t] returns the smallest binding and the other bindings 605 | of the map [t], or raises [Not_found] if the map [t] is empty. 606 | {b O(logN)} *) 607 | 608 | val pop_min_opt : 'a t -> ('a key * 'a * 'a t) option 609 | (** [pop_min_opt t] returns the smallest binding and the other bindings 610 | of the map [t], or [None] if the map [t] is empty. 
{b O(logN)} *) 611 | 612 | val pop_max : 'a t -> 'a key * 'a * 'a t 613 | (** [pop_max t] returns the largest binding and the other bindings 614 | of the map [t], or raises [Not_found] if the map [t] is empty. 615 | {b O(logN)} *) 616 | 617 | val pop_max_opt : 'a t -> ('a key * 'a * 'a t) option 618 | (** [pop_max_opt t] returns the largest binding and the other bindings 619 | of the map [t], or [None] if the map [t] is empty. *) 620 | 621 | (** @inline *) 622 | include QUERY_MAP with type 'a key := 'a key and type 'a t := 'a t 623 | 624 | (** {1 Iterators} *) 625 | 626 | (** @inline *) 627 | include ITER_MAP with type 'a key := 'a key and type 'a t := 'a t 628 | end 629 | end 630 | 631 | module Map_poly (Ord : Ordered_poly) = struct 632 | module type S = sig 633 | (** Thread-safe mutable map structure given 634 | a totally ordered polymorphic type. *) 635 | 636 | type 'a key = 'a Ord.t 637 | (** The type of the map keys. *) 638 | 639 | type 'a t 640 | (** The type of maps from type ['a key] to ['a] values. *) 641 | 642 | val empty : unit -> 'a t 643 | (** [empty ()] returns a new empty map. {b O(1)} *) 644 | 645 | val singleton : 'a key -> 'a -> 'a t 646 | (** [singleton k v] returns a new map containing only 647 | the binding [k] for [v]. {b O(1)} *) 648 | 649 | val add : 'a key -> 'a -> 'a t -> unit 650 | (** [add k v t] adds the binding [k] for [v] into the map [t]. 651 | If the key [k] was already bound, the previously bound value 652 | is replaced by [v]. {b O(logN)} *) 653 | 654 | val remove : 'a key -> 'a t -> bool 655 | (** [remove k t] deletes any binding of the key [k] in the map [t]. 656 | Returns [true] if the binding was removed, or [false] if no binding 657 | existed for this key. {b O(logN)} *) 658 | 659 | include QUERY_MAP with type 'a key := 'a key and type 'a t := 'a t 660 | 661 | (** {1 Snapshots} 662 | 663 | Concurrent modifications of a map are linearizable. 
The snapshot/copy 664 | provides a coherent view of the bindings of a map along this linearized 665 | timeline. 666 | 667 | Further updates to the original map (or its copies) will trigger a minimal 668 | copy-on-write of the internal substructures of the map. This doesn't impact 669 | the time complexity of any operations, but induces a corresponding memory 670 | complexity for copying the modified subparts once. 671 | *) 672 | 673 | val copy : 'a t -> 'a t 674 | (** [copy t] returns an independently mutable copy of the map [t]. Further 675 | modifications of the map [t] will not affect its copies (and vice-versa.) 676 | {b O(1)} *) 677 | 678 | module View : View_map_poly(Ord).S 679 | (** A read-only, non-mutable view of a map (with more operations, 680 | similar to the purely functional {! Stdlib.Map} interface) *) 681 | 682 | val snapshot : 'a t -> 'a View.t 683 | (** [snapshot t] returns a read-only view of the bindings of the map [t]. 684 | {b O(1)} *) 685 | 686 | val to_view : 'a t -> 'a View.t 687 | (** Same as [snapshot]. *) 688 | 689 | val of_view : 'a View.t -> 'a t 690 | (** [of_view v] returns a new mutable map containing all the bindings 691 | of the view [v]. {b O(1)} *) 692 | 693 | (** {1 Iterators} 694 | 695 | The following functions all proceed on a coherent {! snapshot} of their 696 | map [t] argument, created at the start of their execution. Concurrent 697 | modifications of the original map [t] during their execution will not 698 | be observed by these traversals. 699 | *) 700 | 701 | (** @inline *) 702 | include ITER_MAP with type 'a key := 'a key and type 'a t := 'a t 703 | end 704 | end 705 | 706 | module Map (Ord : Ordered) = struct 707 | module Ord_poly = struct 708 | type _ t = Ord.t 709 | 710 | let compare = Ord.compare 711 | end 712 | 713 | module type Sig = Map_poly(Ord_poly).S 714 | 715 | module type S = sig 716 | (** Thread-safe mutable map structure given a totally ordered type. 
*) 717 | 718 | type key = Ord.t 719 | (** The type of the map keys. *) 720 | 721 | type 'a t 722 | (** The type of maps from type [key] to values of type ['a]. *) 723 | 724 | module View : sig 725 | (** A read-only, non-mutable view of a map (with more operations, 726 | compatible with the purely functional {! Stdlib.Map} interface) *) 727 | 728 | type key = Ord.t 729 | (** The type of the map keys. *) 730 | 731 | type 'a t 732 | (** The type of maps from type [key] to values of type ['a]. *) 733 | 734 | include 735 | View_map_poly(Ord_poly).S with type _ key := key and type 'a t := 'a t 736 | end 737 | 738 | include 739 | Sig 740 | with type _ key := key 741 | and type 'a t := 'a t 742 | and type _ View.key := View.key 743 | and module View := View 744 | end 745 | end 746 | --------------------------------------------------------------------------------