├── client ├── dune-project ├── tlstunnel.opam ├── configuration.ml ├── dune └── client.ml ├── .cirrus.yml ├── unikernel ├── config.ml ├── configuration.ml ├── filesystem.ml └── unikernel.ml └── README.md /client/dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 1.0) -------------------------------------------------------------------------------- /client/tlstunnel.opam: -------------------------------------------------------------------------------- 1 | opam-version: "2.0" 2 | -------------------------------------------------------------------------------- /client/configuration.ml: -------------------------------------------------------------------------------- 1 | ../unikernel/configuration.ml -------------------------------------------------------------------------------- /client/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (name client) 3 | (public_name tlstunnel-client) 4 | (libraries cmdliner logs.fmt fmt.cli logs.cli fmt.tty logs ipaddr domain-name asn1-combinators digestif.c unix)) 5 | -------------------------------------------------------------------------------- /.cirrus.yml: -------------------------------------------------------------------------------- 1 | freebsd_instance: 2 | image_family: freebsd-14-3 3 | 4 | freebsd_task: 5 | pkg_install_script: pkg install -y ocaml-opam gmake bash 6 | ocaml_script: opam init -a --comp=4.14.2 7 | mirage_script: eval `opam env` && opam install --confirm-level=unsafe-yes "mirage>=4.10.0" 8 | configure_script: eval `opam env` && cd unikernel && mirage configure -t hvt 9 | depend_script: eval `opam env` && cd unikernel && gmake depend 10 | build_script: eval `opam env` && cd unikernel && gmake build 11 | tlstunnel_artifacts: 12 | path: unikernel/dist/tlstunnel.hvt 13 | 14 | freebsd_monitoring_task: 15 | pkg_install_script: pkg install -y ocaml-opam gmake bash 16 | ocaml_script: opam init -a --comp=4.14.2 17 |
mirage_script: eval `opam env` && opam install --confirm-level=unsafe-yes "mirage>=4.10.0" 18 | configure_script: eval `opam env` && cd unikernel && mirage configure -t hvt --enable-monitoring 19 | depend_script: eval `opam env` && cd unikernel && gmake depend 20 | build_script: eval `opam env` && cd unikernel && gmake build 21 | tlstunnel_artifacts: 22 | path: unikernel/dist/tlstunnel.hvt 23 | -------------------------------------------------------------------------------- /unikernel/config.ml: -------------------------------------------------------------------------------- 1 | (* mirage >= 4.10.0 & < 4.11.0 *) 2 | (* (c) 2019 Hannes Mehnert, all rights reserved *) 3 | 4 | open Mirage 5 | 6 | let main = 7 | main 8 | ~packages:[ 9 | package ~min:"0.14.0" "tls-mirage" ; 10 | package ~min:"10.0.0" ~sublibs:["mirage"] "dns-certify" ; 11 | package ~min:"6.0.0" "cstruct" ; 12 | package ~min:"7.0.0" "tcpip" ; 13 | package "metrics"; 14 | package ~min:"4.5.0" ~sublibs:["network"] "mirage-runtime"; 15 | ] 16 | "Unikernel.Main" 17 | (block @-> stackv4v6 @-> stackv4v6 @-> job) 18 | 19 | (* uTCP: an alternative TCP implementation (Utcp_mirage.Make), used when the use-utcp configure key is set *) 20 | 21 | let tcpv4v6_direct_conf id = 22 | let packages_v = Key.pure [ package "utcp" ~sublibs:[ "mirage" ] ] in 23 | let connect _ modname = function 24 | | [ip] -> 25 | code ~pos:__POS__ "Lwt.return (%s.connect %S %s)" modname id ip 26 | | _ -> failwith "direct tcpv4v6" 27 | in 28 | impl ~packages_v ~connect "Utcp_mirage.Make" 29 | (ipv4v6 @-> (tcp: 'a tcp typ)) 30 | 31 | let direct_tcpv4v6 id ip = 32 | tcpv4v6_direct_conf id $ ip 33 | (* direct IPv4/IPv6 stack on [netif] whose TCP layer is uTCP *) 34 | let net ?group name netif = 35 | let ethernet = ethif netif in 36 | let arp = arp ethernet in 37 | let i4 = create_ipv4 ?group ethernet arp in 38 | let i6 = create_ipv6 ?group netif ethernet in 39 | let i4i6 = create_ipv4v6 ?group i4 i6 in 40 | let tcpv4v6 = direct_tcpv4v6 name i4i6 in 41 | direct_stackv4v6 ?group ~tcp:tcpv4v6 netif ethernet arp i4 i6 42 | 43 | let use_utcp = 44 | let doc = Key.Arg.info ~doc:"Use uTCP" [ "use-utcp" ] in 45 |
Key.(create "use-utcp" Arg.(flag doc)) 46 | (* stack on the public (service) network interface *) 47 | let stack = 48 | if_impl 49 | (Key.value use_utcp) 50 | (net "service" default_network) 51 | (generic_stackv4v6 default_network) 52 | 53 | let private_stack = 54 | if_impl 55 | (Key.value use_utcp) 56 | (net ~group:"private" "private" (netif ~group:"private" "private")) 57 | (generic_stackv4v6 ~group:"private" (netif ~group:"private" "private")) 58 | (* block device for the persisted configuration: storage on solo5 targets, disk.img otherwise *) 59 | let block = 60 | Key.(if_impl is_solo5 (block_of_file "storage") (block_of_file "disk.img")) 61 | 62 | let enable_monitoring = 63 | let doc = Key.Arg.info 64 | ~doc:"Enable monitoring (syslog, metrics to influx, log level, statmemprof tracing)" 65 | [ "enable-monitoring" ] 66 | in 67 | Key.(create "enable-monitoring" Arg.(flag doc)) 68 | (* dedicated management network when monitoring is enabled, the public stack otherwise *) 69 | let management_stack = 70 | if_impl 71 | (Key.value enable_monitoring) 72 | (if_impl 73 | (Key.value use_utcp) 74 | (net ~group:"management" "management" (netif ~group:"management" "management")) 75 | (generic_stackv4v6 ~group:"management" (netif ~group:"management" "management"))) 76 | stack 77 | 78 | let monitoring = 79 | let monitor = Runtime_arg.(v (monitor None)) in 80 | let connect _ modname = function 81 | | [ stack ; monitor ] -> 82 | code ~pos:__POS__ 83 | "Lwt.return (match %s with\ 84 | | None -> Logs.warn (fun m -> m \"no monitor specified, not outputting statistics\")\ 85 | | Some ip -> %s.create ip ~hostname:(Mirage_runtime.name ()) %s)" 86 | monitor modname stack 87 | | _ -> assert false 88 | in 89 | impl 90 | ~packages:[ package ~min:"0.0.6" "mirage-monitoring" ] 91 | ~runtime_args:[ monitor ] 92 | ~connect "Mirage_monitoring.Make" 93 | (stackv4v6 @-> job) 94 | 95 | let syslog = 96 | let syslog = Runtime_arg.(v (syslog None)) in 97 | let connect _ modname = function 98 | | [ stack ; syslog ] -> 99 | code ~pos:__POS__ 100 | "Lwt.return (match %s with\ 101 | | None -> Logs.warn (fun m -> m \"no syslog specified, dumping on stdout\")\ 102 | | Some ip -> Logs.set_reporter (%s.create %s ip
~hostname:(Mirage_runtime.name ()) ()))" 103 | syslog modname stack 104 | | _ -> assert false 105 | in 106 | impl 107 | ~packages:[ package ~sublibs:["mirage"] ~min:"0.5.0" "logs-syslog" ] 108 | ~runtime_args:[ syslog ] 109 | ~connect "Logs_syslog_mirage.Udp" 110 | (stackv4v6 @-> job) 111 | (* syslog and monitoring are only set up when the enable-monitoring configure key is set *) 112 | let optional_monitoring stack = 113 | if_impl (Key.value enable_monitoring) 114 | (monitoring $ stack) 115 | noop 116 | 117 | let optional_syslog stack = 118 | if_impl (Key.value enable_monitoring) 119 | (syslog $ stack) 120 | noop 121 | 122 | let () = 123 | register "tlstunnel" 124 | [ 125 | optional_syslog management_stack ; 126 | optional_monitoring management_stack ; 127 | main $ block $ stack $ private_stack 128 | ] 129 | -------------------------------------------------------------------------------- /unikernel/configuration.ml: -------------------------------------------------------------------------------- 1 | (* ASN.1 codecs for the persisted configuration and the configuration protocol; client/configuration.ml links to this file *) 2 | let ip = 3 | let f = function 4 | | `C1 i -> Ipaddr.(V4 (V4.of_int32 (Int32.of_int i))) 5 | | `C2 (a, b, c, d) -> 6 | let v6 = 7 | Int32.of_int a, Int32.of_int b, Int32.of_int c, Int32.of_int d 8 | in 9 | Ipaddr.(V6 (V6.of_int32 v6)) 10 | and g = function 11 | | Ipaddr.V4 ip -> `C1 (Int32.to_int (Ipaddr.V4.to_int32 ip)) 12 | | Ipaddr.V6 ip -> 13 | let a, b, c, d = Ipaddr.V6.to_int32 ip in 14 | `C2 (Int32.to_int a, Int32.to_int b, Int32.to_int c, Int32.to_int d) 15 | in 16 | Asn.S.(map f g (choice2 17 | (explicit 0 int) 18 | (explicit 1 (sequence4 (required int) (required int) 19 | (required int) (required int))))) 20 | (* (sni, host, port); decoding raises if sni is not a valid host name *) 21 | let sni = 22 | let f (sni, host, port) = 23 | Domain_name.(host_exn (of_string_exn sni)), host, port 24 | and g (sni, host, port) = 25 | Domain_name.to_string sni, host, port 26 | in 27 | Asn.S.(map f g 28 | (sequence3 29 | (required ~label:"sni" utf8_string) 30 | (required ~label:"host" ip) 31 | (required ~label:"port" int))) 32 | (* persisted forwards; the second choice is never produced by the encoder and decoding it fails with assert false *) 33 | let data = 34 | let f = function 35 | | `C1 s -> s 36 | | `C2 () -> assert false 37 | and g s = `C1
s 38 | in 39 | Asn.S.(map f g (choice2 (sequence_of sni) (explicit 1 null))) 40 | (* decode with [codec] and reject any trailing bytes *) 41 | let decode_strict codec cs = 42 | match Asn.decode codec cs with 43 | | Ok (a, rest) -> 44 | if String.length rest = 0 then 45 | Ok a 46 | else 47 | Error (`Msg "trailing bytes") 48 | | Error (`Parse msg) -> Error (`Msg msg) 49 | 50 | let projections_of asn = 51 | let c = Asn.codec Asn.der asn in 52 | (decode_strict c, Asn.encode c) 53 | 54 | let data_of_cs, data_to_cs = projections_of data 55 | (* NOTE(review): a payload that fails to decode ends in assert false *) 56 | let decode_data data = 57 | match data_of_cs data with 58 | | Ok snis -> 59 | List.fold_left 60 | (fun acc (sni, host, port) -> 61 | Domain_name.Host_map.add sni (host, port) acc) 62 | Domain_name.Host_map.empty snis 63 | | Error `Msg msg -> 64 | Logs.err (fun m -> m "error %s decoding data" msg); 65 | assert false 66 | 67 | let encode_data sni = 68 | let snis = 69 | Domain_name.Host_map.fold 70 | (fun sni (host, port) acc -> (sni, host, port) :: acc) 71 | sni [] 72 | in 73 | data_to_cs snis 74 | 75 | let add_sni snis (sni, host, port) = 76 | (match Domain_name.Host_map.find_opt sni snis with 77 | | None -> () 78 | | Some (ohost, oport) -> 79 | Logs.warn (fun m -> m "overwriting %a -> %a:%d with %a:%d" 80 | Domain_name.pp sni Ipaddr.pp ohost oport Ipaddr.pp host port)); 81 | Logs.info (fun m -> m "%a is now redirected to %a:%d" 82 | Domain_name.pp sni Ipaddr.pp host port); 83 | Domain_name.Host_map.add sni (host, port) snis 84 | 85 | let remove_sni snis sni = 86 | Logs.info (fun m -> m "%a is no longer redirected" Domain_name.pp sni); 87 | Domain_name.Host_map.remove sni snis 88 | 89 | type cmd = 90 | | Add of [`host] Domain_name.t * Ipaddr.t * int 91 | | Remove of [`host] Domain_name.t 92 | | List 93 | | Snis of ([`host] Domain_name.t * Ipaddr.t * int) list 94 | | Result of int * string 95 | 96 | let pp_one ppf (sni, host, port) = 97 | Fmt.pf ppf "%a -> %a:%d" Domain_name.pp sni Ipaddr.pp host port 98 | 99 | let pp_cmd ppf = function 100 | | Add (s, h, p) -> Fmt.pf ppf "adding %a"
pp_one (s, h, p) 101 | | Remove sni -> Fmt.pf ppf "removing %a" Domain_name.pp sni 102 | | List -> Fmt.string ppf "list" 103 | | Snis xs -> Fmt.(list ~sep:(any ";@ ") pp_one) ppf xs 104 | | Result (c, msg) -> Fmt.pf ppf "exited %d: %s" c msg 105 | (* commands exchanged between tlstunnel-client and the unikernel; presumably Snis and Result are replies *) 106 | let cmd = 107 | let f = function 108 | | `C1 (s, h, p) -> Add (s, h, p) 109 | | `C2 s -> Remove Domain_name.(host_exn (of_string_exn s)) 110 | | `C3 () -> List 111 | | `C4 xs -> Snis xs 112 | | `C5 (c, s) -> Result (c, s) 113 | and g = function 114 | | Add (s, h, p) -> `C1 (s, h, p) 115 | | Remove s -> `C2 (Domain_name.to_string s) 116 | | List -> `C3 () 117 | | Snis xs -> `C4 xs 118 | | Result (c, s) -> `C5 (c, s) 119 | in 120 | Asn.S.(map f g 121 | (choice5 122 | (explicit 0 sni) 123 | (explicit 1 utf8_string) 124 | (explicit 2 null) 125 | (explicit 3 (sequence_of sni)) 126 | (explicit 4 (sequence2 127 | (required ~label:"exit" int) 128 | (required ~label:"message" utf8_string))))) 129 | 130 | let cmd_of_str, cmd_to_str = projections_of cmd 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TLStunnel 2 | 3 | This is a MirageOS unikernel accepting TLS connections via the public (service) 4 | network interface on frontend-port, and proxying them using TCP via the private 5 | network interface to the backend IP address and port configured for the 6 | requested server name (SNI). A client connecting to TLStunnel has to establish 7 | a TLS connection, whose payload is forwarded to the backend service via TCP. 8 | 9 | TLStunnel can be used for load-balancing: multiple TLStunnel instances on the 10 | frontend perform the expensive cryptographic operations (asymmetric TLS 11 | handshakes and symmetric cryptography), while one or more backend services 12 | communicate via plain TCP. 13 | 14 | Security-wise, only TLStunnel needs access to the private key of the X.509 15 | certificate(s).
When TLStunnel is configured to do client authentication, only 16 | valid clients can access the backend service, limiting the attack surface 17 | drastically. 18 | 19 | ## Usage 20 | 21 | Running TLStunnel requires two IP addresses: a public-facing one, and one on 22 | the private network (to which TCP connections are forwarded). Forwarding is 23 | configured with a command-line utility that connects via the private network. The 24 | X.509 certificate should be available via DNS (see 25 | [dns-primary-git](https://github.com/robur-coop/dns-primary-git) and 26 | [dns-letsencrypt-secondary](https://github.com/robur-coop/dns-letsencrypt-secondary/)). 27 | 28 | Suppose your public IP address is 1.2.3.4/24 (with default gateway 29 | 1.2.3.1), your address on the private network is 192.168.0.4/24, and your DNS server is 30 | 1.2.3.5 with the key tlstunnel._update.example.org. 31 | 32 | Starting TLStunnel: 33 | 34 | ```bash 35 | $ truncate -s 1m /var/db/tlstunnel 36 | $ solo5-hvt --net:service=tap0 --net:private=tap10 --block:storage=/var/db/tlstunnel -- \ 37 | tlstunnel/unikernel/dist/tlstunnel.hvt --ipv4=1.2.3.4/24 --ipv4-gateway=1.2.3.1 \ 38 | --private-ipv4=192.168.0.4/24 --domains=example.org \ 39 | --dns-server=1.2.3.5 --dns-key=tlstunnel._update.example.org:SHA256:m2gls0y3ZMN4DVKx37x/VoKEdll4J2A9qNIl6JIz2z4= \ 40 | --key-seed=ROkD8o/Xrc4ScDdxM8cV1+4eQiWUEul+3I1twW+I15E= \ 41 | --key=9Fe92fogykIAPBJZU4FUsmpRsAy6YDajIkdSRs650zM= 42 | ``` 43 | 44 | Once TLStunnel has obtained a certificate via DNS, you can already 45 | connect to https://1.2.3.4 and should see the certificate: 46 | 47 | ```bash 48 | $ openssl s_client -connect 1.2.3.4:443 49 | $ curl https://1.2.3.4 50 | ``` 51 | 52 | To configure TLStunnel's forwarding, where a specified hostname is 53 | forwarded to an IP address and port pair, use the binary 54 | `tlstunnel-client` from the `client` subfolder.
The communication is 55 | authenticated (HMAC-SHA256) using the shared secret passed to TLStunnel (`--key=secret`). 56 | 57 | The configuration is stored on the block device in a robust way: on each change, 58 | the new data is written first, and only afterwards is the superblock updated. 59 | 60 | ```bash 61 | $ cd tlstunnel/client 62 | $ dune build 63 | 64 | # Listing all configured hostnames: 65 | $ _build/install/default/bin/tlstunnel-client list --key=9Fe92fogykIAPBJZU4FUsmpRsAy6YDajIkdSRs650zM= -r 192.168.0.4:1234 66 | 67 | # Adding a new forward: 68 | $ _build/install/default/bin/tlstunnel-client add --key=9Fe92fogykIAPBJZU4FUsmpRsAy6YDajIkdSRs650zM= -r 192.168.0.4:1234 test.example.org 192.168.0.42 80 69 | 70 | # Removing a forward: 71 | $ _build/install/default/bin/tlstunnel-client remove --key=9Fe92fogykIAPBJZU4FUsmpRsAy6YDajIkdSRs650zM= -r 192.168.0.4:1234 test.example.org 72 | ``` 73 | 74 | ## Installation from source 75 | 76 | To install this unikernel from source, you need to have 77 | [opam](https://opam.ocaml.org) (>= 2.1.0) and 78 | [ocaml](https://ocaml.org) (>= 4.13.0) installed. Also, 79 | [mirage](https://mirageos.org) is required (>= 4.10.0, < 4.11.0). Please follow the 80 | [installation instructions](https://mirageos.org/wiki/install). 81 | 82 | The following steps will clone this git repository and compile the unikernel: 83 | 84 | ```bash 85 | $ git clone https://github.com/robur-coop/tlstunnel.git 86 | $ cd tlstunnel/unikernel && mirage configure -t hvt 87 | $ make depend 88 | $ make build 89 | ``` 90 | 91 | ## Installing as binary 92 | 93 | Binaries are available at [Reproducible OPAM 94 | builds](https://builds.robur.coop/job/tlstunnel/), see [Deploying binary MirageOS 95 | unikernels](https://hannes.robur.coop/Posts/Deploy) and [Reproducible MirageOS 96 | unikernel builds](https://hannes.robur.coop/Posts/ReproducibleOPAM) for details. 97 | 98 | ## Questions? 99 | 100 | Please open an issue if you have questions, feature requests, or comments.
101 | -------------------------------------------------------------------------------- /client/client.ml: -------------------------------------------------------------------------------- 1 | let ( let* ) = Result.bind 2 | 3 | let rec ign_intr f v = 4 | try f v with Unix.Unix_error (Unix.EINTR, _, _) -> ign_intr f v 5 | 6 | let connect (host, port) = 7 | let connect () = 8 | try 9 | let sockaddr = Unix.ADDR_INET (host, port) in 10 | let s = Unix.(socket PF_INET SOCK_STREAM 0) in 11 | Unix.(connect s sockaddr); 12 | Ok s 13 | with 14 | | Unix.Unix_error (err, f, _) -> 15 | Logs.err (fun m -> m "unix error in %s: %s" f (Unix.error_message err)); 16 | Error (`Msg "connect failure") 17 | in 18 | connect () 19 | 20 | let read fd = 21 | try 22 | let rec r b ?(off = 0) l = 23 | if l = 0 then 24 | Ok () 25 | else 26 | let read = ign_intr (Unix.read fd b off) l in 27 | if read = 0 then 28 | Error (`Msg "end of file") 29 | else 30 | r b ~off:(read + off) (l - read) 31 | in 32 | let bl = Bytes.create 8 in 33 | let* () = r bl 8 in 34 | let l = Bytes.get_int64_be bl 0 in 35 | let l_int = Int64.to_int l in (* TODO *) 36 | let b = Bytes.create l_int in 37 | let* () = r b l_int in 38 | Ok (Bytes.unsafe_to_string b) 39 | with 40 | Unix.Unix_error (err, f, _) -> 41 | Logs.err (fun m -> m "Unix error in %s: %s" f (Unix.error_message err)); 42 | Error (`Msg "unix error in read") 43 | 44 | let read_cmd fd = 45 | let* data = read fd in 46 | Configuration.cmd_of_str data 47 | 48 | let write fd data = 49 | try 50 | let rec w b ?(off = 0) l = 51 | if l = 0 then 52 | () 53 | else 54 | let written = ign_intr (Unix.write fd b off) l in 55 | w b ~off:(written + off) (l - written) 56 | in 57 | let len = 8 + String.length data in 58 | let csl = Bytes.create len in 59 | Bytes.set_int64_be csl 0 (Int64.of_int (String.length data)); 60 | Bytes.blit_string data 0 csl 8 (String.length data); 61 | w csl len; 62 | Ok () 63 | with 64 | Unix.Unix_error (err, f, _) -> 65 | Logs.err (fun m -> m "Unix error 
in %s: %s" f (Unix.error_message err)); 66 | Error (`Msg "unix error in write") 67 | 68 | module H = Digestif.SHA256 69 | 70 | let write_cmd fd key cmd = 71 | let data = Configuration.cmd_to_str cmd in 72 | let auth = H.(to_raw_string (hmac_string ~key data)) in 73 | write fd (auth ^ data) 74 | 75 | let write_read_print key remote cmd = 76 | let* s = connect remote in 77 | let* () = write_cmd s key cmd in 78 | let* cmd = read_cmd s in 79 | Unix.close s; 80 | Logs.app (fun m -> m "result: %a" Configuration.pp_cmd cmd); 81 | Ok () 82 | 83 | let list () key remote = 84 | write_read_print key remote Configuration.List 85 | 86 | let add () key remote sni host port = 87 | write_read_print key remote (Configuration.Add (sni, host, port)) 88 | 89 | let remove () key remote sni = 90 | write_read_print key remote (Configuration.Remove sni) 91 | 92 | let help () man_format cmds = function 93 | | None -> `Help (`Pager, None) 94 | | Some t when List.mem t cmds -> `Help (man_format, Some t) 95 | | Some x -> 96 | print_endline ("unknown command '" ^ x ^ "', available commands:"); 97 | List.iter print_endline cmds; 98 | `Ok () 99 | 100 | let setup_log style_renderer level = 101 | Fmt_tty.setup_std_outputs ?style_renderer (); 102 | Logs.set_level level; 103 | Logs.set_reporter (Logs_fmt.reporter ~dst:Format.std_formatter ()) 104 | 105 | open Cmdliner 106 | 107 | let host_port : (Unix.inet_addr * int) Arg.conv = 108 | let parse s = 109 | match String.split_on_char ':' s with 110 | | [ hostname ; port ] -> 111 | begin try 112 | `Ok (Unix.inet_addr_of_string hostname, int_of_string port) 113 | with 114 | Not_found -> `Error "failed to parse IP:port" 115 | end 116 | | _ -> `Error "broken: no port specified" 117 | in 118 | parse, fun ppf (h, p) -> Format.fprintf ppf "%s:%d" 119 | (Unix.string_of_inet_addr h) p 120 | 121 | let remote = 122 | let doc = "The remote host:port to connect to" in 123 | Arg.(value & opt host_port (Unix.inet_addr_loopback, 1234) & 124 | info [ "r" ; "remote" ] 
~doc ~docv:"IP:PORT") 125 | 126 | let key = 127 | let doc = "The shared secret" in 128 | Arg.(value & opt string "" & info [ "key" ] ~doc ~docv:"KEY") 129 | 130 | let hn : [`host] Domain_name.t Arg.conv = 131 | let parse s = 132 | match Domain_name.of_string s with 133 | | Error `Msg m -> `Error m 134 | | Ok d -> match Domain_name.host d with 135 | | Error `Msg m -> `Error m 136 | | Ok h -> `Ok h 137 | in 138 | parse, Domain_name.pp 139 | 140 | let sni = 141 | let doc = "The SNI." in 142 | Arg.(required & pos 0 (some hn) None & info [ ] ~doc ~docv:"SNI") 143 | 144 | let setup_log = 145 | Term.(const setup_log 146 | $ Fmt_cli.style_renderer () 147 | $ Logs_cli.level ()) 148 | 149 | let list_cmd = 150 | let term = Term.(term_result (const list $ setup_log $ key $ remote)) in 151 | Cmd.(v (info "list") term) 152 | 153 | let ip_conv : Ipaddr.t Arg.conv = 154 | let parse s = 155 | match Ipaddr.of_string s with 156 | | Ok ip -> `Ok ip 157 | | Error `Msg msg -> `Error msg 158 | in 159 | parse, Ipaddr.pp 160 | 161 | let ip = 162 | let doc = "The IP address." in 163 | Arg.(required & pos 1 (some ip_conv) None & info [] ~doc ~docv:"IP") 164 | 165 | let port = 166 | let doc = "The port." 
in 167 | Arg.(required & pos 2 (some int) None & info [] ~doc ~docv:"PORT") 168 | 169 | let add_cmd = 170 | let term = 171 | Term.(term_result (const add $ setup_log $ key $ remote $ sni $ ip $ port)) 172 | in 173 | Cmd.(v (info "add") term) 174 | 175 | let remove_cmd = 176 | let term = 177 | Term.(term_result (const remove $ setup_log $ key $ remote $ sni)) 178 | in 179 | Cmd.(v (info "remove") term) 180 | 181 | let cmds = [ list_cmd ; add_cmd ; remove_cmd ] 182 | 183 | let () = 184 | let info = 185 | let doc = "Tlstunnel configuration client" in 186 | Cmd.info "tlstunnel-client" ~doc 187 | and help = 188 | Term.(ret (const help $ setup_log $ Arg.man_format $ choice_names $ const None)) 189 | in 190 | let group = Cmd.group ~default:help info cmds in 191 | exit (Cmd.eval group) 192 | 193 | -------------------------------------------------------------------------------- /unikernel/filesystem.ml: -------------------------------------------------------------------------------- 1 | module Make (Block : Mirage_block.S) = struct 2 | module H = Digestif.SHA256 3 | 4 | let s_version = 1 5 | 6 | module IS = Set.Make(Int64) 7 | 8 | type superblock = { 9 | super_version : int ; (* 2 byte *) 10 | (* padding - 6 byte *) 11 | super_counter : int ; (* 8 byte *) 12 | timestamp : Ptime.t ; (* 8 byte *) 13 | active_sector : int64 ; (* 8 byte *) 14 | data_length : int ; (* 8 byte *) 15 | data_checksum : string ; 16 | (* padding until length - 32 *) 17 | (* superblock_checksum : string ; *) 18 | used_sectors : IS.t ; (* not persistent *) 19 | } 20 | 21 | let superblock_size = 40 + 2 * H.digest_size 22 | 23 | let empty_superblock () = { 24 | super_version = s_version ; 25 | super_counter = 0 ; 26 | timestamp = Mirage_ptime.now () ; 27 | active_sector = 0L ; 28 | data_length = 0 ; 29 | data_checksum = "" ; 30 | used_sectors = IS.empty ; 31 | } 32 | 33 | let ns_per_day = Int64.mul 86_000L 1_000_000_000L 34 | let ps_per_ns = 1_000L 35 | 36 | let decode_timestamp data off = 37 | let ns = 
String.get_int64_be data off in 38 | let d = Int64.unsigned_div ns ns_per_day 39 | and ps = Int64.(mul (unsigned_rem ns ns_per_day) ps_per_ns) 40 | in 41 | Ptime.v (Int64.to_int d, ps) 42 | 43 | let encode_timestamp data off v = 44 | let d, ps = Ptime.Span.to_d_ps (Ptime.to_span v) in 45 | let ns = Int64.(add (mul (Int64.of_int d) ns_per_day) (div ps ps_per_ns)) in 46 | Bytes.set_int64_be data off ns 47 | 48 | let safe_int ~msg d = 49 | if d > Int64.of_int max_int then 50 | Error (`Overflow msg) 51 | else 52 | Ok (Int64.to_int d) 53 | 54 | type decode_err = 55 | [ `Overflow of string |`Bad_checksum | `Bad_superblock_version of int ] 56 | 57 | let pp_decode_err ppf = function 58 | | `Overflow msg -> Fmt.pf ppf "integer overflow %s" msg 59 | | `Bad_checksum -> Fmt.string ppf "bad superblock checksum" 60 | | `Bad_superblock_version v -> 61 | Fmt.pf ppf "superblock version %d is not supported (supported is %d)" 62 | v s_version 63 | 64 | let decode_superblock buf : (superblock, [> decode_err ]) result = 65 | let payload, checksum = 66 | let mid = String.length buf - H.digest_size in 67 | String.sub buf 0 mid, String.sub buf mid H.digest_size 68 | in 69 | if String.equal checksum H.(to_raw_string (digest_string payload)) then 70 | let super_version = String.get_uint16_be payload 0 71 | and super_counter = String.get_int64_be payload 8 72 | and timestamp = decode_timestamp payload 16 73 | and active_sector = String.get_int64_be payload 24 74 | and data_length = String.get_int64_be payload 32 75 | and data_checksum = String.sub payload 40 H.digest_size 76 | in 77 | let ( let* ) = Result.bind in 78 | if super_version = s_version then 79 | let* super_counter = safe_int ~msg:"superblock counter" super_counter in 80 | let* data_length = safe_int ~msg:"data length" data_length in 81 | Ok { super_version ; super_counter ; timestamp ; active_sector ; 82 | data_length ; data_checksum ; used_sectors = IS.empty } 83 | else 84 | Error (`Bad_superblock_version super_version) 85 | 
else 86 | Error `Bad_checksum 87 | 88 | let encode_superblock t buf = 89 | Bytes.set_uint16_be buf 0 t.super_version; 90 | Bytes.set_int64_be buf 8 (Int64.of_int t.super_counter); 91 | encode_timestamp buf 16 t.timestamp; 92 | Bytes.set_int64_be buf 24 t.active_sector; 93 | Bytes.set_int64_be buf 32 (Int64.of_int t.data_length); 94 | Bytes.blit_string t.data_checksum 0 buf 40 H.digest_size; 95 | let eop = Bytes.length buf - H.digest_size in 96 | let payload = Bytes.sub_string buf 0 eop in 97 | let checksum = H.(to_raw_string (digest_string payload)) in 98 | Bytes.blit_string checksum 0 buf eop H.digest_size 99 | 100 | let lwt_err_to_msg ~pp_error f = 101 | let open Lwt.Infix in 102 | f >|= Result.map_error (fun e -> `Msg (Fmt.to_to_string pp_error e)) 103 | 104 | let read_data block = 105 | let open Lwt.Infix in 106 | Block.get_info block >>= fun info -> 107 | let open Lwt_result.Infix in 108 | let ss = info.Mirage_block.sector_size in 109 | assert (ss >= superblock_size); 110 | let data_per_sector = ss - 8 in (* each sector is prefixed by a next pointer *) 111 | let super_data_first, super_data_last = Cstruct.create ss, Cstruct.create ss in 112 | let first_super, last_super = 0L, Int64.pred info.Mirage_block.size_sectors in 113 | lwt_err_to_msg ~pp_error:Block.pp_error 114 | (Block.read block first_super [ super_data_first ]) >>= fun () -> 115 | lwt_err_to_msg ~pp_error:Block.pp_error 116 | (Block.read block last_super [ super_data_last ]) >>= fun () -> 117 | Lwt_result.lift 118 | (match decode_superblock (Cstruct.to_string super_data_first), 119 | decode_superblock (Cstruct.to_string super_data_last) 120 | with 121 | | Ok a, Ok b -> 122 | (match compare a.super_counter b.super_counter with 123 | | 0 -> Ok (a, None) 124 | | 1 -> Ok (a, Some (last_super, super_data_first)) 125 | | -1 -> Ok (b, Some (first_super, super_data_last)) 126 | | _ -> assert false) 127 | | Error `Bad_checksum, Ok b -> Ok (b, Some (first_super, super_data_last)) 128 | | Ok a, Error 
`Bad_checksum -> Ok (a, Some (last_super, super_data_first)) 129 | | Error a, _ -> Error a 130 | | _, Error b -> Error b) >>= fun (superblock, to_write) -> 131 | let scratch = Cstruct.create ss in 132 | let rec read_one sectors data sector = 133 | match sector = 0L, Cstruct.length data = 0 with 134 | | true, true -> Lwt.return (Ok sectors) 135 | | false, false -> 136 | lwt_err_to_msg ~pp_error:Block.pp_error 137 | (Block.read block sector [ scratch ]) >>= fun () -> 138 | let next = Cstruct.BE.get_uint64 scratch 0 in 139 | let len = min (Cstruct.length data) data_per_sector in 140 | Cstruct.blit scratch 8 data 0 len; 141 | read_one (IS.add sector sectors) (Cstruct.shift data len) next 142 | | true, false -> Lwt.return (Error (`Msg "early end of data")) 143 | | false, true -> Lwt.return (Error (`Msg "sector chain exceeds data")) 144 | in 145 | let data = Cstruct.create superblock.data_length in 146 | read_one IS.empty data superblock.active_sector >>= fun used_sectors -> 147 | let data = Cstruct.to_string data in 148 | if String.equal superblock.data_checksum H.(to_raw_string (digest_string data)) then 149 | (match to_write with 150 | | None -> Lwt.return (Ok ()) 151 | | Some (idx, d) -> 152 | lwt_err_to_msg ~pp_error:Block.pp_write_error 153 | (Block.write block idx [ d ])) >|= fun () -> 154 | { superblock with used_sectors }, data 155 | else 156 | Lwt.return (Error (`Msg "bad data checksum")) 157 | 158 | let write_data block old_superblock data = 159 | let open Lwt.Infix in 160 | (* first check that we could write data on block without overwriting the old data *) 161 | Block.get_info block >>= fun info -> 162 | let ss = info.Mirage_block.sector_size 163 | and sectors = info.Mirage_block.size_sectors 164 | in 165 | assert (ss >= superblock_size); 166 | let data = Cstruct.of_string data in 167 | let data_per_sector = ss - 8 in (* each sector is prefixed by a next pointer *) 168 | let sectors_needed = (Cstruct.length data + (pred ss)) / data_per_sector in 169 | if 2 + 
sectors_needed + IS.cardinal old_superblock.used_sectors > Int64.to_int sectors then 170 | Lwt.return (Error (`Msg "not enough blocks")) 171 | else 172 | (* write data *) 173 | let open Lwt_result.Infix in 174 | let data_sector = Cstruct.create ss in 175 | let rec is_free i = 176 | if Int64.succ i >= sectors then 177 | Error (`Msg "no more sectors") (* according to the test above this should not happen *) 178 | else if IS.mem i old_superblock.used_sectors then 179 | is_free (Int64.succ i) 180 | else 181 | Ok i 182 | in 183 | let rec write_one sector data acc = 184 | (if Cstruct.length data <= data_per_sector then 185 | Lwt.return (Ok 0L) 186 | else 187 | Lwt_result.lift (is_free (Int64.succ sector))) >>= fun next -> 188 | Cstruct.BE.set_uint64 data_sector 0 next; 189 | let len = min (Cstruct.length data) data_per_sector in 190 | Cstruct.blit data 0 data_sector 8 len; 191 | lwt_err_to_msg ~pp_error:Block.pp_write_error 192 | (Block.write block sector [ data_sector ]) >>= fun () -> 193 | let acc' = IS.add sector acc in 194 | if next = 0L then 195 | Lwt.return (Ok acc') 196 | else 197 | write_one next (Cstruct.shift data data_per_sector) acc' 198 | in 199 | Lwt_result.lift (is_free 1L) >>= fun first_sector -> 200 | write_one first_sector data IS.empty >>= fun used_sectors -> 201 | let superblock = 202 | let empty = empty_superblock () in 203 | { 204 | empty with 205 | super_counter = succ old_superblock.super_counter ; 206 | active_sector = first_sector ; 207 | data_length = Cstruct.length data ; 208 | data_checksum = H.(to_raw_string (digest_string (Cstruct.to_string data))) ; 209 | used_sectors ; 210 | } 211 | in 212 | let s = Bytes.create ss in 213 | encode_superblock superblock s; 214 | let s = Cstruct.of_bytes s in 215 | lwt_err_to_msg ~pp_error:Block.pp_write_error 216 | (Block.write block (Int64.pred sectors) [ s ]) >>= fun () -> 217 | lwt_err_to_msg ~pp_error:Block.pp_write_error 218 | (Block.write block 0L [ s ]) >|= fun () -> 219 | superblock 220 | 221 | let 
init block = 222 | let open Lwt.Infix in 223 | let superblock = 224 | let empty = empty_superblock () in 225 | { empty with data_checksum = H.(to_raw_string (digest_string "")) } 226 | in 227 | Block.get_info block >>= fun info -> 228 | let ss = info.Mirage_block.sector_size in 229 | assert (ss >= superblock_size); 230 | let s = Bytes.create ss in 231 | encode_superblock superblock s; 232 | let last_sector = Int64.pred info.Mirage_block.size_sectors in 233 | let open Lwt_result.Infix in 234 | let s = Cstruct.of_bytes s in 235 | lwt_err_to_msg ~pp_error:Block.pp_write_error 236 | (Block.write block last_sector [ s ]) >>= fun () -> 237 | lwt_err_to_msg ~pp_error:Block.pp_write_error 238 | (Block.write block 0L [ s ]) >|= fun () -> 239 | superblock 240 | end 241 | -------------------------------------------------------------------------------- /unikernel/unikernel.ml: -------------------------------------------------------------------------------- 1 | (* (c) 2020 Hannes Mehnert, all rights reserved *) 2 | 3 | (* left to do: 4 | - haproxy1 support (PROXY TCP4|6 SOURCEIP DESTIP SRCPORT DESTPORT\r\n) at the beginning of the TCP connection to the backend 5 | - NG: apart from SNI allow other ports to be redirected (no proxy) 6 | *) 7 | 8 | module K = struct 9 | open Cmdliner 10 | 11 | let host = 12 | Arg.conv ~docv:"HOSTNAME" 13 | ((fun s -> Result.bind (Domain_name.of_string s) Domain_name.host), 14 | Domain_name.pp) 15 | 16 | let key_v = 17 | Arg.conv ~docv:"HOST:HASH:DATA" 18 | Dns.Dnskey.(name_key_of_string, 19 | (fun ppf v -> Fmt.string ppf (name_key_to_string v))) 20 | 21 | let frontend_port = 22 | let doc = Arg.info ~doc:"The TCP port of the frontend." 
["frontend-port"] in
    Mirage_runtime.register_arg Arg.(value & opt int 443 doc)

  let key =
    let doc = Arg.info ~doc:"The shared secret" ["key"] in
    Mirage_runtime.register_arg Arg.(required & opt (some string) None doc)

  let configuration_port =
    let doc = Arg.info ~doc:"The TCP port for configuration." ["configuration-port"] in
    Mirage_runtime.register_arg Arg.(value & opt int 1234 doc)

  let dns_key =
    let doc = Arg.info ~doc:"nsupdate key" ["dns-key"] in
    Mirage_runtime.register_arg Arg.(required & opt (some key_v) None doc)

  let dns_server =
    let doc = Arg.info ~doc:"dns server IP" ["dns-server"] in
    Mirage_runtime.register_arg
      Arg.(required & opt (some Mirage_runtime_network.Arg.ip_address) None doc)

  let domains =
    let doc = Arg.info ~doc:"domains" ["domains"] in
    Mirage_runtime.register_arg Arg.(value & opt_all host [] doc)

  let key_seed =
    let doc = Arg.info ~doc:"certificate key seed" ["key-seed"] in
    Mirage_runtime.register_arg Arg.(required & opt (some string) None doc)
end

open Lwt.Infix

module Main (Block : Mirage_block.S) (Public : Tcpip.Stack.V4V6) (Private : Tcpip.Stack.V4V6) = struct
  (* [snis name] records one more connection for [name] in the "tlstunnel"
     metrics source, which reports one counter per name plus a "total".
     NOTE(review): the increment happens inside the data callback given to
     [Metrics.add], so presumably nothing is counted while the source is
     disabled -- confirm against the metrics library. *)
  let snis =
    let create ~f =
      let data : (string, int) Hashtbl.t = Hashtbl.create 7 in
      (fun x ->
         let key = f x in
         let cur = match Hashtbl.find_opt data key with
           | None -> 0
           | Some x -> x
         in
         Hashtbl.replace data key (succ cur)),
      (fun () ->
         let data, total =
           Hashtbl.fold (fun key value (acc, total) ->
               (Metrics.uint key value :: acc), value + total)
             data ([], 0)
         in
         Metrics.uint "total" total :: data)
    in
    let src =
      let open Metrics in
      let doc = "Counter metrics" in
      let incr, get = create ~f:Fun.id in
      let data thing = incr thing; Data.v (get ()) in
      Src.v ~doc ~tags:Metrics.Tags.[] ~data "tlstunnel"
    in
    (fun r -> Metrics.add src (fun x -> x) (fun d -> d r))

  (* [access kind] creates a "connections" metrics source tagged with
     [kind]. The returned function records [`Open] (active + 1, total + 1)
     or [`Close] (active - 1) and reports the current numbers. *)
  let access kind =
    let s = ref (0, 0) in
    let open Metrics in
    let doc = "connection statistics" in
    let data () =
      Data.v [
        int "active" (fst !s) ;
        int "total" (snd !s) ;
      ] in
    let tags = Tags.string "kind" in
    let src = Src.v ~doc ~tags:Tags.[ tags ] ~data "connections" in
    (fun action ->
       (match action with
        | `Open -> s := (succ (fst !s), succ (snd !s))
        | `Close -> s := (pred (fst !s), snd !s));
       Metrics.add src (fun x -> x kind) (fun d -> d ()))

  let frontend_access = access "frontend"
  let tls_access = access "tls"
  let config_access = access "config"
  let http_access = access "http"
  let backend_access = access "backend"

  module FS = Filesystem.Make(Block)

  (* Mutable runtime configuration: the on-disk superblock last written and
     the SNI -> (backend IP, port) map. *)
  type config = {
    mutable superblock : FS.superblock ;
    mutable sni : (Ipaddr.t * int) Domain_name.Host_map.t ;
  }

  (* Load the configuration from [block]. On a checksum failure the device
     is (re)initialised with an empty SNI map; any other error fails the
     promise. *)
  let read_configuration block =
    FS.read_data block >>= function
    | Error `Bad_checksum ->
      (FS.init block >>= function
        | Ok superblock ->
          Lwt.return { superblock ; sni = Domain_name.Host_map.empty }
        | Error `Msg e ->
          Logs.err (fun m -> m "error initializing the block device %s" e);
          Lwt.fail_with "initializing block device")
    | Error `Msg msg ->
      Logs.err (fun m -> m "error reading block device %s" msg);
      Lwt.fail_with "reading block device"
    | Error (#FS.decode_err as e) ->
      Logs.err (fun m -> m "error reading block device %a" FS.pp_decode_err e);
      Lwt.fail_with "reading block device"
    | Ok (superblock, data) ->
      Logs.info (fun m -> m "read from %a (counter %u) %u bytes data"
                    (Ptime.pp_rfc3339 ()) superblock.FS.timestamp
                    superblock.FS.super_counter
                    superblock.FS.data_length);
      let config = { superblock ; sni = Domain_name.Host_map.empty } in
      if String.length data > 0 then begin
        let sni = Configuration.decode_data data in
        config.sni <- sni;
      end;
      Logs.info (fun m -> m "SNI map has %d entries"
                    (Domain_name.Host_map.cardinal config.sni));
      Lwt.return config

  (* Persist [config.sni] to [block]; on success remember the new
     superblock. *)
  let write_configuration block config =
    let open Lwt_result.Infix in
    let data = Configuration.encode_data config.sni in
    FS.write_data block config.superblock data >|= fun superblock ->
    config.superblock <- superblock

  (* Install [snis] as the SNI map and persist it. If persisting fails the
     previous in-memory map is restored, so the running tunnel keeps
     matching what is stored on disk. *)
  let persist_sni block config snis =
    let previous = config.sni in
    config.sni <- snis;
    write_configuration block config >|= function
    | Ok () -> Ok ()
    | Error _ as e -> config.sni <- previous; e

  (* Execute one authenticated configuration command and build the reply. *)
  let handle_config block config cmd =
    match cmd with
    | Configuration.Add (sni, host, port) ->
      let snis = Configuration.add_sni config.sni (sni, host, port) in
      persist_sni block config snis >|= (function
          | Ok () ->
            let msg =
              Format.asprintf "%a was successfully added" Domain_name.pp sni
            in
            Configuration.Result (0, msg)
          | Error `Msg m ->
            let msg = Format.asprintf "error %s adding %a" m Domain_name.pp sni in
            Configuration.Result (1, msg))
    | Configuration.Remove sni ->
      let snis = Configuration.remove_sni config.sni sni in
      persist_sni block config snis >|= (function
          | Ok () ->
            let msg =
              Format.asprintf "%a was successfully removed" Domain_name.pp sni
            in
            Configuration.Result (0, msg)
          | Error `Msg m ->
            let msg =
              Format.asprintf "error %s removing %a" m Domain_name.pp sni
            in
            Configuration.Result (1, msg))
    | Configuration.List ->
      let snis =
        Domain_name.Host_map.fold
          (fun sni (host, port) acc -> (sni, host, port) :: acc)
          config.sni []
      in
      Lwt.return (Configuration.Snis snis)
    | _ ->
      Lwt.return (Configuration.Result (1, "unexpected"))

  (* Decode [data] as a command, run it, and encode the reply as a string. *)
  let handle_command block config data =
    (match Configuration.cmd_of_str data with
     | Ok cmd -> handle_config block config cmd
     | Error `Msg err -> Lwt.return (Configuration.Result (2, err))) >|= fun reply ->
    Configuration.cmd_to_str reply

  module H = Digestif.SHA256

  (* [auth key data] expects [data] = HMAC-SHA256(key, payload) ^ payload
     with a non-empty payload; returns [Some payload] if the tag verifies. *)
  let auth key data =
    let len = String.length data in
    if len > H.digest_size then
      let mac = String.sub data 0 H.digest_size
      and payload = String.sub data H.digest_size (len - H.digest_size) in
      (* H.equal is constant-time, so the comparison does not leak how many
         leading bytes of a forged tag were correct. *)
      if H.equal (H.hmac_string ~key payload) (H.of_raw_string mac) then
        Some payload
      else
        None
    else
      None

  let config_cmd block config key data =
    match auth key data with
    | None -> Lwt.return (Configuration.cmd_to_str (Configuration.Result (3, "authentication failure")))
    | Some data -> handle_command block config data

  (* Serve a single request on the configuration port. The wire format is an
     8-byte big-endian length followed by that many bytes; only the first
     read() is considered. The connection is always closed and the metric
     released, also on short or malformed input. *)
  let config_change block config key tcp =
    config_access `Open;
    Lwt.finalize
      (fun () ->
         Private.TCP.read tcp >>= function
         | Error e ->
           Logs.err (fun m -> m "config TCP read error %a" Private.TCP.pp_error e);
           Lwt.return_unit
         | Ok `Eof ->
           Logs.warn (fun m -> m "config TCP read eof");
           Lwt.return_unit
         | Ok `Data buf when Cstruct.length buf < 8 ->
           (* too short to even carry the length prefix *)
           Logs.warn (fun m -> m "truncated config message");
           Lwt.return_unit
         | Ok `Data buf ->
           let l = Cstruct.BE.get_uint64 buf 0 in
           let payload = Cstruct.to_string ~off:8 buf in
           if String.length payload <> Int64.to_int l then begin
             Logs.warn (fun m -> m "truncated config message");
             Lwt.return_unit
           end else
             config_cmd block config key payload >>= fun res ->
             let out = Cstruct.create (8 + String.length res) in
             Cstruct.BE.set_uint64 out 0 (Int64.of_int (String.length res));
             Cstruct.blit_from_string res 0 out 8 (String.length res);
             Private.TCP.write tcp out >|= function
             | Ok () -> ()
             | Error e ->
               Logs.warn (fun m -> m "config TCP write error %a" Private.TCP.pp_write_error e))
      (fun () ->
         config_access `Close;
         Private.TCP.close tcp)

  module TLS = Tls_mirage.Make(Public.TCP)

  (* We assume [content] is (the start of) an HTTP request and want to reply
     with a 301 (moved permanently) whose Location header is
     "https://" ^ host ^ url. So we extract
     (a) the request line "METHOD URL ..." and
     (b) the value of the "Host:" header (matched case-insensitively; an
         explicit port in it is kept). *)
  let extract_location content =
    match List.map String.trim (String.split_on_char '\n' content) with
    | [] ->
      (* not reachable: String.split_on_char never returns [] *)
      Logs.warn (fun m -> m "no http header found in %S" content);
      None
    | request :: headers ->
      let is_host_header x =
        String.length x >= 5 &&
        String.lowercase_ascii (String.sub x 0 5) = "host:"
      in
      match String.split_on_char ' ' request, List.find_opt is_host_header headers with
      | _method :: url :: _, Some host ->
        let name = String.trim (String.sub host 5 (String.length host - 5)) in
        if name = "" then begin
          Logs.warn (fun m -> m "no name in host header %S" host);
          None
        end else
          Some (String.concat "" [ "https://" ; name ; url ])
      | _ ->
        Logs.warn (fun m -> m "no url or host header found in %S" content);
        None

  (* Render a minimal HTTP/1.1 response; a non-empty [body] is sent as
     UTF-8 plain text. *)
  let http_reply ?(body = "") ?(headers = []) ~status_code status =
    let status = Printf.sprintf "HTTP/1.1 %u %s" status_code status
    and headers =
      "Server: OCaml TLStunnel" ::
      Printf.sprintf "Content-Length: %u" (String.length body) ::
      (if body = "" then [] else [ "Content-Type: text/plain; charset=utf-8" ]) @
      headers
    in
    String.concat "\r\n" (status :: headers @ [ "" ; body ])

  (* Port 80 handler: answer with a redirect to the HTTPS location of the
     request, then close. *)
  let redirect tcp =
    http_access `Open;
    Public.TCP.read tcp >>= fun data ->
    let reply = match data with
      | Error e ->
        Logs.err (fun m -> m "TCP error %a" Public.TCP.pp_error e);
        None
      | Ok `Eof ->
        Logs.err (fun m -> m "TCP eof");
        None
      | Ok `Data data ->
        (* this is slighly brittle since it only uses the first bytes read() *)
        extract_location (Cstruct.to_string data)
    in
    (match reply with
     | None -> Lwt.return_unit
     | Some data ->
       let reply =
         http_reply ~headers:[ "Location: " ^ data ] ~status_code:301
           "Moved permanently"
       in
       Public.TCP.write tcp (Cstruct.of_string reply) >|= function
       | Ok () -> ()
       | Error e ->
         Logs.err (fun m -> m "error %a sending redirect" Public.TCP.pp_write_error e))
    >>= fun () ->
    http_access `Close;
    Public.TCP.close tcp

  (* Copy client (TLS) -> backend (TCP) until EOF or an error. Flows are
     NOT closed here: the caller tears both sides down exactly once. *)
  let rec read_tls_write_tcp tls tcp =
    TLS.read tls >>= function
    | Error e ->
      Logs.err (fun m -> m "TLS read error %a" TLS.pp_error e);
      Lwt.return_unit
    | Ok `Eof -> Lwt.return_unit
    | Ok `Data buf ->
      Private.TCP.write tcp buf >>= function
      | Error e ->
        Logs.err (fun m -> m "TCP write error %a" Private.TCP.pp_write_error e);
        Lwt.return_unit
      | Ok () ->
        read_tls_write_tcp tls tcp

  (* Copy backend (TCP) -> client (TLS) until EOF or an error. Flows are
     NOT closed here: the caller tears both sides down exactly once. *)
  let rec read_tcp_write_tls tcp tls =
    Private.TCP.read tcp >>= function
    | Error e ->
      Logs.err (fun m -> m "TCP read error %a" Private.TCP.pp_error e);
      Lwt.return_unit
    | Ok `Eof -> Lwt.return_unit
    | Ok `Data buf ->
      TLS.write tls buf >>= function
      | Error e ->
        Logs.err (fun m -> m "TLS write error %a" TLS.pp_write_error e);
        Lwt.return_unit
      | Ok () ->
        read_tcp_write_tls tcp tls

  let default_host = Domain_name.(host_exn (of_string_exn "default"))

  (* Frontend handler: complete the TLS handshake, pick the backend from the
     SNI map (falling back to the "default" entry), and proxy bytes in both
     directions. Every opened flow and counter is closed exactly once. *)
  let tls_accept priv config tls_config tcp_flow =
    frontend_access `Open;
    (* TODO this should timeout the TLS handshake with a reasonable timer *)
    TLS.server_of_flow tls_config tcp_flow >>= function
    | Error e ->
      Logs.warn (fun m -> m "TLS error %a" TLS.pp_write_error e);
      frontend_access `Close;
      Public.TCP.close tcp_flow
    | Ok tls_flow ->
      tls_access `Open;
      (* single teardown of the client-facing side *)
      let close () =
        tls_access `Close;
        frontend_access `Close;
        TLS.close tls_flow >>= fun () ->
        Public.TCP.close tcp_flow
      in
      match TLS.epoch tls_flow with
      | Error () ->
        Logs.warn (fun m -> m "unexpected error retrieving the TLS session");
        close ()
      | Ok epoch ->
        let sni, sni_text =
          let default () =
            Domain_name.Host_map.find_opt default_host config.sni
          in
          match epoch.Tls.Core.own_name with
          | None ->
            snis "no sni";
            default (), "no sni"
          | Some sni ->
            let r =
              match Domain_name.Host_map.find_opt sni config.sni with
              | None ->
                Logs.warn (fun m -> m "server name %a not configured"
                              Domain_name.pp sni);
                default ()
              | Some (host, port) ->
                snis (Domain_name.to_string sni);
                Some (host, port)
            in
            r, Domain_name.to_string sni
        in
        match sni with
        | None ->
          let reply =
            http_reply
              ~body:("Couldn't figure which service you want ('" ^ sni_text ^ "'), and no default is configured")
              ~status_code:404 "Not Found"
          in
          TLS.write tls_flow (Cstruct.of_string reply) >>= fun _ ->
          close ()
        | Some (host, port) ->
          Private.TCP.create_connection priv (host, port) >>= function
          | Error e ->
            Logs.err (fun m -> m "error %a connecting to backend"
                          Private.TCP.pp_error e);
            let reply =
              http_reply
                ~body:("Couldn't connect to backend service for '" ^ sni_text ^ "', please come back later")
                ~status_code:502 "Bad gateway"
            in
            TLS.write tls_flow (Cstruct.of_string reply) >>= fun _ ->
            close ()
          | Ok backend ->
            backend_access `Open;
            (* The first direction to finish ends the session; closing the
               flows afterwards also unblocks the other direction. *)
            Lwt.finalize
              (fun () ->
                 Lwt.pick [
                   read_tls_write_tcp tls_flow backend ;
                   read_tcp_write_tls backend tls_flow
                 ])
              (fun () ->
                 backend_access `Close;
                 Private.TCP.close backend >>= fun () ->
                 close ())

  module D = Dns_certify_mirage.Make(Public)

  (* Entry point: load the configuration, start the configuration and the
     port 80 redirect listeners, then obtain certificates for all
     [K.domains] and serve TLS on [K.frontend_port]. Certificates are
     re-fetched seven days before the earliest expiry (at least one hour
     later). *)
  let start block pub priv =
    read_configuration block >>= fun config ->
    Private.TCP.listen (Private.tcp priv) ~port:(K.configuration_port ())
      (config_change block config (K.key ()));
    Public.TCP.listen (Public.tcp pub) ~port:80 redirect;
    let rec retrieve_certs () =
      Lwt_list.fold_left_s (fun acc domain ->
          let key_seed = Domain_name.to_string domain ^ ":" ^ (K.key_seed ()) in
          D.retrieve_certificate pub (K.dns_key ())
            ~hostname:domain
            ~additional_hostnames:[ Domain_name.(append_exn (of_string_exn "*") domain) ]
            ~key_seed (K.dns_server ()) 53 >>= function
          | Error `Msg err -> Lwt.fail_with err
          | Ok certificates -> Lwt.return (certificates :: acc))
        [] (K.domains ()) >>= fun cert_chains ->
      (match List.rev cert_chains with
       | [] -> failwith "empty certificate chains"
       | a :: _ -> Lwt.return a) >>= fun first ->
      let certificates = `Multiple_default (first, cert_chains) in
      match Tls.Config.server ~certificates () with
      | Error `Msg msg -> failwith msg
      | Ok tls_config ->
        let priv_tcp = Private.tcp priv in
        Public.TCP.listen (Public.tcp pub) ~port:(K.frontend_port ()) (tls_accept priv_tcp config tls_config);
        let now = Mirage_ptime.now () in
        let seven_days_before_expire =
          let next_expire =
            let expiring =
              List.map snd
                (List.map X509.Certificate.validity
                   (List.map (function (s::_, _) -> s | _ -> assert false)
                      cert_chains))
            in
            let diffs = List.map (fun exp -> Ptime.diff exp now) expiring in
            let closest_span = List.hd (List.sort Ptime.Span.compare diffs) in
            fst (Ptime.Span.to_d_ps closest_span)
          in
          max (Duration.of_hour 1) (Duration.of_day (max 0 (next_expire - 7)))
        in
        Mirage_sleep.ns seven_days_before_expire >>= fun () ->
        retrieve_certs ()
    in
    retrieve_certs ()
end
--------------------------------------------------------------------------------