├── .merlin ├── META.in ├── u_time.ml ├── u_test_main.ml ├── README.md ├── .gitignore ├── u_actionid.ml ├── u_controlid.ml ├── u_id.ml ├── u_perf.ml ├── u_random.ml ├── u_info.ml ├── .ocamlinit ├── u_tests.ml ├── u_lazy.ml ├── u_cycle.mli ├── u_float.ml ├── u_normal.ml ├── u_action.ml ├── .ocp-indent ├── u_recent_acts.ml ├── u_test.ml ├── u_loop.ml ├── u_system.ml ├── u_log.mli ├── u_log.ml ├── u_set.ml ├── Makefile ├── u_set.mli ├── u_obs.mli ├── u_cycle.ml ├── u_permanent_id.ml ├── u_recent.ml ├── u_stat.ml ├── u_control.ml ├── u_obs.ml ├── u_exp.ml ├── u_learn.ml └── u_eval.ml /.merlin: -------------------------------------------------------------------------------- 1 | PKG unix moving-percentile 2 | S . 3 | B . 4 | -------------------------------------------------------------------------------- /META.in: -------------------------------------------------------------------------------- 1 | archive(byte) = "unitron.cma" 2 | archive(native) = "unitron.cmxa" 3 | -------------------------------------------------------------------------------- /u_time.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Discrete time, whose origin is normally 0. 3 | *) 4 | type t = int 5 | -------------------------------------------------------------------------------- /u_test_main.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Tests + demo program 3 | *) 4 | 5 | let main () = 6 | if not (U_tests.run ()) then 7 | exit 1 8 | 9 | let () = main () 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # unitron 2 | 3 | Somewhat like a perceptron using unary input instead of binary. 
4 | 5 | # Installation 6 | 7 | See https://github.com/mjambon/mjambon-opam-repo 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | META 3 | 4 | *.annot 5 | *.cmt 6 | *.cmti 7 | *.cmo 8 | *.cma 9 | *.cmi 10 | *.a 11 | *.o 12 | *.cmx 13 | *.cmxs 14 | *.cmxa 15 | 16 | u_test 17 | *.log 18 | 19 | out 20 | -------------------------------------------------------------------------------- /u_actionid.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Identifiers of actions, specified as strings and translated into ints. 3 | See u_permanent_id.ml 4 | *) 5 | 6 | include U_permanent_id.Make (struct 7 | let name = "action ID" 8 | end) 9 | -------------------------------------------------------------------------------- /u_controlid.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Identifiers of controls, specified as strings and translated into ints. 3 | See u_permanent_id.ml 4 | *) 5 | 6 | include U_permanent_id.Make (struct 7 | let name = "control ID" 8 | end) 9 | -------------------------------------------------------------------------------- /u_id.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Generic ID generator suitable for permanent resource identifiers. 3 | 4 | It internally uses ints for efficiency, but maintains a table 5 | for translating those ints back to their original strings. 6 | *) 7 | 8 | let tbl = Hashtbl.create 1000 9 | 10 | -------------------------------------------------------------------------------- /u_perf.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Real time measurements. 3 | *) 4 | 5 | let time f = 6 | let t1 = Unix.gettimeofday () in 7 | let result = f () in 8 | let t2 = Unix.gettimeofday () in 9 | let dt = t2 -. 
t1 in 10 | result, dt 11 | 12 | let print_time name f = 13 | let result, dt = time f in 14 | U_log.logf "time %s: %.6f" name dt; 15 | result 16 | -------------------------------------------------------------------------------- /u_random.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Utilities for random number generation. 3 | *) 4 | 5 | (* 6 | Return true with probability `proba`. 7 | *) 8 | let pick proba = 9 | Random.float 1. < proba 10 | 11 | (* 12 | Pick a value following a normal distribution. 13 | *) 14 | let normal ?(mean = 0.) ?(stdev = 1.) () = 15 | let x = U_normal.pick () in 16 | mean +. stdev *. x 17 | -------------------------------------------------------------------------------- /u_info.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Informational metrics, used to produce stats. 3 | *) 4 | type t = { 5 | goal: float; 6 | pos_contrib: float; 7 | neg_contrib: float; 8 | pos_contrib_count: int; 9 | neg_contrib_count: int; 10 | } 11 | 12 | let dummy = { 13 | goal = nan; 14 | pos_contrib = 0.; 15 | neg_contrib = 0.; 16 | pos_contrib_count = 0; 17 | neg_contrib_count = 0; 18 | } 19 | -------------------------------------------------------------------------------- /.ocamlinit: -------------------------------------------------------------------------------- 1 | (* -*- tuareg -*- 2 | 3 | Launch utop from the same directory for these commands to be executed 4 | on startup. 5 | 6 | *) 7 | 8 | #use "topfind";; 9 | #require "unix";; 10 | #require "moving-percentile";; 11 | #require "unitron";; 12 | 13 | Printexc.record_backtrace true;; 14 | Sys.set_signal Sys.sigpipe Sys.Signal_ignore;; 15 | open Printf;; 16 | printf "Loaded .ocamlinit for unitron.\n%!";; 17 | -------------------------------------------------------------------------------- /u_tests.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Test suite. 
3 | *) 4 | 5 | let tests = U_test.flatten [ 6 | (* Unit tests of individual modules and subsystems *) 7 | "U_stat", U_stat.tests; 8 | "U_permanent_id", U_permanent_id.tests; 9 | "U_loop", U_loop.tests; 10 | "U_recent", U_recent.tests; 11 | 12 | (* Evaluation of the behavior of the system for several scenarios *) 13 | "U_eval", U_eval.tests; 14 | ] 15 | 16 | let run () = 17 | U_test.run_tests tests 18 | -------------------------------------------------------------------------------- /u_lazy.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Produce a lazy update function for a given time step. 3 | It allows calling it multiple times without worrying 4 | about unnecessary recomputations. 5 | *) 6 | let get (refresh : U_time.t -> 'a) : U_time.t -> 'a = 7 | let last_updated = ref max_int in 8 | let cached_result = ref None in 9 | fun t -> 10 | if t <> !last_updated then ( 11 | let result = refresh t in 12 | cached_result := Some result; 13 | last_updated := t; 14 | result 15 | ) 16 | else 17 | match !cached_result with 18 | | None -> assert false 19 | | Some x -> x 20 | -------------------------------------------------------------------------------- /u_cycle.mli: -------------------------------------------------------------------------------- 1 | (* 2 | Run a system forever or for a number of steps, 3 | with some useful logging. 4 | 5 | max_iter: number of iterations; the default is to loop forever. 6 | before_step, after_step: functions to call at the beginning and at the 7 | end of each iteration. 8 | if `after_step` returns `false`, iterations 9 | will stop. 
10 | *) 11 | val loop : 12 | ?inner_log_mode: U_log.mode -> 13 | ?max_iter: U_time.t -> 14 | ?before_step: (U_time.t -> unit) -> 15 | ?after_step: (U_time.t -> bool) -> 16 | U_system.t -> unit 17 | -------------------------------------------------------------------------------- /u_float.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Utilities operating on floats 3 | *) 4 | 5 | let is_finite x = 6 | match classify_float x with 7 | | FP_infinite 8 | | FP_nan -> false 9 | | FP_normal 10 | | FP_subnormal 11 | | FP_zero -> true 12 | 13 | let minf l f = List.fold_left (fun acc x -> min acc (f x)) 0. l 14 | let min l = List.fold_left min infinity l 15 | 16 | let maxf l f = List.fold_left (fun acc x -> max acc (f x)) 0. l 17 | let max l = List.fold_left max neg_infinity l 18 | 19 | let sumf l f = List.fold_left (fun acc x -> acc +. f x) 0. l 20 | let sum l = List.fold_left (+.) 0. l 21 | 22 | let default ~if_nan x = 23 | if (x <> x) then 24 | if_nan 25 | else 26 | x 27 | -------------------------------------------------------------------------------- /u_normal.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Box–Muller transform 3 | 4 | See https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform 5 | *) 6 | 7 | (* 8 | Return a pair of independent random numbers following 9 | the standard normal distribution (mean = 0, stdev = 1). 10 | *) 11 | let rec pick_pair () = 12 | let u = Random.float 2. -. 1. in 13 | let v = Random.float 2. -. 1. in 14 | let s = u *. u +. v *. v in 15 | if s = 0. || s >= 1. then 16 | pick_pair () 17 | else 18 | let z0 = u *. sqrt (-. 2. *. log s /. s) in 19 | let z1 = v *. sqrt (-. 2. *. log s /. s) in 20 | z0, z1 21 | 22 | (* 23 | Return a single sample from the standard normal distribution. 
24 | *) 25 | let pick () = 26 | fst (pick_pair ()) 27 | -------------------------------------------------------------------------------- /u_action.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Library for managing actions. 3 | An action is a named function that has arbitrary effects on the world. 4 | Actions are triggered by controls. Several controls may trigger the 5 | same action at the same time. Each action runs at most once at a given 6 | time step. 7 | *) 8 | 9 | type t = { 10 | id: U_actionid.t; 11 | func: unit -> unit; 12 | } 13 | 14 | let create id func = 15 | { id; func } 16 | 17 | let create_set () = 18 | U_set.create (fun x -> x.id) 19 | 20 | let add actionid func set = 21 | U_set.add set (create actionid func) 22 | 23 | let get set id = 24 | match U_set.get set id with 25 | | None -> failwith ("Invalid action ID " ^ U_actionid.to_string id) 26 | | Some x -> x 27 | -------------------------------------------------------------------------------- /.ocp-indent: -------------------------------------------------------------------------------- 1 | # See https://github.com/OCamlPro/ocp-indent/blob/master/.ocp-indent for more 2 | 3 | # Indent for clauses inside a pattern-match (after the arrow): 4 | # match foo with 5 | # | _ -> 6 | # ^^^^bar 7 | # the default is 2, which aligns the pattern and the expression 8 | match_clause = 4 9 | 10 | # When nesting expressions on the same line, their indentation are in 11 | # some cases stacked, so that it remains correct if you close them one 12 | # at a line. This may lead to large indents in complex code though, so 13 | # this parameter can be used to set a maximum value. Note that it only 14 | # affects indentation after function arrows and opening parens at end 15 | # of line. 
16 | # 17 | # for example (left: `none`; right: `4`) 18 | # let f = g (h (i (fun x -> # let f = g (h (i (fun x -> 19 | # x) # x) 20 | # ) # ) 21 | # ) # ) 22 | max_indent = 2 23 | -------------------------------------------------------------------------------- /u_recent_acts.ml: -------------------------------------------------------------------------------- 1 | (* 2 | An act is an instance of an action, i.e. a pair (time, action). 3 | We keep track of recent acts for reinforcement purposes. 4 | *) 5 | 6 | type t = { 7 | recent: U_controlid.t U_set.set U_recent.t; 8 | } 9 | 10 | let create window_length = 11 | let recent = U_recent.init window_length (fun age -> U_set.create_set ()) in 12 | { recent } 13 | 14 | let step x = 15 | U_recent.step x.recent (fun old -> U_set.clear old; old) 16 | 17 | let get_latest x = 18 | U_recent.get_latest x.recent 19 | 20 | let add x controlid = 21 | let latest = get_latest x in 22 | U_set.add latest controlid 23 | 24 | (* 25 | Fold over all the pairs (age, control ID) 26 | *) 27 | let fold x acc0 f = 28 | U_recent.fold x.recent acc0 (fun age set acc -> 29 | U_set.fold set acc (fun controlid acc -> 30 | f age controlid acc 31 | ) 32 | ) 33 | 34 | let iter x f = 35 | fold x () (fun age controlid () -> f age controlid) 36 | 37 | let to_list x = 38 | fold x [] (fun age controlid acc -> (age, controlid) :: acc) 39 | -------------------------------------------------------------------------------- /u_test.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Library for running tests. 3 | *) 4 | 5 | open Printf 6 | 7 | let flatten ll = 8 | List.flatten ( 9 | List.map (fun (prefix, l) -> 10 | List.map (fun (name, f) -> 11 | prefix ^ " > " ^ name, f 12 | ) l 13 | ) ll 14 | ) 15 | 16 | let run_test (name, f) = 17 | eprintf "test %s: START\n%!" name; 18 | let success = 19 | try f () 20 | with e -> 21 | eprintf "Exception %s\n%!" 
(U_log.string_of_exn e); 22 | false 23 | in 24 | eprintf "test %s: %s\n%!" name (if success then "OK" else "ERROR"); 25 | name, success 26 | 27 | let print_result (name, success) = 28 | eprintf "%-30s %s\n" name (if success then "OK" else "ERROR") 29 | 30 | let print_summary passed total = 31 | eprintf "Tests passed: %i/%i\n%!" 32 | passed total 33 | 34 | let run_tests tests = 35 | let results = List.map run_test tests in 36 | List.iter print_result results; 37 | let passed = List.length (List.filter snd results) in 38 | let total = List.length results in 39 | print_summary passed total; 40 | passed = total 41 | -------------------------------------------------------------------------------- /u_loop.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Generic main loop. 3 | *) 4 | 5 | let should_start opt_max_iter t = 6 | match opt_max_iter with 7 | | None -> true 8 | | Some n -> t < n 9 | 10 | let run ?(inner_log_mode = `Skip) ?max_iter f = 11 | let iter_count = ref 0 in 12 | let rec loop continue t = 13 | U_log.set_time t; 14 | let ok = continue && should_start max_iter t in 15 | if ok then ( 16 | let continue = f t in 17 | U_log.flush (); 18 | incr iter_count; 19 | loop continue (t + 1) 20 | ) 21 | in 22 | let (), dt = 23 | U_perf.time (fun () -> 24 | U_log.set_mode inner_log_mode; 25 | loop true 0; 26 | U_log.clear_time (); 27 | U_log.set_mode `Full 28 | ) 29 | in 30 | let step_duration = dt /. float !iter_count in 31 | U_log.logf "total time: %.6f s" dt; 32 | U_log.logf "step duration: %.2g ms, %.2g KHz" 33 | (1e3 *. step_duration) (1. /. (1e3 *. 
step_duration)); 34 | U_log.flush () 35 | 36 | let test () = 37 | let acc = ref [] in 38 | let f t = 39 | acc := t :: !acc; 40 | true 41 | in 42 | run ~max_iter:3 f; 43 | assert (!acc = List.rev [0; 1; 2]); 44 | true 45 | 46 | let tests = [ 47 | "main", test; 48 | ] 49 | -------------------------------------------------------------------------------- /u_system.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Initialization of a complete system. 3 | 4 | At each cycle t, the system is fed some input consisting in a set of 5 | control IDs. Each action activated by at least one control is executed. 6 | *) 7 | 8 | type time = int 9 | 10 | type t = { 11 | (* Parameters *) 12 | window_length: int; 13 | (* Reinforcement time window *) 14 | 15 | goal_function : time -> float; 16 | 17 | read_active_controls : time -> (U_controlid.t -> unit) -> unit; 18 | (* `read_active_controls t add` is in charge of registering 19 | active controls using the provided `add` function. *) 20 | 21 | get_control : U_controlid.t -> U_control.t; 22 | get_action : U_actionid.t -> U_action.t; 23 | recent_acts : U_recent_acts.t; 24 | 25 | observables : U_obs.state; 26 | } 27 | 28 | let create 29 | ~window_length 30 | ~goal_function 31 | ~read_active_controls 32 | ~get_control 33 | ~get_action = 34 | let recent_acts = U_recent_acts.create window_length in 35 | let observables = U_obs.create () in 36 | { 37 | window_length; 38 | goal_function; 39 | read_active_controls; 40 | get_control; 41 | get_action; 42 | recent_acts; 43 | observables; 44 | } 45 | -------------------------------------------------------------------------------- /u_log.mli: -------------------------------------------------------------------------------- 1 | (* 2 | Logging. 3 | *) 4 | 5 | val debug : bool 6 | (* to be used as: 7 | if debug then 8 | logf "..." 
...; 9 | *) 10 | 11 | val string_of_exn : exn -> string 12 | (* convert an exception into a readable multiline string including 13 | a backtrace. *) 14 | 15 | val log : string -> unit 16 | (* logging function, writes to stderr. 17 | Warning: it does not automatically flush the buffered output. *) 18 | 19 | val logf : ('a, unit, string, unit) format4 -> 'a 20 | (* printf-like logging function, writes to stderr. 21 | Warning: it does not automatically flush the buffered output. *) 22 | 23 | val clear_time : unit -> unit 24 | val set_time : int -> unit 25 | (* set the time to be displayed by each call to `log` or `logf`, 26 | meant to be an iteration number rather than real time. *) 27 | 28 | type mode = [ `Full | `Skip | `Off ] 29 | 30 | val set_mode : mode -> unit 31 | (* set the logging mode or "level". 32 | `Skip` will result in `logf` printing only if the current time 33 | has a single leading nonzero digit followed by zeroes. 34 | The goal of the skip mode is to produce less and less output 35 | as the number of steps increases. *) 36 | 37 | val flush : unit -> unit 38 | (* flush buffered log output. 
*) 39 | -------------------------------------------------------------------------------- /u_log.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Logging 3 | *) 4 | 5 | open Printf 6 | 7 | let debug = false 8 | 9 | let string_of_exn e = 10 | sprintf "%s\n%s" 11 | (Printexc.to_string e) 12 | (Printexc.get_backtrace ()) 13 | 14 | let () = Printexc.record_backtrace true 15 | 16 | let out_channel = stderr 17 | 18 | let flush () = 19 | Pervasives.flush out_channel 20 | 21 | let time = ref None 22 | 23 | let set_time t = 24 | time := Some t 25 | 26 | let clear_time () = 27 | time := None 28 | 29 | let log s = 30 | match !time with 31 | | Some t -> 32 | fprintf out_channel "[%i] %s\n" 33 | t s 34 | | None -> 35 | fprintf out_channel "[] %s\n" 36 | s 37 | 38 | let rec has_only_one_nonzero_digit n = 39 | if n < 0 then 40 | invalid_arg "has_only_one_nonzero_digit: negative value" 41 | else if n < 10 then 42 | true 43 | else if n mod 10 = 0 then 44 | has_only_one_nonzero_digit (n / 10) 45 | else 46 | false 47 | 48 | type mode = [ `Full | `Skip | `Off ] 49 | 50 | let mode = ref (`Full : mode) 51 | 52 | let set_mode x = 53 | mode := x 54 | 55 | let should_skip () = 56 | match !time with 57 | | None -> 58 | true 59 | | Some t -> 60 | not (has_only_one_nonzero_digit t) 61 | 62 | let logf msgf = 63 | let print = 64 | match !mode with 65 | | `Skip when should_skip () -> 66 | (fun s -> ()) 67 | | `Full | `Skip -> 68 | log 69 | | `Off -> 70 | (fun s -> ()) 71 | in 72 | Printf.kprintf print msgf 73 | -------------------------------------------------------------------------------- /u_set.ml: -------------------------------------------------------------------------------- 1 | (* 2 | A mutable collection of unique elements identified by a key. 3 | *) 4 | 5 | type ('k, 'v) t = { 6 | get_key: 'v -> 'k; 7 | (* Obtain an object's key. The key must be comparable and hashable, 8 | see the documentation for Hashtbl. 
*) 9 | 10 | tbl: ('k, 'v) Hashtbl.t; 11 | } 12 | 13 | type 'a set = ('a, 'a) t 14 | 15 | let create get_key = 16 | let tbl = Hashtbl.create 100 in 17 | { get_key; tbl } 18 | 19 | let create_set () = 20 | create (fun x -> x) 21 | 22 | let get x k = 23 | try Some (Hashtbl.find x.tbl k) 24 | with Not_found -> None 25 | 26 | let add x v = 27 | let k = x.get_key v in 28 | Hashtbl.replace x.tbl k v 29 | 30 | let remove x v = 31 | Hashtbl.remove x.tbl (x.get_key v) 32 | 33 | let remove_key x k = 34 | Hashtbl.remove x.tbl k 35 | 36 | (* 37 | Get and remove. 38 | *) 39 | let pop x k = 40 | let opt = get x k in 41 | if opt <> None then 42 | remove_key x k; 43 | opt 44 | 45 | let to_list x = 46 | Hashtbl.fold (fun k v acc -> v :: acc) x.tbl [] 47 | 48 | let of_list l = 49 | let x = create (fun v -> v) in 50 | List.iter (add x) l; 51 | x 52 | 53 | let of_list_full l get_key = 54 | let x = create get_key in 55 | List.iter (add x) l; 56 | x 57 | 58 | let iter x f = 59 | Hashtbl.iter (fun k v -> f v) x.tbl 60 | 61 | let sort x = 62 | let l = to_list x in 63 | let get_key = x.get_key in 64 | List.sort (fun a b -> compare (get_key a) (get_key b)) l 65 | 66 | let iter_ordered x f = 67 | List.iter f (sort x) 68 | 69 | let fold x acc f = 70 | Hashtbl.fold (fun k v acc -> f v acc) x.tbl acc 71 | 72 | let clear x = 73 | Hashtbl.clear x.tbl 74 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: default build test utop install uninstall reinstall clean 2 | 3 | default: build 4 | 5 | # Launch utop with a suitable value for OCAMLPATH 6 | utop: 7 | OCAMLPATH=$$(dirname `pwd`):$$OCAMLPATH utop 8 | 9 | PACKAGES = unix moving-percentile 10 | 11 | LIBSOURCES = \ 12 | u_time.ml \ 13 | u_log.mli u_log.ml \ 14 | u_test.ml \ 15 | u_perf.ml \ 16 | u_set.mli u_set.ml \ 17 | u_float.ml \ 18 | u_normal.ml \ 19 | u_stat.ml \ 20 | u_random.ml \ 21 | u_lazy.ml \ 22 | \ 23 | 
u_permanent_id.ml \ 24 | u_controlid.ml \ 25 | u_actionid.ml \ 26 | u_control.ml \ 27 | u_action.ml \ 28 | u_recent.ml \ 29 | u_recent_acts.ml \ 30 | u_loop.ml \ 31 | u_info.ml \ 32 | u_obs.mli u_obs.ml \ 33 | u_system.ml \ 34 | u_learn.ml \ 35 | u_cycle.mli u_cycle.ml \ 36 | \ 37 | u_exp.ml \ 38 | u_eval.ml \ 39 | \ 40 | u_tests.ml 41 | 42 | build: META 43 | ocamlfind ocamlc -a -o unitron.cma -bin-annot -package "$(PACKAGES)" \ 44 | $(LIBSOURCES) 45 | ocamlfind ocamlopt -a -o unitron.cmxa -bin-annot -package "$(PACKAGES)" \ 46 | $(LIBSOURCES) 47 | ocamlfind ocamlopt -o u_test -bin-annot -linkpkg -package "$(PACKAGES)" \ 48 | unitron.cmxa u_test_main.ml 49 | 50 | test: build 51 | time -p \ 52 | ./u_test 2>&1 | tee test.log | stdbuf -o 0 grep '^>' | tee summary.log 53 | 54 | META: META.in 55 | echo 'requires = "$(PACKAGES)"' > META 56 | cat META.in >> META 57 | 58 | install: META 59 | ocamlfind install unitron META \ 60 | `ls *.cm[ioxa] *.cmx[as] *.o *.a *.mli | grep -F -v '_main.'` 61 | 62 | uninstall: 63 | ocamlfind remove unitron 64 | 65 | reinstall: 66 | $(MAKE) uninstall; $(MAKE) install 67 | 68 | clean: 69 | rm -f *~ *.cm[ioxat] *.cmti *.cmx[as] *.o *.a *.annot META \ 70 | u_test test.log summary.log 71 | -------------------------------------------------------------------------------- /u_set.mli: -------------------------------------------------------------------------------- 1 | (* 2 | A mutable collection of unique elements identified by a key. 3 | *) 4 | 5 | (* 6 | A map from keys to values. 7 | Each key is associated with a unique value. 8 | *) 9 | type ('k, 'v) t = { 10 | get_key: 'v -> 'k; 11 | (* Obtain an object's key. The key must be comparable and hashable, 12 | see the documentation for Hashtbl. *) 13 | 14 | tbl: ('k, 'v) Hashtbl.t; 15 | } 16 | 17 | (* 18 | A map in which key and value are equal. 19 | *) 20 | type 'a set = ('a, 'a) t 21 | 22 | val create : ('v -> 'k) -> ('k, 'v) t 23 | (* Create an empty map. 
*) 24 | 25 | val create_set : unit -> 'a set 26 | (* Create an empty set. *) 27 | 28 | val get : ('k, 'v) t -> 'k -> 'v option 29 | (* Get an element from its key. *) 30 | 31 | val add : ('k, 'v) t -> 'v -> unit 32 | (* Add an element. *) 33 | 34 | val remove : ('k, 'v) t -> 'v -> unit 35 | (* Remove an element. *) 36 | 37 | val remove_key : ('k, 'v) t -> 'k -> unit 38 | (* Remove an element from its key. *) 39 | 40 | val pop : ('k, 'v) t -> 'k -> 'v option 41 | (* Get an element, then remove it. *) 42 | 43 | val to_list : ('k, 'v) t -> 'v list 44 | (* List the elements. *) 45 | 46 | val sort : ('k, 'v) t -> 'v list 47 | (* List and sort the elements based on their keys. *) 48 | 49 | val of_list : 'a list -> 'a set 50 | (* Create a set from a list of elements. Elements must be suitable keys. *) 51 | 52 | val of_list_full : 'v list -> ('v -> 'k) -> ('k, 'v) t 53 | (* Create a map from a list of elements. *) 54 | 55 | val iter : ('k, 'v) t -> ('v -> unit) -> unit 56 | (* Iterate over the elements in no particular order. *) 57 | 58 | val iter_ordered : ('k, 'v) t -> ('v -> unit) -> unit 59 | (* Iterate over the elements in the order defined by the keys, 60 | which involves a sorting step. *) 61 | 62 | val fold : ('k, 'v) t -> 'acc -> ('v -> 'acc -> 'acc) -> 'acc 63 | (* Fold over the elements in no particular order. *) 64 | 65 | val clear : ('k, 'v) t -> unit 66 | (* Remove all the elements. *) 67 | -------------------------------------------------------------------------------- /u_obs.mli: -------------------------------------------------------------------------------- 1 | (* 2 | Various values, updated at each cycle, meant to be used 3 | as input by an IO module. 
4 | *) 5 | 6 | (* 7 | instant = latest value (identity: instant signal = signal) 8 | recent = moving average over a short window 9 | average = moving average over a long window 10 | normalized = signal translated and scaled such that mean = 0 and stdev = 1, 11 | using moving average and moving variance estimated over 12 | a long window. 13 | goal = feedback obtained at each cycle, that we try to predict 14 | prediction = prediction of the goal function 15 | delta = prediction - goal 16 | contribution = one term in the prediction produced at a given time 17 | (the number of contributions changes over time) 18 | *) 19 | type t = private { 20 | normalized_goal : float; 21 | (* normalized goal *) 22 | recent_normalized_goal : float; 23 | (* recent normalized goal *) 24 | recent_delta : float; 25 | (* recent |delta| / average |delta| *) 26 | activity : float; 27 | (* activity = 28 | recent number of contributions / average number of contributions *) 29 | recent_pos_contrib : float; 30 | (* recent positive contribution / average positive contribution *) 31 | recent_neg_contrib : float; 32 | (* recent negative contribution / average negative contribution *) 33 | } 34 | 35 | type state 36 | type time = int 37 | 38 | val create : unit -> state 39 | (* Create an initial state at time = -1 with meaningless values. 40 | `get` may not be called before an `update` takes place. *) 41 | 42 | val update : state -> time -> U_info.t -> unit 43 | (* Add data for a new timestep. Timesteps must be consecutive. 44 | Skipping or repeating is not allowed. *) 45 | 46 | val get : state -> time -> t 47 | (* Get the observable data for the current timestep. 48 | Fails if the specified timestep differs. 49 | May be called multiple times for the same timestep. *) 50 | 51 | val to_string : t -> string 52 | (* Produce a string representation for logging and debugging. 
*) 53 | -------------------------------------------------------------------------------- /u_cycle.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Definition of one cycle (step) corresponding to one time increment. 3 | *) 4 | 5 | open U_log 6 | open U_system 7 | 8 | let register_active_controls x t = 9 | let recent_acts = x.recent_acts in 10 | U_recent_acts.step recent_acts; 11 | let add controlid = U_recent_acts.add recent_acts controlid in 12 | x.read_active_controls t add 13 | 14 | (* 15 | Run each action at most once. 16 | An action can be triggered by multiple controls but runs at most once 17 | per cycle. 18 | *) 19 | let run_actions x t = 20 | let actionids = 21 | let controlids = U_recent_acts.get_latest x.recent_acts in 22 | U_set.fold controlids (U_set.create_set ()) (fun controlid acc -> 23 | let control = x.get_control controlid in 24 | U_set.add acc U_control.(control.actionid); 25 | acc 26 | ) 27 | in 28 | U_set.iter actionids (fun actionid -> 29 | let action = x.get_action actionid in 30 | U_action.(action.func) () 31 | ) 32 | 33 | (* 34 | Run one cycle at time t. 35 | 36 | 1. Activate controls (get their list). 37 | 2. Record active controls (acts) and keep them until they're older 38 | than some max. 39 | 3. Perform the actions triggered by the controls. 40 | 4. Collect feedback from the goal function. 41 | 5. Decompose feedback as a sum of contributions from all recent acts. 
42 | *) 43 | let step (x : U_system.t) t = 44 | register_active_controls x t; 45 | run_actions x t; 46 | let goal = x.goal_function t in 47 | let info = U_learn.learn x goal in 48 | U_obs.update x.observables t info 49 | 50 | let loop 51 | ?inner_log_mode 52 | ?max_iter 53 | ?(before_step = fun t -> ()) 54 | ?(after_step = fun t -> true) 55 | system = 56 | let add_duration, get_mean_duration = U_stat.create_mean_acc () in 57 | U_loop.run ?inner_log_mode ?max_iter (fun t -> 58 | before_step t; 59 | let (), step_duration = 60 | U_perf.time (fun () -> 61 | step system t 62 | ) 63 | in 64 | add_duration step_duration; 65 | after_step t 66 | ); 67 | let step_duration = get_mean_duration () in 68 | logf "effective step duration: %.2g ms, %.2g KHz" 69 | (1e3 *. step_duration) (1. /. (1e3 *. step_duration)) 70 | -------------------------------------------------------------------------------- /u_permanent_id.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Generic ID generator suitable for permanent resource identifiers. 3 | 4 | It internally uses ints for efficiency, but maintains a table 5 | for translating those ints back to their original strings. 6 | 7 | This is essentially a global string <-> int mapping that never gets 8 | garbage-collected. 9 | 10 | Note that distinct systems in the same process may use the same 11 | identifiers to refer to distinct objects while sharing this mapping 12 | between string IDs and int IDs. 13 | This works of course as long as each system maintains its own 14 | mapping between the IDs and the objects specific to the system. 15 | *) 16 | 17 | module type Param = sig 18 | val name : string 19 | end 20 | 21 | module type Id = sig 22 | type t = private int 23 | 24 | val of_string : string -> t 25 | (* Any string is valid. 
It is guaranteed that 26 | a = b => of_string a = of_string b 27 | *) 28 | 29 | val to_string : t -> string 30 | (* It is guaranteed that 31 | a = b => to_string a = to_string b 32 | *) 33 | end 34 | 35 | module Make (Param : Param) : Id = struct 36 | type t = int 37 | 38 | let tbl_str_to_int = Hashtbl.create 1000 39 | let tbl_int_to_str = Hashtbl.create 1000 40 | 41 | let counter = ref 0 42 | 43 | let of_string s = 44 | try Hashtbl.find tbl_str_to_int s 45 | with Not_found -> 46 | let last = !counter in 47 | let i = last + 1 in 48 | if i = 0 then 49 | failwith ("Cannot create new " ^ Param.name) 50 | else ( 51 | counter := i; 52 | Hashtbl.add tbl_int_to_str i s; 53 | Hashtbl.add tbl_str_to_int s i; 54 | i 55 | ) 56 | 57 | let to_string i = 58 | try Hashtbl.find tbl_int_to_str i 59 | with Not_found -> assert false 60 | end 61 | 62 | module Test = struct 63 | module Id = Make (struct let name = "test ID" end) 64 | 65 | let test () = 66 | let a = Id.of_string "a" in 67 | let b = Id.of_string "b" in 68 | let a2 = Id.of_string "a" in 69 | assert (a = a2); 70 | assert (a <> b); 71 | assert (Id.to_string a = Id.to_string a2); 72 | assert (Id.to_string a == Id.to_string a2); 73 | assert (Id.to_string a = "a"); 74 | assert (Id.to_string b = "b"); 75 | true 76 | end 77 | 78 | let tests = [ 79 | "main", Test.test; 80 | ] 81 | -------------------------------------------------------------------------------- /u_recent.ml: -------------------------------------------------------------------------------- 1 | (* 2 | Data structure for keeping a window of the last w states in a sequence. 3 | *) 4 | 5 | (* 6 | For a window of length 5, we use an array structured as follows: 7 | 8 | [0 1 2 3 4] 9 | ^ ^ 10 | | current index 11 | previous index 12 | *) 13 | type 'a t = { 14 | array: 'a array; 15 | mutable index: int; 16 | (* position of the latest element in the array *) 17 | } 18 | 19 | (* 20 | Initialize the circular array, proceeding from oldest to latest. 
21 | The argument of the user function is the age of the element, 22 | like for `iter`. 23 | *) 24 | let init w f = 25 | if w < 1 then 26 | invalid_arg "U_recent.init"; 27 | let a = 28 | Array.init w (fun pos -> 29 | let age = w - pos - 1 in 30 | f age 31 | ) 32 | in 33 | { 34 | array = a; 35 | index = w - 1; 36 | } 37 | 38 | (* 39 | Return the latest element, whose age is 0. 40 | *) 41 | let get_latest x = 42 | x.array.(x.index) 43 | 44 | (* 45 | Iterate from oldest to latest. 46 | The user function takes the age of the element and the element. 47 | *) 48 | let iter x f = 49 | let a = x.array in 50 | let w = Array.length a in 51 | let oldest_index = (x.index + 1) mod w in 52 | for i = 0 to w - 1 do 53 | (* position in the array *) 54 | let pos = (oldest_index + i) mod w in 55 | 56 | (* age of the element, using age(latest) = 0 *) 57 | let age = w - 1 - i in 58 | 59 | let elt = a.(pos) in 60 | f age elt 61 | done 62 | 63 | (* 64 | Accumulate over the elements from oldest to latest. 65 | *) 66 | let fold x acc0 f = 67 | let acc = ref acc0 in 68 | iter x (fun age elt -> 69 | acc := f age elt !acc 70 | ); 71 | !acc 72 | 73 | let to_list x = 74 | List.rev (fold x [] (fun age elt acc -> (age, elt) :: acc)) 75 | 76 | (* 77 | Move in time, adding 1 to the age of all past elements. 78 | The oldest element is passed to the user function to be recycled 79 | into the new latest element. 
*)
let step x f =
  let slots = x.array in
  (* The slot following the latest element holds the oldest one. *)
  let next = (x.index + 1) mod Array.length slots in
  slots.(next) <- f slots.(next);
  x.index <- next

let test () =
  let elements x = List.map snd (to_list x) in
  let seen_ages = ref [] in
  let x =
    init 3 (fun age ->
      seen_ages := age :: !seen_ages;
      -age
    )
  in
  (* `init` must have been called on ages 2, 1, 0 in that order. *)
  assert (!seen_ages = [0; 1; 2]);
  (* Recycle the oldest element into `fresh` and check both ends. *)
  let step_and_check ~expected_oldest ~fresh =
    step x (fun oldest ->
      assert (oldest = expected_oldest);
      fresh
    );
    assert (get_latest x = fresh)
  in
  step_and_check ~expected_oldest:(-2) ~fresh:11;
  step_and_check ~expected_oldest:(-1) ~fresh:12;
  step_and_check ~expected_oldest:0 ~fresh:13;
  List.iter (fun i -> Printf.printf "%i\n%!" i) (elements x);
  assert (elements x = [11; 12; 13]);
  true

let tests = [
  "main", test;
]

--------------------------------------------------------------------------------
/u_stat.ml:
--------------------------------------------------------------------------------
(*
   Simple stats (mean, stdev, etc.) used primarily for testing.
*)

(* Sum of a list of floats, added from left to right. *)
let sum l =
  List.fold_left ( +. ) 0. l

(*
   Accumulator for an arithmetic mean that only keeps a counter and
   a running total.
   Returns a pair (add_sample, get_mean). With no sample, get_mean
   evaluates 0. /. 0.
*)
let create_mean_acc () =
  let count = ref 0 and total = ref 0. in
  let add_sample x =
    incr count;
    total := !total +. x
  in
  let get_mean () = !total /. float !count in
  add_sample, get_mean

(*
   Accumulator for a mean and a standard deviation. All the samples
   are kept in a list.
   Returns a triple (add_sample, get_mean, get_mean_and_stdev).
   The variance is the sum of squared deviations divided by (n - 1).
*)
let create_stdev_acc () =
  let samples = ref [] in
  let add_sample x = samples := x :: !samples in
  let get_mean () =
    let l = !samples in
    sum l /. float (List.length l)
  in
  let get_mean_and_stdev () =
    let mean = get_mean () in
    let l = !samples in
    let add_squared_deviation acc x = acc +. (x -. mean) ** 2. in
    let variance =
      List.fold_left add_squared_deviation 0. l
      /. float (List.length l - 1)
    in
    mean, sqrt variance
  in
  add_sample, get_mean, get_mean_and_stdev

let test_mean () =
  let add, get = create_mean_acc () in
  List.iter add [1.; 3.];
  assert (get () = 2.);
  true

let test_stdev () =
  let add, get_mean, get_mean_and_stdev = create_stdev_acc () in
  List.iter add [1.; 5.];
  let expected_mean = 3. and expected_stdev = sqrt 8. in
  assert (get_mean () = expected_mean);
  let mean, stdev = get_mean_and_stdev () in
  assert (mean = expected_mean);
  assert (stdev = expected_stdev);
  true

(* Mean and standard deviation of a list of floats, in one call. *)
let get_mean_and_stdev l =
  let add, _, finish = create_stdev_acc () in
  List.iter add l;
  finish ()

(*
   Percentile of a list of floats, interpolated linearly between
   the two nearest sorted values.
   The data is sorted once, when the list is supplied; the returned
   function takes the percentile p, which must be in the range [0,1].
   Raises Invalid_argument on an empty list or on p outside of [0,1].
*)
let get_percentile l =
  let sorted = Array.of_list l in
  let n = Array.length sorted in
  if n = 0 then
    invalid_arg "get_percentile: no data";
  Array.sort compare sorted;
  let last = n - 1 in
  fun p ->
    (* Written as a negation so that a NaN is rejected as well. *)
    if not (p >= 0. && p <= 1.) then
      invalid_arg "get_percentile: p must be within [0,1]";
    let index = float last *. p in
    let low_index = truncate index in
    let high_index = min last (low_index + 1) in
    assert (low_index >= 0);
    (* fractional part of the index = weight of the upper neighbor *)
    let high_weight, _ = modf index in
    let low_weight = 1. -. high_weight in
    low_weight *. sorted.(low_index) +. high_weight *. sorted.(high_index)

let test_percentiles () =
  let ( =~ ) a b = abs_float (a -. b) < 1e-6 in
  let check data p expected =
    assert (get_percentile data p =~ expected)
  in

  check [0.] 0. 0.;
  check [1.23] 0. 1.23;
  check [1.23] 0.5 1.23;
  check [1.23] 1. 1.23;

  check [0.; 1.] 0.6 0.6;
  check [0.; 1.] 0. 0.;
  check [0.; 1.] 1. 1.;

  check [0.; 1.; 10.] 0.5 1.;
  check [0.; 1.; 10.] 0.25 0.5;
  check [0.; 1.; 10.] 0.75 5.5;

  check [-1.; -0.5; 0.; 10.] 0. (-1.);
  check [-1.; -0.5; 0.; 10.] 0.5 (-0.25);

  true


let tests = [
  "mean", test_mean;
  "stdev", test_stdev;
  "percentiles", test_percentiles;
]

--------------------------------------------------------------------------------
/u_control.ml:
--------------------------------------------------------------------------------
(*
   Library for managing controls.
   A control is a named virtual button, permanently linked to one
   action. Several controls may be linked to the same action.
   When a control is activated, the linked action will run during the
   same time step. If several controls trigger the same action during the
   same time step, the action will run just once.
*)

open Printf

type contribution = {
  mutable last: float;
    (* last contribution, as corrected for a perfect result *)

  variance: Mv_var.state;
    (* object that includes exponential moving average
       and exponential moving variance *)
}

type t = {
  id: U_controlid.t;
  actionid: U_actionid.t;
  contributions: contribution array;
    (* one contribution per age, starting from age 0 *)
}

(* Fresh contribution: no value yet (NaN) and empty moving statistics. *)
let create_contribution () =
  let alpha = 0.1 in
  let variance = Mv_var.init ~alpha_avg:alpha ~alpha_var:alpha () in
  { last = nan; variance }

(* New control holding one fresh contribution per age of the window. *)
let create ~window_length ~id ~actionid =
  let contributions =
    Array.init window_length (fun _ -> create_contribution ())
  in
  { id; actionid; contributions }

(* Empty set of controls, keyed by control ID. *)
let create_set () =
  U_set.create (fun control -> control.id)

(* Create a control and register it into the set. *)
let add ~window_length ~id ~actionid set =
  U_set.add set (create ~window_length ~id ~actionid)

(* Look up a control by ID. Raises Failure if the ID is unknown. *)
let get set id =
  match U_set.get set id with
  | Some control -> control
  | None -> failwith ("Invalid control ID " ^ U_controlid.to_string id)

(*
   Get the weight used to correct the contribution based on the difference
   between observed and predicted value of the goal function.

   This weight is the recent standard deviation, as estimated by
   an exponential moving average. The wider a contribution fluctuates,
   the greater correction it will receive,
   relative to the other contributions.

   With at most one sample, the weight is infinite, which is usable,
   unlike a NaN. Past that point the early estimates of the standard
   deviation are very coarse; it is unclear if or how they should be
   tweaked for better results.
*)
let get_weight (x : contribution) =
  if Mv_var.get_count x.variance > 1 then
    Mv_var.get_stdev x.variance
  else
    infinity

(* Record a new value for a contribution. The value must not be a NaN. *)
let update_contrib (x : contribution) v =
  (* a NaN is the only float that differs from itself *)
  assert (v = v);
  Mv_var.update x.variance v;
  x.last <- v

(* Contribution of a control for the given age. *)
let get_contribution x age =
  x.contributions.(age)

(* Moving average of a contribution. *)
let get_average contrib =
  Mv_var.get_average contrib.variance

(* Last value of a contribution; a NaN is replaced by 0. *)
let get_contrib_value contrib =
  U_float.default ~if_nan:0. contrib.last

(* Moving standard deviation of a contribution. *)
let get_stdev contrib =
  Mv_var.get_stdev contrib.variance

(* Moving average of the contribution of a control for the given age. *)
let get_contribution_average x age =
  get_average (get_contribution x age)

(*
   Call the user function on each contribution of a control, by
   increasing age. Note that the value is passed as `~contrib` here
   but as `~contrib_value` by `map_contributions`.
*)
let iter_contributions x f =
  let visit age c =
    let mv = c.variance in
    let value = get_contrib_value c in
    let average = Mv_var.get_average mv in
    let stdev = Mv_var.get_stdev mv in
    f ~age ~contrib:value ~average ~stdev
  in
  Array.iteri visit x.contributions

(*
   Same as `iter_contributions` but collects the results into an array.
*)
let map_contributions x f =
  let convert age c =
    let mv = c.variance in
    let contrib_value = get_contrib_value c in
    let average = Mv_var.get_average mv in
    let stdev = Mv_var.get_stdev mv in
    f ~age ~contrib_value ~average ~stdev
  in
  Array.mapi convert x.contributions

(* One "age:(average, stdev)" item per contribution, space-separated. *)
let info_of_contributions a =
  let describe age c =
    let mv = c.variance in
    let average = Mv_var.get_average mv in
    let stdev = Mv_var.get_stdev mv in
    sprintf "%i:(%.2g, %.2g)"
      age
      average stdev
  in
  a
  |> Array.mapi describe
  |> Array.to_list
  |> String.concat " "

(* One-line description of a control, for logging. *)
let to_info x =
  let id = U_controlid.to_string x.id in
  let contributions = info_of_contributions x.contributions in
  sprintf "control %s: [%s]" id contributions

(*
   Create a CSV file and write its header: one `contrib` column
   per age followed by one `orig_weight` column per age.
   The channel is returned open; closing it is up to the caller.
*)
let open_csv window_length fname =
  let columns title_of_age =
    Array.to_list (Array.init window_length title_of_age)
  in
  let contrib_columns = columns (fun age -> sprintf "contrib[%i]" age) in
  let weight_columns = columns (fun age -> sprintf "orig_weight[%i]" age) in
  let header = String.concat "," (contrib_columns @ weight_columns) in
  let oc = open_out fname in
  fprintf oc "%s\n" header;
  oc

(*
   Write one CSV row for a control: the contribution values
   followed by the weights, matching the header of `open_csv`.
*)
let print_csv oc x =
  let cells get_value =
    x.contributions
    |> Array.map (fun c -> sprintf "%g" (get_value c))
    |> Array.to_list
  in
  let values = cells get_contrib_value in
  let weights = cells get_weight in
  fprintf oc "%s\n"
    (String.concat "," (values @ weights))

--------------------------------------------------------------------------------
/u_obs.ml:
--------------------------------------------------------------------------------
(*
   Various values, updated at each cycle, meant to be used
   as input by an IO module.
   See mli.
*)

open Printf

type t = {
  normalized_goal : float;
    (* normalized goal *)
  recent_normalized_goal : float;
    (* recent normalized goal *)
  recent_delta : float;
    (* recent |delta| / average |delta| *)
  activity : float;
    (* activity =
       recent number of contributions / average number of contributions *)
  recent_pos_contrib : float;
    (* recent positive contribution / average positive contribution *)
  recent_neg_contrib : float;
    (* recent negative contribution / average negative contribution *)
}

type time = int

type state = {
  mutable value : t option;
  mutable last_updated : time;
  update_value : time -> U_info.t -> t;
}

(*
   Observables computed for time t.
   Raises Invalid_argument if t is negative or if the state
   was not last updated at time t.
*)
let get state t =
  if t < 0 then
    invalid_arg "U_obs.get: negative time";
  if t <> state.last_updated then
    invalid_arg "U_obs.get: state is not up to date";
  match state.value with
  | Some observables -> observables
  | None -> assert false

let update state t info =
  if t <> state.last_updated + 1 then
    invalid_arg (
      sprintf "U_obs.update: last updated at %i, current time is %i"
        state.last_updated t
    );
assert (U_float.is_finite info.U_info.goal); 49 | let new_value = state.update_value t info in 50 | state.value <- Some new_value; 51 | state.last_updated <- t 52 | 53 | let create_stat window get = 54 | let alpha = 1. /. float window in 55 | let state = Mv_var.init ~alpha_avg:alpha ~alpha_var:alpha () in 56 | let force_update t = 57 | let x = get t in 58 | if U_float.is_finite x then 59 | Mv_var.update state x 60 | in 61 | let update = U_lazy.get force_update in 62 | let get_average t = 63 | update t; 64 | Mv_var.get_average state 65 | in 66 | let get_stdev t = 67 | update t; 68 | Mv_var.get_stdev state 69 | in 70 | let get_normalized t = 71 | update t; 72 | Mv_var.get_normalized state 73 | in 74 | get_average, get_stdev, get_normalized 75 | 76 | let create () = 77 | let open U_info in 78 | let short_window = 20 in 79 | let long_window = 1000 in 80 | 81 | let create_long_stat get = 82 | create_stat long_window get 83 | in 84 | 85 | let create_short_stat get = 86 | create_stat short_window get 87 | in 88 | 89 | let input = ref U_info.dummy in 90 | 91 | let get_goal t = !input.goal in 92 | let get_pos_contrib t = !input.pos_contrib in 93 | let get_neg_contrib t = !input.neg_contrib in 94 | let get_pos_contrib_count t = !input.pos_contrib_count in 95 | let get_neg_contrib_count t = !input.neg_contrib_count in 96 | 97 | let _, _, get_norm_goal = create_long_stat get_goal in 98 | let get_recent_norm_goal, _, _ = create_short_stat get_norm_goal in 99 | 100 | let get_prediction t = 101 | get_pos_contrib t +. get_neg_contrib t 102 | in 103 | 104 | let get_delta t = 105 | abs_float (get_prediction t -. get_goal t) 106 | in 107 | 108 | let get_avg_delta, _, _ = create_long_stat get_delta in 109 | let get_recent_delta, _, _ = create_short_stat get_delta in 110 | let get_delta_ratio t = 111 | get_recent_delta t /. 
get_avg_delta t 112 | in 113 | 114 | let get_avg_pos_contrib, _, _ = create_long_stat get_pos_contrib in 115 | let get_recent_pos_contrib, _, _ = create_short_stat get_pos_contrib in 116 | let get_pos_contrib_ratio t = 117 | get_recent_pos_contrib t /. get_avg_pos_contrib t 118 | in 119 | 120 | let get_avg_neg_contrib, _, _ = create_long_stat get_neg_contrib in 121 | let get_recent_neg_contrib, _, _ = create_short_stat get_neg_contrib in 122 | let get_neg_contrib_ratio t = 123 | get_recent_neg_contrib t /. get_avg_neg_contrib t 124 | in 125 | 126 | let get_contrib_count t = 127 | float (get_pos_contrib_count t + get_neg_contrib_count t) 128 | in 129 | 130 | let get_avg_contrib_count, _, _ = create_long_stat get_contrib_count in 131 | let get_recent_contrib_count, _, _ = create_short_stat get_contrib_count in 132 | let get_activity t = 133 | get_recent_contrib_count t /. get_avg_contrib_count t 134 | in 135 | 136 | let update_value t new_input = 137 | input := new_input; 138 | let normalized_goal = get_norm_goal t in 139 | let recent_normalized_goal = get_recent_norm_goal t in 140 | let recent_delta = get_delta_ratio t in 141 | let activity = get_activity t in 142 | let recent_pos_contrib = get_pos_contrib_ratio t in 143 | let recent_neg_contrib = get_neg_contrib_ratio t in 144 | { 145 | normalized_goal; 146 | recent_normalized_goal; 147 | recent_delta; 148 | activity; 149 | recent_pos_contrib; 150 | recent_neg_contrib; 151 | } 152 | in 153 | { 154 | value = None; 155 | last_updated = -1; 156 | update_value; 157 | } 158 | 159 | let to_string x = 160 | sprintf "normalized_goal=%.3f \ 161 | recent_normalized_goal=%.3f \ 162 | recent_delta=%.3f \ 163 | activity=%.3f \ 164 | recent_pos_contrib=%.3f \ 165 | recent_neg_contrib=%.3f" 166 | x.normalized_goal 167 | x.recent_normalized_goal 168 | x.recent_delta 169 | x.activity 170 | x.recent_pos_contrib 171 | x.recent_neg_contrib 172 | -------------------------------------------------------------------------------- 
/u_exp.ml:
--------------------------------------------------------------------------------
(*
   Define an experiment, i.e. a set of parameters and
   conditions that make the system stop once it reaches
   a satisfying state.
*)

open Printf
open U_log

let default_time_to_confirm_convergence = 1000

type goal = {
  goal_name: string;
  goal_condition: U_system.t -> U_time.t -> bool;

  goal_time_to_confirm_convergence: U_time.t;
    (* Number of consecutive times the goal condition must be met
       in order to assume convergence toward the goal. *)

  mutable goal_reached_at: U_time.t option;
    (* Date of the cycle after which the goal was reached.
       For the number of steps, add 1. *)

  mutable goal_converged: bool;
}

type experiment = {
  exp_name: string;
  exp_goals: goal list;
}

(*
   Set `goal_converged` according to how long the goal condition
   has been holding as of time t: the goal has converged once
   the condition has been met for at least
   `goal_time_to_confirm_convergence` consecutive times.
*)
let update_converged_flag x t =
  x.goal_converged <-
    (match x.goal_reached_at with
     | Some reached_at ->
         let streak = t - reached_at + 1 in
         assert (streak >= 1);
         streak >= x.goal_time_to_confirm_convergence
     | None ->
         false)

(*
   If convergence hasn't been confirmed yet,
   update the fields `goal_reached_at` and `goal_converged`.
*)
let update_goal_state x system t =
  if not x.goal_converged then begin
    let reached_at =
      if not (x.goal_condition system t) then
        (* the streak is broken *)
        None
      else
        match x.goal_reached_at with
        | Some _ as since -> since
        | None -> Some t
    in
    x.goal_reached_at <- reached_at;
    update_converged_flag x t
  end

(*
   Update the state of every goal of the experiment for time t
   and tell whether all of them have converged, in which case
   the experiment is over.
*)
let stop_condition (x : experiment) system t =
  let goals = x.exp_goals in
  List.iter (fun goal -> update_goal_state goal system t) goals;
  let all_converged = List.for_all (fun g -> g.goal_converged) goals in
  if all_converged then
    assert (List.for_all (fun g -> g.goal_reached_at <> None) goals);
  all_converged

(*
   New goal, not reached yet. `cond` is the condition tested at each step.
*)
let create_goal
    ?(time_to_confirm_convergence = default_time_to_confirm_convergence)
    name cond =
  {
    goal_name = name;
    goal_condition = cond;
    goal_time_to_confirm_convergence = time_to_confirm_convergence;
    goal_reached_at = None;
    goal_converged = false;
  }

(* New experiment made of a name and a list of goals. *)
let create_experiment name goals =
  { exp_name = name; exp_goals = goals }

(*
   Inspect the results of the same experiment with the same goals,
   repeated several times, and associate each goal name with
   the number of steps it took to reach it.
*)
let make_report l =
  let first_exp =
    match l with
    | first :: _ -> first
    | [] -> assert false
  in

  (* goal name -> numbers of steps collected so far *)
  let steps_by_goal = Hashtbl.create 10 in
  let find_steps name =
    try Hashtbl.find steps_by_goal name
    with Not_found -> assert false
  in
  List.iter
    (fun goal -> Hashtbl.add steps_by_goal goal.goal_name (ref []))
    first_exp.exp_goals;

  List.iter (fun exp ->
    List.iter (fun goal ->
      assert goal.goal_converged;
      match goal.goal_reached_at with
      | None -> assert false
      | Some t ->
          let steps = find_steps goal.goal_name in
          (* dates start from 0 *)
          steps := (t + 1) :: !steps
    ) exp.exp_goals
  ) l;

  let goals_reached =
    List.map (fun goal ->
      let name = goal.goal_name in
      (name, !(find_steps name))
    ) first_exp.exp_goals
  in
  (first_exp.exp_name, goals_reached)

(* Name under which a goal appears in the LaTeX output. *)
let latex_of_goal_name name =
  match name with
  | "a0" -> "A"
  | "b0" -> "B"
  | other -> other

(*
   Log the statistics on the number of steps that were needed to reach
   one goal, then print a one-line summary and a LaTeX snippet
   to standard output.
*)
let print_goal_report exp_name (goal_name, int_list) =
  let data = List.map float int_list in
  let mean, stdev = U_stat.get_mean_and_stdev data in
  let percentile = U_stat.get_percentile data in
  let p0 = percentile 0. in
  let p10 = percentile 0.10 in
  let p50 = percentile 0.50 in
  let p90 = percentile 0.90 in
  let p100 = percentile 1. in
  logf "experiment %s, number of steps used to reach goal %s:"
    exp_name goal_name;
  logf " n = %i" (List.length data);
  logf " mean, stdev = %.2f, %.2f" mean stdev;
  logf " p0 (min) = %g" p0;
  logf " p10 = %.1f" p10;
  logf " p50 (median) = %.1f" p50;
  logf " p90 = %.1f" p90;
  logf " p100 (max) = %g" p100;

  (* This line is meant to produce a summary with `grep '^>'` *)
  printf "> %-30s %.1f\n%!"
    (exp_name ^ "." ^ goal_name) p50;

  (* The following lines are for the paper *)
  printf {|
Number of steps to converge to Condition$_{%s}$:

$$
\begin{eqnarray}
\mathrm{10^{th} \dots 90^{th}\ percentile} &=& %.1f \dots %.1f\\
\mathrm{median} &=& %.1f\\
\hat{\mu} &=& %.1f\\
\hat{\sigma} &=& %.1f\\
\end{eqnarray}
$$
%!|}
    (latex_of_goal_name goal_name)
    p10 p90
    p50
    mean
    stdev

(* Print the report produced by `make_report`, one goal after another. *)
let print_report (exp_name, goals_reached) =
  List.iter
    (fun goal_result -> print_goal_report exp_name goal_result)
    goals_reached

--------------------------------------------------------------------------------
/u_learn.ml:
--------------------------------------------------------------------------------
(*
   Reinforcement step.
*)

open U_log
open U_system

(*
   The original weights with which we correct contributions are based
   on the recent standard deviation of those contributions.
   It turns out that if we give more weight to the largest standard
   deviations, it helps contributions to non-noisy effects to converge faster
   when they're mixed with effects associated with noise.

   Strictly increasing, continuous function f from [0,1] to [0,1],
   with the following properties:
     f(0) = 0
     f(1) = 1
     f(x) = x' such that x' < x
     f(x) behaves linearly near 0

   For this, we use a function of the form ax^b + (1-a)x,
   determined empirically.
*)
let reduce_low_relative_weight x =
  let curved = x ** 1.3 in
  0.5 *. curved +. 0.5 *. x

(*
   Turn a relative weight, which must be within [0,1], into the weight
   actually used. The result is never less than 0.001.
*)
let adjust_relative_weight x =
  assert (x >= 0. && x <= 1.);
  (* The lower bound is a hack intended to correct for any initial
     underestimation of the standard deviation and for sudden changes
     of a contribution previously estimated with high certainty. *)
  max 0.001 (reduce_low_relative_weight x)

(*
   Express all values as a fraction of the maximum value.
*)
let normalize_max l =
  let maximum = U_float.maxf l snd in
  (* note: the result is in reverse order *)
  List.rev_map (fun (item, value) -> (item, value /. maximum)) l

(*
   Express all values as a fraction of the total.
   The result is in reverse order.
*)
let normalize_total l =
  let total = U_float.sumf l snd in
  List.rev_map (fun (item, value) -> (item, value /. total)) l

(*
   Add `amount` to the value of a contribution and record the result.
*)
let shift_contribution x amount =
  let old_contrib = U_control.get_contrib_value x in
  let new_contrib = old_contrib +. amount in
  if debug then
    logf "contrib: %g -> %g"
      old_contrib new_contrib;
  U_control.update_contrib x new_contrib

(*
   Adjust the value of each contribution, in the case where
   the weight of each contribution is known with enough confidence.

   The weights are redistributed so as to increase higher weights.
   This is done in two passes:

   1. Normalize the weights with respect to the maximum weight,
      modify them to favor higher weights
   2. Normalize the new weights with respect to the total.

   Each contribution then receives its share of `delta`.
*)
let adjust_partial_contributions ~delta contributions_and_weights =
  contributions_and_weights
  |> normalize_max
  |> List.rev_map (fun (x, weight) -> (x, adjust_relative_weight weight))
  |> normalize_total
  |> List.iter (fun (x, share) -> shift_contribution x (share *. delta))

(*
   Split `delta` into equal parts, one per contribution.
   Nothing happens if there is no contribution.
*)
let adjust_contributions_evenly ~delta contributions =
  match List.length contributions with
  | 0 -> ()
  | n ->
      let contrib_delta = delta /. float n in
      List.iter (fun x -> shift_contribution x contrib_delta) contributions

(*
   Split `delta` evenly among the contributions whose weight is infinite,
   leaving the other ones untouched.
*)
let adjust_contributions_with_infinite_weight ~delta contributions =
  contributions
  |> List.filter (fun x -> U_control.get_weight x = infinity)
  |> adjust_contributions_evenly ~delta

(*
   Correct the contributions so that their sum moves toward `feedback`.
   The difference between the feedback and the current prediction
   is distributed according to the total weight of the contributions:
   - positive and finite: by weight
   - zero: evenly
   - infinite: evenly among the contributions of infinite weight
*)
let adjust_contributions contributions feedback =
  let add_to_prediction acc x =
    let contrib = U_control.get_contrib_value x in
    if debug then
      logf "contribution to prediction: %g" contrib;
    acc +. contrib
  in
  let prediction = List.fold_left add_to_prediction 0. contributions in
  let contributions_and_weights =
    List.rev_map (fun x -> (x, U_control.get_weight x)) contributions
  in
  let total_weight = U_float.sumf contributions_and_weights snd in
  let delta = feedback -. prediction in
  logf "feedback: %g, prediction: %g, delta: %g, total_weight: %g"
    feedback prediction delta total_weight;
  match total_weight with
  | w when w > 0. && w < infinity ->
      adjust_partial_contributions ~delta contributions_and_weights
  | w when w = 0. ->
      adjust_contributions_evenly ~delta contributions
  | w when w = infinity ->
      adjust_contributions_with_infinite_weight ~delta contributions
  | _ ->
      (* negative or NaN *)
      assert false

(*
   Summarize the contributions: sum and count of the positive moving
   averages, sum and count of the negative ones. Averages that are
   neither positive nor negative are ignored.
*)
let extract_info feedback contributions =
  let open U_info in
  let tally info contrib =
    match U_control.get_average contrib with
    | avg when avg > 0. ->
        { info with
          pos_contrib = info.pos_contrib +. avg;
          pos_contrib_count = info.pos_contrib_count + 1 }
    | avg when avg < 0. ->
        { info with
          neg_contrib = info.neg_contrib +. avg;
          neg_contrib_count = info.neg_contrib_count + 1 }
    | _ ->
        info
  in
  let empty = {
    goal = feedback;
    pos_contrib = 0.;
    neg_contrib = 0.;
    pos_contrib_count = 0;
    neg_contrib_count = 0;
  } in
  List.fold_left tally empty contributions

(*
   Learning step: collect the contribution attached to each recent act,
   i.e. the one matching the age of the act, correct those contributions
   using the feedback and return a summary of them.
*)
let learn (x : U_system.t) (feedback : float) : U_info.t =
  let contributions =
    U_recent_acts.fold x.recent_acts [] (fun age controlid acc ->
      U_control.get_contribution (x.get_control controlid) age :: acc
    )
  in
  adjust_contributions contributions feedback;
  extract_info feedback contributions

--------------------------------------------------------------------------------
/u_eval.ml:
--------------------------------------------------------------------------------
(*
   A very simple setup to make sure we don't have a major bug
   and to evaluate the behavior of the system in different conditions.

   Not tested at this point:
   - any number of actions other than 2
   - multiple controls for the same action
   - delayed contributions
*)

open Printf
open U_log

let default_global_iter = 100
let default_window_length = 5

(* We give up with an error after this many steps *)
let max_iter = 100_000

let default_base_contrib_a0 = 1.
let default_base_contrib_a1 = -0.5
let default_base_contrib_a2 = 0.25

let default_base_contrib_b0 = 0.1
let default_base_contrib_b1 = 0.2
let default_base_contrib_b2 = 0.05

let default_epsilon_a0 = 0.05
let default_epsilon_b0 = 0.005

let default_determine_actions_ab t = (U_random.pick 0.5, U_random.pick 0.5)
let determine_actions_always_b t = (U_random.pick 0.5, true)

(* Log the state of a control and append it to a CSV file. *)
let print_control oc x =
  logf "%s" (U_control.to_info x);
  U_control.print_csv oc x

(*
   Given the controls obtained from several runs, compute for each age
   the (mean, stdev) of the moving averages and the (mean, stdev)
   of the moving standard deviations of the contributions.
*)
let get_average_contributions window_length acc =
  let stat =
    Array.init window_length (fun _age ->
      U_stat.create_stdev_acc (), U_stat.create_stdev_acc ()
    )
  in
  List.iter (fun control ->
    U_control.iter_contributions control (fun ~age ~contrib ~average ~stdev ->
      let (add_average, _, _), (add_stdev, _, _) = stat.(age) in
      add_average average;
      add_stdev stdev
    )
  ) acc;

  Array.map (fun ((_, _, get_avg_stat), (_, _, get_stdev_stat)) ->
    get_avg_stat (), get_stdev_stat ()
  ) stat

(* Log the output of `get_average_contributions`, one line per age. *)
let print_contrib_stats controlid contrib_stat_array =
  Array.iteri (fun age ((avg_mean, avg_stdev), (stdev_mean, stdev_stdev)) ->
    logf "contribution %s[%i]: avg:(%.2g, %.2g) stdev:(%.2g, %.2g)"
      (U_controlid.to_string controlid) age
      avg_mean avg_stdev stdev_mean stdev_stdev
  ) contrib_stat_array

(* Log the observables of the system for time t. *)
let print_observables system t =
  let x = U_obs.get U_system.(system.observables) t in
  logf "observables: %s" (U_obs.to_string x)

(*
   Manage effects that are spread over several consecutive steps.
   Returns a pair (add_action, pop_effects):
   - add_action takes the list of the contributions of a new action,
     the first one being due at the next call to pop_effects;
   - pop_effects consumes the first pending contribution of each
     scheduled action and returns their sum.

   NOTE(review): pop_effects is wrapped with `U_lazy.get`, presumably
   so that it runs once per time step -- confirm in u_lazy.ml.
*)
let create_delayed_effect_manager () =
  let scheduled_contributions = ref [] in
  let add_action future_contributions =
    scheduled_contributions :=
      future_contributions :: !scheduled_contributions
  in
  let pop_effects t =
    (* Single pass: add up the heads from left to right and keep
       the non-empty tails, in their original order. *)
    let current, rev_future =
      List.fold_left (fun (total, pending) l ->
        match l with
        | [] -> (total, pending)
        | [last] -> (total +. last, pending)
        | first :: rest -> (total +. first, rest :: pending)
      ) (0., []) !scheduled_contributions
    in
    scheduled_contributions := List.rev rev_future;
    current
  in
  add_action, U_lazy.get pop_effects

let csv_dir = "out"

(* Directory where the CSV files go. It is created if it doesn't exist. *)
let get_csv_dir () =
  (match Sys.file_exists csv_dir with
   | true -> ()
   | false -> Unix.mkdir csv_dir 0o777);
  csv_dir

(* Path of the CSV file of the given name. *)
let get_csv_filename name =
  let dir = get_csv_dir () in
  sprintf "%s/%s.csv" dir name

(*
   Run the system once, from t=0 until the experiment says it is over.

   Two actions A and B are set up, each triggered by one control.
   Each time an action is picked by `determine_actions_ab`, it schedules
   three contributions to the goal function, for the current step and
   the next two, each being a base contribution plus the noise of
   the action. The goal function is the sum of the contributions
   due at time t plus the global noise.

   The state of both controls is written at each step to the CSV files
   `<name>-a` and `<name>-b`. Both files are closed when the function
   exits, whether it returns normally or raises.

   Returns the experiment and both controls in their final state.
   Raises Failure "Too many iterations" if the experiment is still
   going on at t = max_iter.
*)
let test_system_once
    ?inner_log_mode
    ~name
    ~create_experiment
    ~window_length
    ~controlid_a
    ~controlid_b
    ~base_contrib_a0
    ~base_contrib_a1
    ~base_contrib_a2
    ~base_contrib_b0
    ~base_contrib_b1
    ~base_contrib_b2
    ?(noise_a = fun t -> 0.)
    ?(noise_b = fun t -> 0.)
    ?(noise = fun t -> 0.)
    ?(determine_actions_ab = default_determine_actions_ab)
    () =

  let controls = U_control.create_set () in
  let actions = U_action.create_set () in

  (* Register an action that raises a flag when it runs, together with
     the control that triggers it. Each action has its own frequency
     and constant contribution, independent from the other one. *)
  let register_action action_name controlid =
    let was_active = ref false in
    let actionid = U_actionid.of_string action_name in
    U_action.add actionid (fun () -> was_active := true) actions;
    U_control.add ~window_length ~id:controlid ~actionid controls;
    was_active
  in
  let a_was_active = register_action "A" controlid_a in
  let b_was_active = register_action "B" controlid_b in

  let before_step t =
    logf "--------------------------------------------------------------";
    a_was_active := false;
    b_was_active := false
  in

  let add_action, pop_effects = create_delayed_effect_manager () in

  (* Schedule the effects of an action: the same noise sample is added
     to each base contribution. *)
  let schedule_effects extra base_contributions =
    add_action (List.map (fun base -> base +. extra) base_contributions)
  in

  let read_active_controls t add =
    let a, b = determine_actions_ab t in
    if a then (
      logf "A*";
      add controlid_a;
      schedule_effects (noise_a t)
        [base_contrib_a0; base_contrib_a1; base_contrib_a2]
    );
    if b then (
      logf "B*";
      add controlid_b;
      schedule_effects (noise_b t)
        [base_contrib_b0; base_contrib_b1; base_contrib_b2]
    )
  in

  let goal_function t =
    pop_effects t +. noise t
  in

  let get_control id = U_control.get controls id in
  let get_action id = U_action.get actions id in
  let get_controls () =
    get_control controlid_a, get_control controlid_b
  in

  let system =
    U_system.create
      ~window_length
      ~goal_function
      ~read_active_controls
      ~get_control
      ~get_action
  in
  let experiment = create_experiment get_controls in

  let oc_a =
    U_control.open_csv window_length (get_csv_filename (name ^ "-a")) in
  let oc_b =
    (* don't leave the first file open if the second one can't be created *)
    try U_control.open_csv window_length (get_csv_filename (name ^ "-b"))
    with e -> (close_out_noerr oc_a; raise e)
  in

  let after_step t =
    let control_a, control_b = get_controls () in
    print_control oc_a control_a;
    print_control oc_b control_b;
    print_observables system t;
    let stop = U_exp.stop_condition experiment system t in
    let continue = not stop in
    if continue && t >= max_iter then (
      eprintf "> %s ERROR: too many iterations (%i)\n%!" name max_iter;
      failwith "Too many iterations"
    );
    continue
  in

  (* The loop can raise, if only because of the limit on the number of
     iterations enforced by `after_step`. In that case, the CSV files are
     closed without letting a closing error mask the original exception. *)
  (try
     U_cycle.loop
       ?inner_log_mode
       ~before_step
       ~after_step
       system
   with e ->
     close_out_noerr oc_a;
     close_out_noerr oc_b;
     raise e);

  close_out oc_a;
  close_out oc_b;
  let control_a, control_b = get_controls () in
  (experiment, control_a, control_b)

(*
   We run the system from t=0 several times in order to get a sense
   of the variability of the results.
*)
let test_system
    ~name
    ~create_experiment
    ~base_contrib_a0
    ~base_contrib_a1
    ~base_contrib_a2
    ~base_contrib_b0
    ~base_contrib_b1
    ~base_contrib_b2
    ?(global_iter = default_global_iter)
    ?(window_length = default_window_length)
    ?noise_a
    ?noise_b
    ?noise
    ?determine_actions_ab
    () =

  let controlid_a = U_controlid.of_string "A" in
  let controlid_b = U_controlid.of_string "B" in

  (* Results of the runs, latest first. *)
  let experiments = ref [] in
  let controls_a = ref [] in
  let controls_b = ref [] in

  for i = 1 to global_iter do
    (* Only the first run gets the `Skip mode; the other ones get `Off. *)
    let inner_log_mode = if i = 1 then `Skip else `Off in
    logf "--- Run %i/%i ---" i global_iter;
    let experiment, control_a, control_b =
      test_system_once
        ~name: (sprintf "%s-%i" name i)
        ~inner_log_mode
        ~create_experiment
        ~window_length
        ~controlid_a
        ~controlid_b
        ~base_contrib_a0
        ~base_contrib_a1
        ~base_contrib_a2
        ~base_contrib_b0
        ~base_contrib_b1
        ~base_contrib_b2
        ?noise_a
        ?noise_b
        ?noise
        ?determine_actions_ab
        ()
    in
    experiments := experiment :: !experiments;
    controls_a := control_a :: !controls_a;
    controls_b := control_b :: !controls_b;
  done;

  let contrib_stat_a = get_average_contributions window_length !controls_a in
  let contrib_stat_b = get_average_contributions window_length !controls_b in
  print_contrib_stats controlid_a contrib_stat_a;
  print_contrib_stats controlid_b contrib_stat_b;
  U_exp.print_report (U_exp.make_report !experiments);
  true

(*
   The two standard goals, named "a0" and "b0": the value of
   the contribution of age 0 of control A (resp. B) must be within
   epsilon_a0 (resp. epsilon_b0) of its base contribution.
   Both tolerances must be positive.
*)
let create_default_goals
    ?(epsilon_a0 = default_epsilon_a0)
    ?(epsilon_b0 = default_epsilon_b0)
    ~base_contrib_a0
    ~base_contrib_b0
    get_controls =
  assert (epsilon_a0 > 0.);
  assert (epsilon_b0 > 0.);
  (* Condition on the contribution of age 0 of one of the two controls,
     `pick` selecting it from the pair returned by `get_controls`. *)
  let near ~pick ~target ~epsilon _system _t =
    let control = pick (get_controls ()) in
    let contrib0 = U_control.get_contribution control 0 in
    let x = U_control.get_contrib_value contrib0 in
    abs_float (x -. target) <= epsilon
  in
  let cond_a0 = near ~pick:fst ~target:base_contrib_a0 ~epsilon:epsilon_a0 in
  let cond_b0 = near ~pick:snd ~target:base_contrib_b0 ~epsilon:epsilon_b0 in
  [
    U_exp.create_goal "a0" cond_a0;
    U_exp.create_goal "b0" cond_b0;
  ]

(*
   Build the function which, given access to the controls, creates
   an experiment made of the default goals followed by any extra goals.
*)
let make_create_experiment
    ~base_contrib_a0
    ~base_contrib_b0
    ?epsilon_a0
    ?epsilon_b0
    ?(create_extra_goals = fun get_controls -> [])
    name =
  fun get_controls ->
    let base_goals =
      create_default_goals
        ?epsilon_a0
        ?epsilon_b0
        ~base_contrib_a0
        ~base_contrib_b0
        get_controls
    in
    let extra_goals = create_extra_goals get_controls in
    let goals = base_goals @ extra_goals in
    U_exp.create_experiment name goals

let make_test
    ?window_length
    ?(base_contrib_a0 = default_base_contrib_a0)
    ?(base_contrib_a1 = default_base_contrib_a1)
    ?(base_contrib_a2 = default_base_contrib_a2)
    ?(base_contrib_b0 = default_base_contrib_b0)
    ?(base_contrib_b1 = default_base_contrib_b1)
    ?(base_contrib_b2 = default_base_contrib_b2)
    ?epsilon_a0
    ?epsilon_b0
    ?noise_a
    ?noise_b
    ?noise
    ?determine_actions_ab
    ~name
    () =
  let create_experiment =
    make_create_experiment
      ~base_contrib_a0
      ~base_contrib_b0
      ?epsilon_a0
      ?epsilon_b0
      name
  in
  test_system
    ~name
    ~create_experiment
    ~base_contrib_a0
    ~base_contrib_a1
    ~base_contrib_a2
    ~base_contrib_b0
    ~base_contrib_b1
    ~base_contrib_b2
    ?window_length
    ?noise_a
    ?noise_b
    ?noise
    ?determine_actions_ab
() 367 | 368 | let test_default () = 369 | make_test 370 | ~name: "default" 371 | () 372 | 373 | let test_window1 () = 374 | make_test 375 | ~name: "window1" 376 | ~window_length: 1 377 | ~base_contrib_a1: 0. 378 | ~base_contrib_a2: 0. 379 | ~base_contrib_b1: 0. 380 | ~base_contrib_b2: 0. 381 | () 382 | 383 | let test_window3 () = 384 | make_test 385 | ~name: "window3" 386 | ~window_length: 3 387 | () 388 | 389 | let test_window10 () = 390 | make_test 391 | ~name: "window10" 392 | ~window_length: 10 393 | () 394 | 395 | let test_scaled_contributions () = 396 | let scale x = -1000. *. x in 397 | let abs_scale x = abs_float (scale x) in 398 | make_test 399 | ~base_contrib_a0: (scale default_base_contrib_a0) 400 | ~base_contrib_a1: (scale default_base_contrib_a1) 401 | ~base_contrib_a2: (scale default_base_contrib_a2) 402 | ~base_contrib_b0: (scale default_base_contrib_b0) 403 | ~base_contrib_b1: (scale default_base_contrib_b1) 404 | ~base_contrib_b2: (scale default_base_contrib_b2) 405 | ~epsilon_a0: (abs_scale default_epsilon_a0) 406 | ~epsilon_b0: (abs_scale default_epsilon_b0) 407 | ~name: "scaled_contributions" 408 | () 409 | 410 | let test_large_difference () = 411 | make_test 412 | ~name: "large_difference" 413 | ~base_contrib_a0: 100. 414 | () 415 | 416 | let test_large_difference_corrected () = 417 | (* hardcoded ratio of max gap between expected contributions, 418 | from default setup to this setup *) 419 | let r = (100. -. (-0.5)) /. (1. -. (-0.5)) in 420 | let epsilon_a0 = r *. default_epsilon_a0 in 421 | let epsilon_b0 = r *. default_epsilon_b0 in 422 | make_test 423 | ~name: "large_difference_corrected" 424 | ~base_contrib_a0: 100. 
425 | ~epsilon_a0 426 | ~epsilon_b0 427 | () 428 | 429 | (* B active => A active *) 430 | let test_subaction () = 431 | let determine_actions_ab t = 432 | let a = U_random.pick 0.5 in 433 | let b = a && U_random.pick 0.5 in 434 | a, b 435 | in 436 | make_test 437 | ~name: "subaction" 438 | ~determine_actions_ab 439 | () 440 | 441 | (* 442 | Change the contributions of A and B suddenly, and see if we can adjust 443 | the predictions. 444 | *) 445 | let test_adaptation () = 446 | assert (default_base_contrib_a0 = 1.); 447 | assert (default_base_contrib_b0 = 0.1); 448 | let noise_a t = 449 | if t < 100 then 1. 450 | else 0. 451 | in 452 | let noise_b t = 453 | if t < 50 then (-0.1) 454 | else 0. 455 | in 456 | make_test 457 | ~name: "adaptation" 458 | ~noise_a 459 | ~noise_b 460 | ~determine_actions_ab: (fun t -> U_random.pick 0.5, U_random.pick 0.5) 461 | () 462 | 463 | (* same parameters as noisy_contribution below, without the noise. *) 464 | let test_nonnoisy_contribution () = 465 | assert (default_base_contrib_a0 = 1.); 466 | make_test 467 | ~name: "nonnoisy_contribution" 468 | ~determine_actions_ab:determine_actions_always_b 469 | ~epsilon_b0: 1000. 470 | () 471 | 472 | let test_noisy_other_contribution () = 473 | assert (default_base_contrib_a0 = 1.); 474 | make_test 475 | ~name: "noisy_other_contribution" 476 | ~determine_actions_ab:determine_actions_always_b 477 | ~noise_b:(fun _ -> 478 | U_random.normal ~stdev: 0.5 () 479 | ) 480 | ~epsilon_b0: 1000. 
481 | () 482 | 483 | let test_noisy_contributions () = 484 | assert (default_base_contrib_a0 = 1.); 485 | assert (default_base_contrib_b0 = 0.1); 486 | let noise_a t = 487 | U_random.normal ~stdev:0.4 () 488 | in 489 | let noise_b t = 490 | U_random.normal ~stdev:0.04 () 491 | in 492 | make_test 493 | ~name: "noisy_contributions" 494 | ~epsilon_a0: 0.4 495 | ~epsilon_b0: 0.04 496 | ~noise_a 497 | ~noise_b 498 | () 499 | 500 | let test_global_noise suffix noise_stdev = 501 | assert (default_base_contrib_a0 = 1.); 502 | assert (default_base_contrib_b0 = 0.1); 503 | let noise t = 504 | U_random.normal ~stdev:noise_stdev () 505 | in 506 | make_test 507 | ~name: ("global_noise" ^ suffix) 508 | ~noise 509 | () 510 | 511 | let test_global_noise1 () = test_global_noise "1" 0.1 512 | let test_global_noise2 () = test_global_noise "2" 0.2 513 | 514 | let tests = [ 515 | "default", test_default; 516 | "window1", test_window1; 517 | "window3", test_window3; 518 | "window10", test_window10; 519 | "scaled_contributions", test_scaled_contributions; 520 | "large difference", test_large_difference; 521 | "large difference_corrected", test_large_difference_corrected; 522 | "subaction", test_subaction; 523 | "adaptation", test_adaptation; 524 | "non-noisy contribution", test_nonnoisy_contribution; 525 | "noisy other contribution", test_noisy_other_contribution; 526 | "noisy contributions", test_noisy_contributions; 527 | "global noise1", test_global_noise1; 528 | "global noise2", test_global_noise2; 529 | ] 530 | --------------------------------------------------------------------------------