├── .gitignore ├── LICENSE ├── README.md ├── data ├── canonical_mcmc.py └── target_systems.py ├── experiments ├── LJ3D.yaml ├── config.yaml └── main.py ├── install.sh ├── requirements.txt ├── run_exp.sh ├── src ├── E_model.py ├── dataloader.py ├── diffusion_model.py ├── distance_on_torus.py └── evaluation.py └── toy_example.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .pip_cache/* 2 | .venv/* 3 | **pycache** 4 | data/LJ3D_N=* 5 | experiments/outputs/* 6 | experiments/wandb/* 7 | experiments/LJ3D.ckpt 8 | experiments/wandb.key 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Balint Mate 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural TI 2 | 3 | This repository contains the implementation of the paper 4 | > [Neural Thermodynamic Integration: Free Energies from Energy-based Diffusion Models](https://arxiv.org/abs/2406.02313) by Bálint Máté, François Fleuret and Tristan Bereau. 5 | 6 | ## Environment 7 | The ```install.sh``` script creates the virtualenv needed to run the experiments. It requires python>=3.9 with the ```virtualenv``` package installed, since the script calls ```python3 -m virtualenv```. 8 | 9 | ## Toy experiment 10 | The notebook ```toy_example.ipynb``` contains a simple experiment that demonstrates the idea on a 1D Gaussian mixture. 11 | 12 | ## 3D Lennard-Jones experiment 13 | 14 | The ```run_exp.sh``` script activates the virtualenv created by ```install.sh``` and then executes ```experiments/main.py``` using the configs ```experiments/config.yaml``` and ```experiments/LJ3D.yaml```. When run for the first time, it first generates the training data using MCMC. The samples are then dumped to files and loaded in later runs. 15 | 16 | 17 | ## Logging 18 | All the plots and metrics are also logged to the ```experiments/wandb``` directory by default. If you create a file at ```experiments/wandb.key``` containing your Weights & Biases API key, then all the logs will be pushed to your wandb account. 
19 | 20 | ## Citation 21 | If you find our paper or this repository useful, please consider citing it: 22 | 23 | ``` 24 | @article{mate2024neural, 25 | author = {Mát{\'e}, Bálint and Fleuret, Fran{\c{c}}ois and Bereau, Tristan}, 26 | title = {Neural Thermodynamic Integration: Free Energies from Energy-Based Diffusion Models}, 27 | journal = {The Journal of Physical Chemistry Letters}, 28 | volume = {15}, 29 | number = {45}, 30 | pages = {11395--11404}, 31 | year = {2024}, 32 | doi = {10.1021/acs.jpclett.4c01958}, 33 | note = {PMID: 39503734}, 34 | URL = {https://doi.org/10.1021/acs.jpclett.4c01958}, 35 | eprint = {2406.02313}, 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /data/canonical_mcmc.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from dataclasses import dataclass 4 | from data.target_systems import TargetSystemAbs 5 | from typing import Callable 6 | import os, sys, pickle 7 | from functools import partial 8 | import time 9 | from distance_on_torus import dist2_on_torus 10 | import wandb 11 | import optax 12 | 13 | 14 | @dataclass 15 | class Canonical_Sampler: 16 | target_system: TargetSystemAbs 17 | N: int 18 | 19 | @partial(jax.jit, static_argnames=["self"]) 20 | def interaction(self, x): 21 | r2 = dist2_on_torus(x) 22 | r2 = r2 + jnp.eye(len(x)) 23 | return self.target_system.U_ij(r2) 24 | 25 | @partial(jax.jit, static_argnames=["self"]) 26 | def U(self, x): 27 | potential = self.target_system.U_x(x).sum() 28 | mask = 1 - jnp.eye(len(x)) 29 | interaction = (self.interaction(x) * mask).sum() 30 | return potential + interaction 31 | 32 | @partial(jax.jit, static_argnames=["self"]) 33 | def propose(self, x, key): 34 | z = jax.random.normal(key, x.shape) * jnp.sqrt(2 * self.dx) 35 | x = x + z 36 | x = x % 1 37 | return x 38 | 39 | # Monte Carlo sampling 40 | def sample(self, key, dx, 
return_samples=False): 41 | self.dx = dx 42 | data_path = self.target_system.data_path + f"_N={self.N}" 43 | print(50 * "-") 44 | print(f"N = {self.N}") 45 | if os.path.isfile(data_path) and not return_samples: 46 | print("Data already generated") 47 | return 48 | 49 | ## positon MC sampling 50 | 51 | x0 = jax.random.uniform(key, shape=(self.N, self.target_system.num_dim)) 52 | 53 | # grad descent to spread out the points 54 | def loss(x): 55 | D = dist2_on_torus(x) 56 | mask = 1 - jnp.eye(len(x)) 57 | return ((1 / (D + 1e-4)) * mask).sum() 58 | 59 | optim = optax.adam(learning_rate=1e-4) 60 | opt_state = optim.init(x0) 61 | 62 | @jax.jit 63 | def move(x0, opt_state): 64 | grad = jax.grad(loss)(x0) 65 | updates, opt_state = optim.update(grad, opt_state, x0) 66 | x0 = optax.apply_updates(x0, updates) % 1 67 | mask = jnp.eye(len(x0)) 68 | D = dist2_on_torus(x0) 69 | D = D / self.target_system.sigma**2 70 | D_min = (D + mask).min() 71 | return D_min, x0, opt_state 72 | 73 | D_min = 0 74 | while D_min < 0.7: 75 | D_min, x0, opt_state = move(x0, opt_state) 76 | print(f"D2_min: {D_min:.4f} ", end="\r") 77 | 78 | x_curr = x0 79 | samples_list = [] 80 | i = 0 81 | 82 | @jax.jit 83 | def body_fn(i, carry): 84 | x_traj, U_traj, acc_prob_traj, key = carry 85 | key1, key2 = jax.random.split(key) 86 | x_curr = x_traj[i] 87 | U_curr = U_traj[i] 88 | x_prop = self.propose(x_curr, key1) 89 | U_prop = self.U(x_prop) 90 | U_diff = U_prop - U_curr 91 | acc_prob = jnp.exp(-U_diff) 92 | acc_prob_traj = acc_prob_traj.at[i].set(jnp.clip(acc_prob, max=1)) 93 | take_new = jax.random.uniform(key2, (1,))[0] < acc_prob 94 | x_new = take_new * x_prop + (1 - take_new) * x_curr 95 | x_traj = x_traj.at[i + 1].set(x_new) 96 | U_new = take_new * U_prop + (1 - take_new) * U_curr 97 | U_traj = U_traj.at[i + 1].set(U_new) 98 | key = jax.random.split(key)[0] 99 | return x_traj, U_traj, acc_prob_traj, key 100 | 101 | NUM_TO_SAMPLE = self.target_system.num_samples + self.target_system.burn_in 102 | while 
i < NUM_TO_SAMPLE: 103 | N = 1000 104 | ####### 105 | x_traj = jnp.zeros((N,) + x_curr.shape) 106 | x_traj = x_traj.at[0].set(x_curr) 107 | U_traj = jnp.zeros((N,)) 108 | U_traj = U_traj.at[0].set(self.U(x_curr)) 109 | acc_prob_traj = jnp.zeros(N) 110 | carry = (x_traj, U_traj, acc_prob_traj, key) 111 | x_traj, U_traj, acc_prob_traj, key = jax.lax.fori_loop( 112 | 0, N - 1, body_fn, carry 113 | ) 114 | x_curr = x_traj[-1] 115 | 116 | i += N 117 | if i > self.target_system.burn_in: 118 | samples_list.append(x_traj) 119 | 120 | i = 0 121 | 122 | print("") 123 | samples = jnp.concatenate(samples_list) 124 | samples = jax.random.permutation(jax.random.key(0), samples, axis=0) 125 | if not return_samples: 126 | with open(data_path, "wb") as file: 127 | pickle.dump(samples, file) 128 | print(f"Generated data has size {os.path.getsize(data_path)/2**20:.1f} MB") 129 | print(f"and shape {samples.shape}") 130 | if wandb.run is not None: 131 | wandb.log({f"acceptance rate/{self.N}": acc_prob_traj.mean()}) 132 | print(50 * "-") 133 | else: 134 | return samples 135 | 136 | def __hash__(self): 137 | return 0 138 | 139 | def __eq__(self, other): 140 | return True 141 | -------------------------------------------------------------------------------- /data/target_systems.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import jax.numpy as jnp 3 | from typing import Sequence 4 | import jax 5 | 6 | 7 | ## eps of LJ is actually 2*eps, because i'm not dividing by 2 when summing LJ(D^2) 8 | @dataclass 9 | class TargetSystemAbs: 10 | 11 | def U_ij(self, r2): 12 | raise NotImplementedError 13 | 14 | def U_x(self, x): 15 | raise NotImplementedError 16 | 17 | 18 | @dataclass 19 | class LJ(TargetSystemAbs): 20 | beta: float = 1 21 | 22 | def U_ij_soft(self, a, r2): 23 | sr6 = (self.sigma**2 / (a * self.sigma**2 + r2)) ** 3 24 | U_ij = 4 * self.eps * (sr6**2 - sr6) 25 | return U_ij 26 | 27 | def U_ij(self, r2): 28 | 
return self.U_ij_soft(0, r2) 29 | 30 | def U_x(self, x): 31 | return jnp.array(0.0) 32 | 33 | 34 | @dataclass 35 | class LJ3D(LJ): 36 | num_dim: int = 3 37 | sigma: float = 1 / 6 38 | eps: float = 0.4 39 | data_path: str = "../data/LJ3D" 40 | num_samples: int = 30 * 1000 41 | burn_in: int = 10 * 1000 42 | -------------------------------------------------------------------------------- /experiments/LJ3D.yaml: -------------------------------------------------------------------------------- 1 | target_system: 2 | _target_: data.target_systems.LJ3D 3 | 4 | E_model: 5 | _target_: E_model.E_model 6 | size_to_pad: {40: 350, 80: 1200, 120: 2300, 160: 4000, 200: 6500} 7 | target_system: ${target_system} 8 | data_to_generate: [40, 60, 80, 100, 120, 140, 160, 180, 200] 9 | sampling_dx: [3e-5, 1e-5,7e-6, 4e-6,2e-6, 9e-7,7e-7,5e-7, 4e-7] 10 | 11 | train_N_list: [40,80,120,160,200] 12 | eval_N_list: [60,100,140,180] 13 | eval_num_batches: [10,10,10,10] 14 | eval_batch_size: [5,5,5,5] 15 | num_train_steps: 100000 16 | eval_every_n_steps: 5000 17 | 18 | 19 | 20 | batch_size: 32 21 | 22 | 23 | optim: 24 | _target_: optax.adamw 25 | learning_rate: 26 | _target_: optax.exponential_decay 27 | init_value: 1e-3 28 | decay_rate: 1e-1 29 | transition_steps: 50000 30 | end_value: 1e-5 31 | 32 | 33 | 34 | model: 35 | _target_: diffusion_model.diffusion_model 36 | target_system: ${target_system} 37 | num_integration_steps: 1000 38 | sigma_min: 1e-3 39 | sigma_max: 5e-1 40 | 41 | logZ_estimate: 42 | num_batches: 5 43 | batch_size: 10 44 | N_range: [2,201,1] 45 | 46 | -------------------------------------------------------------------------------- /experiments/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | - _self_ 4 | - LJ3D 5 | 6 | jax_config: 7 | jax_enable_x64: False 8 | jax_numpy_rank_promotion: 'raise' 9 | jax_debug_nans: True 10 | 11 | 12 | wandb_project_name: neural-TI 13 | PRNGKey: 0 14 | cuda_devices: '0' 15 | 16 | 
17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /experiments/main.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import wandb 3 | import hydra 4 | import omegaconf 5 | import jax.numpy as jnp 6 | from evaluation import eval_model 7 | from data.canonical_mcmc import Canonical_Sampler 8 | import os 9 | import sys 10 | import optax 11 | import time 12 | import pickle 13 | from dataloader import DataLoader 14 | 15 | 16 | @hydra.main(version_base=None, config_path=".", config_name="config.yaml") 17 | def main(cfg): 18 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.cuda_devices 19 | #### jax flags ### 20 | for cfg_name, cfg_value in cfg.jax_config.items(): 21 | jax.config.update(cfg_name, cfg_value) 22 | try: 23 | wandb_key = open("./wandb.key", "r").read() 24 | wandb.login(key=wandb_key) 25 | run = wandb.init(project=cfg.wandb_project_name) 26 | except: 27 | print("Weights and biases key not found or not valid. 
Will be logging locally.") 28 | run = wandb.init(project=cfg.wandb_project_name, mode="offline") 29 | wandb.run.log_code("..") 30 | wandb.config.update( 31 | omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) 32 | ) 33 | 34 | print("devices: ", *jax.devices()) 35 | 36 | target_system = hydra.utils.instantiate(cfg.target_system) 37 | 38 | run.tags = run.tags + (f"{target_system.num_dim}D",) 39 | for N in cfg.eval_N_list + cfg.train_N_list: 40 | assert N in cfg.data_to_generate 41 | for i, N in enumerate(cfg.data_to_generate): 42 | sampler = Canonical_Sampler(target_system, N=N) 43 | sampler.sample(key=jax.random.PRNGKey(cfg.PRNGKey), dx=cfg.sampling_dx[i]) 44 | 45 | print(80 * "-") 46 | print("Preparing data, this might take a few minutes...") 47 | eval_dataloaders = [] 48 | for i, N in enumerate(cfg.eval_N_list): 49 | with open(target_system.data_path + f"_N={N}", "rb") as pickle_file: 50 | x = pickle.load(pickle_file) 51 | eval_dataloaders.append(DataLoader(x, batch_size=cfg.eval_batch_size[i])) 52 | train_x = [] 53 | train_n = [] 54 | for i, N in enumerate(cfg.train_N_list): 55 | with open(target_system.data_path + f"_N={N}", "rb") as pickle_file: 56 | x = pickle.load(pickle_file) 57 | padding_shape = (len(x), max(cfg.train_N_list) - N, x.shape[-1]) 58 | x = jnp.concatenate((x, jnp.zeros(padding_shape)), 1) 59 | train_x.append(x) 60 | train_n.append(jnp.full((len(x), 1), N)) 61 | 62 | train_x = jnp.concatenate(train_x) 63 | train_n = jnp.concatenate(train_n) 64 | train_loader = DataLoader(train_x, train_n, batch_size=cfg.batch_size) 65 | 66 | print("Done.") 67 | print(80 * "-") 68 | 69 | ddpm = hydra.utils.instantiate(cfg.model) 70 | ddpm.num_features = target_system.num_dim 71 | ddpm.E_model = hydra.utils.instantiate(cfg.E_model) 72 | ddpm.init_params( 73 | key=jax.random.PRNGKey(cfg.PRNGKey + 1), maxN=max(cfg.train_N_list) 74 | ) 75 | 76 | num_params = sum(x.size for x in jax.tree_util.tree_leaves(ddpm.params)) 77 | print(f"num params: 
{num_params/1000:.1f}K") 78 | 79 | ## train 80 | optim = hydra.utils.instantiate(cfg.optim) 81 | opt_state = optim.init(ddpm.params) 82 | key = jax.random.PRNGKey(cfg.PRNGKey + 2) 83 | params = ddpm.params 84 | 85 | target_name = cfg.target_system._target_.split(".")[-1] 86 | if "model_name" in cfg.keys(): 87 | target_name += f'.{cfg["model_name"]}' 88 | ckpt_path = f"{target_name}.ckpt" 89 | 90 | @jax.jit 91 | def update_step(key, params, batch, opt_state): 92 | loss_and_grad_fn = jax.value_and_grad(ddpm.loss_fn) 93 | loss, grad = loss_and_grad_fn(params, batch, key) 94 | updates, opt_state = optim.update(grad, opt_state, params) 95 | params = optax.apply_updates(params, updates) 96 | return loss, params, opt_state 97 | 98 | def eval_step(): 99 | 100 | ## save params 101 | file = open(f"{ckpt_path}", "wb") 102 | pickle.dump({"params": params, "opt_state": opt_state, "cfg": cfg}, file) 103 | file.close() 104 | wandb.save(f"{ckpt_path}", policy="now") 105 | 106 | ## logs 107 | ddpm.params = params 108 | print("\nSampling... 
", end=" ") 109 | for i, loader in enumerate(eval_dataloaders): 110 | print(f"{cfg.eval_N_list[i]}", end=" ") 111 | sys.stdout.flush() 112 | eval_model(loader, ddpm, target_system, cfg.eval_num_batches[i]) 113 | print("Done.") 114 | 115 | start = time.time() 116 | for s in range(1, cfg.num_train_steps + 1): 117 | 118 | x_batch, n_batch = train_loader.next() 119 | key = jax.random.split(key, 2)[0] 120 | loss, params, opt_state = update_step( 121 | key, params, (x_batch, n_batch), opt_state 122 | ) 123 | wandb.log({"loss": loss}) 124 | print(f"training progress: {s/cfg.num_train_steps:.3f}", end=" ") 125 | print(f"time: {time.time()-start:.2f}s", end=" ") 126 | print(f"loss: {loss:.4f}", end=" \r") 127 | if s % cfg.eval_every_n_steps == 0: 128 | eval_step() 129 | 130 | print(80 * "=") 131 | file = open(f"{ckpt_path}", "wb") 132 | pickle.dump({"params": params, "opt_state": opt_state, "cfg": cfg}, file) 133 | file.close() 134 | wandb.save(f"{ckpt_path}", policy="now") 135 | wandb.save( 136 | f"eval_logs_{target_name}", 137 | policy="now", 138 | ) 139 | run.finish() 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm -rf .venv 3 | python3 -m virtualenv --system-site-packages .venv 4 | source .venv/bin/activate 5 | export PIP_CACHE_DIR=.pip_cache/ 6 | pip3 install --upgrade pip 7 | pip3 install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 8 | pip3 install -r requirements.txt 9 | deactivate -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb 2 | flax 3 | optax 4 | hydra-core 5 | matplotlib 6 | pytest 7 | -------------------------------------------------------------------------------- 
/run_exp.sh: -------------------------------------------------------------------------------- 1 | source .venv/bin/activate 2 | ps aux|grep wandb|grep -v grep | awk '{print $2}'|xargs kill -9 &> /dev/null 3 | 4 | 5 | export XLA_PYTHON_CLIENT_MEM_FRACTION=.99 6 | export PYTHONPATH=${PWD}:${PYTHONPATH} 7 | 8 | cd src 9 | export PYTHONPATH=${PWD}:${PYTHONPATH} 10 | 11 | cd ../experiments 12 | python3 main.py "$@" #1>out 2>error -------------------------------------------------------------------------------- /src/E_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | from typing import Sequence 5 | from distance_on_torus import dist2_on_torus, dR_on_torus 6 | from data.target_systems import TargetSystemAbs 7 | from functools import partial 8 | 9 | 10 | class E_model(nn.Module): 11 | target_system: TargetSystemAbs 12 | NN: Sequence[int] = (64, 64) 13 | cutoff: float = 2 / 6 14 | size_to_pad: int = 1800 15 | ## for 2d 16 | # cutoff: float = 3 / 10 17 | # size_to_pad: int = 4000 18 | # cutoff: float = 1 19 | # size_to_pad: int = 110**2 20 | num_features: int = 32 21 | num_vec_features: int = 8 22 | agg_norm: int = 10 23 | num_layers: int = 3 24 | 25 | @nn.compact 26 | def __call__(self, t, x, n): 27 | 28 | num_dim = x.shape[-1] 29 | mask_node = jnp.array(jnp.arange(len(x)) < n, dtype="int32") 30 | mask2 = jnp.einsum("i,j->ij", mask_node, mask_node) 31 | dR = dR_on_torus(x) / self.cutoff 32 | D2 = (dR**2).sum(-1) + (1 - mask2) * 10 # filling particles are far 33 | 34 | def sizetopad(size): 35 | if type(self.size_to_pad) == int: 36 | return self.size_to_pad 37 | else: 38 | keys = filter(lambda k: k >= size, list(self.size_to_pad.keys())) 39 | key = min(keys) 40 | return self.size_to_pad[key] 41 | 42 | edges = jnp.stack( 43 | jnp.where( 44 | (D2 < 1) * (D2 > 0), 45 | size=sizetopad(len(x)), 46 | fill_value=-42, 47 | ) 48 | ) 49 | senders, receivers = edges[0], 
edges[1] 50 | 51 | edge_dist2 = D2.reshape(-1)[senders * len(D2) + receivers] 52 | mask_edge = senders != -42 53 | edge_dR = dR.reshape(-1, num_dim)[senders * len(D2) + receivers] 54 | edge_dR = jnp.expand_dims(edge_dR, 1) 55 | 56 | ## particle type could be added here later 57 | h = jax.nn.one_hot(jnp.zeros((len(x),)), 2) 58 | h = jax.vmap(nn.Dense(self.num_features, use_bias=False))(h) 59 | ### 60 | 61 | h_vec = jnp.zeros((len(x), self.num_vec_features, num_dim)) 62 | 63 | edge_embedder = nn.Dense(self.num_vec_features, use_bias=False) 64 | edge_embedder = jax.vmap(edge_embedder, in_axes=-1, out_axes=-1) 65 | h_edge_vec = jax.vmap(edge_embedder)(edge_dR) 66 | 67 | ## particle type should be addded here as well (src,target) 68 | h_edge = MLP(self.NN + (h.shape[1],))(edge_dist2.reshape(-1, 1)) 69 | 70 | ###### 71 | for _ in range(self.num_layers): 72 | dh, dh_vec, dh_edge, dh_edge_vec = Layer(self.NN, self.agg_norm)( 73 | t, 74 | h, 75 | h_vec, 76 | h_edge, 77 | h_edge_vec, 78 | edge_dist2, 79 | edge_dR, 80 | mask_edge, 81 | senders, 82 | receivers, 83 | ) 84 | h += dh 85 | h_vec += dh_vec 86 | h_edge += dh_edge 87 | h_edge_vec += dh_edge_vec 88 | return jnp.einsum("nf,n->", h, mask_node) / self.agg_norm 89 | 90 | 91 | class Layer(nn.Module): 92 | NN: Sequence[int] 93 | agg_norm: float 94 | 95 | @nn.compact 96 | def __call__( 97 | self, 98 | t, 99 | h, 100 | h_vec, 101 | h_edge, 102 | h_edge_vec, 103 | edge_dist2, 104 | edge_dR, 105 | mask_edge, 106 | senders, 107 | receivers, 108 | ): 109 | 110 | inp = jnp.concatenate( 111 | [ 112 | jnp.einsum("nfx,nfx->nf", h_vec[receivers], h_edge_vec), 113 | jnp.einsum("nfx,nfx->nf", h_vec[senders], h_edge_vec), 114 | jnp.einsum("nfx,nfx->nf", h_vec[senders], h_vec[receivers]), 115 | jnp.einsum("nfx,nfx->nf", h_vec[senders], h_vec[senders]), 116 | jnp.einsum("nfx,nfx->nf", h_vec[receivers], h_vec[receivers]), 117 | jnp.einsum("nfx,nfx->nf", h_edge_vec, h_edge_vec), 118 | h[senders], 119 | h[receivers], 120 | h_edge, 121 | ], 
122 | -1, 123 | ) 124 | # print( 125 | # h.shape, 126 | # h_vec.shape, 127 | # h_edge.shape, 128 | # h_edge_vec.shape, 129 | # edge_dist2.shape, 130 | # inp.shape, 131 | # ) 132 | 133 | ## Message passing 134 | 135 | message_w_model = MessageWeight(self.NN + (h.shape[1] + h_vec.shape[1],)) 136 | message_w_model = partial(message_w_model, t) 137 | message_w_model = jax.vmap(message_w_model) 138 | 139 | mw, mw_vec = jnp.split(message_w_model(inp), [h.shape[1]], axis=-1) 140 | 141 | ## smooth cutoff 142 | cutoff = 0.5 * (jnp.cos(edge_dist2 * jnp.pi) + 1) 143 | mw = jnp.einsum("nf,n->nf", mw, cutoff) 144 | mw_vec = jnp.einsum("nf,n->nf", mw_vec, cutoff) 145 | 146 | m = jnp.einsum("efx,ef,e->efx", h_edge_vec, mw_vec, mask_edge) 147 | h_vec = jnp.zeros(h_vec.shape).at[receivers].add(m) / self.agg_norm 148 | 149 | # h_vec_scale = jax.vmap(MLP(self.NN + (h_vec.shape[1],)))(h) 150 | # h_vec = jnp.einsum("nfx,nf->nfx", h_vec, h_vec_scale) 151 | 152 | m = jnp.einsum("ef,ef,e->ef", mw, h[senders], mask_edge) 153 | h = jnp.zeros(h.shape).at[receivers].add(m) / self.agg_norm 154 | 155 | ######################## 156 | ### per_atom update 157 | h = jax.vmap(MLP(self.NN + (h.shape[1],)))(h) 158 | hvec_update = nn.Dense(h_vec.shape[1], use_bias=False) 159 | hvec_update = jax.vmap(hvec_update, in_axes=-1, out_axes=-1) 160 | h_vec = jax.vmap(hvec_update)(h_vec) 161 | ## per edge update 162 | edge_update = nn.Dense(h_edge_vec.shape[1], use_bias=False) 163 | edge_update = jax.vmap(edge_update, in_axes=-1, out_axes=-1) 164 | h_edge_vec = jnp.concatenate((h_edge_vec, h_vec[senders], h_vec[receivers]), 1) 165 | h_edge_vec = jax.vmap(edge_update)(h_edge_vec) 166 | 167 | h_edge = jnp.concatenate((h_edge, h[senders], h[receivers]), 1) 168 | h_edge = MLP(self.NN + (h.shape[1],))(h_edge) 169 | 170 | return h, h_vec, h_edge, h_edge_vec 171 | 172 | 173 | class MLP(nn.Module): 174 | features: Sequence[int] 175 | 176 | @nn.compact 177 | def __call__(self, x): 178 | for i, feat in 
enumerate(self.features): 179 | x = nn.Dense(feat)(x) 180 | if i != len(self.features) - 1: 181 | x = nn.swish(x) 182 | return x 183 | 184 | 185 | class MessageWeight(nn.Module): 186 | NN: Sequence[int] 187 | 188 | @nn.compact 189 | def __call__(self, t, x): 190 | x = jnp.concatenate((x.reshape(-1), t.reshape(-1))) 191 | return MLP(self.NN)(x) 192 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax 3 | 4 | 5 | class DataLoader: 6 | def __init__(self, x, n=None, batch_size=128): 7 | num_batches = len(x) // batch_size 8 | x = jax.random.permutation(jax.random.key(0), x, axis=0) 9 | if n is None: 10 | self.N = x.shape[1] 11 | n = jnp.full((len(x), 1), self.N) 12 | assert len(x) == len(n) 13 | self.x_all = jnp.stack(jnp.split(x[: num_batches * batch_size], num_batches)) 14 | n = jax.random.permutation(jax.random.key(0), n, axis=0) 15 | self.n_all = jnp.stack(jnp.split(n[: num_batches * batch_size], num_batches)) 16 | self.i = -1 17 | 18 | def next(self): 19 | self.i += 1 20 | x_batch = self.x_all[self.i % len(self.x_all)] 21 | n_batch = self.n_all[self.i % len(self.x_all)] 22 | return x_batch, n_batch 23 | -------------------------------------------------------------------------------- /src/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | from jax.scipy.stats.norm import pdf as normpdf 6 | from functools import partial 7 | from src.distance_on_torus import dist2_on_torus 8 | 9 | 10 | class diffusion_model: 11 | def __init__( 12 | self, 13 | num_integration_steps, 14 | sigma_min, 15 | sigma_max, 16 | target_system, 17 | ): 18 | self.sigma_min, self.sigma_max = sigma_min, sigma_max 19 | self.num_integration_steps = num_integration_steps 20 | self.target_system = target_system 21 | 22 | def 
init_params(self, key, maxN): 23 | self.params = self.E_model.init( 24 | rngs=key, 25 | t=jnp.ones( 26 | 1, 27 | ), 28 | x=jax.random.uniform(jax.random.PRNGKey(6), (maxN, self.num_features)), 29 | n=jnp.array([maxN]), 30 | ) 31 | 32 | ## dx = f dt + g dW 33 | ## f = 0; g**2 = beta 34 | ## sigma = sqrt(int_beta**2) 35 | def sigma(self, t): 36 | return (self.sigma_min ** (1 - t)) * (self.sigma_max**t) 37 | 38 | def beta(self, t): # d/dt sigma^2(t) = 2 sigma * sigma' 39 | # min * (max/min)**t 40 | return 2 * self.sigma(t) ** 2 * jnp.log(self.sigma_max / self.sigma_min) 41 | 42 | def energy(self, params, x, t, n): 43 | E_NN = self.E_model.apply(params, x=x, t=t, n=n) 44 | # return E_NN * sigma_min 45 | ## fixing boundary coniditons 46 | R2 = dist2_on_torus(x) 47 | mask = jnp.array(jnp.arange(len(x)) < n, dtype="int32") 48 | mask = mask.reshape(-1, 1) 49 | mask = jnp.einsum("ie,jt->ij", mask, mask) * (1 - jnp.eye(len(x))) 50 | R2 = R2 + (1 - mask) ## avoid small distances 51 | E_softLJ = (self.target_system.U_ij_soft(t[0], R2) * mask).sum() 52 | 53 | NN_w = (1 - t[0]) * t[0] 54 | LJ_w = (1 - t[0]) * self.sigma_min 55 | 56 | return NN_w * E_NN + LJ_w * E_softLJ 57 | 58 | def force(self, params, x, t, n): 59 | return -jax.grad(lambda x: self.energy(params, x=x, t=t, n=n))(x) 60 | 61 | def loss_fn(self, params, batch, key): 62 | x0, n = batch 63 | key1, key2 = jax.random.split(key, 2) 64 | 65 | t = jax.random.uniform(key1, (len(x0), 1)) 66 | 67 | z = jax.random.normal(key2, x0.shape) 68 | 69 | def loss_one(x0, n, z, t): 70 | xt = x0 + self.sigma(t[0]) * z # 71 | xt = xt % 1 72 | mask = (jnp.arange(len(x0)) < n[0]).reshape(-1, 1) 73 | ## score = force = - grad (logp) -z/sigma 74 | force_times_sigma = self.force(params, xt, t, n) 75 | z_pred = -force_times_sigma 76 | error2 = (z_pred - z) ** 2 77 | return (error2 * mask).mean() 78 | 79 | return jax.vmap(loss_one)(x0, n, z, t).mean() 80 | 81 | @partial(jax.jit, static_argnames=["self", "num_samples", "n"]) 82 | def 
sample(self, params, key, num_samples, n): 83 | x = jax.random.uniform(key, (num_samples, n, self.num_features)) 84 | n = jnp.array([n]) 85 | logZ = jnp.array(0.0) 86 | dt = 1 / self.num_integration_steps 87 | t = jnp.array([1.0]) 88 | init_value = (x, logZ, t, key) 89 | 90 | def body_fun(i, carry): 91 | x, logZ, t, key = carry 92 | z = jax.random.normal(key, x.shape) 93 | 94 | def rescaled_E(x, t): # absorb the division by sigma 95 | return self.energy(params, x=x, t=t, n=n) / self.sigma(t)[0] 96 | 97 | def force_dlogZ(x, t): 98 | dUdx, dUdt = jax.grad(rescaled_E, argnums=(0, 1))(x, t) 99 | return -dUdx, dUdt 100 | 101 | score, dlogZ = jax.vmap(force_dlogZ, in_axes=(0, None))(x, t) 102 | 103 | drift = self.beta(t[0]) * score 104 | x += drift * dt + z * (self.beta(t[0]) * dt) ** 0.5 105 | return x % 1, logZ + dt * dlogZ.mean(), t - dt, jax.random.split(key)[0] 106 | 107 | x, logZ, t, key = jax.lax.fori_loop( 108 | 0, self.num_integration_steps, body_fun, init_value 109 | ) 110 | 111 | return x, logZ 112 | -------------------------------------------------------------------------------- /src/distance_on_torus.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax 3 | 4 | 5 | def dR_on_torus(x): 6 | 7 | ## (num_dim, x,y,num_dim) 8 | def diff_fn(a, b): 9 | return a - b 10 | 11 | diff_fn = jax.vmap(jax.vmap(diff_fn, in_axes=(None, 0)), in_axes=(0, None)) 12 | dR1 = diff_fn(x, x) 13 | dR2 = diff_fn(x, x + 1) 14 | dR3 = diff_fn(1 + x, x) 15 | 16 | ### get absolute min 17 | dR = jnp.where(jnp.abs(dR1) < jnp.abs(dR2), dR1, dR2) 18 | dR = jnp.where(jnp.abs(dR) < jnp.abs(dR3), dR, dR3) 19 | return dR 20 | 21 | 22 | def dist2_on_torus(x): 23 | return (dR_on_torus(x) ** 2).sum(-1) 24 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 
import jax 3 | import jax.numpy as jnp 4 | import wandb 5 | from distance_on_torus import dist2_on_torus 6 | from functools import partial 7 | 8 | 9 | def g_hist_one(x, bins): 10 | num_dims = x.shape[-1] 11 | dR = dist2_on_torus(x) ** 0.5 12 | mask = 1 - jnp.eye(len(x)) 13 | dR = dR * mask - (1 - mask) 14 | R_hist = jnp.histogram(dR.reshape(-1), bins=bins)[0] 15 | g_hist = R_hist / (bins[1:] ** (num_dims - 1)) 16 | return g_hist 17 | 18 | 19 | @partial(jax.jit, static_argnames=["ddpm"]) 20 | def eval_one_batch(ddpm, params, bins_g, x_train, key): 21 | x_samples, logZ = ddpm.sample( 22 | params, key=key, num_samples=len(x_train), n=x_train.shape[1] 23 | ) 24 | 25 | g_hist = jax.vmap(g_hist_one, in_axes=(0, None)) 26 | g_train = g_hist(x_train, bins_g).sum(0) 27 | g_samples = g_hist(x_samples, bins_g).sum(0) 28 | 29 | return g_train, g_samples, logZ 30 | 31 | 32 | def eval_model(dataloader, ddpm, target, num_batches): 33 | logdict = {} 34 | g_train, g_samples = 0, 0 35 | logZ = 0 36 | 37 | bins_g = jnp.linspace(0, 4 * target.sigma, 300) 38 | key = jax.random.PRNGKey(5) 39 | 40 | for i in range(num_batches): 41 | key = jax.random.split(key)[0] 42 | x_train, _ = dataloader.next() 43 | g1, g2, logZ_ = eval_one_batch(ddpm, ddpm.params, bins_g, x_train, key) 44 | logZ += logZ_ 45 | g_train += g1 46 | g_samples += g2 47 | logZ /= num_batches 48 | n = dataloader.N 49 | fig = plt.figure(figsize=(5, 5)) 50 | plt.plot(bins_g[1:], g_train, label="MC data", linewidth=3) 51 | plt.plot(bins_g[1:], g_samples, label="diffusion samples", linewidth=1.5) 52 | plt.legend() 53 | plt.yticks([]) 54 | plt.ylim(bottom=0) 55 | plt.xlim(0, 3 * target.sigma) 56 | plt.xlabel("r$r/\sigma$", fontsize=18) 57 | plt.xticks(jnp.arange(4) * target.sigma, jnp.arange(4)) 58 | plt.ylabel(r"$g(r)$", fontsize=18) 59 | logdict = { 60 | **logdict, 61 | f"g(r)/{n}": wandb.Image(fig), 62 | f"logZ/{n}": logZ, 63 | } 64 | plt.close() 65 | 66 | wandb.log(logdict) 67 | 
-------------------------------------------------------------------------------- /toy_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\n", 10 | "import jax.numpy as jnp\n", 11 | "import jax\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "from typing import Sequence\n", 14 | "import flax.linen as nn\n", 15 | "import optax\n", 16 | "jax.config.update(\"jax_numpy_rank_promotion\", \"raise\")" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "### Data, two Gaussians, Z =3" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 4, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "text/plain": [ 34 | "[]" 35 | ] 36 | }, 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | }, 41 | { 42 | "data": { 43 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkcAAAGdCAYAAAAYDtcjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABtkklEQVR4nO3deXhTVfoH8O9N2nSB7oVuQMsmUEDKUKjFfehYUVEUHVBGmA6DA0NdqOPC/FgUmamjDDIqY9URwQVFZkZcB4QqoFKBFlChLILQQiEtW/clbe75/XGbtGnT0rRJbpJ+P8+TJ8nNXd4Qmr495z3nSEIIASIiIiICAGjUDoCIiIjIlTA5IiIiImqGyRERERFRM0yOiIiIiJphckRERETUDJMjIiIiomaYHBERERE1w+SIiIiIqBkvtQPojIaGBuzbtw8RERHQaJjfERERuQNZllFcXIzRo0fDy8t1UxDXjawd+/btw7hx49QOg4iIiDph9+7dGDt2rNphtMktk6OIiAgAyj9uVFSUytEQERFRR5w9exbjxo0z/x53VW6ZHJm60qKiotCnTx+VoyEiIiJbuHpJjGtHR0RERORkTI6IiIiImmFyRERERNSMzcnRjh07MGnSJERHR0OSJGzcuLHVPocOHcLtt9+OoKAg9OjRA2PHjkVhYaH59draWsybNw9hYWHo2bMnpkyZguLi4i69ESIiIiJ7sDk5qqqqwqhRo7Bq1Sqrrx8/fhzXXHMNhg4dim3btuGHH37AokWL4Ovra95n/vz5+OSTT7BhwwZs374dZ86cwV133dX5d0FERERkJ5IQQnT6YEnChx9+iMmTJ5u3TZs2Dd7e3nj77betHlNWVoZevXph3bp1uPvuuwEAhw8fxrBhw5CTk4Orrrrqstc9ffo0+vbti1OnTnG0GhERkZtwl9/fdq05kmUZn332Ga644gqkpqaid+/eSEpKsuh6y8vLQ319PVJSUszbhg4din79+iEnJ8fqeevq6lBeXm6+VVRU2DNsIiIiIjO7JkclJSWorKzEs88+i5tvvhlffPEF7rzzTtx1113Yvn07AECv10On0yE4ONji2IiICOj1eqvnzczMRFBQkPkWHx9vz7CJiIiIzOzecgQAd9xxB+bPn4+EhAQ8+eSTuO2225CVldXp8y5YsABlZWXmW35+vr1CJiIiIhe3atUqxMXFwdfXF0lJSdi9e3e7+69cuRJDhgyBn58f+vbti/nz56O2trbD17NrchQeHg4vL69WLTvDhg0zj1aLjIyEwWBAaWmpxT7FxcWIjIy0el4fHx8EBgaabwEBAfYMm4iIiFzU+vXrkZGRgSVLlmDv3r0YNWoUUlNTUVJSYnX/devW4cknn8SSJUtw6NAhvPHGG1i/fj3+/Oc/d/iadk2OdDodxo4diyNHjlhsP3r0KGJjYwEAY8aMgbe3N7Kzs82vHzlyBIWFhUhOTrZnOEREROTmVqxYgdmzZyMtLQ3x8fHIysqCv78/Vq9ebXX/nTt34uqrr8Z9992HuLg43HTTTbj33nsv29rUnM1rq1VWVuLYsWPm5ydOnMD+/fsRGhqKfv364bHHHsPUqVNx3XXX4cYbb8SmTZvwySefYNu2bQCAoKAgzJo1CxkZGQgNDUVgYCAefPBBJCcnd2ikGhHZRgiBmnojAMDPWwtJklSOiIioYwwGA/Ly8rBgwQLzNo1Gg5SUlDYHcY0fPx7vvPMOdu/ejXHjxuHnn3/G559/jvvvv7/D17U5OcrNzcWNN95ofp6RkQEAmDlzJtasWYM777wTWVlZyMzMxEMPPYQhQ4bgP//5D6655hrzMS+88AI0Gg2mTJmCuro6pKam4p///KetoRBRe6ouQOx9C//NOYTVl0bhoIhDYmwINsxJZoJERKqqqKhAeXm5+bmPjw98fHxa7Xf+/HkYjUZERERYbI+IiMDhw4etnvu+++7D+fPncc0110AIgYaGBsyZM8embrUuzXOkFneZJ4F
INSWHgHemAOVFAACjkJBRPxcfydcgf2kq/HU2/11ERNRlpt/fLS1ZsgRPPfVUq+1nzpxBTEwMdu7caVF68/jjj2P79u3YtWtXq2O2bduGadOmYdmyZUhKSsKxY8fw8MMPY/bs2Vi0aFGH4uQ3JJGnqasE1v8GKC+CHDIA284H4pfa/Vjh/QpOGKLUjo6ICPn5+YiJiTE/t9ZqBCgDvbRabaslxtobxLVo0SLcf//9+P3vfw8AGDlyJKqqqvDAAw/g//7v/6DRXL7cmgvPEnmab1YAF44BgTGonbkZs+r/hM+M46CVBDK9/wUIWe0IiaibCwgIsBiF3lZypNPpMGbMGItBXLIsIzs7u81BXNXV1a0SIK1WC0CpwewIJkdEnqTqPPBd45xiE58D/EMhoMHi+jSUC38M1xRAe/RzdWMkIrJBRkYGXn/9daxduxaHDh3C3LlzUVVVhbS0NADAjBkzLAq2J02ahFdeeQXvv/8+Tpw4gS1btmDRokWYNGmSOUm6HHarEXmSvDVAfRUQNQoYeivQOErtAoLwlvFXSPf6CF7fvQSMnKxqmEREHTV16lScO3cOixcvhl6vR0JCAjZt2mQu0i4sLLRoKVq4cCEkScLChQtRVFSEXr16YdKkSfjLX/7S4WuyIJvIU8gy8GICUFoATH4FSLgP1YYGxC/eDAAIRxlyfNLhLRmBP34H9B6mbrxE1O24y+9vdqsReYqCb5XEyCcIiJ/c6uXzCMI2OUF58v37Tg2NiMidMDki8hSHPlbuh00CdP5Wd/mP8VrlwYH/AO7XaExE5BRMjog8gSwDhz5RHsff3uZuX8kJEF5+QNkpoPiAk4IjInIvTI6IPIH+e6DiLKALAAbc0OZuddDB2P965cmR/zknNiIiN8PkiMgT/LxNue9/LeBlfb4QE+PgicqDo5sdGxMRkZtickTkCUzJUTutRiZy/8Z9zuwD6iocFBARkftickTk7uprgILG1ak7kByJoD5ASBwgjE3HERGRGZMjInd3ahdgrAN6RgLhV3TsmLjGUWsnv3ZcXEREborJEZG7O9GY4Ay4HpCkjh3D5IiIqE1Mjojc3ek9yn0/64swWtW/MTk6+z1QW2b/mIiI3BiTIyJ3JhuBor3K4z5jO35cYDQQHAsIuel4IiICwOSIyL2dOwIYKgDvHravldYnUbkvyrN/XEREbozJEZE7M3WpxfwC0GhtOzZmjHLPliMiIgtMjojcmSk56jvO9mPNyVEu11kjImqGyRGROzN1icUk2n5s5JWApAUqi4HyM/aNi4jIjTE5InJX9bVKzREARI+2/XidPxARrzxm3RERkRmTIyJ3de6QMsu1fxgQENm5c5i61s7ss19cRERujskRkbvS/6jcR4zo+OSPLUWOVO6LD9gnJiIiD8DkiMhd6RsTGlOC0xkRIy3PRURETI6I3FaxPZKjxpqjijNA1YWux0RE5AGYHBG5IyGaWnsiRnT+PD4BQEh/5XHxj12Pi4jIAzA5InJHpYVAXRmg8QbCr+jauSIbkyt2rRERAWByROSeTF1qvYYCXjoIIVBtaDDfhC2TOkawKJuIqDkvtQMgok4wjVSLHAEhBO7OykFewSXzy4mxIdgwJ7lj5zK1HDE5IiICwJYjIvdUfFC5jxiBmnqjRWIEALkFl1BTb+zYuUwL1p7/CZA7eAwRkQdjckTkjkwzY/cearH568dvtP1cwbGA1gdoqFVqmYiIujkmR0TuxlgPXPxZedyiGNtPp7X9fBotED5YeWxKuoiIujGbk6MdO3Zg0qRJiI6OhiRJ2LhxY5v7zpkzB5IkYeXKlRbbL168iOnTpyMwMBDBwcGYNWsWKisrbQ2FqHu6dBKQ6wFvfyCwj33O2WuIcn+eyRERkc3JUVVVFUaNGoVVq1a1u9+HH36I7777DtHR0a1emz59Og4ePIgtW7bg008/xY4dO/DAAw/YGgpR92Rq3QkfDGjs1Pgb3pgcnTtqn/MREbkxm0erTZw
4ERMnTmx3n6KiIjz44IPYvHkzbr31VovXDh06hE2bNmHPnj1ITEwEALz00ku45ZZbsHz5cqvJFBE1c74xgTElNPbQq7F77txh+52TiMhN2b3mSJZl3H///XjssccwfPjwVq/n5OQgODjYnBgBQEpKCjQaDXbt2mX1nHV1dSgvLzffKioq7B02kfswJ0ddnPyxuV5Dm85tyxxJREQeyO7J0d/+9jd4eXnhoYcesvq6Xq9H7969LbZ5eXkhNDQUer3e6jGZmZkICgoy3+Lj4+0dNpH7MHWr9bJjchQ6EJC0QF05UGH955CIqLuwa3KUl5eHf/zjH1izZg0kSbLbeRcsWICysjLzLT8/327nJnIrQijzEQH27Vbz0gGhjWussWuNiFzMqlWrEBcXB19fXyQlJWH37t1t7nvDDTdAkqRWt5ZlPu2xa3L09ddfo6SkBP369YOXlxe8vLxQUFCARx99FHFxcQCAyMhIlJSUWBzX0NCAixcvIjIy0up5fXx8EBgYaL4FBATYM2wi91FxFjBUKK08oQPse25TsnWeRdlE5DrWr1+PjIwMLFmyBHv37sWoUaOQmpraKpcw+e9//4uzZ8+abwcOHIBWq8U999zT4WvaNTm6//778cMPP2D//v3mW3R0NB577DFs3rwZAJCcnIzS0lLk5eWZj/vyyy8hyzKSkpLsGQ6R5zF1qYX2V1p77Mk0nJ9zHRGRC1mxYgVmz56NtLQ0xMfHIysrC/7+/li9erXV/UNDQxEZGWm+bdmyBf7+/jYlRzaPVqusrMSxY8fMz0+cOIH9+/cjNDQU/fr1Q1hYmMX+3t7eiIyMxJAhyhfvsGHDcPPNN2P27NnIyspCfX090tPTMW3aNI5UI7ocR3SpmZgmgrxwrP39iIi6qKKiAuXl5ebnPj4+8PHxabWfwWBAXl4eFixYYN6m0WiQkpKCnJycDl3rjTfewLRp09CjR48Ox2dzy1Fubi5Gjx6N0aNHAwAyMjIwevRoLF68uMPnePfddzF06FBMmDABt9xyC6655hq89tprtoZC1P1cPK7chw20/7lDG89pmn2biMhB4uPjLQZaZWZmWt3v/PnzMBqNiIiIsNgeERHR5iCu5nbv3o0DBw7g97//vU3x2dxydMMNN0DYMNT35MmTrbaFhoZi3bp1tl6aiEyJi73rjZqfs+w0UF8LePva/xpERADy8/MRExNjfm6t1cge3njjDYwcORLjxo2z6TiurUbkThyZHPUIB3wCAQhliRIiIgcJCAiwGGjVVnIUHh4OrVaL4uJii+3FxcVtDuIyqaqqwvvvv49Zs2bZHB+TIyJ3IRuBSwXKY0ckR5LUdF5T9x0RkYp0Oh3GjBmD7Oxs8zZZlpGdnY3k5OR2j92wYQPq6urwm9/8xubrMjkichdlp5UFZ7U+QGDM5ffvDFMt0wUmR0TkGjIyMvD6669j7dq1OHToEObOnYuqqiqkpaUBAGbMmGFRsG3yxhtvYPLkya0GinWEzTVHRKQSU5daSJz9FpxtyVyUzeSIiFzD1KlTce7cOSxevBh6vR4JCQnYtGmTuUi7sLAQmhbfiUeOHME333yDL774olPXZHJE5C7M9Ub9HXcNU7caW46IyIWkp6cjPT3d6mvbtm1rtW3IkCE2DR5rid1qRO7CkcXYJqZutYsnHHcNIiIXx+SIyF2YEhZHJkembrXy00B9jeOuQ0TkwpgcEbmLS6bkyIHdav6hgG+Q8pitR0TUTTE5InIHsuycliNJYlE2EXV7TI6I3EGlHmioASQtENTXsdcK4zIiRNS9MTkicgemRCW4H6D1duy1gmOVe9OEk0RE3QyTIyJ34IyRaiYhcco9lxAhom6KyRGRO3BqctTYclTKliMi6p6YHBG5A2cmR6ZutdJCpRCciKibYXJE5A5M9T+mVp1GQghUG4z2vVZgDKDxAowGoOKsfc9NROQGuHwIkTsoLVTug5uSIyEE7s7KQV7BJfteS+sFBPVRao5KC4A
gBy1yS0TkothyROTq6iqAmovK4+CmYfw19UaLxCgxNgR+3lr7XJNF2UTUjbHliMjVlZ5S7n2Dm2avbiF3YQrCeuhQU2+nLjYO5yeibowtR0Suztyl1q/NXfx1WkiSZL9rcsQaEXVjTI6IXF0HkiO7Y7caEXVjTI6IXJ2p9caUsDhDcOO12K1GRN0QkyMiV9fJliMh0Plh/qZErOIMUF/buXMQEbkpFmQTubpOJkd3Z+Xg0Nnyzl3TPxTQ9QQMlUDZKSB8cOfOQ0TkhthyROTqOpkcNU+M4qMCre6jTCLZgGpDA4QQTS9IEkesEVG3xZYjIlfWfI6joL7t79uG3IUp8PPWYviSzRbbW04imRgbgg1zkptGvYXEASUHgUsnOhs9EZFbYssRkSszzXHkFwL4Wm/9uRxlmH/r7S0nkcwtuGQ5TxKH8xNRN8XkiMiVqTGM38TcrXbS+dcmIlIRkyMiV9ZGcuSQBWdbMl3T1HpFRNRNsOaIyJWZurScseBsS6Z13MpOO/Y6REQuhi1HRK7MnBw1tRw5dMHZ5oL6KPfV5wFDtf3PT0TkothyROTKLlNzZFpw1q7rqpn4BgO6AMBQobQe9brC/tcgInJBNrcc7dixA5MmTUJ0dDQkScLGjRvNr9XX1+OJJ57AyJEj0aNHD0RHR2PGjBk4c+aMxTkuXryI6dOnIzAwEMHBwZg1axYqKyu7/GaIPM5lkiO7LzjbnCQ161ordMw1iIhckM3JUVVVFUaNGoVVq1a1eq26uhp79+7FokWLsHfvXvz3v//FkSNHcPvtt1vsN336dBw8eBBbtmzBp59+ih07duCBBx7o/Lsg8kS15UBNY/eZGqPVgKa5lVh3RETdiM3dahMnTsTEiROtvhYUFIQtW7ZYbHv55Zcxbtw4FBYWol+/fjh06BA2bdqEPXv2IDExEQDw0ksv4ZZbbsHy5csRHR3dibdB5IHKTHMchQI+AerEYKo74og1IupGHF6QXVZWBkmSEBwcDADIyclBcHCwOTECgJSUFGg0GuzatcvR4RC5DzXnODIxd6sxOSKi7sOhBdm1tbV44okncO+99yIwUJndV6/Xo3fv3pZBeHkhNDQUer3e6nnq6upQV1dnfl5RUeG4oIlchTk56tyyIXbBbjUi6oYc1nJUX1+PX//61xBC4JVXXunSuTIzMxEUFGS+xcfH2ylKIhdmSkhsWFPNz1uLxNgQ8/O2hvl3eBJJ07XZrUZEKlq1ahXi4uLg6+uLpKQk7N69u939S0tLMW/ePERFRcHHxwdXXHEFPv/88w5fzyEtR6bEqKCgAF9++aW51QgAIiMjUVJSYrF/Q0MDLl68iMjISKvnW7BgATIyMszPi4qKmCCR5zMnR306fIgkSdgwJ9m8Rpqfd+vRbEKg45NImlqtyosAYwOg5ewfRORc69evR0ZGBrKyspCUlISVK1ciNTUVR44cadUTBQAGgwG/+tWv0Lt3b/z73/9GTEwMCgoKzOU9HWH3liNTYvTTTz9h69atCAsLs3g9OTkZpaWlyMvLM2/78ssvIcsykpKSrJ7Tx8cHgYGB5ltAgErFqUTOVF6k3AfGADC19jRctsVHkiT467zgr/OyOsy/5SSS8VHtLGjbMxLQeAPCCFSctf09EBF10YoVKzB79mykpaUhPj4eWVlZ8Pf3x+rVq63uv3r1aly8eBEbN27E1Vdfjbi4OFx//fUYNWpUh69p85+BlZWVOHbsmPn5iRMnsH//foSGhiIqKgp333039u7di08//RRGo9FcRxQaGgqdTodhw4bh5ptvxuzZs5GVlYX6+nqkp6dj2rRpHKlG1FyzbjVHLRmSuzAFft5aDF+y2foOGg0QGK3M1F12Wt36JyLyGBUVFSgvLzc/9/HxgY+PT6v9DAYD8vLysGDBAvM2jUaDlJQU5OTkWD33xx9/jOTkZMybNw8fffQRevXqhfvuuw9PPPEEtNqOrSZgc8tRbm4uRo8ejdGjRwMAMjIyMHr
0aCxevBhFRUX4+OOPcfr0aSQkJCAqKsp827lzp/kc7777LoYOHYoJEybglltuwTXXXIPXXnvN1lCIPJexoamlJiimVWsPYJ9lQ5RJJC+zk2m0HEesEZGdxMfHW9QSZ2ZmWt3v/PnzMBqNiIiIsNgeERHR5iCun3/+Gf/+979hNBrx+eefY9GiRfj73/+OZcuWdTg+m1uObrjhBggh2ny9vddMQkNDsW7dOlsvTdR9VOoBIStdWj16Aw2y+aXchSnw12mt1hM5hLkom7NkE5F95OfnIyYmxvzcWqtRZ8myjN69e+O1116DVqvFmDFjUFRUhOeffx5Llizp0DlYXUnkikxdaoHRStcWmpIjf50W/jon/uiaCsI5nJ+I7CQgIMBisFZbwsPDodVqUVxcbLG9uLi4zUFcUVFR8Pb2tuhCGzZsGPR6PQwGA3Q63WWv6/BJIImoEzoxUs1hOBEkEalEp9NhzJgxyM7ONm+TZRnZ2dlITk62eszVV1+NY8eOQZab/qg8evQooqKiOpQYAUyOiFxTi5Fq9lTTkfmNmuNcR0SkooyMDLz++utYu3YtDh06hLlz56KqqgppaWkAgBkzZlgUbM+dOxcXL17Eww8/jKNHj+Kzzz7DX//6V8ybN6/D12S3GpErcmDL0bXPfWXbAc0LsoXA5Su4iYjsZ+rUqTh37hwWL14MvV6PhIQEbNq0yVykXVhYCI2mqa2nb9++2Lx5M+bPn48rr7wSMTExePjhh/HEE090+JpMjohcUVljy1GQfVqOTDNn5zYb8WYa7WaaMLJNptar+mqg5hLgH2qXmIiIOio9PR3p6elWX9u2bVurbcnJyfjuu+86fT0mR0SuyFTfE2iflqOWM2cD1mfPtsrbVxkxV1WijFhjckREHo41R0SuyFRzZMduteYzZ7c1e3abWJRNRN0IkyMiV1NfA1RfUB7bqVutyzicn4i6ESZHRK7GVG/k3QPwDVY1FDOOWCOiboTJEZGrKW82Uk2SGhectXH4vb2ZR6xxlmwi8nwsyCZyNc1GqjlqwVmbmbrV2HJERN0AW46IXE2zOY5aLjhrj8Vm21NtMFpfH9HUrcaaIyLqBthyRORqTN1qLYbx5y5MQVgPnUMXm01cthWJsSHYMCfZ8jqmlqPq80rBuLefw2IgIlIbW46IXI255chypJq/roPzEtnINEGkSW7BpdYTQ/qFKAXiAFB+xu4xEBG5EiZHRK6mzP5zHLXHNEFk7sKU9nZqNpyfdUdE5NmYHBG5EiGaLTrrnOQIME0Q2VTLVG0wotrQYFl/xLmOiKibYM0RkSupLQUMlcrjwGjVwkhctlW5b15/xOSIiLoJthwRuRJTl5p/GKDzd+qlW9YeAUr9kakVydCzMVljtxoReTi2HBG5EnOXmvOXDWm+OG21wWhuPbo7KweHzpbjLs0lrNABouw0HDdejohIfUyOiFyJqVXGScXYLZkWp23u0NlyAMAZhAMAROkpJkdE5NHYrUbkSpw8Us0WZ0QYAEAqL1IKx4mIPBSTIyJXYip2VqFb7XL0IhQAIDXUAtUXVI6GiMhxmBwRuZLyppYjl1hwthkDvFEigpUnLMomIg/GmiMiV9LYciQCY1xjwdkWzogw9JZKlTijR6sdDhGRQ7DliMhVyLJ5aY5a/yinLjjbUUWNdUec64iIPBlbjohcRVUJINcDkgYiIBLAQQDOWXC2o84IZcQakyMi8mRsOSJyFaaRagFRgKbp7xZHLTjbGWfMLUesOSIiz8XkiMhVmBIOFxypZtLUclSkbiBERA7E5IjIVZS77hxHJqw5IqLugMkRkaswtcaouODs5Zhbjir1QEOdusEQETkIkyMiV1He2Brjwi1HFxEA4eWrPGkcWUdE5GmYHBG5ClOy4cI1R4AEYYqPXWtE5KFsTo527NiBSZMmITo6GpIkYePGjRavCyGwePFiREVFwc/PDykpKfjpp58s9rl48SKmT5+OwMBABAcHY9asWaisrOz
SGyFye+ZuNVdOjgAR2NiyxeSIiDyUzclRVVUVRo0ahVWrVll9/bnnnsOLL76IrKws7Nq1Cz169EBqaipqa2vN+0yfPh0HDx7Eli1b8Omnn2LHjh144IEHOv8uiNydsUGp4wGAIBdPjoKYHBGRZ7N5EsiJEydi4sSJVl8TQmDlypVYuHAh7rjjDgDAW2+9hYiICGzcuBHTpk3DoUOHsGnTJuzZsweJiYkAgJdeegm33HILli9fjuho1y1GJXKYSj0gZEDjBdGjF6qrG9SOqE0iwNStxrmOiMgz2bXm6MSJE9Dr9UhJSTFvCwoKQlJSEnJycgAAOTk5CA4ONidGAJCSkgKNRoNdu3ZZPW9dXR3Ky8vNt4qKCnuGTaS+xnojERCFu1/dhcRlW1UOqG0yW46IyMPZNTnS65VugYiICIvtERER5tf0ej169+5t8bqXlxdCQ0PN+7SUmZmJoKAg8y0+Pt6eYROprzHRkANiXHJNteZYkE1EzrZq1SrExcXB19cXSUlJ2L17d5v7rlmzBpIkWdx8fX1tup5bjFZbsGABysrKzLf8/Hy1QyKyL1PLUbM5jnIXpmDDnGSXWTrExKIgWwh1gyEij7d+/XpkZGRgyZIl2Lt3L0aNGoXU1FSUlJS0eUxgYCDOnj1rvhUUFNh0TbsmR5GRkQCA4uJii+3FxcXm1yIjI1u9oYaGBly8eNG8T0s+Pj4IDAw03wICAuwZNpH6GmfHFs1GqrnSmmrNmWOsrwJqLrW/MxFRF61YsQKzZ89GWloa4uPjkZWVBX9/f6xevbrNYyRJQmRkpPnWskfrcuyaHPXv3x+RkZHIzs42bysvL8euXbuQnJwMAEhOTkZpaSny8vLM+3z55ZeQZRlJSUn2DIfIfZiSo4AolQPpAG8/wN+0xhq71ojIdhUVFRa1xHV11mfcNxgMyMvLs6hl1mg0SElJMdcyW1NZWYnY2Fj07dsXd9xxBw4ePGhTfDYnR5WVldi/fz/2798PQCnC3r9/PwoLCyFJEh555BEsW7YMH3/8MX788UfMmDED0dHRmDx5MgBg2LBhuPnmmzF79mzs3r0b3377LdLT0zFt2jSOVKPuq6x1y5FLY1E2EXVBfHy8RS1xZmam1f3Onz8Po9HYbi1zS0OGDMHq1avx0Ucf4Z133oEsyxg/fjxOn+7495XNQ/lzc3Nx4403mp9nZGQAAGbOnIk1a9bg8ccfR1VVFR544AGUlpbimmuuwaZNmyyKod59912kp6djwoQJ0Gg0mDJlCl588UVbQyHyGKL8DCQANb6RAC6oHc7lBfUBzu5nckREnZKfn4+YmKY/Bn18fOx27uTkZHNvFQCMHz8ew4YNw6uvvopnnnmmQ+ewOTm64YYbINopwpQkCUuXLsXSpUvb3Cc0NBTr1q2z9dJEHkkY6yFX6KEF8Ks3jgMIVjmiDgjqq9yXMzkiItsFBAQgMDDwsvuFh4dDq9W2W8t8Od7e3hg9ejSOHTvW4fjcYrQakServXQGWsgwCC3OQ/mycMUh/BbYrUZETqDT6TBmzBiLWmZZlpGdnW3ROtQeo9GIH3/8EVFRHa/ptLnliIjsS2osxi4Wodiz8Cb467Tw83bNkWpmTI6IyEkyMjIwc+ZMJCYmYty4cVi5ciWqqqqQlpYGAJgxYwZiYmLMdUtLly7FVVddhUGDBqG0tBTPP/88CgoK8Pvf/77D12RyRKQyqXGOozMIw0idFv46N/ixNHWrMTkiIgebOnUqzp07h8WLF0Ov1yMhIQGbNm0yF2kXFhZCo2nqCLt06RJmz54NvV6PkJAQjBkzBjt37rRpAmk3+BYm8mxShdJypBehGKlyLB1majmqOAsY6wGtt7rxEJFHS09PR3p6utXXtm3bZvH8hRdewAsvvNCl67HmiEhlpm61syJM5Uhs0KMXoNUpi+VWnFU7GiIiu2JyRKQyU7faWRGqciQ20GgArrFGRB6KyRGRykzdam6VHAEsyiYij8XkiEhlTS1
HbtStBjRLjk6pGwcRkZ0xOSJSk7EeUqUyuZn7JkdsOSIiz8LkiEhNFXpIEDAILS4gQO1obMPkiIg8FJMjIjWVNw3jF+7248jkiIg8lJt9GxN5FlHWWIwNN+tSAzgRJBF5LCZHRCoRQuCtzd8CcMORakDTUP66cqC2TN1YiIjsiMkRkUpq6o0wljZNAOnyi8225NMT8AtRHrP1iIg8CJMjIhVFSRcAAPffNB4b5iS79mKz1pjrjorUjYOIyI6YHBGpKEq6CADwDu3jfokR0KzuiHMdEZHnYHJEpCJTy5EcEK1yJJ3EEWtE5IGYHBGpxViPXlAKmYWpuNndMDkiIg/E5IhIJVKlHhpJoE54Af5uOJQfYHJERB6JyRGRSkxrqulFKCC56Y8i5zoiIg/kpt/IRO5PMs2ODTec48jE1B1YXgTIRnVjISKyEyZHRCqRKpTk6Iy7LTjbXEAkIGkBYQQq9GpHQ0RkF0yOiFRi0a3mrjTaptYjdq0RkYdgckSkElO3mlu3HAHNirI51xEReQYmR0QqkSqUlqOzHpMcseWIiDwDkyMilWjKTcmRG3erAUyOiMjjMDkiUkODAagqAcCWIyIiV8PkiEgNFWchQaBOeOMiAtSOpms41xEReRgmR0RqsOhSc8MFZ5tjQTYReRgmR0Rq8IQJIE1MyVFtKVBXqWooRET2wOSISA2eMowfAHwDAZ8g5XHj+yIicmdMjojUUNbYcuTuI9VM2LVGRB7E7smR0WjEokWL0L9/f/j5+WHgwIF45plnIIQw7yOEwOLFixEVFQU/Pz+kpKTgp59+sncoRK7LxVuO/Ly1SIwNMT+Pjwps/wCOWCMiD+Jl7xP+7W9/wyuvvIK1a9di+PDhyM3NRVpaGoKCgvDQQw8BAJ577jm8+OKLWLt2Lfr3749FixYhNTUV+fn58PX1tXdIRK6n3LVbjiRJwoY5yaipVxaTFQIYvmRz2wcwOSIiD2L35Gjnzp244447cOuttwIA4uLi8N5772H37t0AlFajlStXYuHChbjjjjsAAG+99RYiIiKwceNGTJs2zd4hEbmectefHVuSJPjrlK+IakND+zsHdd/11YQQ5iQSUFrdJMnNRyASdXN271YbP348srOzcfToUQDA999/j2+++QYTJ04EAJw4cQJ6vR4pKSnmY4KCgpCUlIScnByr56yrq0N5ebn5VlFRYe+wiZynwQBUKhNAnnHRliObddO5joQQuDsrB/GLN5tv92TlWJQREFHXrVq1CnFxcfD19UVSUpK5weVy3n//fUiShMmTJ9t0PbsnR08++SSmTZuGoUOHwtvbG6NHj8YjjzyC6dOnAwD0ej0AICIiwuK4iIgI82stZWZmIigoyHyLj4+3d9hEzlNxFoCA8PLFJXefANKkmxZk19QbkVdwyWJbbsEli5YkIuqa9evXIyMjA0uWLMHevXsxatQopKamoqSkpN3jTp48iT/96U+49tprbb6m3ZOjDz74AO+++y7WrVuHvXv3Yu3atVi+fDnWrl3b6XMuWLAAZWVl5lt+fr4dIyZyssZ6IxEQDbefANLEnBwVAbKsbiwq+frxG9UOgcgjrVixArNnz0ZaWhri4+ORlZUFf39/rF69us1jjEYjpk+fjqeffhoDBgyw+Zp2T44ee+wxc+vRyJEjcf/992P+/PnIzMwEAERGRgIAiouLLY4rLi42v9aSj48PAgMDzbeAAA/5a5u6p7LmyZGHCIgCJA0g15vXjOsukqRDWOj1Nnp9mYHfajchEJwIk+hyKioqLMpl6urqrO5nMBiQl5dnUYqj0WiQkpLSZikOACxduhS9e/fGrFmzOhWf3ZOj6upqaDSWp9VqtZAb/5rs378/IiMjkZ2dbX69vLwcu3btQnJysr3DIXI95UpdjjAVMXsCrbeSIAHdp+6othy6f8/Aep9n8Huv/8H3wHt4yvst7PCZD+3R/6kdHZFLi4+PtyiXMTWgtHT+/HkYjUabSnG++eYbvPH
GG3j99dc7HZ/dR6tNmjQJf/nLX9CvXz8MHz4c+/btw4oVK/C73/0OgDIC5pFHHsGyZcswePBg81D+6OhomwumiNxSY/IgAvuoHIidBfVRugzLTgF9EtWOxrEMVcBbd8DrzF7UCy3+a7wWt103Fqe+eR9DNacg/n0/gDeAEVPUjpTIJeXn5yMmpukPRB8fH7uct6KiAvfffz9ef/11hIeHd/o8dk+OXnrpJSxatAh//OMfUVJSgujoaPzhD3/A4sWLzfs8/vjjqKqqwgMPPIDS0lJcc8012LRpE+c4ou6hMTmS3TQ5qjYYrQ9XD+oDnNrl+S1HQgCfPAyc2QvhF4YppY/gBzEQE8an4LYvR+IvXqsx1WsbxIdzUBsQC99+Yzi0n6iFgIAABAZeZnJZAOHh4dBqtR0uxTl+/DhOnjyJSZMmmbeZeq68vLxw5MgRDBw48LLXtXu3WkBAAFauXImCggLU1NTg+PHjWLZsGXQ6nXkfSZKwdOlS6PV61NbWYuvWrbjiiivsHQqRazK1HAW5Z3KUuGyr9eHq3WUiyH3vAD9uADReqJuyFj+Ipi/aBnhhQcPvscU4BpLRgDNv/Ab3vbKdQ/uJOkmn02HMmDEWpTiyLCM7O9tqKc7QoUPx448/Yv/+/ebb7bffjhtvvBH79+9H3759O3Rdrq1G5ERCCIjG4e5ygPvUHLVcTsTqcPXuMNdR9UVgi9IKLm5ciMrIceaXTP9GMjR4rP4BFItgDNScRULRexzaT9QFGRkZeP3117F27VocOnQIc+fORVVVFdLS0gAAM2bMwIIFCwAAvr6+GDFihMUtODgYAQEBGDFihEVDTXvs3q1GRNYJIXD/P7findoyAMBvNrjPCvam5UQuVBmQuGyr9Z26Q8vRtkyg5iJEr2GY+uMY7P6s6d9CkmCx5IpxXz3wv3T80esjoHIxEOo+yTCRK5k6dSrOnTuHxYsXQ6/XIyEhAZs2bTIXaRcWFrYaCNZVTI6InKSm3oji0z8DPsAl0RN79fVqh2QTZTkRbds7eHpyVFYE5K0BANSl/AW732yaqT8xNsRch2VeciVhKvZ/tgIJmp9Rn/MP4Nbn1IiayCOkp6cjPT3d6mvbtm1r99g1a9bYfD12qxE5UYx0HgBwxoXXVOs0U3JUfR6or1E3FkfY+RJgNACxV0Puf715c+7CFGyYk9y66FrS4O8NvwYAeO1/W+mSIyK3wOSIyImipQsAgDOi80NMXZZvMKDrqTwuc58uww6pumBuNcJ1f7J4yV/X9kKzX8sjcVCOhVRfDez5l4ODJCJ7YXJE5ETRjS1HRZ7YciRJQGBjXY2nrbH2/TqgoQaIvBIYYMsyIRJebWgcUrz7dcDoXl2pRN0VkyMiJ2pqOfLA5AjwzLojIZpajRJ/pySBNvhcHgfRo7eyrMoRzpxN5A6YHBE5UVPNkQd2qwGemRyd/Bq4cEzpMhx5t82HN8ALDVfeqzzZ2/kFuInIeZgcETlRNDy95cgD5zrKfVO5H3kP4NO5Ra8bEu5XHhzLBko9rMuRyAMxOSJyFtmISEkZsVTk8S1HHpIA1JYBhz9VHo/5badPI0L6A3HXAhDAwQ/tEhoROQ6TIyInkaqK4S0ZUS+0OIdgtcNxDE/rVjv8mTJ8v9dQIGpU1841/E7l/sB/uh4XETkUkyMiJ5Eah7cXQ1liwt1VG4yoNjRYrhvWPDnyhPXEDvxXuR9+l82F2K3E3wFIWuDsfuDC8S6HRkSO4/7f0ERuQipXWlM8pUstcdlWxC/ebLkIbWA0AAkw1gFV51WNr8uqLwI/f6U8HnFX18/XIxzof53ymF1rRC6NyRGRk5iSI3cuxm65AC3QYhFaLx+gp7LeEcoKnRxd1wkhUG1oUFrE8j8G5AZlbqPwwfa5gCnJYnJE5NK4thqRk0hl7p8cmRagrak3otpgtL4IbXA/oFIPlBYCMWOcH2QnCSFwd1YO8gouAQA+ClyDUYB9Wo1Mht4GfPIIUHwAuHg
CCO1vv3MTkd2w5YjISZpajty7W820uGqbi9CGxCn3l046KyS7qKk3mhOjAFQjvu4H5YVht9vvIv6hQOx45fHRzfY7LxHZFZMjIifRlCsF2R65dEhzbpocNXet5gd4S0bIoYOAsIH2PfkVNyv3RzlbNpGrYnJE5CTNW47iowLN2xNjQ+Dn3UYrjDvygORognYvAMA4+Gb7n3zIROX+5LdAbbn9z09EXcaaIyJnqKuEVKN02ZwRYdg1J9k8MtzPu+1V3d2SmydHGsi4UbMfAGAcdBO87X2BsIFA2CBlSZLjXwLDJ9v7CkTURWw5InKGxi61cuGPSvhDktBYt+PlWYkR0FRkXHoKMDaoG0sn/EI6ilCpEqWiB+S+SY65iLlrbZNjzk9EXcLkiMgZGpfT8Ph6IwDoGQlofQBhdMtlRCZo9wEAtsmjAI2DGtevSFXuj2V7xmSZRB6GyRGRM5R5xki1DtFogJBY5bEbdq1N0Cj1RtnGX0AImOc9ajUbeFf0TQK8/YGqEqAk3z7nJCK7Yc0RkTM0JkdnRajKgThJSBxw/qjbJUcRuIgrNEUwCgnb5Stxd1YODp1tKppOjA3BhjnJXb+Ql48ypP/YVuDnbUDE8K6fk4jshi1HRA4mhEDDxZMAgFOit7rBOIubFWUrM2MbkaxRWnEOiP4oR0+LxAhoMRt4Vw24Ubk//pV9zkdEdsPkiMiBTLMu7/1BmVDwtOilckRO4kbJkekzSly21Zwc5cjxFvt8/fiN9r/wgBuU+4JvgQaD/c9PRJ3G5IjIgUyzLveRzgEATotwz5vXyJqQxhFrbpAcNZ8Z+6rG5Oi7FsmRX1uzgXdFxHCgRy+gvho4vcf+5yeiTmNyRORg3mhAJJRfvu89/mtsmJPsecP3W3KjliOTaJxHrKYEQtJijzzE8ReUpKbWo5/ZtUbkSpgcETlYlHQBGklAePnBPzjS8xMjoGm0Wm0p0Dj5paszdanJUQmogp9zLmqqO/p5m3OuR0QdwuSIyEFMRb6mLjUR1BfoDokRAOh6AD0ai8/dpPXI1KUmx17T7n7VBqP9piYytRwV5QG1ZXY6KRF1FZMjIgdoXuTb15QcBfdTOSonc7OutWStkhwZY69td7/EZVtxT1aOfS4aFAOEDgCEDJzabZ9zElGXMTkicoDmRb6mliONqaupu3Cj5KiPVII+0nkIjRfkPmNbve7nrUVibIj5ef5ZOy4YGzteuT/5jf3OSURdwuSIyMFmj1RGOkndreUo1H1GrJnrjaJ/Aeh6tnpdkoANc5KRuzDF/hc3deMV7LT/uYmoUxySHBUVFeE3v/kNwsLC4Ofnh5EjRyI3N9f8uhACixcvRlRUFPz8/JCSkoKffvrJEaEQOZ2p1sjEu0KZHRtsOXJZ5nqjfm3XG0mSBH9HDOk3tRyd2QsYqu1/fiKymd2To0uXLuHqq6+Gt7c3/ve//yE/Px9///vfERLS1CT93HPP4cUXX0RWVhZ27dqFHj16IDU1FbW1tfYOh8ipmtcamUhlhcqD7tZy5C7JkRDmliPjZYqxHSK4HxDYB5AbON8RkYuwe3L0t7/9DX379sWbb76JcePGoX///rjpppswcOBAAMovj5UrV2LhwoW44447cOWVV+Ktt97CmTNnsHHjRnuHQ+RUzWuNAOCqfj0gVeiVJ8HdreWosVut9JRLzwAtXTqBaOkiDEJrtd7I8QFITa1HBd86//pEbmDVqlWIi4uDr68vkpKSsHt32wMY/vvf/yIxMRHBwcHo0aMHEhIS8Pbbb9t0PbsnRx9//DESExNxzz33oHfv3hg9ejRef/118+snTpyAXq9HSkpT331QUBCSkpKQk2N9BEhdXR3Ky8vNt4qKCnuHTWR3uQtT8N7UvpAglBXY/cPUDsm5AiIB7x6AMAKlBWpH0yZNgVIIvU8MVj4nNZiTI9YdEbW0fv16ZGRkYMmSJdi7dy9GjRqF1NRUlJSUWN0/NDQU//d//4ecnBz88MMPSEt
LQ1paGjZv3tzha9o9Ofr555/xyiuvYPDgwdi8eTPmzp2Lhx56CGvXrgUA6PXKX9EREREWx0VERJhfaykzMxNBQUHmW3x8vNX9iFyJv04LyZQUBPfrPnMcmUgSEKa0GOPCMXVjaYe2UGmtablkiFPFNXbnnd4DNNSpFweRC1qxYgVmz56NtLQ0xMfHIysrC/7+/li9erXV/W+44QbceeedGDZsGAYOHIiHH34YV155Jb75puMjQu2eHMmyjF/84hf461//itGjR+OBBx7A7NmzkZWV1elzLliwAGVlZeZbfn6+HSMmcqDSblpvZBI2SLl31eRICGgKvgbQerFZpwobpKyz1lALFO1VLw4iJ6moqLDoEaqrs/5HgcFgQF5enkVvk0ajQUpKSpu9Tc0JIZCdnY0jR47guuuu63B8dk+OoqKiWrXsDBs2DIWFyi+JyMhIAEBxcbHFPsXFxebXWvLx8UFgYKD5FhAQYO+wiRyDyZFy76rJ0YVj0FQWo054Y588yGGXUWbVVqbVVkYzNqDa0GDexroj6m7i4+MteoQyMzOt7nf+/HkYjUabepsAoKysDD179oROp8Ott96Kl156Cb/61a86HJ9Xh/fsoKuvvhpHjhyx2Hb06FHExirFqP3790dkZCSys7ORkJAAACgvL8euXbswd+5ce4dDpC4mR8r9hePqxtGWEzsAAHvlwaiDzmGXSVy2FYmxIfjgD8m459Ucc9F+YmxI00LEsVcD+R81Jkd/clgsRK4gPz8fMTEx5uc+Pj52PX9AQAD279+PyspKZGdnIyMjAwMGDMANN9zQoePtnhzNnz8f48ePx1//+lf8+te/xu7du/Haa6/htddeA6DMFfLII49g2bJlGDx4MPr3749FixYhOjoakydPtnc4ROrqJslRtcEIP29t60V1Xb3lqHFWakd0qZlm1c5tTIRyCy7hYrXBYjRjbsEl1NQb4a/zAvolKxtP5wKyEdA4YE4lIhcREBCAwMDAy+4XHh4OrVZrU28ToHS9DRqkfP8kJCTg0KFDyMzM7HByZPdutbFjx+LDDz/Ee++9hxEjRuCZZ57BypUrMX36dPM+jz/+OB588EE88MADGDt2LCorK7Fp0yb4+vraOxwidZmTI88exm9ab0y0XJE1bIByX3EWqKt0fmDtEcKhyZEkSbbNqt07Xpmdu64cOHfY7vEQuSOdTocxY8YgOzvbvE2WZWRnZyM5ObnD55Fluc26Jmvs3nIEALfddhtuu+22Nl+XJAlLly7F0qVLHXF5ItfQUKckBYBHJkfWWkbMrSDmnUIA/3Cg+jxw8TgQNUqlaK04dwSoKoHw8sX3tQMdcgmbZtXWegExY4AT24FTu4CI4Q6JicjdZGRkYObMmUhMTMS4ceOwcuVKVFVVIS0tDQAwY8YMxMTEmOuWMjMzkZiYiIEDB6Kurg6ff/453n77bbzyyisdvqZDkiMiAqSyU8oD7x6Af6i6wTiAqWXkQpXBPCO41e61sEFKcnThmGslRyeVUWpyn3EwHPZWOZhGfZMak6PdQOLv1I6GyCVMnToV586dw+LFi6HX65GQkIBNmzaZi7QLCwuh0TR1hFVVVeGPf/wjTp8+DT8/PwwdOhTvvPMOpk6d2uFrMjkichCLZUM8dI6jli0jpsJjc5ExoCRHp75zvaLsxuTIGHsN4Cq9WH2TlPtTu9SNg8jFpKenIz093epr27Zts3i+bNkyLFu2rEvXc8jCs0QEaC6dUB6YVqf3UKbuNRNT95qZeSJIF0qOZNlcbyTHXqtyMM30SVTuL/4MVJ5TNxaibozJEZGDSKYFV0M8Ozm6bOGxK45YO3cIqL4AePtDjkpQO5omfsFAr2HK49Ntrx1FRI7F5IjIQaTSxpYj0+r0HqzdwmNzcvSTMkLMFTS2GqHfVYDWcfMbdUo/dq0RqY3JEZGDaC41rqvm4d1qlxXaH4AE1JYB1RfVjkbROPkj4lyoS83EXHfEliMitTA5InII0bT
orId3q12Wtx8Q1Fd5fP6ourEASr2RaYmO/h1fa8lpTMlR0V6gwaBuLETdFJMjIgfohTJI9VWApPH42bE7pNcQ5d4VJjcsPgDUXFImXHSlqQVMQgcA/mGAsQ44+73a0RB1S0yOiBygn9Q41X1gH8DLxWpa1OBKyZG53igZ0LrI/EbNSRKH9BOpjMkRkQPEmpKjEM+bGbtTejeOwHKJ5EiZ3wj9XbDeyKTvOOWeyRGRKjgJJJED9NOUKA+6aTG2EEC1ocH83C98CCRAWbJDTbIRONlYbxR3jbqxtKd5y5EQHjuJKJGrYnJE5AD9pMbkqJsWY9+dlYNDZ8vNz6/tq8PbgLLWXE2pMp+PGvQ/AHVlgE8gEOmC9UYm0aMBjRdQWawsXswWSCKnYrcakQOYu9W6actR88QIAL4+ZYAcEK08Uan1SAgBw3FlCL+ITVYWerWi5YzfibEh8PPu4OKx9uLt11Qszq41IqdjckTkAOaC7G7acmTy9eM3mh+LXkOVB+cOOT0OIQTuzsrB1198CABYe6YfRBsTUppm/M5fmor8pamW68Q5E+c7IlINkyMiO+uBGvSSGltOusHs2O3xazZrthxuGrHm/Jajmnoj9hecxziNUhC+4UKc5fpvLSgzfnvBX+elTmIENBVlcxkRIqdjckRkZ30lZcFQ4ReiXm2NCzInRyXObzkCgBHSCQRINSgT/jgkYs1F49WGtpMke6ux5Vp9GpMj/QHAUOWYgIjIKhZkE9mZqd5IDo6DkytVXJpQseUIAK7SKEnZLnkYZGhaFY07w7XPfdXxnYNilHmyyk8rs2W78tQDRB6GLUdEdmaqNxLdvN6oJWNYY81RxRkIFdZYS9bkAwBy5HgArYvGO1J43Zli7ZbHAEB8VGDHgu47VrlnUTaRU7HliMjO+kt6AIAIGaByJK7l7jUH8boIRx/pPJ56/QM89dAfnFfPY6xHokZpsfquMTkyyV2YAn+dFn7e2svGYyrWNtUrdeYYQJm6aPiSzZePu28ScPBDFmUTORlbjojszJQcyWEDVY7EtRw6W458WZmvR1tyoN2CaHvTnNmLnlItLogAHBZ9LV7z12ltKrzuTLF282OU45peqzYY2xw5Z647Or1byaiIyCmYHBHZWX/NWQCACGVy1FK+UJKjYVKBU6+rKVCWDMmR4yFc7GsvcdlW3JOVYz1BihwJePkqC+VeOOb84Ii6Kdf6liByd4ZKREqXAAByCJOjlkwtR/Ea5yZH2pPK5I875RFOvW5bWtYh5RZcst6S5qUDon+hPGbXGpHTMDkisiPp4s8AgPMikMP4rTC1HA2WTgNGg3MuaqiGpmgPAGBni3ojtZjqkHIXplx+ZxZlEzkdkyMiO9JcPA4AOCEiVY7ENZ0WvVAu/KCTjJAu/GT38wshGucuamjqpjr1HSSjAUUiDCdd6HNR6pA6MNmDaabs03scGxARmXG0GpEdSY11ISfkKAxXORbXJOGQiEWSdBia4h+BPvZb/NW0REhegdKtmRgboiz9cULpUsuRhwNww9XtTUXZJYfUXbSXqBthyxGRHWkuNiZHIkrlSFyXqe5IU3zAruetqTeaEyOgWR3Pz9sBADuNrtGlZrOevRrX6BNAUa7a0RB1C0yOiOxIauxW+9mFum9cjanuSFP8o+MvVlsGnN0PANgpu3FbnmmdtVPsWiNyBiZHRHYiZNmcHLHlqG0H5TgAgEb/PSDLDr2WtvBbQMiQQwdCjzCHXsuhzMkRi7KJnIHJEZEdCCHwu1c2QVNXDllIKBARaofkso6IvqgWPpDqKgAHFGU3p2kcwm+Mu86h13E4U1F2UR4gO2/yTKLuiskRkR3U1BtRcVpZ2LRIhGNkbMRl19zyNB1dd8wILX4UjevOnXZsDY325y8BAHLc9Q69jsP1jgd0PYG6cuDcYbWjIfJ4HK1G1AlCCIv1tQCgv0ZZNiRiwHBsmJnsvHXDXIS1dcfaWiJkvzwQSZrDSkvI6OkOiaevVKx
MraDxgrH/9QByHHIdp9BogZgxwIntymSQEW5cP0XkBpgcEdnI2pDxtb8bhwGSsmyIFDao2yVGJqY1xC7ne7lx9nAHjr66QfO98qDvVYBPoMOu4zR9xzUlR4lpakdD5NEc3q327LPPQpIkPPLII+ZttbW1mDdvHsLCwtCzZ09MmTIFxcXFjg6FyC7aGjJuWnCWa6pd3n55kPKg+CBQX9Pl8ymTP1q2Ul1vSo4GTejy+Z3F6iSWJubJILmMCJGjOTQ52rNnD1599VVceeWVFtvnz5+PTz75BBs2bMD27dtx5swZ3HXXXY4MhcjhBklFAAA5bJDKkbi+MwiD6BEByA3A2e+7dC5TS17isq3mbT4wYLwmX3ky+FddOr+zCAHcnZWD+MWbEb94c+vFaPskKvcXjgFVF9QJkkglq1atQlxcHHx9fZGUlITdu9v+I+H111/Htddei5CQEISEhCAlJaXd/a1xWHJUWVmJ6dOn4/XXX0dISFORZllZGd544w2sWLECv/zlLzFmzBi8+eab2LlzJ7777jtHhUPkUDU1NYgztRyFD1E5GncgwRjTuKBqF4uyW7bkxUcFIlFzBP5SHeSekUCEayw2ezkXqwytWiSrDcamliTfYMD0f4utR9SNrF+/HhkZGViyZAn27t2LUaNGITU1FSUlJVb337ZtG+6991589dVXyMnJQd++fXHTTTehqKiow9d0WHI0b9483HrrrUhJsVxYMS8vD/X19Rbbhw4din79+iEnx3rBZF1dHcrLy823iooKR4VN1Cm/W7EeXpKMCuEHERCtdjhuQY5ubAmx4y/63IUp2DAn2VxvJA+cALhJ/de1z33ValurlqQ+pkVomRxR97FixQrMnj0baWlpiI+PR1ZWFvz9/bF69Wqr+7/77rv44x//iISEBAwdOhT/+te/IMsysrOzO3xNhyRH77//Pvbu3YvMzMxWr+n1euh0OgQHB1tsj4iIgF6vt3q+zMxMBAUFmW/x8W66DAB5nPgopdB3sHQaAHBW1w9+HShIJkDum6w8OPmt0qdkB/46LSSpqRjbOMC1641aTn8ANP2fAoBDZ8vNj3MLLsEQzeSIPENFRYVFo0ddXZ3V/QwGA/Ly8iwaVDQaDVJSUtpsUGmpuroa9fX1CA0N7XB8dv8WP3XqFB5++GFs2bIFvr6+djnnggULkJGRYX5eVFTEBIlcwoY5yZAkwOvr74GvgcHDE7vtSDVbydGjAS8/oPo8cO4I0HuoXc4rXfgJgzVFMAgtjP1vtMs5HaXl9AeAkicOX7LZ6v6yqeXozF7AWA9ovZ0RJpHdtfwdvmTJEjz11FOt9jt//jyMRiMiIiwn1o2IiMDhwx2b8+uJJ55AdHR0q56s9tg9OcrLy0NJSQl+8YtfmLcZjUbs2LEDL7/8MjZv3gyDwYDS0lKL1qPi4mJERlpfj8rHxwc+Pj7m5+Xl5Vb3I3I2SYIydP3iUeV5L/v8gu8WtLqm4ekF37SZHLWcU+pyyaf2yGcAgBx5OMb6th7CHx8ViPzGFpm2Jqp0ppbTH1QbGtrcV4QNBnyDlDXjig8A0aOdESKR3eXn5yMmJsb8vPnveHt69tln8f7772Pbtm02NdjYPTmaMGECfvzRckHJtLQ0DB06FE888QT69u0Lb29vZGdnY8qUKQCAI0eOoLCwEMnJyfYOh8g5zh1R7nsPUzcOdxN3jZIcnfwGGPv7Vi9bm1NKaa1rO0HSHlWSo83yWIy18rqptQ/oWLLlUiQN0GcccGyLsggtkyNyUwEBAQgMvPz8Y+Hh4dBqta2m+2mvQcVk+fLlePbZZ7F169ZWo+Yvx+41RwEBARgxYoTFrUePHggLC8OIESMQFBSEWbNmISMjA1999RXy8vKQlpaG5ORkXHXVVfYOh8jxjPXA+cY1wnpxpJpN4q5R7k9+Y3UR2rbmlGqLVHEG2jN7IQsJW4xjrO/T2Nrnr/Nyr8TIhIvQUjei0+kwZswYi2JqU3F1ew0qzz3
3HJ555hls2rQJiYmJNl9XlcrRF154ARqNBlOmTEFdXR1SU1Pxz3/+U41QiLru4glArge8ewCBfdSOxr3EjFH+3arOAcU/AlGjunQ67ZHPAQB7xWCcQ7AdAnRBpuSIw/mpm8jIyMDMmTORmJiIcePGYeXKlaiqqkJamjJT/IwZMxATE2MeBPa3v/0Nixcvxrp16xAXF2ce7NWzZ0/07NmzQ9d0SnK0bds2i+e+vr5YtWoVVq1a5YzLEzmWaSHQXlcAGq7lbBMvH2DA9cCRz4GftnQ9OTr8MQBgs9H2vxTdRswYpXuttBAoPwsERqkdEZFDTZ06FefOncPixYuh1+uRkJCATZs2mYu0CwsLoWn23fvKK6/AYDDg7rvvtjhPW0Xf1nDMMVFXmeqNWIzdOYNSlOTo2Fbguj91+jTROA9t4bcAgE+NHly/6BMA9B6utLSd3g3E36F2REQOl56ejvT0dKuvtWyAOXnyZJevxz9zibrK3HLEeqNOMS3vcWoXUHOp/X3bMVmrJEbGftfgLMLsEZnrMtcdsWuNyBGYHBF1FVuOuia4n/JvJ2Tgp62X398qgTu13wAAGkb+2n6xuap+jYNXCnaqGweRh2JyRNQVcgNwXpnjCOFXqBuLOxt6m3J/8MNOHT5COoHBmiIIL18Yh0yyY2AuKna8cn/2e6COyykR2RuTI6IukC4eB4x1gLc/ENJf7XDc14i7lPtjW4CaUpsP/7V2OwDAOPhmwMrEjx4nqA8QHAsII4f0EzkAkyOiLtCU5CsPesdzpFpX9I5XutaMBqU42xZ15bhL+zUAoCFhhgOCc1GxVyv37Fojsjt+mxN1gabkgPIgcoS6gbg7SQKGN7Yeff9ehw4RQqDa0AB53/voKdXimBwNOe46BwbpYkxdaye/VTcOIg/E5IioCzTFB5UHEUyOuizhXmX+nhM7morc22BaVmTk4s9RsvUfAIC1xpsAd5zxurPiGluOivKA+hp1YyHyMEyOiLpAKmFyZDfB/SCuuBkAUP/daxBCWN2t2mBEtUFZVuR2zU4M0OhxUfTEiejbWy0iK4SykGu1oe0lR9xWSH8gIEqZnf10rtrREHkUJkdEnRSESmgqzihPIoarG4wHEEJgabHSGlKX+w5+989NsJYfJS7binuycqCFEQ96KaPbet44H2//8Zet1kq7OysH8Ys3I3FZZ6cIcGGS1KzuiF1rRPbE5Iiok+I1BcqD4NjuMULKwWrqjXhTH4eDcix6SrW45uxai0Vm46Oa/o3zz5ZjpvYLDNDoIfxCoUueY3UR2UNnyy2eJ8aGtGpdcldCCBj6KPMdCSZHRHbF5UOIOmmoVKg8YJeaHUl4tuFevK17Fvdrv0Bl6QnzKxvmJKOm3ojEZVsRjfN41OsDAIDhhkXw8Wl/McnchSnw12nh5621mkS5G1PNVVmhjK0+gOHELuga6iB5+agdGpFHYMsRUScNMyVHHKnWJj9vLRJjQ8zPm7f+tOVr+UrsMI6ETjIi4PN0aKG0HkkS4K/TwhsN+IfuZfSQ6rBbHgJjwm8ue05/nRb+Oi+PSIwApZUtr+ASjokYXBAB8EEd6grz1A6LyGMwOSLqJHO3GluO2iRJEjbMSUb+0lTkL03FhjkdWxD2zw2zUC784H1mD17w/id8YFBeqKvAP73/gbGaoygX/ni8/gFlhFu3JWG3rCxboynkfEdE9tKdv1WIOk2HelwhnVKeRI1SNxgXJ0kS/HVejS03TduVUWcNVkelnRa98XB9OoTGC7drc7BF9xh0G2fD99Uk/Eqbh1rhjXn1D+GkiHLiO3EOW1vbdsnDAADaQtYdEdkLa46IOmGIdAo6yQjhFwopuJ/a4bgl0wiyxNgQqy1KX8mjUX7nOzD8+w/opzkH5P8XAFAo90JG/VzkCs9c6NfU2mYqRhcCGL5kc5v7fyfHAwA0p74DGgyAl84pcRJ5MiZHRJ1wpeZnAIAcOQpaD6ljcQZTq0huwSXzttyCSxaj0pq
r7/9LXF/3AiZo9uLvN4dDhAzAr96VUQfPTgBMrW2AMk9Te46IPjgnAtGrvhw4vadpckgi6jR2qxF1wkipMTmKSlA3EDfTvAYpd2FKh46phi8+kcej4aoHYRxyq8cnRpejLJvSlEwKaLBTbqx7+3mbOkEReRgmR0SdMFKjDDGXI1lvZKumGiTPmG/ImYRQJrZsOanlN0yOiOyKyRGRreprcIV0GgAgR41WORjqTkxD+E1MxdrfGhuTo6I8oLZMjdCIPAqTIyIbaUoOwlsy4rwIhAiMUTsc6qZyF6aYC9nPIBxy6EBAGIGTHLVG1FVMjohspNF/DwA4IPfvXqvAO5CyQKwHLg7rQP46rcV/P2PcdcqDn79SJyAiD8LkiMhGmrP7AAA/iv4qR+I5rNXRkG3kuOuVB6w7IuoyJkdENtKc3gMA2CsPVjkSz9F8gdjmkx7WdKI1qeUkip602Gx7jLHXApCA80eBsiK1wyFya5zniMgWVReguXgMALBPHqRyMJ4nd2EK/Ly15kkPr33O9i6ilpMoespis5flFwxEjwbO7FVaj0ZPVzsiIrfFliMiWzS2Gh2Xo1CKAJWD8TzKArGWLT+A7a0/lkuWeE5idNmWtIG/VO6PsYuSqCvYckRki9O7AQB58hUqB+K5Wrb8AN2o9ecyLtuSNvgm4OvlwPFswNgAaPkVT9QZbDkisoE4tQsAkCeYHDlS85af5q0/3bGeqOV7Btp5330SAb8QZa6jxkSeiGzHPyuIOkgY61F3cg98wWJstXTHeiKbWtI0WmBQCvDjBuCnL4DY8U6MlMhzsOWIqINqi36AL+pQLvxxTER3i1YLV+Sp9UTtaaslzarBqcr90S+cExyRB2LLEVEHaRu71PbJg7Bn4U0I66HrNr+cyY0MmgBIGqDkIFB2Ggjqo3ZERG7H7i1HmZmZGDt2LAICAtC7d29MnjwZR44csdintrYW8+bNQ1hYGHr27IkpU6aguLjY3qEQ2ZWm4BsAwC55WOPsxEyMyAX5hwJ9xiqPf2LrEVFn2D052r59O+bNm4fvvvsOW7ZsQX19PW666SZUVVWZ95k/fz4++eQTbNiwAdu3b8eZM2dw11132TsUIrsQQqC6zgBN4U4AQI4cr3JERJcx+FfK/dHN6sZBZCerVq1CXFwcfH19kZSUhN272x5wcPDgQUyZMgVxcXGQJAkrV660+Xp271bbtGmTxfM1a9agd+/eyMvLw3XXXYeysjK88cYbWLduHX75S2VOjjfffBPDhg3Dd999h6uuusreIRF1mhACd2floKZwHz73uYQq4cNlQ8j1DbkF+HIZcPwroK4C8OGcXOS+1q9fj4yMDGRlZSEpKQkrV65Eamoqjhw5gt69e7fav7q6GgMGDMA999yD+fPnd+qaDi/ILisrAwCEhoYCAPLy8lBfX4+UlBTzPkOHDkW/fv2Qk5Pj6HCIbFJTb0RewSVcpTkEANgjD0VCbC8WYpNr6x0PhA4EjHWoO7QJ1YYGCCHUjoqoU1asWIHZs2cjLS0N8fHxyMrKgr+/P1avXm11/7Fjx+L555/HtGnT4OPj06lrOjQ5kmUZjzzyCK6++mqMGDECAKDX66HT6RAcHGyxb0REBPR6vdXz1NXVoby83HyrqKhwZNhErSRr8gEA41MmY8OcZNYbkWuTJIhhkwAAW/7zL8Qv3ox7snKYIJHLqKiosPi9XldXZ3U/g8GAvLw8iwYVjUaDlJQUhzaoODQ5mjdvHg4cOID333+/S+fJzMxEUFCQ+RYfz5oPch4NZCQ1thzpBl3PxMhOuuOEjo5UbTBaJD91g28FANyo2QcfGJBbcMliriQiNcXHx1v8Xs/MzLS63/nz52E0GhEREWGxvb0GFXtw2FD+9PR0fPrpp9ixYwf69GkaShoZGQmDwYDS0lKL1qPi4mJERkZaPdeCBQuQkZFhfl5UVMQEiZxmuHQSgVI1hE8ApMhRaofjMbrjhI6OlLhsKxJjQ8wtm3LUaBS
JMMRIF3Ct5kdslceoHSKRWX5+PmJiYszPO9v95Sh2bzkSQiA9PR0ffvghvvzyS/Tvb1m8OmbMGHh7eyM7O9u87ciRIygsLERycrLVc/r4+CAwMNB8CwhgcSE5z/Wa7wEAxthruVaVnXXHCR3tqWXrm6l1SAiB6noZm43KkP6JWi4lQq4lICDA4vd6W8lReHg4tFptq+l+2mtQsQe7J0fz5s3DO++8g3Xr1iEgIAB6vR56vR41NTUAgKCgIMyaNQsZGRn46quvkJeXh7S0NCQnJ3OkGrmkG7X7AQDGgSnt70jkZKbWt9yFTf83hQDuzspB4rKt2NSYHKVo8qBDvVphEnWaTqfDmDFjLBpUZFlGdnZ2mw0q9mD3P4NfeeUVAMANN9xgsf3NN9/Eb3/7WwDACy+8AI1GgylTpqCurg6pqan45z//ae9QiLpECIGasnNIkI4BAGQmR+SClNa3plqti1UG5BVcAgDkiiE4L4UiHBdxg2Y/gNvUCZKoCzIyMjBz5kwkJiZi3LhxWLlyJaqqqpCWlgYAmDFjBmJiYsx1SwaDAfn5+ebHRUVF2L9/P3r27IlBgwZ16Jp2T446MhrC19cXq1atwqpVq+x9eSK7MM1vFHPqM7yoEzgs90W/wJjLH0iksmuf+8r8ePfCmxD49X3ArpcxWfutilERdd7UqVNx7tw5LF68GHq9HgkJCdi0aZO5SLuwsBAaTVNH2JkzZzB69Gjz8+XLl2P58uW4/vrrsW3btg5dkwUURFaY5je6z3s/AOBwzyQM4UgqclGm2qPcxhYjQBn9F9ZDh9oR9wC7XsYEzV4Ya8sAXZiKkRJ1Tnp6OtLT062+1jLhiYuL6/K0FUyOiNqggYzrNT8AAO64ZyYLhslltRz5BzSN/hO9h+Ow3BdDNadQd+gjIOl3KkZK5B4cPkM2kbsRQqDaYMQY6SjCpXII3yBI/RxX+EdkD81H/lmM/pMkbDReDQDwOrBBxQiJ3AeTI6JmTLVGicu24hbtLgCAcfBEQOutcmREnfeR8WrIQoL21E7gUoHa4RC5PCZHRM2Yao0kyLhZuwcAoB1+h8pREXXNWYRhp9w4ce7et9QNhsgNMDkismK0dAxR0kUIXQCkgb9UOxyiLltnnKA82PsWYOScR0TtYXJEZIVpRmFpyM2At6/K0RB13RdyIkSPCKCqBDj8qdrhELk0JkdELWgg4zbtd8qTYberGwyRnTTAC/WjfgMAMO7+F6oNDV0e7kzkqZgcEbVwteaA0qXmGwwMvkntcIjs5vcHh8MoJGgLvsGkJatxT1YOEyQiK5gcEbVwj3Y7AKAh/i52qZFH2VHsiy9lZebgGdovzAvVEpElJkdEzdWWIVWTCwAwjrpP5WCI7G+1cSIA4Nfa7QhFucrRELkmJkdEjYQQkH/4N3ykehyW+0KOTFA7JCK7y5HjUR8xCn6SATO9vlA7HCKXxOSICI2TP76yE6e+eBkAsMF4HcDlQsgjSagZ9yAApWsNhiqV4yFyPUyOiKBM/qg7/S2GaQpRLXxwLPoO+HGhWfIApkVpTRJjQ6CNvx0n5QiESJXwyntDxeiIXBMXniVq9DvtJgCANPo+rLnjJi40Sx6h5aK0ft5a1NQb8VLDnfi7Lgve370IJM0CfINUjpTIdbDliAiAdPFnTNDsVZ4k/YGJEXmU5ovSmv5vfyhfg2NyNKSaS0DOP1WOkMi1MDkiAuCd8w9oJIEvjQkQYYPVDofI4WRo8PeGe5QnOS8DFcXqBkTkQpgcEV04Du0P7wEAXm6YrG4sRE60SR6L+sjRgKESyH5a7XCIXAaTI6Ltz0ESRnxlHIW94gq1oyFyGgEN7im4U3my/13g1G51AyJyEUyOqHsrzgd+/AAAsMLUxUDk4ZqPYNsvBmF9ww0AAPFZBmCsVzEyItfA5Ii6LyGAz/8ECBkNQ27Dj2KA2hEROYVpBFvuwhQAwHMNU1EqekDS/wjx9d9Vjo5IfUyOqPv64QOg4FvAyw/
1KcvUjobIqSRJQlgPHRJjQ3ABQVhUn6a8sON54Mx+VWMjUhuTI+qWROU5iC8WAgAM1zyKKr9olSMicr7mLUifyMn41JgESW6A/O/fQdSUqh0ekWqYHFG3I2QZuS9Oh1RVgp/kGIzYNBiJy7aqHRaRKpQ5kLQAJCyqT0ORCIPm4nHsXjkNQpbVDo9IFUyOqFsQQqDa0IBqQwMMu/6FsYZdqBNeeKg+HQZ4m/dLjA3hsiHU7ZgKtC8hEHMNj6BOeCGpLgcNX/5V7dCIVMHlQ8jjCSFwd1YO8gouIUk6hHd8MgEAzzVMw9v/N6vxr2aFn7eWs2NTt9N8iZFqgxELM0/hee/X4P3N80BgBDButtohEjkVW47I49XUG5FXcAn9pbN4VbcC3mjAp8YkrDbeDH+d1rysQvOlFYi6m6YlRrTYYLwBL9RPUV74/DFg71ut9m/eGiuEcHK0RI7FliPqFgZIZ7BO9xcES1XYKw/Co/VzIfi3AVGb/mG8C+lXhcA771/Axw8CteXA+HQAlq2xgNIdvWFOMv+4II/B3w7kcVr+RSvpv8d63TOIlC7hiNwHsw2Pog46tcMkcnES6m96Fhj/oPL0i/9Dw8Z0iPpac2usSW7BJdTUG1WKk8j+2HJEHqXlX7Tze+/DQ9UvwU+qxUE5Fr8xLMAlBKocJZGbkCSIlKV4+4dq/KZiNbz2v43jB3IQNXO12pERORRbjsgtNG8NsnYz1TyY/qINRTn+4f0yHi5/HlJDLb4yjsK9hoWIiupjPidHphG1r9pgRHW9jMXnJ2Bm/RO4JHpiYMMx+L35S8z32gB/1KodIpFDsOWIXF7L1iBrTDUPMFRhrvZj/MHrEwRLVTAKCS8bJ+MfDVMgQ9NYF6Ecw5FpRO1LXLYV8VFKS+vX8pW4qe5veMZ7DW7GHjzs9SHu02bjlYY78L7xRpUjJbIvVVuOVq1ahbi4OPj6+iIpKQm7d3NF6O7ocqNeWtY3WFNSeBg1/1sM31W/wBPe7yNYqsIJbRwmG57BCw33QIYGibEhFqPTmBgRtdZ8UVoAyD9bbn7cKyoWc+ofwVzDwzgpR6CXVI7F3m/jO595kD5/DOJ0LtBs4siWP9sc4UadZWu+sGHDBgwdOhS+vr4YOXIkPv/8c5uup1rL0fr165GRkYGsrCwkJSVh5cqVSE1NxZEjR9C7d2+1wiIns3XUS+7CFGVeovpqaPQ/wHh8G4598x+M0vwMNP6sFMi9sbJhCpYtehrrtZzDiMgWpjmPLlQZWs0c39TyejNgXIC6H97Dmc+eRX9NMbB/NbB/NUTPSEhDJkLEXo25O7yw6bQ3AAljGhMujnAjW9maL+zcuRP33nsvMjMzcdttt2HdunWYPHky9u7dixEjRnTompJQKX1PSkrC2LFj8fLLLwMAZFlG37598eCDD+LJJ59s99jTp0+jb9++OHXqFPr06dPuvrYQQnDEhZNVG4ytvoBzF6bA31sCai5Bqr6AuvJzeOqtTeiv0WPeKA28LhwFSvIBucF8jCwk7JCvxHvGG5Et/wIJsb34xUvUBUII3JOVg9x2khkhBH79yrfwOf0Npmq/wg2a7xEg1Vic57wIxHERjWNyDE6ISBSLEJSIEJQgGP/+02T49QwEJJa/uhJH/iHZmd/ftuYLU6dORVVVFT799FPztquuugoJCQnIysrq0DVVaTkyGAzIy8vDggULzNs0Gg1SUlKQk5PTav+6ujrU1dWZn1dUVDgkrtqzh7D+n09ZfU2CsPpYed61/RSiQ/t1+FqS9f1gQ0xSh2Oyvl/ra1nu54N6+MKA/+gM5se+kgHS83XwQSW0krK/H4CVppH3B5udsGcE0DcJYvBNqI2bgHE9e2Nc40tsJSLqmuazZgPWf6YkScIHc6/GhaqxSFw2EjrUI1mTj+s13+MXmp8wXDqJcKkc4VI5kjSHW1/kpUcBAJXCF5XwQ1XjfR280SC0aEDTrR5aNMALDdDCKJRkSjR+q4jGbx7lG0O
yeN70uOW+/H5oy2+uioW3VgOEXwGMneWQa1RUVKC8vKnL1sfHBz4+Pq32szVfAICcnBxkZGRYbEtNTcXGjRs7HJ8qydH58+dhNBoRERFhsT0iIgKHD7f+AcrMzMTTTz/t8Lik8iKkeW12+HWo48qEPy6KAJwVYajq2Q8pVydDChsIRCUAQX0ASYIEwF/tQIk8kGnW7MvtE9ZDh8TYEOQWXMJ2eRS2y6MAAL6ow6ToCsQ0nIL24k+IlYoxwLcSIfJFBDVcQE9JGe3WU6pFT9S2/OuK1JLbeD9wgsOSo/j4eIvnS5YswVNPPdVqP1vzBQDQ6/VW99fr9R2Ozy1Gqy1YsMAiCywqKmr1D2sPPuH9UT/eMtu0aIJp/pPbqlWiWVNzW8dYPa5r12rzGLucr+U5JeubOxmT8PIFGm86X3/Ayxd1kg7Cyx/CPxTwC4W31hsRACLA1iAiV9Wylak503QZzVugTM+r66uBukpIhkrAUAmprkK5N9YBxgZArgfkBkhG5V651QNCBoTSDmS+Bxofw2Kb1OZ+LAhvi5dGo3xVhw502DXy8/MRExNjfm6t1UhNqiRH4eHh0Gq1KC4uttheXFyMyMjIVvu3bG5r3hRnT1L4IHjftMQh56aO8VU7ACLqlMu1MrV8zV/nBegCgR6clLU7CggIQGDg5T97W/MFAIiMjLRpf2tUqYLT6XQYM2YMsrOzzdtkWUZ2djaSk5PVCImIiIhcTGfyheTkZIv9AWDLli025ReqdatlZGRg5syZSExMxLhx47By5UpUVVUhLS1NrZCIiIjIxVwuX5gxYwZiYmKQmZkJAHj44Ydx/fXX4+9//ztuvfVWvP/++8jNzcVrr73W4WuqlhxNnToV586dw+LFi6HX65GQkIBNmza1KqIiIiKi7uty+UJhYSE0mqaOsPHjx2PdunVYuHAh/vznP2Pw4MHYuHFjh+c4AlSc56grHDXPERERETmOu/z+5sxbRERERM0wOSIiIiJqhskRERERUTNMjoiIiIiaYXJERERE1AyTIyIiIqJmmBwRERERNcPkiIiIiKgZJkdEREREzai2fEhXyLIMADh79qzKkRAREVFHmX5vm36Puyq3TI6Ki4sBAOPGjVM5EiIiIrJVcXEx+vXrp3YYbXLLtdUaGhqwb98+REREWCw2Zw8VFRWIj49Hfn4+AgIC7HpuV8D35/48/T16+vsDPP898v25P0e9R1mWUVxcjNGjR8PLy3XbZ9wyOXKk8vJyBAUFoaysDIGBgWqHY3d8f+7P09+jp78/wPPfI9+f++sO77E9LMgmIiIiaobJEREREVEzTI5a8PHxwZIlS+Dj46N2KA7B9+f+PP09evr7Azz/PfL9ub/u8B7bw5ojIiIiombYckRERETUDJMjIiIiomaYHBERERE1w+SIiIiIqJlulxz95S9/wfjx4+Hv74/g4GCr+xQWFuLWW2+Fv78/evfujcceewwNDQ3tnvfixYuYPn06AgMDERwcjFmzZqGystIB78A227ZtgyRJVm979uxp87gbbrih1f5z5sxxYuQdFxcX1yrWZ599tt1jamtrMW/ePISFhaFnz56YMmWKeVkaV3Py5EnMmjUL/fv3h5+fHwYOHIglS5bAYDC0e5wrf4arVq1CXFwcfH19kZSUhN27d7e7/4YNGzB06FD4+vpi5MiR+Pzzz50Uqe0yMzMxduxYBAQEoHfv3pg8eTKOHDnS7jFr1qxp9Vn5+vo6KWLbPPXUU61iHTp0aLvHuNPnB1j/TpEkCfPmzbO6v6t/fjt27MCkSZMQHR0NSZKwceNGi9eFEFi8eDGioqLg5+eHlJQU/PTTT5c9r60/x+6k2yVHBoMB99xzD+bOnWv1daPRiFtvvRUGgwE7d+7E2rVrsWbNGixevLjd806fPh0HDx7Eli1b8Omnn2LHjh144IEHHPEWbDJ+/HicPXvW4vb73/8e/fv3R2JiYrvHzp492+K45557zklR227p0qUWsT7
44IPt7j9//nx88skn2LBhA7Zv344zZ87grrvuclK0tjl8+DBkWcarr76KgwcP4oUXXkBWVhb+/Oc/X/ZYV/wM169fj4yMDCxZsgR79+7FqFGjkJqaipKSEqv779y5E/feey9mzZqFffv2YfLkyZg8eTIOHDjg5Mg7Zvv27Zg3bx6+++47bNmyBfX19bjppptQVVXV7nGBgYEWn1VBQYGTIrbd8OHDLWL95ptv2tzX3T4/ANizZ4/F+9uyZQsA4J577mnzGFf+/KqqqjBq1CisWrXK6uvPPfccXnzxRWRlZWHXrl3o0aMHUlNTUVtb2+Y5bf05djuim3rzzTdFUFBQq+2ff/650Gg0Qq/Xm7e98sorIjAwUNTV1Vk9V35+vgAg9uzZY972v//9T0iSJIqKiuwee1cYDAbRq1cvsXTp0nb3u/7668XDDz/snKC6KDY2Vrzwwgsd3r+0tFR4e3uLDRs2mLcdOnRIABA5OTkOiND+nnvuOdG/f/9293HVz3DcuHFi3rx55udGo1FER0eLzMxMq/v/+te/FrfeeqvFtqSkJPGHP/zBoXHaS0lJiQAgtm/f3uY+bX0fuaIlS5aIUaNGdXh/d//8hBDi4YcfFgMHDhSyLFt93Z0+PwDiww8/ND+XZVlERkaK559/3ryttLRU+Pj4iPfee6/N89j6c+xuul3L0eXk5ORg5MiRiIiIMG9LTU1FeXk5Dh482OYxwcHBFi0xKSkp0Gg02LVrl8NjtsXHH3+MCxcuIC0t7bL7vvvuuwgPD8eIESOwYMECVFdXOyHCznn22WcRFhaG0aNH4/nnn2+3GzQvLw/19fVISUkxbxs6dCj69euHnJwcZ4TbZWVlZQgNDb3sfq72GRoMBuTl5Vn822s0GqSkpLT5b5+Tk2OxP6D8TLrTZwXgsp9XZWUlYmNj0bdvX9xxxx1tft+4gp9++gnR0dEYMGAApk+fjsLCwjb3dffPz2Aw4J133sHvfvc7SJLU5n7u9Pk1d+LECej1eovPKCgoCElJSW1+Rp35OXY3rrskrkr0er1FYgTA/Fyv17d5TO/evS22eXl5ITQ0tM1j1PLGG28gNTUVffr0aXe/++67D7GxsYiOjsYPP/yAJ554AkeOHMF///tfJ0XacQ899BB+8YtfIDQ0FDt37sSCBQtw9uxZrFixwur+er0eOp2uVc1ZRESEy31e1hw7dgwvvfQSli9f3u5+rvgZnj9/Hkaj0erP2OHDh60e09bPpDt8VrIs45FHHsHVV1+NESNGtLnfkCFDsHr1alx55ZUoKyvD8uXLMX78eBw8ePCyP6vOlpSUhDVr1mDIkCE4e/Ysnn76aVx77bU4cOCA1dXb3fnzA4CNGzeitLQUv/3tb9vcx50+v5ZMn4Mtn1Fnfo7djUckR08++ST+9re/tbvPoUOHLls06E46855Pnz6NzZs344MPPrjs+ZvXS40cORJRUVGYMGECjh8/joEDB3Y+8A6y5f1lZGSYt1155ZXQ6XT4wx/+gMzMTJee+r4zn2FRURFuvvlm3HPPPZg9e3a7x6r9GRIwb948HDhwoN2aHABITk5GcnKy+fn48eMxbNgwvPrqq3jmmWccHaZNJk6caH585ZVXIikpCbGxsfjggw8wa9YsFSNzjDfeeAMTJ05EdHR0m/u40+dHHeMRydGjjz7ablYPAAMGDOjQuSIjI1tV3JtGMUVGRrZ5TMsitIaGBly8eLHNY7qqM+/5zTffRFhYGG6//Xabr5eUlARAabVwxi/WrnymSUlJaGhowMmTJzFkyJBWr0dGRsJgMKC0tNSi9ai4uNhhn5c1tr7HM2fO4MYbb8T48ePx2muv2Xw9Z3+G1oSHh0Or1bYaGdjev31kZKRN+7uK9PR08+AMW1sPvL29MXr0aBw7dsxB0dlPcHAwrrjiijZjddfPDwAKCgqwdetWm1tb3enzM30OxcXFiIqKMm8vLi5GQkKC1WM683PsdtQuelLL5Qqyi4uLzdt
effVVERgYKGpra62ey1SQnZuba962efNmlyrIlmVZ9O/fXzz66KOdOv6bb74RAMT3339v58js75133hEajUZcvHjR6uumgux///vf5m2HDx926YLs06dPi8GDB4tp06aJhoaGTp3DVT7DcePGifT0dPNzo9EoYmJi2i3Ivu222yy2JScnu2xBryzLYt68eSI6OlocPXq0U+doaGgQQ4YMEfPnz7dzdPZXUVEhQkJCxD/+8Q+rr7vb59fckiVLRGRkpKivr7fpOFf+/NBGQfby5cvN28rKyjpUkG3Lz7G76XbJUUFBgdi3b594+umnRc+ePcW+ffvEvn37REVFhRBC+U89YsQIcdNNN4n9+/eLTZs2iV69eokFCxaYz7Fr1y4xZMgQcfr0afO2m2++WYwePVrs2rVLfPPNN2Lw4MHi3nvvdfr7a8vWrVsFAHHo0KFWr50+fVoMGTJE7Nq1SwghxLFjx8TSpUtFbm6uOHHihPjoo4/EgAEDxHXXXefssC9r586d4oUXXhD79+8Xx48fF++8847o1auXmDFjhnmflu9PCCHmzJkj+vXrJ7788kuRm5srkpOTRXJyshpv4bJOnz4tBg0aJCZMmCBOnz4tzp49a74138ddPsP3339f+Pj4iDVr1oj8/HzxwAMPiODgYPMI0fvvv188+eST5v2//fZb4eXlJZYvXy4OHToklixZIry9vcWPP/6o1lto19y5c0VQUJDYtm2bxWdVXV1t3qfle3z66afF5s2bxfHjx0VeXp6YNm2a8PX1FQcPHlTjLbTr0UcfFdu2bRMnTpwQ3377rUhJSRHh4eGipKRECOH+n5+J0WgU/fr1E0888USr19zt86uoqDD/rgMgVqxYIfbt2ycKCgqEEEI8++yzIjg4WHz00Ufihx9+EHfccYfo37+/qKmpMZ/jl7/8pXjppZfMzy/3c+zuul1yNHPmTAGg1e2rr74y73Py5EkxceJE4efnJ8LDw8Wjjz5q8ZfDV199JQCIEydOmLdduHBB3HvvvaJnz54iMDBQpKWlmRMuV3DvvfeK8ePHW33txIkTFv8GhYWF4rrrrhOhoaHCx8dHDBo0SDz22GOirKzMiRF3TF5enkhKShJBQUHC19dXDBs2TPz1r3+1aOVr+f6EEKKmpkb88Y9/FCEhIcLf31/ceeedFsmGK3nzzTet/p9t3vDrbp/hSy+9JPr16yd0Op0YN26c+O6778yvXX/99WLmzJkW+3/wwQfiiiuuEDqdTgwfPlx89tlnTo6449r6rN58803zPi3f4yOPPGL+94iIiBC33HKL2Lt3r/OD74CpU6eKqKgoodPpRExMjJg6dao4duyY+XV3//xMNm/eLACII0eOtHrN3T4/0++sljfTe5BlWSxatEhEREQIHx8fMWHChFbvOzY2VixZssRiW3s/x+5OEkIIR3fdEREREbkLznNERERE1AyTIyIiIqJmmBwRERERNcPkiIiIiKgZJkdEREREzTA5IiIiImqGyRERERFRM0yOiIiIiJphckRERETUDJMjIiIiomaYHBERERE1w+SIiIiIqJn/B1ZtnPMUBE9gAAAAAElFTkSuQmCC", 44 | "text/plain": [ 45 | "
" 46 | ] 47 | }, 48 | "metadata": {}, 49 | "output_type": "display_data" 50 | } 51 | ], 52 | "source": [ 53 | "p = lambda x: 2*jax.scipy.stats.norm.pdf(x, loc=-2, scale=1)+jax.scipy.stats.norm.pdf(x, loc=2, scale=1)\n", 54 | "E0 = lambda x: -jnp.log(p(x))\n", 55 | "## sum of two gaussian densities with weights (2,1) -> Z =3\n", 56 | "\n", 57 | "key = jax.random.PRNGKey(1)\n", 58 | "N = 5000\n", 59 | "X = jax.random.normal(key,(N,1))\n", 60 | "offset = (((jax.random.randint(key,(N,1),0,3)-.5)>1)*2-1) * 2\n", 61 | "X = X+offset\n", 62 | "fig, ax = plt.subplots()\n", 63 | "bins = jnp.linspace(-10,10,200)\n", 64 | "ax.step(bins[:-1],jnp.histogram(X,bins)[0])\n", 65 | "x = jnp.linspace(-10,10,1000)\n", 66 | "ax2=ax.twinx()\n", 67 | "ax2.plot(x,p(x),c='C1')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 5, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "class E_fn(nn.Module):\n", 77 | " features: Sequence[int]\n", 78 | "\n", 79 | " @nn.compact\n", 80 | " def __call__(self, x, t):\n", 81 | " x = jnp.concatenate((x,t)) \n", 82 | " for i, feat in enumerate(self.features):\n", 83 | " x = nn.Dense(feat)(x)\n", 84 | " if i != len(self.features) - 1:\n", 85 | " x = nn.swish(x)\n", 86 | " return x.sum()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "class DPM():\n", 96 | " def __init__(self):\n", 97 | " sigma_min, sigma_max = 0.001,0.999\n", 98 | " self.lambda_a,self.lambda_b = jnp.arccos(sigma_max), jnp.arccos(sigma_min)\n", 99 | " self.E_fn = E_fn([64,64,64])\n", 100 | " key = jax.random.PRNGKey(1)\n", 101 | " self.params = self.E_fn.init(key,t=jnp.ones(1,),x=jnp.ones(1,))\n", 102 | " \n", 103 | " \n", 104 | " def beta(self,t):\n", 105 | " return 2 * (self.lambda_b - self.lambda_a) * jnp.tan(self.lambda_a + t * (self.lambda_b-self.lambda_a))\n", 106 | " \n", 107 | " def gamma(self,t): # exp(.5 * int 0 to t beta(s) ds)\n", 108 | " return 
jnp.cos(self.lambda_a + t * (self.lambda_b-self.lambda_a))\n", 109 | "\n", 110 | " def sigma(self,t): #sigma^2 = 1- exp(int 0 to t beta(s) ds)\n", 111 | " return jnp.sin(self.lambda_a + t * (self.lambda_b-self.lambda_a))\n", 112 | "\n", 113 | "\n", 114 | " def E(self,params,x,t):\n", 115 | " E1 = lambda x:-jax.scipy.stats.norm.logpdf(x, loc=0, scale=1)\n", 116 | " return t[0]*(1-t[0])*self.E_fn.apply(params,x,t) + (1-t[0]) * E0(x).sum() + t[0] * E1(x).sum()\n", 117 | "\n", 118 | "\n", 119 | "\n", 120 | " def train(self,X,key,steps,lr,batch_size=32):\n", 121 | " optim = optax.adam(learning_rate=lr)\n", 122 | " opt_state = optim.init(self.params)\n", 123 | "\n", 124 | " def loss_fn(params,key,x0):\n", 125 | " key1,key2 = jax.random.split(key)\n", 126 | " t = jax.random.uniform(key1,(len(x0),1))\n", 127 | " z = jax.random.normal(key2,x0.shape)\n", 128 | " xt = self.gamma(t)*x0 + z*self.sigma(t)\n", 129 | " \n", 130 | " dEdx,dEdt = jax.vmap(jax.grad(self.E,(1,2)),in_axes=(None,0,0))(params,xt,t)\n", 131 | "\n", 132 | " score_pred = -dEdx*self.sigma(t)\n", 133 | " score_target = -z\n", 134 | "\n", 135 | " L = ((score_pred -score_target)**2).sum(-1,keepdims=True)\n", 136 | " return L.mean()\n", 137 | "\n", 138 | " @jax.jit\n", 139 | " def update_step(key,params,opt_state,x_batch):\n", 140 | " grad = jax.grad(lambda p: loss_fn(params = p, key = key,x0 = x_batch))(params)\n", 141 | " updates, opt_state = optim.update(grad, opt_state, params)\n", 142 | " params = optax.apply_updates(params, updates)\n", 143 | " return params,opt_state\n", 144 | "\n", 145 | "\n", 146 | "\n", 147 | " for i in range(steps):\n", 148 | " \n", 149 | " key1,key2,key = jax.random.split(key,3)\n", 150 | " x_batch =jax.random.choice(key1,X,(batch_size,))\n", 151 | " self.params,opt_state =update_step(key2,self.params,opt_state,x_batch)\n", 152 | " print(f'step {i+1}/{steps}',end='\\r')\n", 153 | "\n", 154 | "\n", 155 | "\n", 156 | " def sample(self,key,N,num_steps):\n", 157 | " dt = 1/num_steps\n", 
158 | " @jax.jit\n", 159 | " def step(key,t,x,logZ):\n", 160 | " key = jax.random.split(key,2)[0]\n", 161 | " z = jax.random.normal(key,x.shape)\n", 162 | " dEdx,dEdt = jax.vmap(jax.grad(self.E,(1,2)),in_axes=(None,0,None))(self.params,x,t.reshape(1,))\n", 163 | " score = -dEdx\n", 164 | " logZ += dt*dEdt.reshape(-1)\n", 165 | " x += .5*self.beta(t) * (x +2*score)* dt + jnp.sqrt(self.beta(t))*z *(dt)**(1/2)\n", 166 | " return key,t-dt,x,logZ\n", 167 | "\n", 168 | " x = jax.random.normal(jax.random.PRNGKey(4),(N,1))\n", 169 | " logZ = jnp.zeros(N)\n", 170 | " t = 1-dt/2.\n", 171 | " \n", 172 | " \n", 173 | " for _ in range(num_steps):\n", 174 | " key,t,x,logZ = step(key,t,x,logZ)\n", 175 | " return x,logZ\n" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "### Train" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 7, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "step 15000/15000\r" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "model = DPM()\n", 200 | "model.train(X,key,steps=15000,lr=1e-3)\n" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "### Sample, estimate logZ" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 8, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "3.0421011\n" 220 | ] 221 | }, 222 | { 223 | "data": { 224 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA6GklEQVR4nO3deXxU5aH/8e9kmwTIwpYFEiBKWGVXICCiBUSkFlpvi14U2ootLdxC9Vev0brU1obWKrVeRa0LrRapG9CymoIsyiZLJCxG9rAkAYSsJCHJPL8/QgYmG0nI5CScz/v1mtdr5pznzDxzRObLszqMMUYAAAAW8bG6AgAAwN4IIwAAwFKEEQAAYCnCCAAAsBRhBAAAWIowAgAALEUYAQAAliKMAAAAS/lZXYHacLlcOnnypIKDg+VwOKyuDgAAqAVjjHJzc9WhQwf5+FTf/tEswsjJkycVExNjdTUAAEA9HDt2TNHR0dWebxZhJDg4WFLZlwkJCbG4NgAAoDZycnIUExPj/h2vTrMII+VdMyEhIYQRAACamSsNsWAAKwAAsBRhBAAAWIowAgAALEUYAQAAliKMAAAASxFGAACApQgjAADAUoQRAABgKcIIAACwFGEEAABYijACAAAsRRgBAACWIoxc5sCpPL2+/qAKi0utrgoAALbRLHbtbSyjX1gnScouKNavxvawuDYAANgDLSNV2HLorNVVAADANggjF5WUutzPC0vopgEAoLEQRi56Z/NR9/PCYlcNJQEAQEMijFy0Yf8Z9/OCC7SMAADQWAgjFzkue15U4p2WkQslLm0+9I0ueOn9AQBojggjVSjy0tTeuf/5Wve8vllP/WuPV94fAIDmiDBShaJS77RczFt7UJL03tY0r7w/AADNEWHkIlPti4bj6+O4ciEAAGyGMFIF46U04u9LGAEAoCLCSBWMl1pGAny53QAAVGT7X8dv8opUXGGMiJeyiAL8fL30zgAANF+23pvm8Jl83fantYoLb6WYNi3cx11eaho5k1fklfcFAKA5s3XLyPKUdEnS/lN5Hse91U0DAAAqq1MYmTdvnvr27auQkBCFhIQoPj5eK1asqLb8/Pnz5XA4PB6BgYFXXenGkF9U0qDvd+h0Xo3nNx48o7c/PyxzMQnlFZUoPbugQesAAEBTVKdumujoaM2ZM0dxcXEyxuhvf/ubJkyYoJ07d6p3795VXhMSEqLU1FT3a4ejecwo6f3UKv3x7r76wU0xDfJ+j36U4vH6VG6hwoMvBbP//usWSVLrFgHaeuSsFmwpW4vk80e/pY5hQQ1SBwAAmqI6tYzcdddduvPOOxUXF6du3brp2WefVatWrbR58+Zqr3E4HIqMjHQ/IiIirrrS3pB1/kKlY498tKvB3v/L41ker9emnnY/z8gudD8/eDrPHUQkaV3qaZW6jLvFBACAa029x4yUlpZq4cKFys/PV3x8fLXl8vLy1LlzZ8XExGjChAnas+fKS6EXFRUpJyfH4+FtO9KyvPr+ndu28Hj97mW7BG89ctb9/KU1BzzKPbYoRdc/tlyxCcu9Wj8AAKxS5zCSkpKiVq1ayel0avr06Vq0aJF69epVZdnu3bvrrbfe0pIlS/Tuu+/K5XJp2LBhOn78eI2fkZiYqNDQUPcjJqZhukoqasweo4pdLbuOZ7uf+9ayIlcadwIAQHNU5zDSvXt3JScna8uWLfrZz36mqVOnau/evVWWjY+P15QpU9S/f3+NHDlSH3/8sdq3b6/XXnutxs9ISEhQdna2+3Hs2LG6VrPJWff16UrHyruGCmq5Md+3nl+n6x9bTpcNAOCaUud1RgICAtS1a1dJ0qBBg/TFF1/oxRdfvGLAkCR/f38NGDBABw4cqLGc0+mU0+msa9WarPyiErku5odAfx8VFpctstb/mST94MZolZTWPlyUuoxW7M7QqZxC+fo4dH98Fy/UGACAxnPV64y4XC4VFdVuMa/S0lKlpKQoKirqaj+2WckqKHY
//8kt13uce3/bcX2880Slax4e063a9/vNv/fo6X/v1RNL9ujY2fMNV1EAACxQpzCSkJCg9evX68iRI0pJSVFCQoLWrl2ryZMnS5KmTJmihIQEd/lnnnlGn3zyiQ4dOqQdO3bovvvu09GjRzVt2rSG/Rb15FDjDBopvKwbZvKQTlcsH9bCXwM7t672/OVjTPame39wLwAA3lSnbppTp05pypQpSk9PV2hoqPr27atVq1ZpzJgxkqS0tDT5+FzKN+fOndODDz6ojIwMtW7dWoMGDdLGjRurHfB6rSq4UBZG/HwciggJVLeIVvo6s/rBqFnni9UioPp9bE5eNhV42a50je0d2XCVBQCgkdUpjLz55ps1nl+7dq3H67lz52ru3Ll1rlRzsD8zVxGhgQoJ9L9i2c2HvpEkRYaWLXJ2z02d9MzSqgf9luvY+tLsm4U/Gap7Xq96LZcgfzbfAwA0b7bem6a+dh3P0pi56/WtP62tVfnfLdsnSTp+rmx591E9w2ssHxLop/DgQC2YNkRLZgyXr0/13UknWTIeANDM2TqM1HedkdX7TkmSzuRVXrW1NloEVN0g9ckvb9GYXhFa8OBQSdKwru3ULyZMse1aVvteG/af0ROLdzPdFwDQbNV5au+1pL6/3/6+VzfwNaiK8SB/vLuvukUE669Tbqx0rl0rpz755S0K8vdVUUmpRr+w3uP8O5uPatqIWHVuW31oAQCgqbJ1GKmNjOxC91iPcn6+V9eg1DLAVzd3bafC4lK9/9N4ORxX3kCwW0Sw+/mROeOVdf6C+j+T5D62Lz1HxaUujX5hvR6/s6cevOW6q6ojAACNhW6aK8i/UOJ+viT5hKb97QsVXVy0TJK+PJZVj8916J0HBuuD6fHy8XHUayfjsBYBHq8zc4rcLSbPLt9X5/cDAMAqtIzUwayFyZKk1V+dch/bfTJb/WLC6vxe9QkgFbUPdup0btmCc+mXTfeVpJJS11W34AAA0Bj4tbqCqsaVXH6sfA2R6pS6LhUeVMNCZvWx9H9uVu8OIZKko9/ke5zr+viKBv0sAAC8hTByBRkVWhwqqthdUlFRyaWw8srkgQ1Sp3IRIYGaNiJWknT4TH6l88tT0hv08wAA8AZbh5HadJTc9+aWGs//vw++rPL4hv2nNfmNzer15Cr3sXatGn7zv4iQssG1X2XkVjr383/saPDPAwCgoTFmpJa+yavdZoDl7n9za6VjNS1eVl/hwYE1nv8mr0htvRCCAABoKLZuGamtnMJivbh6f63LZ+bU3LXTkLq0bVHp2N9/PNj9vLqWGwAAmgrCSC3874e7lFtYcuWCFw35/Wov1saTn6+P4q9r6379++/20S3d2rtfZ+TUrUUHAIDGRhiphRW7M2o1vqQmoUFX3lCvvv4xbYj7+fCuZcHkvYtLyu9Lz9HLnx7w2mcDAHC1CCO1VNO6IK7Lpu+WlLqqLLPlsVENXqdyPj4ObXlslJb/YoR7SfhuEa3c559bleq1zwYA4GoRRmqppjXK3tl81P08v4p1R0Z2a69A/8r70TSkiJBA9bq45ogktWnpOeWYjfQAAE2VrcNIXRZBvbz1o6J/fnHM/byqlpF1X5+uU70agsPh0FN39XK/Pne+uFIZAgoAoCmwdRipi493nqj2XNb5C+7nFZdlt9KPhscqIqRsWu/nB85o1sKd2n0iW5KUmpGr2ITl6vLoMmUXVA4qAAA0FsJIBSGBtVt6JazFpQGpeUVlM22yzl/Qt1/6rFLZR+7o3jCVq4fo1mVTf//nvZ1aknxS337pM5WUujT2z+vdZW7+wxqrqgcAAIueVVTbzeX8fC6VKw8j24+e8yiz95mx2nU8W0Ni2zRcBesoPLjygmeXj3GRpNzCEu06nqW+0WGNVCsAAC6xdcuIo4oJuwNquQOvv++la6sbTtIiwE9Dr2vbIDv01ldVS9D/5t97Kx37zv99rqf/tacxqgQAgAdbh5Gq/GJUnH41trs+nB5fYzk/38o
Bo4Yxrpa5fIrvlczfeEQb9jf+YFsAgL0RRioICvDVjNu6qkdUSI3l/H0q37rzF2q/SmtjmXRTJ/3gxmj9eVJ/90Jo5QZ2CqtU/v43t6rUy6nqZFaBTuU2nYG+AABrEUYqKG/vCLrCuiC+Pg61a1W2lseIuHaSpIIq1hixWoCfj/74X/00cUBH3dDRM2DtSMvS43f2rHTN9Y8t90ogycguVJdHl2nYnDUa/OxqvbbuYIN/BgCg+SGMVFA+vsPXx6Gtj49Sh9Cqd8UNCvDVw7eXzZIpX9Ds/GVhZNkvbvZyTesuONBzSfpgp5+mjYjVWz+8UasfHulx7jf/bvjxI29sOOTxOnHFVyquZsVaAIB9EEYquHysaXhwoGLaVN4VVyprOXH6ld2+wuKyEHImr2xTuon9O6h3h1DvVrSe7h4Y7X7+zrQhcjgc+laPCF3fvpV+OKyL+9yxs+c9rntp9X7NWLCjXi0mxhhlFxTrjc8OVzr3yZ5MLdiSpqISz1al5Snp+uLI2Tp/FgCg+WFqbwUVh6X6+lQ9E+bL41maevHHu6ik7F/3r6wt63ZYnHxSf75ngLeqeFWe/0E/fbtflPKLStS/wsyhp7/TW306hurhD77Up6mnNXzOGi37xc3y9/XR80lfS5Ju7tpO9w7uVOvPS88uUHxi9euYzFiwQ5L02KIU/XvmzeoTHaqlu05q5oKd7jJ39I7UH7/fVyGB3ttsEABgHVuHkapm3PpUOFhdGCksdrlbRsrDSHNxW/fwas9dH35p9s2JrAL1fybJ4/xvl+6tUxipKoik/u4OzfjHDv1n3ymP43f9X+UF4yRp5Z4MhQT56Y//1a/WnwsAaD7opqmgYkCpLozc1KW1nH5lY0WKij27GKaPvN4rdWsMceE1TwWu7vynqae0JPnSkvmzFu7U9Y8tr1TunQcGy+l3abxNbSXtzaxTeQBA80EYqaDiQmh+1YSR0KAAOf09W0bKZ6sMuc66FVevVkunn9ZUGMx6uS+PZyujwv47T/9rj3709heatTBZ24+e0zd5RVqSfLLS+BKHQxoR116S1DMqRCtnj9DmhFF6/6eV13S5e2C0Av0v/fE8d76Y6cAAcI2ydTdNVSq2jFTstinX0umrwAotI2fzyjbMaxnQvG/rde1b6cic8XK5jDYe/Eb3vbnF4/xHO45rxm1dVeoylVo/9pzMVtrZ/Crf940pN3q87hFZFt4iQpyKaROkY2cLdHuvCP3h7r5q3TJAid/rI5cxuvPFDTp0Jl+Dn12t7w+K1lPf6a1WzuZ9jwEAl/A3egUVs0dVK61KUvfIYI+WEZfL6OTFFoPqunaaGx8fh26Oa+feY+e3S/dqz8kcPbcqVc+tSq20bokkPbmk8pTgEXHt9PYPb6p23x+Hw6ENj3yr0vGAi2NyhlzXRofOlAWcD7YfV6tAPz11V++r+WoAgCaEbpoKKu4jU7Fl5LcTeutHw7vogZtjPab2Xj6ItbY7/zYX5XvsPDjiOo/ju0/k1HhdRIhTf/vxYP3fvQNrvQFhVSqOL3n78yMqYX0SALhmEEYqqNioUXHMyIi49nrqrt5y+vm6B7DmXyjVXy9b0Ku6tUmau5pm4Xz0s8rjPhb9fLhGdmuv0BZXNyW3XSunfj3ec6XYMXPXq+BCKaEEAK4BhJEKKg5g9fGpfqrv5QMsX7i4DodU/aDX5i60hb8eHHGpRajcm1NvVJ+OYe7XPaNC9NVv71CHsKAG++xpI67TkTnjFXYx2Bw+k6+eT65U18dX6LP9Z6q8xpgmuHMhAKCSa6s/oY4W7TxR6VjFMSMV95sJvqwLprxlpKJrZcxIVR4f30uPj+8lSUr75rwC/HwUeXHJ/PJBrxUDXEOadnOs/vTJ1x7H7ntzi/b8ZqxaXhzUuvdkjuZvPKy96TkK8PXR+z+Nv6puIgCAd9k6jOw5WXnMQ8Wf0RW7Mzxeh7UIcD8P8Kt+QKYddGpbuTvKm0FEkn5yy/UyRu4VYcv9Z1+
mJvTvqFO5hbrzLxs8zr2y9qBHy9UrkweqQ1iQXl9/UP6+PvrewGiVlLo0qmeEV+sOAKiarcNIlerwW3ott4A0VQF+PvqfUXEacl1b/eC1TfL3dai41Gh5SroiQgJ1z+ubK13zQoXg8vN/7PB4vST5pCTp9fsH6Vcf7lJ2QbG+eHy02gc7ZYzxWrjMKyrRb/61Rx9sP66Nj36rQbu1GtPekznaeeycvjcgWkEBNe92DQBVIYxUUHHMyOUq7uUC6wyObaMjc8ZrX3qOxr24Qav2ZGrVnqtbpfUn72x3P7/p2f8o2Omn3KISbXlslCJCqt69uT6MMbrvzS36/MA37mPD5qxR6u/uqLbr72qUb+RYvrt0Q3ht3UG9uu6gnrqrt2b/M1mStOXQWf3l3qa5JxOApo2O9Aoq/iO4d4dLa2l8b2DHRq4NrqRHZLDHQOJyT367l96cemmRtQdHxGpzwijd3LVdrd87t6hEkjTk96vrPRj2yJl83frcp3r/i2PKKyrRjrRzik1Y7hFEyn335Y3KLyrRkTP5Kil16cCpvKsahHvkTL66PLpMPZ5YqR5PrNTB03n1fq/LncopVOKKr3TufLE7iEjSv748qS6PLlOXR5fpl/9M1uncogb5PADXPlpGKqjYLtLismbn4tIr/zBcqzNpmiqHw6F7buqk+RuPuI8dmTO+yueS9O60ISp1Gfk4Lo3tMcao15OrVHCxBcHPx6GSCkvZxyYs1z+mDdHwi2Hm0Ok8tXL6KbxCi8m6r09r6ltbJUntg53uH+RHPtqlRz7aVeV3+MGN0Xp/23HtTc9R76dWVTr/4j39NTi2jf73oxT9eHgX9YwKUYsAXwVfYRfjW/+01uP175bu1ds/GlzjNVeScjy72g0NL7do5wn3APHff7eP+kaHKrug2H3/AOByhJEKahofUNW/Ut/+4U360fwv3K9fvIdm6sY2e3SczuZf0Cd7M/TMhBuuWL7iWB+Hw6F9v73D/drlMhqauFqnKvzLfvIbW7T1sVF6ZuleLd2VLkmacdv1+unI67XxwDfaezJbf1lzwF3+Si0DK2ePUI/IELlcRu9vO15tuVkLk93P13992v18SGwb/e3Hgz26X55dtld/3XC4yvf5NPW0ujy6zP065enb5e/ro8LiUoW1CFBuYbGKS41aBPhW2aWz92SORxAZ3ydK+0/lqk/HMD0+vqcmvbZJ+09Vbn15bFGK+3mAr49WzB6h69vXvCEjAHtxmGawGENOTo5CQ0OVnZ2tkJDKS5DX1+V/MZdLfnKMx4yZu+dt1Paj5yRJvx7fU9MqrEIqSduPntPd8zZKkjY8cts1u+iZnWRkF+pkdoFaOf10+9z1Dfre3xvYUS/8oL/Hsa2Hz+oHr21yv+4W0UpfZzZMt8rB39+pRz/apQ+2Vx94Kpo1Kk4/Hh4rfz+HPv3qtN747JB2pmW5z9/QMURLZtxc5SBuY4zWpp72COlVmRLfuVbhEUDzVdvfb1pGKqg4gLXlZRuyVfyXcrl+0aFq2zJAgf6+BJFrRGRooHv9lPceHKqH3k9Wenbtdg1+eEw3TR3eRX2f/kSS9P5P4xUZEqh2wQFqUc0miuUDcit687PD+u3SvfX8FtLWx0fJ18ehqcO61CmMvLh6v15cvb/Kc98d0FFzJ/Wv9lqHw6HbeoTryJzxyi8qUe+nVikixKnvDYzWvLUH3eX+vumo/r7pqA48O67KdWB2Hc9SZGigwoMbbvAwgKaJlpEKvnzqdoUGXeqLP/pNvkY+t1aS1C8mTEtmDK/yvbw5BRTWKyopVfdfr3S/njd5oJbvztC/vzzpUe6+oZ30u4l9JJXNYikudV1xbMeVbDtyVodO52vCgA46cua8olsHqaXTT3e+uEF70yuvlfPQmG76xai4Ssf3nMxWREig2rYM0Afbj+uRDyuPYeka3koHquhquVxVoam29mfmakwVLU1bHx+lklKjsBb++tm7O7Tusu6oy3UMC9KJrAJJ0l3
9OmjWqDjd+eIGXbi4LcADN8fq0XE95M8id0CTUNvfb8JIBbuevl0hFX48+jy1SrlFJdrxxBi1aRlQ6RrYx56T2Wof7FR4cKDSswsUn7hGkvT178ZVuwiet5zIKtCh03nuGUJ/3XBI7YOd+u6A6Dq9T8UBvZk5hRry+9UeZf4982Z9eTxL378xukGmH395LEsTXv78qt+nKnHhrZT00Mhqz+9MO6fvvrLR49i3eoRr/dentf6R2+QyRoXFLkW3DtLMBTuUciJbr943SAM6tfZKfYFrGWGkFqoKIylP317lv2Rp+UBVzuVfkJ+v46pbP+yo4EKpej658soFG8DDY7rpO/076H/e26ldx7PrfL2fj0MbE75VZZcRfzcA1SOM1EJVYWT3b8aqlZOhNEBj2bD/tO5/s2w6dHCgn6bEd9asUd0qtTQZY5R/oVTn8i8ounWQ0rMLtSPtnLq0bakbOoZKksb/ZUOV2zw0pMfv7KkHb7lOJaUubTr0jbvuFfWNDtVf7hmgLu1aehzfeOCMMnIKNb5vlFcWuQOaEsJILVQVRi7fcA1A83L+Qol6PVl5rZaK/vPQLeoaHiypbDzQy58e1B29I/VCUqq+yb/gnjnk7+vQxz8bXqu1VWrSqU0LpZ09X+35dq0C9Mf/6qtb4tqzqWMtFBaXKu3seW09XDae6p7BMeoWEexRpqTUpQ0HziimdZD7v3VFxhhdKHUpaW+mburSxmOl5bP5F7Rh/2n1iAzR2D9fGue08CdD1SMyWKFB/sorKlFxqVHrFv4yxvt7czVHhJFaqCqM7H1mbLUzHgA0fcYYHTydp85tW7oHsh47e16n84rUITTIPUvqStKzCxQZEiiHw6HcwmI9+lGKlqWke7PqbvcOjtHonhH6Vo9wW3YBlZS6KoWyU7mFGv+Xz+q1sm94sFOncos0Iq6d7h4YrZg2Qdp86KyeW5XaUFWWVDaz8q0f3iRfH4f2n8pTz6gQ27e0E0Zqoaowsu+ZO9jsC0C1fvDaJm09fFaSNOnGGMVf31YTB1TeKuL4ufO6+Q+fVjr+8n8P1PCubXX/m1uVcqLm8SttWwbojak36quMXMW2a6m/bTyiFbsz9O2+Ufq//x5Y7XX5RSXadTxb+9JztHJPhh4e001Drmurj7Yf18MffKkekcG6P76z7h4Y3aB7FtXF+Qsl8vPxcXfH5RWV6N3NRzVnxVeW1Mebfjisi1bsTlf/mDD1iAzR3QOjq9z1/FpEGKmFqsLIV7+9w7L/OQE0fcYY7TyWpT4dQ684hbi2g1tLXUZJezM1/d3tVyzrDb8YFaeu4a10V9+oBmmJMcaouNRo1/EsxbRp4dH9sT8zV798P1m7T5SN7XE4pLr+Cg29ro0eu7OnIkIC1b6VU7tPZuvP/9mv3MJijezWXjvTstQ+2Kl7B3eSv6+PNh48o98t21fle80aFadv9QjX80lfe6xwPL5PlJ7+Tm/tPpmtm7u2c/+3PplVoK8ycvSHFamaMqyzBnVurb0nc3Tkm/P6SzVr81TnpyOv0z03dVKXti282gJWWFxq2e8aYaQWCCMAmqLsgmIt2nFcT/+7/gveXY0J/TtoSfLJKs85HFLSL8vG3Bhj9N7WY+4l/8ODnZoS31l/+uTrq/r8Ph1D5TJGfj4O3R/fRZ8fOKOBnVtr0o0xVzWF/vJwWFRS2uADiAuLS/XJ3ky1DPBV/PVtNW/tQb102RYRNQnw89GFEpdGxLXTq/cNqnLsYmFxqU7nFikqNFBHz57XkTP5ysgplJ+PQ88u26ecwrLNPcNa+Gti/45yOKTj5wqUtLdsR/PhXdsqIiRQXcNbqWdUiEpKjQ6dzlPbVk6FBPrp9t6RDXczLiKM1EJVYcSK9SIAoCpn8op0MqtAgf6++sV7OzXuhijNGh2npL2ZevDv26q9rm3LAH3882HqGBYkXx+H3tt6TB/vOK7o1kF64Qf95ePj0PajZ/Xssn3acdky/41p7qR+CvTzVXZBsXI
KixUXEazYti3V2cutBE3BruNZWp6Sob9vOqLzF0ornQ/w89Gt3dorM6dQ584X67r2LfV1Rq5O1nIV6Pra+cQYtW7gtbQII7VQVRjZ/+w4Vm8E0OQVlZTKGF11S252QbG+OHxWrVsGaPY/d+rY2YIGqd8t3drrifE99Y8tae5dtWPaBOm27uF6fHxPpjVfdCqnUMtS0nUyq6DaTS5r67bu7fVp6mmFBvkru6DYffy+oZ307uY092uHQ2oV4KfcohKP6/8982b1iQ69qjpU5JUwMm/ePM2bN09HjhyRJPXu3VtPPvmkxo0bV+01H3zwgZ544gkdOXJEcXFx+sMf/qA777yz9t9EjRtGqtsnAwDswuUyMqq8w7VUFl6mvLVVXx7LUoCvj5b94mbFRVQ9dRZ1d6HEpQVbjirlRI6MjFwuo2/yL2jD/jOKaROkd348RMnHsnT4TL6mxHdWS6efHA5VCneFxWUtLkUlLo8tTqr7TG/1CHhlo7zo6GjNmTNHcXFxMsbob3/7myZMmKCdO3eqd+/elcpv3LhR9957rxITE/Xtb39bCxYs0MSJE7Vjxw7dcEPT3K3zWm8eBIArqWm9jNAg/2r36MLVC/Dz0Q+Hx9ZYpuJCelUpbzGrTctZUxiacNXdNG3atNFzzz2nBx54oNK5SZMmKT8/X0uXLnUfGzp0qPr3769XX3211p/hzZYR1wXPPrg9vxnLwjUAAFtp2fLKAac+vNIycrnS0lJ98MEHys/PV3x8fJVlNm3apIceesjj2NixY7V48eIa37uoqEhFRZcWtsnJ8d7yzsfm/pfH65C5XvsoAACaJKuHj9a5bSYlJUWtWrWS0+nU9OnTtWjRIvXq1avKshkZGYqIiPA4FhERoYyMjBo/IzExUaGhoe5HTExMXasJAACaiTq3jHTv3l3JycnKzs7Whx9+qKlTp2rdunXVBpL6SEhI8GhRycnJ8Vogifnlhx6v9/32Dq98DgAAqFqdw0hAQIC6du0qSRo0aJC++OILvfjii3rttdcqlY2MjFRmZqbHsczMTEVG1rywitPplNPprGvV6sUnwHOfCm/1mwEAgKpd9RBal8vlMb7jcvHx8Vq9erXHsaSkpGrHmAAAAPupU8tIQkKCxo0bp06dOik3N1cLFizQ2rVrtWpV2ZbdU6ZMUceOHZWYmChJmjVrlkaOHKnnn39e48eP18KFC7Vt2za9/vrrDf9NAABAs1SnMHLq1ClNmTJF6enpCg0NVd++fbVq1SqNGTNGkpSWliYfn0uNLcOGDdOCBQv061//Wo899pji4uK0ePHiJrvGCAAAaHwsB1/BkTnjG+z9AQCws9r+flu/7BoAALA1wggAALAUYQQAAFiKMAIAACxFGAEAAJaydRiJaRNkdRUAALA9W4eRUT0irlwIAAB4la3DCAAAsB5hBAAAWIowAgAALEUYAQAAliKMAAAASxFGAACApQgjAADAUoQRAABgKcIIAACwFGEEAABYijACAAAsRRgBAACWIowAAABLEUYAAIClCCMAAMBShBEAAGApwggAALAUYQQAAFjK1mHEz8dhdRUAALA9W4eRdsFOq6sAAIDt2TqMGGN1DQAAgK3DCAAAsB5hBAAAWIowAgAALEUYAQAAlvKzugJWMiobwTqmV4S+ysjRnX2iLK4RAAD2Y+swUi4syF/rf3WbHA7WHQEAoLHRTXMRQQQAAGsQRiSRQwAAsI6twwiLngEAYD1bhxEAAGA9wggAALAUYUSSQwwaAQDAKoQRAABgKcIIAACwFGEEAABYijACAAAsRRgRi54BAGAlwggAALCUrcOIYQlWAAAsZ+swAgAArEcYEWNGAACwEmEEAABYijACAAAsZeswwvhVAACsZ+swAgAArEcYkSR27QUAwDJ1CiOJiYm66aabFBwcrPDwcE2cOFGpqak1XjN//nw5HA6PR2Bg4FVVGgAAXDvqFEbWrVunGTNmaPPmzUpKSlJxcbFuv/125efn13h
dSEiI0tPT3Y+jR49eVaUBAMC1w68uhVeuXOnxev78+QoPD9f27dt1yy23VHudw+FQZGRk/WroRYxfBQDAelc1ZiQ7O1uS1KZNmxrL5eXlqXPnzoqJidGECRO0Z8+eGssXFRUpJyfH4wEAAK5N9Q4jLpdLs2fP1vDhw3XDDTdUW6579+566623tGTJEr377rtyuVwaNmyYjh8/Xu01iYmJCg0NdT9iYmLqW81aYQVWAACsU+8wMmPGDO3evVsLFy6ssVx8fLymTJmi/v37a+TIkfr444/Vvn17vfbaa9Vek5CQoOzsbPfj2LFj9a0mAABo4uo0ZqTczJkztXTpUq1fv17R0dF1utbf318DBgzQgQMHqi3jdDrldDrrU7U6YdEzAACsV6eWEWOMZs6cqUWLFmnNmjWKjY2t8weWlpYqJSVFUVFRdb4WAABce+rUMjJjxgwtWLBAS5YsUXBwsDIyMiRJoaGhCgoKkiRNmTJFHTt2VGJioiTpmWee0dChQ9W1a1dlZWXpueee09GjRzVt2rQG/ir1x5ARAACsU6cwMm/ePEnSrbfe6nH87bff1g9/+ENJUlpamnx8LjW4nDt3Tg8++KAyMjLUunVrDRo0SBs3blSvXr2uruYAAOCaUKcwYmoxyGLt2rUer+fOnau5c+fWqVIAAMA+bL03jWHZMwAALGfrMAIAAKxHGBGLngEAYCXCCAAAsBRhBAAAWMrWYYQVWAEAsJ6tw0g5B8ueAQBgGcIIAACwFGEEAABYijACAAAsZeswwvhVAACsZ+swUo5FzwAAsA5hBAAAWIowAgAALEUYAQAAlrJ3GGEJVgAALGfrMOK6mEV8GMEKAIBlbB5GytIIWQQAAOvYOoyUd9LQMgIAgHVsHUbKW0Z8yCIAAFjG1mGkfPyqg5YRAAAsY/MwwpgRAACsZuswwmwaAACsZ/MwcrFlxOJ6AABgZ7YOI4aWEQAALGfzMMJsGgAArGbrMFI+ZoQRrAAAWMfWYcSIlhEAAKxm6zBS3jLiYAgrAACWsXUYuTSA1dp6AABgZ7YOIwAAwHqEEQAAYCmbhxFz5SIAAMCrbB5GyjCzFwAA6xBGAACApWwdRgy9NAAAWM7WYaScg34aAAAsQxgBAACWsnUYoZsGAADr2TqMAAAA6xFGAACApQgjAADAUrYOI4YVWAEAsJytw0g5ZvYCAGAdwggAALCUrcMIU3sBALCercNIOYfopwEAwCqEEQAAYClbhxF6aQAAsJ6tw0g5ZtMAAGAdwggAALAUYQQAAFjK1mGEqb0AAFjP1mGkHENGAACwTp3CSGJiom666SYFBwcrPDxcEydOVGpq6hWv++CDD9SjRw8FBgaqT58+Wr58eb0rDAAAri11CiPr1q3TjBkztHnzZiUlJam4uFi333678vPzq71m48aNuvfee/XAAw9o586dmjhxoiZOnKjdu3dfdeWvFhvlAQBgPYcx9R85cfr0aYWHh2vdunW65ZZbqiwzadIk5efna+nSpe5jQ4cOVf/+/fXqq6/W6nNycnIUGhqq7OxshYSE1Le6lTz0frI+3nFCj93ZQz+55foGe18AAFD73++rGjOSnZ0tSWrTpk21ZTZt2qTRo0d7HBs7dqw2bdpU7TVFRUXKycnxeAAAgGtTvcOIy+XS7NmzNXz4cN1www3VlsvIyFBERITHsYiICGVkZFR7TWJiokJDQ92PmJiY+lazZvTSAABguXqHkRkzZmj37t1auHBhQ9ZHkpSQkKDs7Gz349ixYw3+GZdjozwAAKzjV5+LZs6cqaVLl2r9+vWKjo6usWxkZKQyMzM9jmVmZioyMrLaa5xOp5xOZ32qBgAAmpk6tYwYYzRz5kwtWrRIa9asUWxs7BWviY+P1+rVqz2OJSUlKT4+vm41BQAA16Q6tYzMmDFDCxYs0JIlSxQcHOwe9xEaGqqgoCBJ0pQpU9SxY0clJiZKkmbNmqWRI0fq+eef1/jx47Vw4UJt27ZNr7/+egN
/lbpjyAgAANarU8vIvHnzlJ2drVtvvVVRUVHuxz//+U93mbS0NKWnp7tfDxs2TAsWLNDrr7+ufv366cMPP9TixYtrHPTa2Ni1FwAA69SpZaQ2S5KsXbu20rHvf//7+v73v1+XjwIAADZh671prmK9NwAA0EBsHUYAAID1CCMAAMBStg4jdNIAAGA9W4eRcg6m0wAAYBnCCAAAsBRhBAAAWMrWYYSZvQAAWM/WYaQcI0YAALAOYQQAAFjK1mGEXhoAAKxn6zBSjpm9AABYhzACAAAsZeswwkZ5AABYz9ZhpBy9NAAAWIcwAgAALEUYAQAAlrJ1GGHECAAA1rN1GCnHrr0AAFiHMAIAACxl7zBCPw0AAJazdxi5iF4aAACsQxgBAACWsnUYMfTTAABgOVuHkXL00gAAYB3CCAAAsBRhBAAAWMrWYYRNewEAsJ6tw4gbc3sBALAMYQQAAFjK1mGEbhoAAKxn6zBSjk4aAACsQxgBAACWsnUYYQVWAACsZ+swUo7JNAAAWIcwAgAALEUYAQAAlrJ1GGFqLwAA1rN1GCnnYHIvAACWIYwAAABL2TqM0EsDAID1bB1GyjG1FwAA6xBGAACApQgjAADAUrYOI0ztBQDAerYOI+UYMgIAgHUIIwAAwFI2DyP00wAAYDWbh5EyTO0FAMA6hBEAAGApW4cRZtMAAGA9W4eRcmyUBwCAdQgjAADAUoQRAABgqTqHkfXr1+uuu+5Shw4d5HA4tHjx4hrLr127Vg6Ho9IjIyOjvnVuMAwZAQDAenUOI/n5+erXr59efvnlOl2Xmpqq9PR09yM8PLyuH+09DBkBAMAyfnW9YNy4cRo3blydPyg8PFxhYWF1vg4AAFzbGm3MSP/+/RUVFaUxY8bo888/r7FsUVGRcnJyPB7eYJjbCwCA5bweRqKiovTqq6/qo48+0kcffaSYmBjdeuut2rFjR7XXJCYmKjQ01P2IiYnxah3ppQEAwDp17qapq+7du6t79+7u18OGDdPBgwc1d+5cvfPOO1Vek5CQoIceesj9Oicnx+uBBAAAWMPrYaQqgwcP1meffVbteafTKafT6fV60EkDAID1LFlnJDk5WVFRUVZ8dJUc7JQHAIBl6twykpeXpwMHDrhfHz58WMnJyWrTpo06deqkhIQEnThxQn//+98lSX/+858VGxur3r17q7CwUG+88YbWrFmjTz75pOG+BQAAaLbqHEa2bdum2267zf26fGzH1KlTNX/+fKWnpystLc19/sKFC3r44Yd14sQJtWjRQn379tV//vMfj/cAAAD25TDNYH5rTk6OQkNDlZ2drZCQkAZ736lvbdW6r0/rT9/vp/8aFN1g7wsAAGr/+83eNGJqLwAAViKMAAAAS9k6jDT5/ikAAGzA1mGkHDN7AQCwDmEEAABYytZhpBlMJAIA4Jpn6zBSjm4aAACsQxgBAACWIowAAABLEUYAAIClCCOSHKzBCgCAZQgjAADAUrYOI8zsBQDAerYOI+WY2gsAgHUIIwAAwFK2DiOGrfIAALCcrcMIAACwHmEEAABYijACAAAsZeswwtReAACsZ+swUs7B3F4AACxDGAEAAJaydRihmwYAAOvZOoyUo5MGAADrEEYAAIClbB1GWIEVAADr2TqMlGMyDQAA1iGMAAAASxFGAACApWwdRpjaCwCA9WwdRso5mNwLAIBlCCMAAMBStg4j9NIAAGA9W4eRckztBQDAOoQRAABgKXuHEfppAACwnL3DyEX00gAAYB3CCAAAsBRhBAAAWMrWYYRdewEAsJ6tw0g5pvYCAGAdwggAALCUrcMIG+UBAGA9W4eRS+inAQDAKoQRAABgKVuHEXppAACwnq3DSDlm0wAAYB3CCAAAsBRhBAAAWMrWYcQwtxcAAMvZOoyUY8gIAADWIYwAAABL+VldASvdPSha8de3VWy7llZXBQAA27J1GJk8pLPVVQAAwPbq3E2zfv163XXXXerQoYMcDocWL158xWv
Wrl2rgQMHyul0qmvXrpo/f349qgoAAK5FdQ4j+fn56tevn15++eValT98+LDGjx+v2267TcnJyZo9e7amTZumVatW1bmyAADg2lPnbppx48Zp3LhxtS7/6quvKjY2Vs8//7wkqWfPnvrss880d+5cjR07tq4fDwAArjFen02zadMmjR492uPY2LFjtWnTpmqvKSoqUk5OjscDAABcm7weRjIyMhQREeFxLCIiQjk5OSooKKjymsTERIWGhrofMTEx3q4mAACwSJNcZyQhIUHZ2dnux7Fjx6yuEgAA8BKvT+2NjIxUZmamx7HMzEyFhIQoKCioymucTqecTqe3qwYAAJoAr7eMxMfHa/Xq1R7HkpKSFB8f7+2PBgAAzUCdw0heXp6Sk5OVnJwsqWzqbnJystLS0iSVdbFMmTLFXX769Ok6dOiQHnnkEX311Vd65ZVX9P777+uXv/xlw3wDAADQrNU5jGzbtk0DBgzQgAEDJEkPPfSQBgwYoCeffFKSlJ6e7g4mkhQbG6tly5YpKSlJ/fr10/PPP6833niDab0AAECS5DDGGKsrcSU5OTkKDQ1Vdna2QkJCrK4OAACohdr+fjfJ2TQAAMA+CCMAAMBSzWLX3vKeJFZiBQCg+Sj/3b7SiJBmEUZyc3MliZVYAQBohnJzcxUaGlrt+WYxgNXlcunkyZMKDg6Ww+FosPfNyclRTEyMjh07xsBYL+I+Nx7udePgPjcO7nPj8OZ9NsYoNzdXHTp0kI9P9SNDmkXLiI+Pj6Kjo732/iEhIfxBbwTc58bDvW4c3OfGwX1uHN66zzW1iJRjACsAALAUYQQAAFjK1mHE6XTqqaeeYlM+L+M+Nx7udePgPjcO7nPjaAr3uVkMYAUAANcuW7eMAAAA6xFGAACApQgjAADAUoQRAABgKVuHkZdfflldunRRYGCghgwZoq1bt1pdpSZr/fr1uuuuu9ShQwc5HA4tXrzY47wxRk8++aSioqIUFBSk0aNHa//+/R5lzp49q8mTJyskJERhYWF64IEHlJeX51Fm165dGjFihAIDAxUTE6M//vGP3v5qTUpiYqJuuukmBQcHKzw8XBMnTlRqaqpHmcLCQs2YMUNt27ZVq1atdPfddyszM9OjTFpamsaPH68WLVooPDxcv/rVr1RSUuJRZu3atRo4cKCcTqe6du2q+fPne/vrNRnz5s1T37593Ys8xcfHa8WKFe7z3GPvmDNnjhwOh2bPnu0+xr1uGE8//bQcDofHo0ePHu7zTf4+G5tauHChCQgIMG+99ZbZs2ePefDBB01YWJjJzMy0umpN0vLly83jjz9uPv74YyPJLFq0yOP8nDlzTGhoqFm8eLH58ssvzXe+8x0TGxtrCgoK3GXuuOMO069fP7N582azYcMG07VrV3Pvvfe6z2dnZ5uIiAgzefJks3v3bvPee++ZoKAg89prrzXW17Tc2LFjzdtvv212795tkpOTzZ133mk6depk8vLy3GWmT59uYmJizOrVq822bdvM0KFDzbBhw9znS0pKzA033GBGjx5tdu7caZYvX27atWtnEhIS3GUOHTpkWrRoYR566CGzd+9e89JLLxlfX1+zcuXKRv2+VvnXv/5lli1bZr7++muTmppqHnvsMePv7292795tjOEee8PWrVtNly5dTN++fc2sWbPcx7nXDeOpp54yvXv3Nunp6e7H6dOn3eeb+n22bRgZPHiwmTFjhvt1aWmp6dChg0lMTLSwVs1DxTDicrlMZGSkee6559zHsrKyjNPpNO+9954xxpi9e/caSeaLL75wl1mxYoVxOBzmxIkTxhhjXnnlFdO6dWtTVFTkLvO///u/pnv37l7+Rk3XqVOnjCSzbt06Y0zZffX39zcffPCBu8y+ffuMJLNp0yZjTFlw9PHxMRkZGe4y8+bNMyEhIe57+8gjj5jevXt7fNakSZPM2LFjvf2VmqzWrVubN954g3vsBbm5uSYuLs4kJSWZkSNHusMI97r
hPPXUU6Zfv35VnmsO99mW3TQXLlzQ9u3bNXr0aPcxHx8fjR49Wps2bbKwZs3T4cOHlZGR4XE/Q0NDNWTIEPf93LRpk8LCwnTjjTe6y4wePVo+Pj7asmWLu8wtt9yigIAAd5mxY8cqNTVV586da6Rv07RkZ2dLktq0aSNJ2r59u4qLiz3udY8ePdSpUyePe92nTx9FRES4y4wdO1Y5OTnas2ePu8zl71Fexo5//ktLS7Vw4ULl5+crPj6ee+wFM2bM0Pjx4yvdD+51w9q/f786dOig6667TpMnT1ZaWpqk5nGfbRlGzpw5o9LSUo+bLkkRERHKyMiwqFbNV/k9q+l+ZmRkKDw83OO8n5+f2rRp41Gmqve4/DPsxOVyafbs2Ro+fLhuuOEGSWX3ISAgQGFhYR5lK97rK93H6srk5OSooKDAG1+nyUlJSVGrVq3kdDo1ffp0LVq0SL169eIeN7CFCxdqx44dSkxMrHSOe91whgwZovnz52vlypWaN2+eDh8+rBEjRig3N7dZ3OdmsWsvYEczZszQ7t279dlnn1ldlWtS9+7dlZycrOzsbH344YeaOnWq1q1bZ3W1rinHjh3TrFmzlJSUpMDAQKurc00bN26c+3nfvn01ZMgQde7cWe+//76CgoIsrFnt2LJlpF27dvL19a00kjgzM1ORkZEW1ar5Kr9nNd3PyMhInTp1yuN8SUmJzp4961Gmqve4/DPsYubMmVq6dKk+/fRTRUdHu49HRkbqwoULysrK8ihf8V5f6T5WVyYkJKRZ/MXVEAICAtS1a1cNGjRIiYmJ6tevn1588UXucQPavn27Tp06pYEDB8rPz09+fn5at26d/vKXv8jPz08RERHcay8JCwtTt27ddODAgWbxZ9qWYSQgIECDBg3S6tWr3cdcLpdWr16t+Ph4C2vWPMXGxioyMtLjfubk5GjLli3u+xkfH6+srCxt377dXWbNmjVyuVwaMmSIu8z69etVXFzsLpOUlKTu3burdevWjfRtrGWM0cyZM7Vo0SKtWbNGsbGxHucHDRokf39/j3udmpqqtLQ0j3udkpLiEf6SkpIUEhKiXr16uctc/h7lZez859/lcqmoqIh73IBGjRqllJQUJScnux833nijJk+e7H7OvfaOvLw8HTx4UFFRUc3jz/RVD4FtphYuXGicTqeZP3++2bt3r/nJT35iwsLCPEYS45Lc3Fyzc+dOs3PnTiPJvPDCC2bnzp3m6NGjxpiyqb1hYWFmyZIlZteuXWbChAlVTu0dMGCA2bJli/nss89MXFycx9TerKwsExERYe6//36ze/dus3DhQtOiRQtbTe392c9+ZkJDQ83atWs9puidP3/eXWb69OmmU6dOZs2aNWbbtm0mPj7exMfHu8+XT9G7/fbbTXJyslm5cqVp3759lVP0fvWrX5l9+/aZl19+2VZTIR999FGzbt06c/jwYbNr1y7z6KOPGofDYT755BNjDPfYmy6fTWMM97qhPPzww2bt2rXm8OHD5vPPPzejR4827dq1M6dOnTLGNP37bNswYowxL730kunUqZMJCAgwgwcPNps3b7a6Sk3Wp59+aiRVekydOtUYUza994knnjARERHG6XSaUaNGmdTUVI/3+Oabb8y9995rWrVqZUJCQsyPfvQjk5ub61Hmyy+/NDfffLNxOp2mY8eOZs6cOY31FZuEqu6xJPP222+7yxQUFJif//znpnXr1qZFixbmu9/9rklPT/d4nyNHjphx48aZoKAg065dO/Pwww+b4uJijzKffvqp6d+/vwkICDDXXXedx2dc63784x+bzp07m4CAANO+fXszatQodxAxhnvsTRXDCPe6YUyaNMlERUWZgIAA07FjRzNp0iRz4MAB9/mmfp8dxhhz9e0rAAAA9WPLMSMAAKDpIIwAAABLEUYAAIClCCMAAMBShBEAAGApwggAALAUYQQAAFiKMAIAACxFGAEAAJYijAAAAEsRRgAAgKUIIwAAwFL
/H9BryDAhCxj8AAAAAElFTkSuQmCC", 225 | "text/plain": [ 226 | "
" 227 | ] 228 | }, 229 | "metadata": {}, 230 | "output_type": "display_data" 231 | } 232 | ], 233 | "source": [ 234 | "\n", 235 | "samples,logZ = model.sample(key,N=len(X),num_steps=1500)\n", 236 | "\n", 237 | "plt.plot(jnp.exp(logZ.cumsum()/jnp.arange(1,len(logZ)+1)))\n", 238 | "plt.hlines([3],0,len(logZ),colors=['black'])\n", 239 | "print(jnp.exp(logZ.mean())) # should be close to 3\n" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 9, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/plain": [ 257 | "[]" 258 | ] 259 | }, 260 | "execution_count": 9, 261 | "metadata": {}, 262 | "output_type": "execute_result" 263 | }, 264 | { 265 | "data": { 266 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAFfCAYAAABuhCaHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA0SElEQVR4nO3df3xT9b0/8NdJm6RJaVIK0rSzrdWJBYaILZQq+oDZ7woiysRO9qjIkK8MLqBQFel3QAcXrSIqwpDOXQX8wTbZLlzFrV4sCnIplZaxywD54bDtHaTcCW1o0yZpc75/hByaNkDTnvTkJK/n45GHzTmnJ+/G8uonn/M5n48giqIIIiJSDY3SBRARUWAY3EREKsPgJiJSGQY3EZHKMLiJiFSGwU1EpDIMbiIilYlWuoCecLvdOHv2LOLi4iAIgtLlEBH1miiKuHTpEpKTk6HRXLtNrcrgPnv2LFJSUpQug4hIdnV1dbjxxhuveYwqgzsuLg6A5wc0mUwKV0NE1Hs2mw0pKSlSvl2LKoPb2z1iMpkY3EQUVrrT/cuLk0REKsPgJiJSmYCDe+/evZg8eTKSk5MhCAJ27NjR5Zjjx4/jwQcfhNlsRmxsLEaNGoXa2lppf2trK+bNm4cBAwagX79+mDp1Kurr63v1gxARRYqAg7u5uRkjRozAhg0b/O7/5ptvMHbsWGRkZOCLL77Af//3f2PZsmWIiYmRjlm0aBE+/vhjbNu2DXv27MHZs2fx8MMP9/ynICKKIEJv5uMWBAHbt2/HlClTpG3Tpk2DVqvFe++95/d7GhsbccMNN2Dr1q145JFHAABff/01hgwZgoqKCowZM+a6r2uz2WA2m9HY2MiLk0QUFgLJNVn7uN1uNz755BMMHjwYeXl5GDRoELKzs326U6qrq+FyuZCbmytty8jIQGpqKioqKvye1+FwwGaz+TyIiCKVrMF9/vx5NDU14aWXXsKECRPwn//5n/jxj3+Mhx9+GHv27AEAWK1W6HQ6xMfH+3xvYmIirFar3/OWlJTAbDZLD958Q0SRTPYWNwA89NBDWLRoEe644w4sWbIEDzzwAEpLS3t83qKiIjQ2NkqPuro6uUomIlIdWW/AGThwIKKjozF06FCf7UOGDMG+
ffsAABaLBU6nEw0NDT6t7vr6elgsFr/n1ev10Ov1cpZKRKRasra4dTodRo0ahRMnTvhsP3nyJNLS0gAAmZmZ0Gq1KC8vl/afOHECtbW1yMnJkbMcClOiKMLubIPd2QaudU2RKOAWd1NTE06fPi09P3PmDA4fPoyEhASkpqbiueeew6OPPop7770X48ePR1lZGT7++GN88cUXAACz2YxZs2ahsLAQCQkJMJlMWLBgAXJycro1ooQimyiKeKS0AtU1FwEAWWn9sW1ODmeJpIgScHBXVVVh/Pjx0vPCwkIAwIwZM7B582b8+Mc/RmlpKUpKSvDUU0/htttuwx//+EeMHTtW+p7XX38dGo0GU6dOhcPhQF5eHt58800ZfhwKdy2udim0AaCq5iJaXO0w6lQ57Q5Rj/RqHLdSOI47ctmdbRi6/FOfbcdW5jG4SfUUG8dNRETBx+AmIlIZBjeFPlEEnM2e/xIRg5tCnCgC7+QBLyYD70xgeBOBwU2hzmUH6io9X9cd8DwninAMbiIilWFwExGpDIObiEhlGNxERCrD4CYiUhkGNxGRyjC4iYhUhsFNRKQyDG4iIpVhcBMRqQyDm4hIZTj7PKmYCAMcnHiKIg5b3KRSIv6gW4HjMU9A/94khjdFFAY3qZIBDmRpTgIAov6nkrMGUkRhcBMRqQyDm4hIZRjcREQqw+CmkCZ2uujIa5BEDG4KcS2u9ms+J4pEDG4iIpUJOLj37t2LyZMnIzk5GYIgYMeOHVc9ds6cORAEAWvXrvXZfuHCBRQUFMBkMiE+Ph6zZs1CU1NToKUQEUWkgIO7ubkZI0aMwIYNG6553Pbt23HgwAEkJyd32VdQUICjR49i165d2LlzJ/bu3YvZs2cHWgoRUUQK+Jb3iRMnYuLEidc85h//+AcWLFiATz/9FJMmTfLZd/z4cZSVleHgwYPIysoCAKxfvx73338/1qxZ4zfoiYjoCtn7uN1uN6ZPn47nnnsOw4YN67K/oqIC8fHxUmgDQG5uLjQaDSorK/2e0+FwwGaz+TyIiCKV7MH98ssvIzo6Gk899ZTf/VarFYMGDfLZFh0djYSEBFitVr/fU1JSArPZLD1SUlLkLpuISDVkDe7q6mq88cYb2Lx5MwRBkO28RUVFaGxslB51dXWynZuISG1kDe4vv/wS58+fR2pqKqKjoxEdHY2amho888wzuOmmmwAAFosF58+f9/m+trY2XLhwARaLxe959Xo9TCaTz4OIKFLJOh/39OnTkZub67MtLy8P06dPx8yZMwEAOTk5aGhoQHV1NTIzMwEAu3fvhtvtRnZ2tpzlkFqJome2P61R6UqIQlLAwd3U1ITTp09Lz8+cOYPDhw8jISEBqampGDBggM/xWq0WFosFt912GwBgyJAhmDBhAp588kmUlpbC5XJh/vz5mDZtGkeUkCe038kD6iqBlDHAo79XuiKikBNwV0lVVRVGjhyJkSNHAgAKCwsxcuRILF++vNvn+OCDD5CRkYH77rsP999/P8aOHYu33nor0FIoHLnsntAGgLoDgKtF2XqIQlDALe5x48Z1mfjnWr799tsu2xISErB169ZAX5qIiMC5SoiIVIfBTUSkMgxuIiKVYXATEakMg5uISGUY3BTauFYZURcMbgppMe8/oHQJRCGHwU0hTXPhG5/ngssOgK1wimwMblKVARuHYZtuBfzOPSmKgLOZ3SsU9hjcpDqjNCdhgMN3o3eOkxeTgXcmMLwprDG4STXGOtZefaerpdMcJ/Y+qYlICQxuUo0WUa90CUQhgcFNRKQyDG4iIpVhcFNo8DMi5Kg7Tfr6oHswWsCuEiJA5qXLiHqk46o3luHS5nxnMUQAXy4ej/zV+7uOJCGKUGxxk/I6rnpjPSJtFgG0IAbQxQKdRm4bBIY4RS4GN6nSPv1Cn+d2V5syhRApgMFNYeGelz9XugSiPsPgJtW65g05RGGMwU2qxRtyKFIxuImIVIbBTUSkMgxuCmlDk0ww
aKOULoMopAQc3Hv37sXkyZORnJwMQRCwY8cOaZ/L5cLzzz+P4cOHIzY2FsnJyXj88cdx9uxZn3NcuHABBQUFMJlMiI+Px6xZs9DU1NTrH4bCz7Y5ORD8Tr5NFLkCDu7m5maMGDECGzZs6LLPbrfj0KFDWLZsGQ4dOoR///d/x4kTJ/Dggw/6HFdQUICjR49i165d2LlzJ/bu3YvZs2f3/KegsNUxtFugx0H3YACA63ujeQs8RayAb3mfOHEiJk6c6Hef2WzGrl27fLb96le/wujRo1FbW4vU1FQcP34cZWVlOHjwILKysgAA69evx/333481a9YgOTm5Bz8GRQYB+c5iGODAl9MmQXzhE6ULIlJE0Pu4GxsbIQgC4uPjAQAVFRWIj4+XQhsAcnNzodFoUFlZ6fccDocDNpvN50GRSvDcBs/+E4pgQQ3u1tZWPP/88/jpT38Kk8kEALBarRg0aJDPcdHR0UhISIDVavV7npKSEpjNZumRkpISzLIpxBi0UchK6y89z0rrzwuWFNGCNjugy+XCT37yE4iiiI0bN/bqXEVFRSgsLJSe22w2hncEEQQB2+bkoMXVDsAT5N6viSJRUILbG9o1NTXYvXu31NoGAIvFgvPnz/sc39bWhgsXLsBisfg9n16vh17PC1GRTBAEGHVX/3VNTTACXGaSIoTsXSXe0D516hQ+++wzDBgwwGd/Tk4OGhoaUF1dLW3bvXs33G43srOz5S6HIsRvZmRd/yCiMBFwi7upqQmnT5+Wnp85cwaHDx9GQkICkpKS8Mgjj+DQoUPYuXMn2tvbpX7rhIQE6HQ6DBkyBBMmTMCTTz6J0tJSuFwuzJ8/H9OmTeOIEvKRYen+zTe8VEmRJODgrqqqwvjx46Xn3r7nGTNm4Je//CU++ugjAMAdd9zh832ff/45xo0bBwD44IMPMH/+fNx3333QaDSYOnUq1q1b18MfgcLV+7NGQ+jm6BHBxX4SihwBB/e4ceMgdlgXsLNr7fNKSEjA1q1bA31pijCBjPhL+LdRwSuEKMRwrhJSligCTraWiQLB4CbleBcJXvN9WU7HhRUoUjC4STkdFwkOkEEbhTtT+/ts48IKFCkY3BQanv5rQIcLgoD3Z40OUjFEoY3BTaFBGxvwt3R3xAlRuGFwExGpDIObiEhlGNwUEkRcf/w/EXkwuCkk+JvtrzdTt3bnRjAitWJwU8gK9OLj4MQ46WtO+0rhjMFNYePfOEMgRQgGN6mX1gikjPF8nTIGgs4g7bI729ldQmEraCvgEAWdIABPlHnuwNQageYra5Hes/pzDEtLwrY5ORzvTWGHLW5SN0EAdLGAIPhczDTCgaqaC/iu2Qm7s42tbworbHGTcmQO044t6+qYuTjoHoysVQAgICutP1vfFDbY4iZliCKwacKV51oDDroHX3meMsbT/RGIjn3eAEZpTsIABwCgquYiR5pQ2GCLm5ThsgPWI56vLcMhRsci31kMAxyoXpoLY6wpsJUUAKnP295ghfGNDPlrJgoRbHGT4sSZf0b+rw8AENCCGKnPukcEIfCWOpHKMLhJcS0uN46d84wIGZrU/QWCu4M92hSOGNwUUuS+gLhNtwLgPCgUZhjcFFJkyWytEUfdaQCAYZoa6QIlUbhgcFP4EQTkO4uVroIoaBjcFJY6do4Y4QC7SyiccDgghT3vzTgQ85QuhUgWbHFTWGqB3ueGnlGak56x40RhIODg3rt3LyZPnozk5GQIgoAdO3b47BdFEcuXL0dSUhIMBgNyc3Nx6tQpn2MuXLiAgoICmEwmxMfHY9asWWhqaurVD0Lky9PPndm6UelCiGQXcHA3NzdjxIgR2LBhg9/9q1evxrp161BaWorKykrExsYiLy8Pra2t0jEFBQU4evQodu3ahZ07d2Lv3r2YPXt2z38KIr8E2KFXuggi2QXcxz1x4kRMnDjR7z5RFLF27VosXboUDz30EADg3XffRWJiInbs2IFp06bh+PHjKCsr
w8GDB5GV5Zn4fv369bj//vuxZs0aJCcn9+LHITWyOzmHCFEgZO3jPnPmDKxWK3Jzc6VtZrMZ2dnZqKioAABUVFQgPj5eCm0AyM3NhUajQWVlpd/zOhwO2Gw2nweFj3tWf650CUSqImtwW61WAEBiYqLP9sTERGmf1WrFoEGDfPZHR0cjISFBOqazkpISmM1m6ZGSkiJn2RQistL6y3q7O1G4UsWokqKiIjQ2NkqPuro6pUsimVUtzeV82UTdJGtwWywWAEB9fb3P9vr6emmfxWLB+fPnffa3tbXhwoUL0jGd6fV6mEwmnweFF6MuiqFN1E2yBnd6ejosFgvKy8ulbTabDZWVlcjJyQEA5OTkoKGhAdXV1dIxu3fvhtvtRnZ2tpzlEBGFpYBHlTQ1NeH06dPS8zNnzuDw4cNISEhAamoqFi5ciFWrVuHWW29Feno6li1bhuTkZEyZMgUAMGTIEEyYMAFPPvkkSktL4XK5MH/+fEybNo0jSiKIKIqccpWohwIO7qqqKowfP156XlhYCACYMWMGNm/ejMWLF6O5uRmzZ89GQ0MDxo4di7KyMsTExEjf88EHH2D+/Pm47777oNFoMHXqVKxbt06GH4fUosXVDu9yBxkWeefgJgp3gqjC5a9tNhvMZjMaGxvZ361S9qZGGNekAgCan6lFbJxZvnM72zB0+acAAANacTzmCc/2Z2th7Cff6xDJKZBcU8WoEgpvvCZJFBgGNxGRyjC4iYhUhsFNRKQyDG4iIpVhcBMRqQyDm4hIZRjcpAj13T1AFDoY3NTnRFHEY29/pXQZRKrF4KY+1+Jqx9fWK4th8HZ3osAwuElxnM6VKDAMbiIilWFwExGpDIObiEhlGNxERCrD4CYiUhkGNxGRyjC4iYhUhsFNRKQyDG4iIpVhcBMRqQyDmxTBm9yJeo7BTX1PFLFNt0LpKohUi8FNfc9lxzBNDQDAnTgc0BoVLohIXRjcpKjW6TsBzg5IFBDZg7u9vR3Lli1Deno6DAYDbrnlFvzrv/4rxA5LnoiiiOXLlyMpKQkGgwG5ubk4deqU3KWQKsgf2gZtFLLS+gMAMiwm2c9PpLRouU/48ssvY+PGjdiyZQuGDRuGqqoqzJw5E2azGU899RQAYPXq1Vi3bh22bNmC9PR0LFu2DHl5eTh27BhiYmLkLokijCAI2DYnBy2udoiOZuBVpSsikpfswb1//3489NBDmDRpEgDgpptuwm9/+1t89ZVnqSpRFLF27VosXboUDz30EADg3XffRWJiInbs2IFp06bJXRJFIEEQYNRFw+5UupLAiaKIFlc7DNooLjJBfsneVXLXXXehvLwcJ0+eBAD89a9/xb59+zBx4kQAwJkzZ2C1WpGbmyt9j9lsRnZ2NioqKvye0+FwwGaz+TyIwpEoiniktAJDl3+K/NIKny5GIi/ZW9xLliyBzWZDRkYGoqKi0N7ejhdeeAEFBQUAAKvVCgBITEz0+b7ExERpX2clJSVYsYLDxyj8tbjaUV1zEQBQVXMRLa52GHWy/zMllZO9xf3hhx/igw8+wNatW3Ho0CFs2bIFa9aswZYtW3p8zqKiIjQ2NkqPuro6GSsmIlIX2f+UP/fcc1iyZInUVz18+HDU1NSgpKQEM2bMgMViAQDU19cjKSlJ+r76+nrccccdfs+p1+uh1+vlLpWISJVkb3Hb7XZoNL6njYqKgtvtBgCkp6fDYrGgvLxc2m+z2VBZWYmcnBy5yyGSqLG/2O5sh93ZpsraKXhkb3FPnjwZL7zwAlJTUzFs2DD85S9/wWuvvYYnnngCgOdq/8KFC7Fq1Srceuut0nDA5ORkTJkyRe5yiCTWteNx8y+qIWjUc99Z1qrPPP9N649tc3I4yoQABCG4169fj2XLluFf/uVfcP78eSQnJ+PnP/85li9fLh2zePFiNDc3Y/bs2WhoaMDYsWNRVlbGMdxhzjvMrS8bjwZjHL6Juhm3tP8dt7T/HXb7JRj7mfuugACI
ogi7s93vPl6opI4EUYWfwWw2G8xmMxobG2Ey8c44NfAOc6uuuYiRFh22NzwCALA/Wxv0IG2+1IDYV9P67PV6ouP741W11DNk1tvqPrYyj8EdxgLJNfV8ZiRV6zjM7Wtr347DV0P3Qsf3B/B0jQyI1cGoi1KwKgpV/PNNFGKqluZiQKxOFX9wSBlscROFGKOOt7rTtTG4qY+JMMKh6OsTqR2Dm/qQiD/oVqA6Zq5iFcS89wD6dFgLURAwuKnPGOFAluak9Pyge3DfrH6jNeKo2zOqRFN/BHDZg/+agRJFGNAKfiKg7uDFSeobndaZzGzdiO9gwrG+6MsVBOQ7i3Es5ongv1ZPiCL0792P4zFfef6YiXlKV0Qhji1u6hsd1pk86k7DdzChL9d692nHOu2h1V3isiPqfzzz1Y/SnAzNTwQUUhjc1OfyncXoy9DuYs33gXcmKBLenrsj2zj/CPUKu0qoz/V1XBm0UUi33ICDFwZ7WrQAUHfA07LVxfZZHZ3vjpTmH+mzCihcsMVNYU8QBGybexfyncXIbN2oWB2d7470zj9CFCgGN0UEzzVQAXZwXndSP3aVkGKy0vrDoI3kuThEwNkM+HsPRNHTleNs8xzHDhXqgMFNfe7LxeNh7GeK8FXMPTcjGdecBCzDu+zDO3lAXSWMALbpBl++oEvkweCmPmfURUX89KSGjjcjWY/47nS1AHWV0tNRmpMwKDpNAIUa9nETEakMg5so1HB8N10Hg5uC7lpLckWS7uZxzPsPBLcQUr3I7mikoPPedHKs5hyOR/CSoqIoIr+0olvHai584/ki4RbA+zVRB2xxU1B1vukEQEQOAWxxtePYOc+SbekDu3m35hNlQayI1IzBTX0ucocAenz485xuHhnZ7xNdHYObgorX2bqK8L9bJAMGNwVNIP26RNR9DG4Kmo79uhkWk8LVqEzKGEBrULoKClFBCe5//OMfeOyxxzBgwAAYDAYMHz4cVVVV0n5RFLF8+XIkJSXBYDAgNzcXp06dCkYppJDOQwDfnzVawWrUI7N1I+zP1nouTHboUzHCwX4nksge3BcvXsTdd98NrVaLP//5zzh27BheffVV9O/fXzpm9erVWLduHUpLS1FZWYnY2Fjk5eWhtbVV7nJIAd4hgFmrPpO2sV+3e+zQe+YI7/SGVcfMhf69SQxvAhCEcdwvv/wyUlJSsGnTJmlbenq69LUoili7di2WLl2Khx56CADw7rvvIjExETt27MC0adPkLon6WOchgJwF0Ev0zDnSw/CN+p/KPl/8gUKT7C3ujz76CFlZWcjPz8egQYMwcuRI/OY3v5H2nzlzBlarFbm5udI2s9mM7OxsVFT4v5DlcDhgs9l8HqQOVUtzPau8RHqTW/TMBng85gmYfzeZA/2oV2QP7r///e/YuHEjbr31Vnz66aeYO3cunnrqKWzZsgUAYLVaAQCJiYk+35eYmCjt66ykpARms1l6pKSkyF02BYlRF8lTt3bgskuzAWr/8RVn+6NekT243W437rzzTrz44osYOXIkZs+ejSeffBKlpaU9PmdRUREaGxulR11dnYwVExGpi+zBnZSUhKFDh/psGzJkCGprawEAFosFAFBfX+9zTH19vbSvM71eD5PJ5POgUCfCgFbPCi+RfEFNvLzKTU+WSNYa0X5jdqfTiVwhnuQP7rvvvhsnTpzw2Xby5EmkpaUB8FyotFgsKC8vl/bbbDZUVlYiJ6e7twJTSOvQn2tckwq8MyEyw1u8vJLNi8mIea8HM/4JAhzTP/FZ4Ljg7a8wdPmnyC+tYHhHMNmDe9GiRThw4ABefPFFnD59Glu3bsVbb72FefPmAfDMU7Fw4UKsWrUKH330EY4cOYLHH38cycnJmDJlitzlkBI69OcCAOoOeEZDhCJvizgYIeiySyvZaOqPXOdg/wy6aKRabpCe/6XWM1qHK8RHNtmHA44aNQrbt29HUVERVq5cifT0dKxduxYFBQXSMYsXL0Zz
czNmz56NhoYGjB07FmVlZYiJieB5P8PdOxOUrqALt9sNzeW1HZEypstNL6FAEATPzUuvep4bL1/UbOFq9REtKPNxP/DAA3jggat/NBQEAStXrsTKlSuD8fIUQtwJt3jml1Z4XmmDNgpZaf1xtOactK1g4xf4bePltR29nwpCcIx0x78l1TFzAQAH3YMBMU+hikhpnKuEgqr1sZ1KlwDA01jYNicHVb+4cv/AyfpLitVjEAIYDqg1eoK6g1Gak6Hb/URBxxVwKLhCqOtBEATE6kPjV36ffmH3DxYE5DuLYYADRjikVjdFrtD4LSai6xDQAl4DIg92lRD1sQv/96DSJZDKMbiJ+pioNfrdftA9mKNFqFvYVUKksO/mHsXY1ysvh3boXBOg0MXgJlKYqDWy/5oCwq4SIiKVYXATXca5P0gtGNxEl8k594d3Fj9/Yoz9cFzrmUHzuHYoDMY42V6XIgP7uIlk5l1z81jNORz303UtCAIyiv4LdvslZBjj0NLm7vsiSdXY4iaSWec1N/0RNBoY+5khaHr5TzCYsxtSyGJwE6nWlfm+I3bO8wjF4Ca6TBTRJ6vLdF7x3jtzIQBkpfW/5n4frhZpvu+QnvOcZMc+bopYnWfoK3j7Kxy2OpGV1j+oK9N3Pq935sIWVzsM2q6LK3fcb2+yAes821uc7fB/DyaFO7a4KWJ1nqHvhNUGQJnVZQRBgFEXfdU/Ft79Rt2V1viPXt/bV+VRiGFwE6lIx26ULnN6O+3s544QDG6KaGMda5UuISAdW+Rd5vRe831epIwQDG6KaC2iPLPxeW+48VzclOWU/mmNnvUxr4YXKSMCL04S9ZL3hhvv2O2hSabgvZggAE+Uwd5ghfGNjCvbn/4r8MaI4L0uhRS2uCm4tAbfFmLKGE+rUbF6jGi/MRuAfPNfd77h5tg529UPluPnF4Su59CG3iLHFDxscVOQeVqI0sd3rVHZdSgFAY7pnyCz+CO0QA8DAli0txcyWzfiy6WTYIw1hdQ6nKRObHGT7Lr08QoCoIv1PEIhtATv+o3BqkWEAa0wdvijYIde3p+/w8rv7Tdmez7ZUMRgi5tkIYoiWlztEEXPjSw7lC6oB4xwyLAKjYg/6FYgS3NSrrL867Dye/X0B2EUnMF9PQopDG7qtc4X5wxohXdBl863b4cC7y3kVZ0mgqqOmYuD7sHIdxb3/NxwdAnt4K0lefmTQyh8iqE+FfSukpdeegmCIGDhwoXSttbWVsybNw8DBgxAv379MHXqVNTX1we7FAqSa82GF6zbxnvDewv5sZV5+PIXk6QuBwAYpTkZcL/31Yb/ZbZuxJDWdy7/IQje+2B3tvvOr+K0c8bAMBfU4D548CB+/etf4/bbb/fZvmjRInz88cfYtm0b9uzZg7Nnz+Lhhx8OZilEPqRbyPXRyHcWI7N1Y4/OI4oi8ksr/O6zQx/kvnSPrFWf4bG3v7qyYc33OWNgmAtacDc1NaGgoAC/+c1v0L//lZnNGhsb8fbbb+O1117DD3/4Q2RmZmLTpk3Yv38/Dhw4EKxyqA+FXhv7egTPxcMeaHG1S8P/0gcqNyTvUK2fTzy8GSdsBS24582bh0mTJiE3N9dne3V1NVwul8/2jIwMpKamoqLCf8vF4XDAZrP5PChUidimW6F0EYr48Oc5SpdAESIowf273/0Ohw4dQklJSZd9VqsVOp0O8fHxPtsTExNhtVr9nq+kpARms1l6pKSkBKNsksFIix7DNDUAANEyXNmbbfpYX3bnX3WObooIsgd3XV0dnn76aXzwwQeIifGz4F4PFBUVobGxUXrU1dXJcl6S3/uzRktfCzPLOOIhSLwXWKuW5l7/YAo7sgd3dXU1zp8/jzvvvBPR0dGIjo7Gnj17sG7dOkRHRyMxMRFOpxMNDQ0+31dfXw+LxeL3nHq9HiaTyedBocknp1Ua2ler2juRVCAr5HjnLfG3sk1veS6wht5wSwo+2cdx33fffThy5IjPtpkz
ZyIjIwPPP/88UlJSoNVqUV5ejqlTpwIATpw4gdraWuTksI+QlLdNtwIQH/TZ1nGseiAr5HiOg9+VbYh6SvbgjouLww9+8AOfbbGxsRgwYIC0fdasWSgsLERCQgJMJhMWLFiAnJwcjBlzjekqiYKoBXocdadhmKYGwzQ1sLvsgN58ZX+HsereFXKMuuv/8xEEdOs4okAo8hv1+uuvQ6PRYOrUqXA4HMjLy8Obb76pRClEl3luIT8W80T3v0UUAWczDGgN0p2RRP71SXB/8cUXPs9jYmKwYcMGbNiwoS9enqhbArpVRRSBd/JgrKvE8RjPbe0QeaGQ+gZnByTqCZcdqKuUno7SnOTNLtRnGNxERCrD4CYiUhkGN1E44yRTYYnBTdQDoZaHLdBL09Medadd2bGJMwSGIw4wJboG78o+dme7z/aLdhc6zwVo0EYpODPilRVxWqHDN2kvQ1N/BLAe8Vw01XEx4XDC4Ca6is4r+3T0f17fg+OdpuIRBKVnRvSupQm0Tt8J45q06xxPasWuEiJ/XHa0NNtQXXMhgO9pkWZGPOpOQ7rlBgWXbrvS9g90fhUKfQxuIj+Mb2TAuCb1cgvaN/S+XDz+ut+f7yzGtrl3KTY/Sceczlz1GfJLKxjeYYTBTRHtevNa+1uD0tCNGflEKDs5YovLt0/eO78KhQf2cVNE885r3eJqh73JBqxTuiKi62OLm2QiwoBWVd72LS0crMK5rbkSTmRii5t6TxTxB90KZGlOAm8oXYxyBAX+aPl8YnC2I2vVZwACnDCLVIctbuo9l90T2h2ljFHdepO9HQFiePNOmSoJjL9PDI+//ZUitVDfYIubZGV/+msYY02e0FbZii9yjQA56B6s+Pzc337XDMiz5CuFIAY3yUtrjOi79DJbN+I7mHD1lSuJeo9dJUTX4S+CrxbLduivsZdIHgxuouvochOOqPSt7RTpGNzUc5fXXFTjEMDu+LvbAgAYpqnxvQnHZZdubT/mTkP7jaMBAGLKGAxL9XxPVlp/BW9392WEAxxnEl7Yx00943YDb90LWI9AXWNHuu8nzmJUxcy95jGPOItRNf1BGAUnBK0R2+C5a9GgjVLsdvfOqmPmXl4TM0/pUkgmbHFT4ERRCu2ODroHq24I4LV0bKPemeq5yaVzS1oEPKNndLGAIEhD85QO7Y7zcwNcEzPcsMVNgXPZr4R2wi1onrkbWS+UowV6HAuRVqbc3n/8B2jRXg5t+z+l7Xemhk6XiPcuyqqaiwAEvGxZiy3TbkbsugylSyOZMbipV8Sf70F+6WFpHuhwJay5FcaUbM+TDqu7vz9rtOKta6+Od1ECniBvabYpXBUFA7tKqFdaXG4cO+cJh6FJppBpffaI1oj2Gz3hfNA9GBcQ59PdgLpKn9BGyhgIITZm/cpdlMp311DwsMVNstk2J0fdYSEIcEz/BJnFH12+89GzHNix5zNhfKNTd8Ozp4HYgaq7O5TCg+wt7pKSEowaNQpxcXEYNGgQpkyZghMnTvgc09rainnz5mHAgAHo168fpk6divr6erlLoT7QcS3GsMgwwbv8l/eHEfxfcNWp75Z+Ch+yB/eePXswb948HDhwALt27YLL5cKPfvQjNDc3S8csWrQIH3/8MbZt24Y9e/bg7NmzePjhh+UuhfrAPas/V7oEoogje1dJWVmZz/PNmzdj0KBBqK6uxr333ovGxka8/fbb2Lp1K374wx8CADZt2oQhQ4bgwIEDGDNmjNwlUR8IpRtOgs4yXNXDHr0r14fSWHMKTND7uBsbGwEACQkJAIDq6mq4XC7k5uZKx2RkZCA1NRUVFRV+g9vhcMDhuHLnms3GK+WhpGppLgbE6iInBGaWqbKbRBR9V67PSuuv/usSESqoo0rcbjcWLlyIu+++Gz/4wQ8AAFarFTqdDvHx8T7HJiYmwmq1+j1PSUkJzGaz9EhJSQlm2RQgoy58Wm6dV5Tx+0lCpT/rY29/BbuzHdU1FwFwHUo1
C2qLe968efjb3/6Gffv29eo8RUVFKCwslJ7bbDaGNwWFv7HQSqxsI5eOf3ROWG240OxUsBqSS9CCe/78+di5cyf27t2LG2+8UdpusVjgdDrR0NDg0+qur6+HxWLxey69Xg+9XtmJ6SlyeMdCh4OOn4S26VbgntWcdjYcyN5VIooi5s+fj+3bt2P37t1IT0/32Z+ZmQmtVovy8nJp24kTJ1BbW4ucnBy5yyHqPa3RsxQboL4l2bRGuBOHA/DMcjgANs+izpwtUNVkb1bMmzcPW7duxX/8x38gLi5O6rc2m80wGAwwm82YNWsWCgsLkZCQAJPJhAULFiAnJ4cjSlRCFMXIarMJAvBEmWeOFrUtySYIaJ2+E8Y1aQA8MwUCnjtD853FSlZGvSB7cG/cuBEAMG7cOJ/tmzZtws9+9jMAwOuvvw6NRoOpU6fC4XAgLy8Pb775ptylUJC0uNqlqVwzLCq/zb27vDMAqpDBaMJx7VAMcR2Tto3SnLw8TzepkSCKouo+M9lsNpjNZjQ2NsJkMildTsSxNzXCuCYVAND8TC1i48wKV0TXI7rdaLFfggEOCGtuBQAcdach/RfVMOq1CldHQGC5xkmmKCCiKIbfbe4RQNBoYOxnhhB7g0+fN+foVicGN3WbKIp4ZON+5K0uu/7BFJou93mTuoXHmCfqEy3ONhRZFyIr5qS0LSL6t8MOPyapHVvc1H2uZmRproS2GILzUVNgRKcddocLKrzUFdEY3NQ9ooiY9x6Qntqf/hrCE+qcs4OuiF2XgaMv3IX8jfsZ3irC4KbucdmhqfesM3nUnQYYuYiAammNXRYSPlpr5bwlKsLgpoDlO4sZ2momeFb2yWzdqHQl1EMMbgoYP1CHAwF2cP4ftWJwU7eIolvpEojoMgY3XZ8oQnxnovQ0Ym5zD1Od5xwn9eE4brq+Thcm358zLmwWTohE0pzjzTZgjdLVUE+wxU0ByXcWQ9AwtNUunOYcj0QMbro2UQScV+az4IXJ8GSEA/YmG0Q3r2WoAf/k0tWJIvBOHlBXqXQlFGTVMXOBdcBx7VBkFP0XBA3bdKGM/3fo6lx2n9A+6B6MFg4hC2tDXMdgb76kdBl0HQxu6hb7019fXjGF/dvh7rG3v+Lt7yGOwU3dYhe5yGyk+Npq85lznUIPg5t8iSJERxPsDheaHW3S5ntWf65gUdSXBAD5pRVsdYcwXpykK0QR4jt5EOoqcdQ9GDOcS3AsxveQrLT+vPkmzG3TrcCkcy961hblkMGQxP8rkU4Ur6xe7rJDuHwxcpTmJBIEm3TYl4vHw9jPc8ckb74JUwm3ABe+wTBNDQx+FhIWRVGaQZC/B8picEeyjsP9UsYAj/3RZ/c+/ULpa6Muiq2vcPdEGXB5IeHORFHEI6UVqK65CMDzyWvbnByGt0LYxx3JOg73qzsAu53DwCJb1xD2LA7dhu+anVJoA0BVzUXO360gNqFIcs/Ln6P6cp/2WMdanxY3RR5RhE8rm0IHW9yRosNoEbuzTXp01LG9dUE04WjUUACeu+kMxrg+LJb6hNbo6SIDPP/VGqRdRjhwsaEBx2rOwYBWGNAKTngQOtjiDlPeC0kGbRQEUYT41r0QrEdw1D0YP3EuRwycMMIhtbAB4EPdiivfD+Cm5/bC3mZHhjGOt0CHI0Hw9Gt3uDjtVR0zF9gIHO/w+3HUnYZ8Z7EU36JbvPxfN1rslzx/3AXhyu8d+7+DRtHg3rBhA1555RVYrVaMGDEC69evx+jRo5UsKSx0vJCUlRqPbZrnIVg907KO0pzEJ7pfYKimpsv33ayxSl/fmdofRn00hBhzn9VNChAEQBfbrUOHaWpwLOYJ6fnRV4ZiyP/bh5MvjUWG6xiOa4diaf81qK5t4MXLIFMsuH//+9+jsLAQpaWlyM7Oxtq1a5GXl4cTJ05g0KBBsr+et1UQCezO9ssfcYEztTUQYo747O8c2sfcaV22vT9rNP/RRRpv10ndAbQnDke+YzmO
Wy8hMyUe70UVS3/8vYa1H8PZ2lPIcB0D4Jnn5NvaGhigx9Gac/ju4kUYdZE75t8QxE+qgqjQ7VHZ2dkYNWoUfvWrXwEA3G43UlJSsGDBAixZssTnWIfDAYfjyrhSm82GlJQUNDY2wmQydev17E2NMK5Jle8HCAfPnoaoNaDF1QbjmjTfff/vbLdbYhRGOozrF4Er3R6A1JXS3GRD7LoMJatUBfuztTD26/4nVpvNBrPZ3K1cU6Tj0ul0orq6Grm5uVcK0WiQm5uLioqKLseXlJTAbDZLj5SUlL4sNyxUuQejyj34yoaUMUDsQAj6fjDGmq9cpPLu0xr7vkhSnrfrRBCkxRYEQbiyXRcLY3wijmuHKl1pRFOkxX327Fl873vfw/79+5GTkyNtX7x4Mfbs2YPKSt/5n+VocUdSV4lXTHQUWtsuj7W9HMQGODz/ELVGzz9GL29Ly3ssu0noGjr/e4ox9ENrS5Pn646/dxEs0K6SQFrcqhhVotfrodf3bh5oQaMJ6GNLuOjabtb6PzCAi1RE/v49dXzOz2vBpUhXycCBAxEVFYX6+nqf7fX19bBYLEqURESkGooEt06nQ2ZmJsrLy6Vtbrcb5eXlPl0nRETUlWJdJYWFhZgxYwaysrIwevRorF27Fs3NzZg5c6ZSJRERqYJiwf3oo4/if//3f7F8+XJYrVbccccdKCsrQ2JiolIlERGpgmLjuHsjkKuvRERqEPLjuImIqOcY3EREKsPgJiJSGQY3EZHKMLiJiFRGFbe8d+YdCGOz2a5zJBGROnjzrDsD/VQZ3JcueSa34SyBRBRuLl26BLP52vMqqXIct9vtxtmzZxEXFxfQZP/eWQXr6uo4/rsTvjdXx/fm6vjeXF2g740oirh06RKSk5Ohuc6sgqpscWs0Gtx44409/n6TycRfsqvge3N1fG+uju/N1QXy3lyvpe3Fi5NERCrD4CYiUpmICm69Xo/i4uJeL8oQjvjeXB3fm6vje3N1wXxvVHlxkogokkVUi5uIKBwwuImIVIbBTUSkMgxuIiKVYXATEalMxAT3Cy+8gLvuugtGoxHx8fF+j6mtrcWkSZNgNBoxaNAgPPfcc2hra+vbQkPATTfdBEEQfB4vvfSS0mUpYsOGDbjpppsQExOD7OxsfPXVV0qXpLhf/vKXXX4/MjIylC5LEXv37sXkyZORnJwMQRCwY8cOn/2iKGL58uVISkqCwWBAbm4uTp061evXjZjgdjqdyM/Px9y5c/3ub29vx6RJk+B0OrF//35s2bIFmzdvxvLly/u40tCwcuVKnDt3TnosWLBA6ZL63O9//3sUFhaiuLgYhw4dwogRI5CXl4fz588rXZrihg0b5vP7sW/fPqVLUkRzczNGjBiBDRs2+N2/evVqrFu3DqWlpaisrERsbCzy8vLQ2trauxcWI8ymTZtEs9ncZfuf/vQnUaPRiFarVdq2ceNG0WQyiQ6How8rVF5aWpr4+uuvK12G4kaPHi3OmzdPet7e3i4mJyeLJSUlClalvOLiYnHEiBFKlxFyAIjbt2+XnrvdbtFisYivvPKKtK2hoUHU6/Xib3/72169VsS0uK+noqICw4cPR2JiorQtLy8PNpsNR48eVbAyZbz00ksYMGAARo4ciVdeeSXiuoycTieqq6uRm5srbdNoNMjNzUVFRYWClYWGU6dOITk5GTfffDMKCgpQW1urdEkh58yZM7BarT6/Q2azGdnZ2b3+HVLl7IDBYLVafUIbgPTcarUqUZJinnrqKdx5551ISEjA/v37UVRUhHPnzuG1115TurQ+889//hPt7e1+fye+/vprhaoKDdnZ2di8eTNuu+02nDt3DitWrMA999yDv/3tb4iLi1O6vJDhzQ1/v0O9zRRVt7iXLFnS5SJJ50ek/yPzCuS9KiwsxLhx43D77bdjzpw5ePXVV7F+/Xo4HA6FfwoKBRMnTkR+fj5uv/125OXl4U9/+hMaGhrw4YcfKl1axFB1i/uZ
Z57Bz372s2sec/PNN3frXBaLpcuIgfr6emmf2vXmvcrOzkZbWxu+/fZb3HbbbUGoLvQMHDgQUVFR0u+AV319fVj8PsgpPj4egwcPxunTp5UuJaR4f0/q6+uRlJQkba+vr8cdd9zRq3OrOrhvuOEG3HDDDbKcKycnBy+88ALOnz+PQYMGAQB27doFk8mEoUOHyvIaSurNe3X48GFoNBrpfYkEOp0OmZmZKC8vx5QpUwB4Vl4qLy/H/PnzlS0uxDQ1NeGbb77B9OnTlS4lpKSnp8NisaC8vFwKapvNhsrKyquObusuVQd3IGpra3HhwgXU1taivb0dhw8fBgB8//vfR79+/fCjH/0IQ4cOxfTp07F69WpYrVYsXboU8+bNi6gpKysqKlBZWYnx48cjLi4OFRUVWLRoER577DH0799f6fL6VGFhIWbMmIGsrCyMHj0aa9euRXNzM2bOnKl0aYp69tlnMXnyZKSlpeHs2bMoLi5GVFQUfvrTnypdWp9ramry+aRx5swZHD58GAkJCUhNTcXChQuxatUq3HrrrUhPT8eyZcuQnJwsNQZ6rFdjUlRkxowZIoAuj88//1w65ttvvxUnTpwoGgwGceDAgeIzzzwjulwu5YpWQHV1tZidnS2azWYxJiZGHDJkiPjiiy+Kra2tSpemiPXr14upqamiTqcTR48eLR44cEDpkhT36KOPiklJSaJOpxO/973viY8++qh4+vRppctSxOeff+43V2bMmCGKomdI4LJly8TExERRr9eL9913n3jixIlevy7n4yYiUhlVjyohIopEDG4iIpVhcBMRqQyDm4hIZRjcREQqw+AmIlIZBjcRkcowuImIVIbBTUSkMgxuIiKVYXATEanM/wdGKiIdaJaysAAAAABJRU5ErkJggg==", 267 | "text/plain": [ 268 | "
" 269 | ] 270 | }, 271 | "metadata": {}, 272 | "output_type": "display_data" 273 | } 274 | ], 275 | "source": [ 276 | "plt.figure(figsize=(4,4))\n", 277 | "plt.step(bins[:-1],jnp.histogram(X,bins)[0])\n", 278 | "plt.step(bins[:-1],jnp.histogram(samples,bins)[0])" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "### Learnt Energy interpolation" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 10, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABQsAAADFCAYAAADzGaSgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB14klEQVR4nO3deXxU9b0//teZmcxM9oXsIRAIS9iDICkqim2Uay3WLja1XrH5We+tSutt2vut1ArVqsFKrb2Wyi1KsVYLra3WqxS1KahIBAXCGgIEsmeykGWyzmRmzu+PmTNJyDYzOTNnltfz8cijZXJmzmfAz+ec8/68P++PIIqiCCIiIiIiIiIiIgp5KqUbQERERERERERERP6BwUIiIiIiIiIiIiICwGAhEREREREREREROTBYSERERERERERERAAYLCQiIiIiIiIiIiIHBguJiIiIiIiIiIgIAIOFRERERERERERE5KBRugGusNlsaGhoQHR0NARBULo5REFBFEV0dXUhPT0dKpX/zhuw/xPJL1D6P8AxgMgbAmUMYP8nkl+g9H+AYwCRN7g6BgREsLChoQGZmZlKN4MoKNXW1mLq1KlKN2NM7P9E3uPv/R/gGEDkTf4+BrD/E3mPv/d/gGMAkTdNNAYERLAwOjoagP3LxMTEKNwaouBgNBqRmZnp7F/+iv2fSH6B0v8BjgFE3hAoYwD7P5H8AqX/AxwDiLzB1TEgIIKFUspxTEwMBwkimfl7Sj/7P5H3+Hv/BzgGEHmTv48B7P9E3uPv/R/gGEDkTRONAR4VKdi6dSuysrKg1+uRl5eHw4cPu/S+Xbt2QRAE3H777Z6cloiIiIiIiIgU4G4c4LnnnsPcuXMRHh6OzMxM/OAHP0B/f7+PWktEk+F2sHD37t0oKirCpk2bcPToUSxZsgRr1qxBc3PzuO+rqqrCj370I6xatcrjxhIRERERERGRb7kbB3jttdfw8MMPY9OmTSgvL8dLL72E3bt34yc/+YmPW05EnnA7WPjss8/ivvvuQ2FhIebPn49t27YhIiICO3bsGPM9VqsVd911Fx577DHMnDlzUg0mIiIiIiIiIt9xNw5w8OBBXHvttfjWt76FrKws3HzzzbjzzjtdXpVIRMpyK1hoNptx5MgR5OfnD36ASoX8/HyUlpaO+b7HH38cycnJuPfee106j8lkgtFoHPZD/u/d0wbcs+Mwni85D6tNVLo55EUsRRDazBYbnnn3LNbtOIw3jtUp3Rwiktnfy+qxbsdhbN13gddzoiDWbOzHj/5yHN/70zFUtfYo3RzyY57EAa655hocOXLE+Zxw8eJF7NmzB1/84hfHPA/jAL5hs4n43w8qcc+Ow/i/4w1KN4f8lFsbnLS2tsJqtSIlJWXY6ykpKTh79uyo7zlw4ABeeukllJWVuXye4uJiPPbYY+
40jRT28YVW/OcrRwAAH5xrgVUU8V/5cxRuFXmDtARh27ZtyMvLw3PPPYc1a9agoqICycnJY76PpQiCx8N/PYG/HasHAHx4rgWiCHz1qqkKt4qI5PD2iQY8tKsMgL1/mwasKLp5rrKNIiLZWW0i7vvDZzhe1wkAKKttx/s/uAH6MLXCLSN/5Ekc4Fvf+hZaW1tx3XXXQRRFWCwWfPe73x13GTLjAL7xwgeVeObdCgD2Z/fY8DBcPydJ4VaRv/FogxNXdXV14e6778b27duRmJjo8vs2bNiAzs5O509tba0XW0mTJYoinninHAAwNT4cAPDC/kq0dpuUbBZ5CUsRhLZDFy/jb8fqoRKAz+fYg8OPv30GnX0DCreMiCZrwGrDE2/br+cLM+y7Tv52fyUaOvqUbBYRecHbJxpwvK4T+jAVYvQa1Lb14Q+lVUo3i4LI/v378dRTT+G3v/0tjh49ir/97W9455138POf/3zM9zAO4H1tPWb85l8XAAw+uz/5TjlEkSsJaDi3goWJiYlQq9Voamoa9npTUxNSU1NHHF9ZWYmqqiqsXbsWGo0GGo0Gf/jDH/DWW29Bo9GgsrJy1PPodDrn9ujcJt3/nao3orzRCJ1Ghbe/dx0WT42FyWLDX49weWKw8UUpAi4/8G+//7gKAFBw9TT87u5lyE6KREfvAF5nfycKeP862wyDsR+JUTq8/t1r8LmZCbDYRLzMAAJR0PnzZ/YgzH9en42ffHEeAOBPh2sZMKBRuRsHAIBHH30Ud999N77zne9g0aJF+MpXvoKnnnoKxcXFsNlso76HcQDv+7/jDegbsGJ+Wgze+f4qhIepUdHUhU+r2pVuGvkZt4KFWq0Wy5YtQ0lJifM1m82GkpISrFy5csTxOTk5OHnyJMrKypw/t912G2688UaUlZUhMzNz8t+AFPc3R82ymxekIi5Ci4Kr7f+ub59oVLJZ5AXjLUEwGAyjvkcqRbB9+3aXzlFcXIzY2FjnD8cJ/9HSZcL75fabxG9fkwWNWoXCa2cAAF79pJoPGEQBTprk++pVGdCHqXHPyiwAwNvHG9m/iYJIk7EfH1+4DEEA7lg+FWuXpEMfpsKl1h6UN3Yp3TzyQ+7GAQCgt7cXKtXwcINabV/mzmuKcqRSQncsn4rY8DDcujgNALDnJJ/daTi3lyEXFRVh+/btePnll1FeXo77778fPT09KCwsBACsW7cOGzZsAADo9XosXLhw2E9cXByio6OxcOFCaLVaeb8NKWJ/RQsA4EuOgeamefZA0sn6TrR0cSlyKPOkFAGXH/ivfWebYbWJWDw1FnNTowEAty/NgE6jwkU+YBAFNLPFhgMXWgEAty1JBwDcmJOMSK0a9R19OFrToWDriEhOH5239/XFGbGYGh+BSJ0G12Tb79P2n2tWsmnkx9yJAwDA2rVr8cILL2DXrl24dOkS3n//fTz66KNYu3atM2hIvtXabcLx2g4AcAYJb5pvf3b/19lmBnFpGLc2OAGAgoICtLS0YOPGjTAYDMjNzcXevXudmUY1NTUjZhAoeNW19+JSaw/UKgErs6cAAJJj9FiQHoPTDUZ8eK4FX1vGjQ+CxWRKEUikZQcajQYVFRXIzs4e9h6dTgedTueF1tNk/eus/QFCqlUIAFE6Da6fk4T3zzTh3dMGzE/nchGiQHS0ph29ZisSo3SYn2bvx/owNW7MScbbJxrx4bkWLJser3AriUgOB87bJ/qvmz04kXvj3CT862wz9le04IHVs5RqGvkxd+MAP/3pTyEIAn7605+ivr4eSUlJWLt2LZ588kmlvkLIK628DADISY1GcrQeAHDdrERo1SrUtNmf62cmRSnZRPIjbgcLAWD9+vVYv379qL/bv3//uO/duXOnJ6ckP3XQMeAsmRqLGH2Y8/VVs5NwusGIT6vaGCwMIkOXINx+++0ABpcgjDYmSKUIhvrpT3+Krq4u/PrXv+YS4wAyNOtoaLAQANYsSHUGC39wE3dBJw
pEBxyZRqtmJ0KlEpyvXzsrEW+faERp5WX84CalWkdEciq9aL9/v3bWYLBQ+v9ltR0wW2zQapj8QSO5EwfQaDTYtGkTNm3a5IOWkSsOVtqv9dcN6fuROg2WTovDoUtt+LSqjcFCcvIoWEgkOeZYlnT1jIRhr181LQ6APVOBgktRURHuueceLF++HCtWrMBzzz03YglCRkYGiouLnaUIhoqLiwOAEa+TfztZ34lukwUJkVosTI8d9rsv5CRDEICzhi40d/U7ZyqJKHAcq7Vfr1dccT2/xrFq4FhtO3rNFkRoeetIFMgMnf1oMpqgEoDczDjn6zMSIxEXEYaO3gGUNxqxZMjviCg4HK3uADDyWr9sejwOXWrDkep2FFw9TYGWkT/ilBFNysn6DgDAkqlxw15fOs2+VOl8czeM/QM+bhV5U0FBAbZs2YKNGzciNzcXZWVlI5YgNDayQG6wOeYI/F81LX5Y1hEAxEdqscCx/Fha3kBEgcNmE3GithPA8OABAExLiEBGXDgGrCI+406JBGDr1q3IysqCXq9HXl4eDh8+PO7xHR0dePDBB5GWlgadToc5c+Zgz549PmotXelEXQcAYE5K9LDgvyAIWOro/5zsJwo+vWYLzjfb64tfORlwlePZ/Ug1+z4NYrCQPNY/YMVZx4YGi6cOzzRKitYhMyEcogjnAwgFj/Xr16O6uhomkwmHDh1CXl6e83f79+8ft9zAzp078eabb3q/kSQr6cHhqulxo/5eKox+8AKDhUSB5mJrD7pMFujDVJidPHz5kSAIWJ5lf4gocxRFp9C1e/duFBUVYdOmTTh69CiWLFmCNWvWoLl59E0xzGYzbrrpJlRVVeH1119HRUUFtm/fjoyMDB+3nCQn6uz35VfeuwODk/3H2deJgs6ZBiNsIpAcrUNKzPBVQEsdqwIrW3qY6ENODBaSx8objbDYRCREapERFz7i99JSxbMGo6+bRkQyk5YtSDOPV5I2ODp4sdVXTSIimUiBgUUZsdCoR94aLnasHpAykih0Pfvss7jvvvtQWFiI+fPnY9u2bYiIiMCOHTtGPX7Hjh1oa2vDm2++iWuvvRZZWVm44YYbsGTJEh+3nCTHHf148RWrggA4VwmcNXT5sEVE5AvHnRMFcSN+NyVKh1RHALGC/Z8cGCwkj52qtw84izJiIQjCiN/npNpvOMobOeAQBbKGjj4YjP1Qq4RRMxEAYEVWAjQqAbVtfahr7/VxC4loMk7Wj/0AAdg3MQPsDxqiKPqqWeRnzGYzjhw5gvz8fOdrKpUK+fn5KC0tHfU9b731FlauXIkHH3wQKSkpWLhwIZ566ilYrdZRjzeZTDAajcN+SF7ljfa/00UZI6/nOY6d0C80d8Nssfm0XUTkXafqx84qBoB5adEAgLONHHfJjsFC8tj55m4AQI5jYLmS9Ho5BxyigCZlHeWkRo+5uUGkTuPMSGC9k+Dmbr0yya5duyAIgnMndfIf55rsk3pzU0e/ni9Ij4VaJaCly4Qmo8mXTSM/0traCqvV6qxRLElJSYHBYBj1PRcvXsTrr78Oq9WKPXv24NFHH8Uvf/lLPPHEE6MeX1xcjNjYWOdPZmam7N8jlLX1mNHabQYAzE4ZueNpeqwe0XoNLDYRlS3dvm4eEXmRdK2f55gUuJL0+hkm+pADg4XkMWnAmZM8+sPF/CGzkwNWzk4SBapyx3IEKRg4lqum25coS7ukU/Bxt16ZpKqqCj/60Y+watUqH7WU3HGuyR4UmJMy+vU8XKt21jI8zqXI5AabzYbk5GT87ne/w7Jly1BQUIBHHnkE27ZtG/X4DRs2oLOz0/lTW1vr4xYHtwuOif6p8eGjTv4JgoB5qdJSZE72EwULq0109v8raxNLpGAhE31IwmAhecw54IwyMwkAGXHhiNJpYLbacLGlx5dNIyIZVTgeGOamThAs5E5qQc/demUAYLVacdddd+Gxxx7DzJkzfdhackV7jxmt3fZswbEeIABgiW
OJsrSMiUJPYmIi1Go1mpqahr3e1NSE1NTUUd+TlpaGOXPmQK1WO1+bN28eDAYDzGbziON1Oh1iYmKG/ZB8pJ1Qx+vr85wrg5hdRBQs6tp7YbLYoNWokJkQMeoxUt+vMHTBZmPJEWKwkDw0dBlDdtLoNxwqlYBZjpsRLmUgClxSofN5YyxRlCxzZBaeaTSi12zxervItzypVwYAjz/+OJKTk3Hvvfe6dB7WLPMtaZXA1PhwROpGLzMADJYW4cYHoUur1WLZsmUoKSlxvmaz2VBSUoKVK1eO+p5rr70WFy5cgM02uMLk3LlzSEtLg1ar9XqbabjzTdJE/9jX81mO31U2896dKFhIfT87KQpq1ci9BgBg+pRIaFQC+gasaDT2+7J55KcYLCSPDF3GMN7DxcykSADARQYLiQJSj8mC6sv2DUvGqmcmSY8LR2qMHlabiOO1zD4KNp7UKztw4ABeeuklbN++3eXzsGaZb51rHn8JsmSu4/dScJFCU1FREbZv346XX34Z5eXluP/++9HT04PCwkIAwLp167Bhwwbn8ffffz/a2trw0EMP4dy5c3jnnXfw1FNP4cEHH1TqK4Q06f591jiZhdmJjnv3Vq4KIgoW5ydYggwAYWoVpk2xZx3y2Z0ABgvJQ9LDwngDDjCYdcgbDqLAJPX1pGgdpkTpJjxeyi48WsOlyKGuq6sLd999N7Zv347ExESX38eaZb51Xrqej1FSRCJNFtS09TJzOIQVFBRgy5Yt2LhxI3Jzc1FWVoa9e/c6JxFqamrQ2NjoPD4zMxPvvvsuPv30UyxevBjf//738dBDD+Hhhx9W6iuENFeWIc903LvXtPVyR2SiIOFK3weGPLuzhBgBGDsljGgcg/UKx89EmCHNTnLAIQpI0pLDnAmyCiVLp8XhnZONOMq6hUHH3XpllZWVqKqqwtq1a52vSUsRNRoNKioqkJ2dPeJ9Op0OOt3EgWmSx2DB8/H7+JQoHRKjdGjtNuFcUzdyM+N80DryR+vXr8f69etH/d3+/ftHvLZy5Up88sknXm4VTaTbZHHuZp49TsAgJUaHCK0avWYratt7xyw3RESBw5WsYoCrAmk4ZhaSR6ov24N/WVMixz1u6IAjiiyUShRozjp2RHM1WOjcEbm2g30+yLhbrywnJwcnT55EWVmZ8+e2227DjTfeiLKyMi4v9hNSmYEZiaMXPB9qbqr9IeMc6xYSBRzp3j0hUosYfdiYxwmCwMl+oiAiiiIuOVb5zZwg+J+dyFWBNIiZheSR6jb7w0XWlPEfLrKmREIQAGO/BW09ZpeWMRKR/7jQ4loWsWR+Wgw0KgFtPWbUtfeNueMaBaaioiLcc889WL58OVasWIHnnntuRL2yjIwMFBcXQ6/XY+HChcPeHxcXBwAjXidlmCxWNHT2AbAXNp/I3JQYfHzhMjc5IQpANY6JgWkuXJdnJkXhdIPRkV2UMuHxROS/OnoH0NVvLx8yUf8fTPRhsJAYLCQPWG0iah3BwmkTBAv1YWqkx4ajvqMPF1t7GCwkCjCXHDcL2UkTBxIAe5/PSYvGqXojTtR1MlgYZAoKCtDS0oKNGzfCYDAgNzd3RL0ylYqLFgJFbVsfRBGI0mkwJXLinWmlDOOKJu5QTRRopIn+6RPcuwPATGYWEgUNqe8nR+sQrlWPe6yUeVjf0Yc+s3XC4ym4MVhIbmvs7MOAVUSYWkBabPiEx89MikR9Rx8utfTg6qwEH7SQiOTQZ7aiobMfADAj0fWaRYunxjmChR24dXGat5pHCnG3XtlQO3fulL9B5DFpWeL0KREQBGHC4+ekSjsis5YRUaCRSg5Mdymz0B4svMSliEQBz9XyYYC9TEFcRBg6egdQdbkH89JivN088mOc/ie3ScsYMuMjoFZN/HAhzWDWOGY1iCgwVDluLuIiwpDgQtaRZMnUWADAibpOr7SLiORRddn1TCNgMIDQ0mVCV/+A19pFRPKrabNf06e5EDCQyhLw3p0o8D
lLELh4rZeWKtey/4c8BgvJbdUuLkGWZMY7Bpx2DjhEgUTKKJAKnbtqUUYcAOBUfSdsNm5yQuSvBjMLXevjMfowJEXby4lweSJRYKl2Y3IgM96+cqipqx/9A1avtouIvMtZgsDF0kBSCSFOFhCDheQ2d5YxABxwiAKVp8HCOSlR0Iep0GWycDc1Ij8mZRZOtFnZUFL90soWLkUmChRmiw0NHY7NjFy4f0+I1CJSq4Yo2muXEVHgYmYheYrBQnKbO8sYgCGZhW282SAKJFLm0Ew3g4UatQoL0qWlyB1yN4uIZOJuZiEwWPycmYVEgaO+ow82EdCHqZzZweMRBIGT/URBorrNvWv94KpAPruHOgYLyW3uZxbalzK0dpvQZ+ZSBqJAcanVnjnkzuYmksWsW0jk1wasNtQ7HgRcKXouyXYEC5lZSBQ4pImBaQmubWYEDK4MqmOwkChg9ZmtaDKaALj+7D6NEwXkwGAhuUUURbdqngBAbHgYovX2jbfrWLeQKGB4ugwZAJZMjQPAzEIif9XQ0QeLTYQ+TIVkFzKNJNImJ8wsJAoc0kP/tATXr+cMGBAFPqn/Rus1iIsIc+k9Q5chs/Z4aGOwkNzS1mNGt8kCQRiccZyIIAjOdGbecBAFhvYeM9p77budZiW6Xs9MImUWnm4wYsBqk7VtRDR50sRfZnwEVCrXMo0AYJYjs/DS5R5Y+RBBFBBq3JzoBwY3OWEZIaLANVhuxPWs4rQ4PdQqASaLDS3dJm82j/wcg4XkljrHkqXkaB30YWqX3yctRWahVKLAcMlxc5EWq0eEVuP2+7OmRCJar4HJYsO5pi65m0dEkyRtWjDVERBwVXpcOLQaFcyWwWXMROTfPOnv0mYInOgnClzSs/s0F5N8ACBMrUJarB4An91DnUfBwq1btyIrKwt6vR55eXk4fPjwmMf+7W9/w/LlyxEXF4fIyEjk5ubilVde8bjBpCxpJ7WMOPceLlgolSiwXGrxfAkyAKhUAusWEvkxKdCX4WawUK0SnJsesW4hUWCo9+D+fXCDwl6IIrOIyc6dOAAAdHR04MEHH0RaWhp0Oh3mzJmDPXv2+Ki15EnfB1iGgOzcDhbu3r0bRUVF2LRpE44ePYolS5ZgzZo1aG5uHvX4hIQEPPLIIygtLcWJEydQWFiIwsJCvPvuu5NuPPmeNOCkuxss5BbsRAGluk1asuRZsBAAFmXEAWCwkMgfDT5AuF9mQKpbyGAhUWBo8OD+faojWNhlsqCzb8Ar7aLA4m4cwGw246abbkJVVRVef/11VFRUYPv27cjIyPBxy0OXp4k+DBYS4EGw8Nlnn8V9992HwsJCzJ8/H9u2bUNERAR27Ngx6vGrV6/GV77yFcybNw/Z2dl46KGHsHjxYhw4cGDSjSffcz5cuJmJwAEnuDC7OPhJux9KJQQ8sTAjBgBwptEoS5uISD6Dk396t987M1HaEZmbnBD5u/4BK1q7zQDcCxiEa9VIcmx+xPt3AtyPA+zYsQNtbW148803ce211yIrKws33HADlixZ4uOWh67JJvqw74c2t4KFZrMZR44cQX5+/uAHqFTIz89HaWnphO8XRRElJSWoqKjA9ddfP+ZxJpMJRqNx2A/5B+eyJbcHHPvxde19XMoQ4JhdHBpq2wc3P/DU/DR7sPBsoxEWbnJC5Fek67m7NQsBIDtZ2hGZmYVE/q6xsx8AEB6mdnk3VMngrqgsIxTqPIkDvPXWW1i5ciUefPBBpKSkYOHChXjqqadgtVrHPA/jAPLyJKsYGLw3qGPfD2luBQtbW1thtVqRkpIy7PWUlBQYDIYx39fZ2YmoqChotVrceuuteP7553HTTTeNeXxxcTFiY2OdP5mZme40k7yoodOzYKE0QHWbLDD2W2RvF/kOs4tDg/Rg4Oqu56OZPiUSEVo1TBYbqi4zA4nIX1isNhiM9gCCR8uQmVlIFDAahmQRu7obqkTaEZnZReRJHODixYt4/fXXYbVasWfPHj
z66KP45S9/iSeeeGLM8zAOIJ+hWcXuTgxKx0uZiRSafLIbcnR0NMrKyvDpp5/iySefRFFREfbv3z/m8Rs2bEBnZ6fzp7a21hfNJBdImQjuzk5EaDWId8xmNnDQCVi+yC7mjKLyTBYrmrrsgYRMD7KOJGqVgJzUaADA6Qb+OxL5i+YuE6w2ERqV4Fxm6A6pZmFrt4m1zIj83OBmRu5PDDgzC9sZLCT32Ww2JCcn43e/+x2WLVuGgoICPPLII9i2bduY72EcQD7SM3eEVo3YcPeyiqWJRIOxH1YbVwWGKo07BycmJkKtVqOpqWnY601NTUhNTR3zfSqVCrNmzQIA5Obmory8HMXFxVi9evWox+t0Ouh07t+8knf1mi1o77U/FLhbsxCwBxjbewfQ2NmHeY7liRRYxptVPHv27Jjv6+zsREZGBkwmE9RqNX7729+OmV1cXFyMxx57TNZ2k3vq2/sgivabi4RI7aQ+a356DI7WdOBMoxFfzmVBayJ/IGUKpMXpoVa5l2kEANH6MKTE6NBkNOFiSzeWTouXu4lEJJPBzYzcr08qbXJS186J/lDnSRwgLS0NYWFhUKvVztfmzZsHg8EAs9kMrXbkPSbjAPJp6JBWEIS7nVWcFK2DRiXAYhPRZOx3O1GIgoNbmYVarRbLli1DSUmJ8zWbzYaSkhKsXLnS5c+x2WwwmUzunJr8gDQ7Ea3TIEbv3uwEMJiNWO8YuCh0uJNdzBlF5dU6Hgoy4yPcvrm40oL0WADAGWYWEvkN5yqBWM9v/rOTuBSZKBA4lyF70N+nSjXHuQw55HkSB7j22mtx4cIF2GyDdavPnTuHtLS0UQOFJK/6Dnu/9STQp1YJSHNMMHBVYOhyK7MQAIqKinDPPfdg+fLlWLFiBZ577jn09PSgsLAQALBu3TpkZGSguLgYgD1LaPny5cjOzobJZMKePXvwyiuv4IUXXpD3m5DXSUE+T7IKgcE6hxxwApcvsos5o6i8Ghl2QpZIm5ycaTBCFMVJBx+JaPKcmUaTKDMwMykSBysvc5MTIj8n1Rv3JGAgbXJW19EHm02EyoNMZAoe7sYB7r//fvzmN7/BQw89hO9973s4f/48nnrqKXz/+99X8muEjMk+u6fHhqO2rQ/1HX1YLmfDKGC4HSwsKChAS0sLNm7cCIPBgNzcXOzdu9e5LLGmpgYq1WDCYk9PDx544AHU1dUhPDwcOTk5+OMf/4iCggL5vgX5hKc7IUvSOTsR8IbOKt5+++0ABmcV169f7/LnMLvYv0kZBFMnsROyZG5qNFQCcLnHjJYuE5Jj3F8GRUTykoKFUyexrGgws5DBQiJ/1jCJgEFarL1UgdliQ0u3CSm8hoc0d+MAmZmZePfdd/GDH/wAixcvRkZGBh566CH8+Mc/VuorhJTJPrtnxIcDl1iGIJS5HSwEgPXr148ZGLhyaeETTzwx7o5HFDg83Xpdks7MwqDA7OLgJxUyn8xOyBJ9mBrZSVE439yN0w1GBguJ/MBkr+cAlyETBQKbTRxSs9D9/q5Rq5Aao0d9Rx/q2nsZLCS34gAAsHLlSnzyySdebhWNZuhO6J7gqkDyKFhIoWmyy5YGg4WsWRjImF0c/GrbpJqF8hQznp8eg/PN3TjTaMSNOcmyfCYReW5wd9RJBAuT7cHC6ss9sFht0KjdKoNNRD5wuccMs8UGQYDHgb7MhHDUd/Shtq0Py6bL3EAi8hqpBIG0s7G7Mpz7DTBYGKp4Z0cucxZE9zSV2fE+g7EfFqttgqPJn61fvx7V1dUwmUw4dOgQ8vLynL/bv38/du7c6fzzE088gfPnz6Ovrw9tbW04ePAgA4V+Ts7MQmB43UIKfFu3bkVWVhb0ej3y8vJw+PDhMY/929/+huXLlyMuLg6RkZHIzc3FK6+84sPW0pVEcXKZRpK0GD30YSoMWEXnpkhE5F+kjKDkaB20Gs8e+wZ3ROYmJ0SBwmYT0ehI0PE0s5CrAonBQnLZZB8ukqJ0CF
MLsNpENHexXh2RP+rqH0BH7wAAGYOF6fZg4emGTlk+j5Sze/duFBUVYdOmTTh69CiWLFmCNWvWoLm5edTjExIS8Mgjj6C0tBQnTpxAYWEhCgsL8e677/q45STp7BtAr9kKYHLLkFUqATMTHUuRm1m3kMgfNcgwMSBtciKtOiAi/9fabYLZaoNKAFI9zCqWVh/Ut/dBFEU5m0cBgsFCconFaoPB6CiQ7OENh0olIDWWm5wQ+TPpYSA+IgxROnkqVcxzZBZWt/WizxGkoMD07LPP4r777kNhYSHmz5+Pbdu2ISIiAjt27Bj1+NWrV+MrX/kK5s2bh+zsbDz00ENYvHgxDhw44OOWk0QqVJ4YpYU+TD2pz5KWInOTEyL/VC9DfdKpjoBBXQczC4kCRZ2j76fG6D0uE5Iea+/7PWYrjH0W2dpGgYPBQnJJc5cJVpuIMLWA5Gidx58jDToNnaxbSOSP5F6CDACJUTpMidRCFIHzzV2yfS75ltlsxpEjR5Cfn+98TaVSIT8/H6WlpRO+XxRFlJSUoKKiAtdff/2Yx5lMJhiNxmE/JB85liBLspMiATBYSOSv5Ojv0v0AMwuJAoccG5mFa9WYEqkFwMmCUMVgIblEutlIjdVDpRI8/hzWPiDyb7VtjmBhvHzBQgCYkxINAKgwMFgYqFpbW2G1Wp2bGUlSUlJgMBjGfF9nZyeioqKg1Wpx66234vnnn8dNN9005vHFxcWIjY11/mRmZsr2HUieBwjJTMeOyBe5IzKRX5Kjv0uZhQ0dfbDauBSRKBDIsZHZ0Pdzg9LQxGAhuUSOmifAYIFVBguJ/JO0RHFqgjw7IUvmptqDheeaGCwMNdHR0SgrK8Onn36KJ598EkVFRdi/f/+Yx2/YsAGdnZ3On9raWt81NgQ4HyCYWUgU9BqcGxx43t9TYvQIUwuw2ERnSSIi8m9yTQxKqwLrucFRSJKnIBUFvTrnw8Xkso2YWUjk36TMwmkyLkMGBoOFZ5lZGLASExOhVqvR1NQ07PWmpiakpqaO+T6VSoVZs2YBAHJzc1FeXo7i4mKsXr161ON1Oh10Os/LXdD4nMsSJ5ltAMC5wUl77wDaesxIcCxXIiL/IMdkv1olID0uHNWXe1HX1ivLRAMReVd9x+T2GpA4Nznhs3tIYmYhuWTwZsOz3ZQkUrCwnqnMRH7JWbPQS8uQmVkYuLRaLZYtW4aSkhLnazabDSUlJVi5cqXLn2Oz2WAymbzRRHKBnMuQw7Vq54MIswuJ/Euf2YrLPWYAkw8YOHdEbmfAgCgQyFWfeDDRh8/uoYjBQnKJXA8XGcwsJPJboig6C5jLucEJAMxJsWcgNRlN6Og1y/rZ5DtFRUXYvn07Xn75ZZSXl+P+++9HT08PCgsLAQDr1q3Dhg0bnMcXFxfj/fffx8WLF1FeXo5f/vKXeOWVV/Dv//7vSn2FkCfnBicAMNOxFPkig4UhY+vWrcjKyoJer0deXh4OHz7s0vt27doFQRBw++23e7eBBABo6LT39UitGjHhk1tM5twRmUsRiQKC3M/udXx2D0lchkwukWYTJrtsKS3WnpnY2TeAbpMFUTr+J0jkLy73mNE3YIUgDNYXlUu0PgwZceGo7+hDhaELeTOnyPr55BsFBQVoaWnBxo0bYTAYkJubi7179zo3PampqYFKNTgP2dPTgwceeAB1dXUIDw9HTk4O/vjHP6KgoECprxDS+gesaO2WJ9NIkp0UhY/Ot6KSm5yEhN27d6OoqAjbtm1DXl4ennvuOaxZswYVFRVITk4e831VVVX40Y9+hFWrVvmwtaFtaLBAEDzfnBDgjshEgaTbZEFn3wCAyd/PM9EntDGzkFwiDRBpsZN7uIjWhyFGbw8QNnLQIfIrUm3SlGg9dBq17J/PTU6Cw/r161FdXQ2TyYRDhw4hLy/P+bv9+/dj586dzj8/8cQTOH/+PPr6+tDW1oaDBw8yUKgg6VoeoVUjLiJMls/MTrZnDVc2M7MwFD
z77LO47777UFhYiPnz52Pbtm2IiIjAjh07xnyP1WrFXXfdhcceewwzZ870YWtDm5wlB5hZSBQ4pL4frdcgWj+5a72UKNTSZUL/gHXSbaPAwmAhTcjYP4AukwWAPNlGg3ULGSwk8ifSQ0CmzDshS6S6hRUMFhIpYugS5MlmGkmkHZEvcBly0DObzThy5Ajy8/Odr6lUKuTn56O0tHTM9z3++ONITk7GvffeO+E5TCYTjEbjsB/yTL0MOyFLpjpqFtaxZiGR35NjYyNJfEQY9GH2kJGhk3ULQw2DhTShRsfNRnxEGCK0k182LA1cjRxwiPyK9BAwVebNTSQ5jszCCu6ITKSI+nb5Mo0ks5Pt/bqmrRe9Zotsn0v+p7W1FVar1Vl2QJKSkgKDwTDqew4cOICXXnoJ27dvd+kcxcXFiI2Ndf5kZmZOut2hqlGmzQkBIDNeunfvw4DVNunPIyLvaZBpJ2QAEATB+TlM9Ak9DBbShORcxjD0c+o5O0nkV6TMwqmTrE06FmdmoaELoih65RxENDZntoGMfTwpWofEKB1EkRMBNFxXVxfuvvtubN++HYmJiS69Z8OGDejs7HT+1NbWermVwUva4ESO+/ekaB10GhVs4mASARH5Jz67k1y4uwRNqF6meoWSdBZKJfJLg5mF3gkWZidHQq0SYOy3oMloQmqsvJuoENH46mTeCVkyLy0aH503obyxC0unxcv62eQ/EhMToVar0dTUNOz1pqYmpKamjji+srISVVVVWLt2rfM1m82elabRaFBRUYHs7Oxh79HpdNDpdF5ofeiRsovkuH8XBAEZ8eG42NKDuvZeTJvinRUIRDR5cgcLnTVL+ewecphZSBNqkHEZAzCY0cBUZiL/4u1lyDqNGjMS7fXNzhpYh4rI1+SsYzTU/LQYAEB5I/t1MNNqtVi2bBlKSkqcr9lsNpSUlGDlypUjjs/JycHJkydRVlbm/Lnttttw4403oqysjEuMvUgURdn7e6bj3qCWm5wQ+bV6Z7BQnmd36bmAmYWhh5mFNCG5ZyekoCODhUT+QxRFry9DBoC5KdG40NyNc01dWD032WvnIaKR6r2wDBkA5jmChWcYLAx6RUVFuOeee7B8+XKsWLECzz33HHp6elBYWAgAWLduHTIyMlBcXAy9Xo+FCxcOe39cXBwAjHid5NXWY4bJYoMgACmx8mRqDu6IzPt3In8mZwkCYHDCgbuhhx4GC2lCDZ3y7aY29HMMnf2w2kSoVfLsyEhEnrvcY0b/gP3BQq6SA6OZkxKNd042osLAnVOJfMlqE507Gcq5wQkwGCw822iEzSZCxet60CooKEBLSws2btwIg8GA3Nxc7N2717npSU1NDVQqLlxSmrQEOSlKB51GLctnZiY4MgvbGDAg8lc2L1zrp3JVYMhisJAmJHdmYXK0HhqVAItNREsX65YR+QMpUyA1Rg+txnsPenOlHZGbmIFE5EstXSYMWO0TdCnR8taEm5kUCa1ahR6zFbXtvZg+JVLWzyf/sn79eqxfv37U3+3fv3/c9+7cuVP+BtEIznrjMk4MMLOQyP+1dtuv9SoBsl3rM5y7offDYrVBo+aEUKjgvzSNa3gmgjxBPbVKcAYI6zs4O0nkD3yxBBkYDBaeb+qG1cYdkYl8RQoepMboZb/RD1OrnH37ZH2nrJ9NRO5r7JS33jjAmoVEgcAb1/rkaD3C1AKsNhFNXSZZPpMCA4OFNK6WLhMsjqXCydHy3XA4t2B3LJMgImV5e3MTybSECOjDVDBZbKi+3OPVcxHRoHovbW4iWTotDgBwtLrDK59PRK5zrgqSsayINJnYZDTBZLHK9rlEJB/nLugyXuvVKsH57F7HMgQhhcFCGtfQ2Qk5awtOdQw4Dax9QOQXfJVZqFYJmJ3sWIps6PLquYhoUIOXNjeRLJseDwA4UtPulc8nItdJAQM565MmRGoRobXXP+SuqET+Se7yYZKMONYtDEUMFtK4BpcxyDvgODMLeb
NB5Bdq26TMQu8GC4HBpchnGSwk8hnpeitXSZErXTXNHiw8Xd+JPjOzjoiUNLgbqnz9XRAE1i0k8nP1Hd651rPvhyaPgoVbt25FVlYW9Ho98vLycPjw4TGP3b59O1atWoX4+HjEx8cjPz9/3OPJvzQ4CyTLO+CkM7OQyK9ImYWZXl6GDAA5qcwsJPI1Z2ZhnHf6+NT4cCRH62CxiThR1+GVcxCRa7yVXcS6heROHGCoXbt2QRAE3H777d5tYIjzVqKPdO/ARJ/Q4nawcPfu3SgqKsKmTZtw9OhRLFmyBGvWrEFzc/Oox+/fvx933nkn9u3bh9LSUmRmZuLmm29GfX39pBtP3ueNZQzA4DIopjITKU8URZ/VLAQGMwvPNTFYSOQr3so2kAiC4FyKfOhSm1fOQUQTM1tsaHZsQiD3/Tuzi0Kbu3EASVVVFX70ox9h1apVPmpp6HI+u8tYrxQY0ve5OWlIcTtY+Oyzz+K+++5DYWEh5s+fj23btiEiIgI7duwY9fhXX30VDzzwAHJzc5GTk4MXX3wRNpsNJSUlk248eV+91+oeSLsh82YjEDG7OLi0dpthstigEuDcqdybpGBh1eUe9A9wuSKRL0jXW2+WGlg1OwkA8MG5Fq+dg4jG12TshygCWo0KUyK1sn52ZoIjs5CbHIQkd+MAAGC1WnHXXXfhsccew8yZM33Y2tDktZqF8SwhForcChaazWYcOXIE+fn5gx+gUiE/Px+lpaUufUZvby8GBgaQkJAw5jEmkwlGo3HYDyljMJXZO8uQu/otMPYPyPrZ5F3MLg4+0hLk1Bg9tBrvl7JNitIhPiIMNhE439Tt9fMRhTpj/wC6+i0A5H+AGGr1XHuw8FhNO9p7zF47DxGNbXAnZD0EQb7NCQFmFoYyT+MAjz/+OJKTk3Hvvfe6dB7GATzXP2DFZce111s1Cxs6+mGzibJ+9pXON3Xh+386hoL/LcXLB6u8fj4am1tPha2trbBarUhJSRn2ekpKCgwGg0uf8eMf/xjp6enDBporFRcXIzY21vmTmZnpTjNJRt5ahhyh1SA+IsxxDt5wBBJmFwcfXy5BBuzLFQc3OeFNIJG3SdfZ+IgwRGg1XjtPelw4clKjYROBfRXjL0sjIu8Y3NxE/okB6T6hjjULQ44ncYADBw7gpZdewvbt210+D+MAnpOu9RFaNWLDw2T97NQYPdQqAWarDS3dJlk/e6jyRiO+8tuDeOt4Aw5dasOmt07j0b+f8tr5aHw+3Q158+bN2LVrF9544w3o9WNHuzds2IDOzk7nT21trQ9bSZI+sxVtjtmJNJnrHgDc5CQQ+SK7mDOKvjcYLPT+TsiSnNQYANzkhMgXBndC9n4fX7MgFQDw16N1Xj8XEY3krYl+YHCDk9ZuM3c9p3F1dXXh7rvvxvbt25GYmOjy+xgH8NzQvi93VrFGrUJqjD1+463JArPFhgdfPYpukwXLpsfjhzfNgSAArx6qwTsnGr1yThqfW8HCxMREqNVqNDU1DXu9qakJqamp4753y5Yt2Lx5M9577z0sXrx43GN1Oh1iYmKG/ZDvSTOTUToNYvTyZyJINzH1joGN/J8vsos5o+h70kXfl8FCKbOwgpucEHnd4E7I3u/jX182FYIAfHzhMuuaESnAWzXLACA2IgzRjmcCZheGFnfjAJWVlaiqqsLatWuh0Wig0Wjwhz/8AW+99RY0Gg0qKytHPQ/jAJ7zZlYxMFi30FtlCP50uAYXW3uQGKXD9nXL8b0vzMaDq2cBADbvLYfZYvPKeWlsbgULtVotli1bNmz5oLSccOXKlWO+7xe/+AV+/vOfY+/evVi+fLnnrSWfahiyc6LcsxPA4EMLC6WGDleyizmj6Hu+XoYMDAkWMrOQyOvqvBg8uFJmQgSum2XPIvnfD0d/GCQi7xlas9AbBpci8/49lLgbB8jJycHJkydRVlbm/Lnttttw4403oqysjMkAXjA4Meitvu+9YKHZYsPz/7oAAH
gofzYSHJszPXjjLCRF61Db1oc3y1jv3tfcXoZcVFSE7du34+WXX0Z5eTnuv/9+9PT0oLCwEACwbt06bNiwwXn8008/jUcffRQ7duxAVlYWDAYDDAYDurtZ1N7fNXpxGQMwGCzkMuTA4YvsYs4o+p4SmYVzUuzBwuYuEzdCCDDcDT3wSEuTfNXH199ozwTYdbgWp+o7fXJOIrLz5jJkAMh0jCO1zCwMOe7EAfR6PRYuXDjsJy4uDtHR0Vi4cCG0Wnl36qbBZ2pvlA8DgKlx3gsWvnfGgNZuE5Kjdfjm1YOB5HCtGv/ftTMA2Jcjk2+5HSwsKCjAli1bsHHjRuTm5qKsrAx79+51LkusqalBY+PgmvIXXngBZrMZX//615GWlub82bJli3zfgryi3suZCKxZGHiYXRx8RFFUJLMwSqdBZoJ9DDjL7MKAwd3QA1O946HeF8uQASBv5hTcsjAVFpuI//jDZ9zIiMiHvLkMGWBmYShzNw5AvuXtiQKp79d74dl912H7SrKCqzMRph4eovrG8qkIUws4XtvBCUgf86gQ3fr167F+/fpRf7d///5hf66qqvLkFOQH6r28jEHa0t0bAw55T1FREe655x4sX74cK1aswHPPPTdiVjEjIwPFxcUA7NnFGzduxGuvvebMLgaAqKgoREVFKfY9yK612wyTxQaVAKR6qa+PZW5KDGrb+lBhMGJl9hSfnps8M3Q3dADYtm0b3nnnHezYsQMPP/zwiONfffXVYX9+8cUX8de//hUlJSVYt26dT9pM3p/8G03xVxehwtCFi609+OKvP8LyrARkJ0UiJUaPtFg9spOisHRaPNQq+cucEIWqzr4BdJksAAbvs+UmTfR5qyZpV/8A+gasSI727T0JucadOMCVdu7cKX+DyKl+SAkxbxisWShv369r78WBC60QBOAby0cuT58SpcPN81PxzslGvH2iEQszYmU9P41N/l0rKGhINwGZCd7JNpIGnCZjPwasthGzCOSfCgoK0NLSgo0bN8JgMCA3N3fErKJKNfhvOTS7eKhNmzbhZz/7mS+bTqOQLvipMXpoNb7tgzmp0fhneRM3OQkQ0m7oQ0uNyL0bOmDfEd1kMjn/zB3RJ8dssaG5y/73meHDUgNxEVr85bsr8fDfTuL9M004fKkNhy+1DTsmOVqH/8qfgztXZHqlNjJRqJGu6VMitYjQeucxT9oRufqyvAEDi9WG4n+cxcsHq2CxiViRlYBf35nrtSWVRMHEZhOd+wBkemmlkFTKpL69DzabCJVMk33vnraXt7o6K2HMuMMXF6XhnZON2HOyET/+t7m8Z/ARBgtpTN5empgYqYNWrYLZakOTsd+nSyBpcphdHDxqHJMCU700KTAeaZMTLkMODOPthn727FmXPmOi3dAB+47ojz322KTaSoMaO/sgioBOo8KUSN/WiJri2NGw5nIvPrl0GQ0dfWgy9qOxsx/HajrQ3GXCT944icOXLuOZO5Zw0pBokmrbHPfuXrymZyfbV4Vcau2RNWDwyBunsPuzwU3tDle14e6XDuNvD1yDGH2YLOcgClbNXSaYrTaoVQLSvLYqMBxqlQCTYxJSrhVJ756yrzr7twVj179fPTcJ+jAVatp6cbrByOxCH+FdGY1qwGpDY6c0O+GdGT2VSkCatBSZdU+IFCFlBkxXMFh4ztAFm030+fnJt1zZDR3gjuhyk4IHmQkRis3ET5sSgW8sz8R/5c9B8VcXY2fhCnz6SD4e+eI8aFQC3ixrwMa/n4Yochwgmgwps9Bb9+7SZ4epBfQNWNFo7JflM/9xshG7P6uFSgC2fusqfPT/bkRarB4XmrvxzN4KWc5BFMykvp8Wq4fGSxNvYWqVM7uw6nKPLJ/Z0mXCp9X2VQdrFo4dLIzUabB6TjIA4L3TBlnOTRNjsJBG1djRD5sjEyEpWue18zh3RO5ksJBICc5g4RTfBwtnJEYiTC2gx2xl7dIA4Ivd0AHuiC43KXt4mgITAu
PRalS47/qZ+O1dV0EQgD8drsFbxxuUbhZRQPPFhmUatQrTp0QCACqbuyf9eWaLffkxANy/Ohu3Lk5DZkIEfnnHEgDAq4eqcalVnsAEUbCSdief6uVyI1Lfr5YpWFhS3gRRBBZlxE64Cdvnc+zBwg/Ot8pybpoYg4U0Kml2IiM+3KuZCIM7IsszM0lE7qlps1/spzku/r4UplYhO8m+nKmCS5H9HndDD0xSsNCbmUaTcfOCVHz/87MBAI//3xl09g4o3CKiwDVYb9y7/T07yREsbJl8sHD3Z7WoaetFcrQOD944y/n6NbMS8fmcZNhE4H8/qJz0eYiCmXMVgZfLemU5kguqZKpZ+sG5FgBA/ryUCY4EVs1JBACcqOtAe49ZlvPT+BgspFHVOpcxeHfAkYKFcu+qRESuUXIZMmDf5AQANzkJEEVFRdi+fTtefvlllJeX4/777x+xG/rQDVCefvppPProo9ixY4dzN3SDwYDu7sk/YJJrnNdzP8ssHOrBG2dhdnIULveY8buPGBQg8lSdlzc4kEgTfZMNFtpsInYcuAQAeGB19ohNWR5YnQ0A+NvRegYHiMZR56NrvZyZhVabiIOVlwEMBgLHkxYbjjkpURBF4MAFZhf6AoOFNCpngWQvZyJIy6Kk8xGR7/SZrc5dUpVYhgwAc1PtS0y5yUlgKCgowJYtW7Bx40bk5uairKxsxG7ojY2NzuOH7oaelpbm/NmyZYtSXyHk1PrpMuShtBoV/nvNXADA7z+uQhuDAkRuE0XRZ0sRncHC5skFDD4414JLrT2I1mnw9eWZI36/PCsB89NiYLba8H8nWKaAaCy+enZ3Zha2Tj7R52R9Jzr7BhCt12CxixuWXD87CQDwoSMjkbyLwUIala9mJ6SHl+o21iIh8jVpeWKMXoO4CN/ukiqZl2bPLDzT0KnI+cl969evR3V1NUwmEw4dOoS8vDzn7/bv34+dO3c6/1xVVQVRFEf8/OxnP/N9w0OUs2ahQhMCrrppfgoWpMeg12zF7k+5qQ2Ru9p6zOg1WwHYywh500yZliG/drgGAHDH8kxE6TSjHvP1ZVMBAK8fqZvUuYiCWV2H7zMLJ7sp2YHz9oDfNdlTXN6U5fo59mDhxxdauSmaDzBYSKOqbffN7ISUzdTQ0Y8Bq82r5yKi4aQlBNMVqFcoWZBun0m82NqDXrNFsXYQBaPOvgF0OGoAentZ4mQJgoBvX5MFAPjjJ9Wwcod0IrdIS5BTYnTQadRePddMR2Zhc5cJXf2e1Rnt7B3A/opmAEDB1SOzCiVfzk2HRiXgRF0nLsiwoQpRsLFYbc76/96+1mcmhEMQgB6zFa3dk1sF8JFjo5LrHNmCrlieFQ+NSkBDZ79zzCPvYbCQRuUskOzlAScpSgedRgWrTUQDd0Ml8il/yDhKitYhJUYHUQTKG42KtYMoGEnX8imRWkSOkbXjT9YuSUdcRBjqO/qw72yz0s0hCii+qjcOALHhYUiK1gEALrZ4tjpo7+lGDFhFzE2JxlxH/eLRTInS4dpZ9npm750xeHQuomDW2NkPq02EVq1CsqNfeotOo0Z6rD2ZqGYSKwN7TBYcrWkHAKyaNXG9QkmEVoOFjiXLhy+1eXx+cg2DhTRC/8BgHTNvpzKrVIJzKbIUuCAi31B6cxPJQkd24al6BguJ5OSrkiJy0Yep8dWl9iWHb5bVK9waosDi3A3VR/19liO70NMNyv5eZq9BeFtu+oTH3jTfXhf3/TNNHp2LKJhJGXYZ8eFQqQSvny8rcfJ1Cw9fasOAVcTU+HC366bnzUxwfgZ5F4OFNEK9I8MvQqtGfESY18/HYCGRMqodfU6pzU0kCzKkYCHrFhLJSbquBkqwEABuX2oPHPyzvAk9JpYmIHJVnY82N5HMS7NvUObJqoBmYz9KL9p3Qb1tievBwrLaDjR39bt9PqJg5quNjSRS+aJLrZ5nFjqXIM9KhCC4F+DMm+EIFlYxWOhtDBbSCNLsRGZ8hN
ud1xPSQ0zNZQYLiXypxlGzcFqCcjULAWBhuv2B41QDMwuJ5OQsNZDgmwcIOSzKiMWMxEj0D9iYRUTkhtoh9+++MN9x7T7jwbX77RONEEVg6bQ4lyYzUmL0WDI1FqIIlJSzRAHRUM5ndx9NDEq7oU+mhujBSnuw8Fo3liBLlk1PgCDYg5XNRk4eeBODhTSCtOmBrwYcKauJmYVEvmOx2pw3F/6SWXi+qQsmi1XRthAFkxrHssRpAZRZKAiCM9PoreMNCreGKHDU+Pj+fb4js/BMo9HtXUn/7ujbX3Yhq1CSP8+eXfhBRYtb5yIKds6+76OJgjkp9mDh+WbPShBc7jbhrMH+3pXZU9x+f2x4GOal2sefQ1yK7FUMFtIIUkrxjETfDDjSQ0w1MwuJfKaxsx8WmwitRoXUGL2ibUmP1SM+IgwWm4hzBu50SCQXX21WJrcvLkoDABy40Mpd0olcMGC1OTMLZyT6ZrXArOQoaNUqdPVb3NqVtPpyD47XdkAlALcudj1YuGqOfcfUjytbYbHa3G4vUbC65HiG9tWz++xk+4ZEVZd7Yba43xelAN/clGgkRnm2IcuKGaxb6AsMFtIIVY5goVSPwNukYGFtW6/bM5NE5BlpUmBaQoRPiiGPRxAE585mpxpYt5BIDgNWmzNjf0aSsqUG3DUnJQqZCeEwW2zOukZENLa69j5YbSLCw9RIifHubqgSrUaF2Y4Mo9NuLEX+P0dW4bWzEp07KrtiUUYsYsPD0NVvwQnWOCZykp7ds3w0UZASo0O0TgOrTfSobqG0BNmTrELJ57jJiU8wWEgjVDlnJ3wz4EjLJbpMFnT0DvjknEShrrLFnsGX7SdBhAXp3OSESE61bb2DwYNoZbOH3SUIgnPJ4T9Zt5BoQoMT/b6pNy4ZuhTZFaIoOndBXuvGEmQAUKsEXOMILhzgJAIRAKC9x4zOPvvz83Qf1SAXBAGzJrEU+WClfXOjyQQLl2fZg4UVTV3o6DV7/Dk0PgYLaRiL1eZctuSr2Qn9kFnQatYtJPIJKVg401GkWGkLuMkJkawuDck0UDp72BM3OYKF/zrbDKuNqw78ydatW5GVlQW9Xo+8vDwcPnx4zGO3b9+OVatWIT4+HvHx8cjPzx/3ePLMYAkh304ASpucuDrRd9bQhfPN3dCqVVizINXt8103274ZwkfnWbeQCAAuOeoVpsXqEa5V++y8s5MdwcIm98oHNRn7cbGlB4IAfG6G58HCxCgdZjrGu8+q2j3+HBofg4U0TH1HHyw2ETqNCmk+rGMmzYRIm6sQkXddbLH3tWw/CRZKy5DPNhpZi4hIBlIfn+kn2cPuunpGAqL1GlzuMaOslg8C/mL37t0oKirCpk2bcPToUSxZsgRr1qxBc/PoO9Tu378fd955J/bt24fS0lJkZmbi5ptvRn19vY9bHtyqLvt2GaLkqmnxAIAj1e2wuRDUl7IKV89NQmx4mNvnWzXLXrfwWE0Huk2sZ0rkXILso/JhEqluobs7Ipc6sgoXpsciNsL9MWCoqx3ZhZ9WcSmytzBYSMNcGrKMwZeZCNJMqPRwQ0TeNZhZ6B+BhOkJEYjWa2Cy2FDR5NnuakQ06KLjej7Tx8EDuYSpVbjesaHBh+e45NBfPPvss7jvvvtQWFiI+fPnY9u2bYiIiMCOHTtGPf7VV1/FAw88gNzcXOTk5ODFF1+EzWZDSUnJqMebTCYYjcZhPzQxZybxFN9uZjQ/PQbhYWp09g3gQsv4QQObTcRbZfYg8VeWZnh0vmlTIjAtIQIWm4hPHEEHolDm63qFEqleqbv37HLUK5RcPYPBQm9jsJCGUWp2QgpYVE5wo0FEk9fVP4AmowkAkJ3oH5mFKpWA3Mw4APaMASKanEut/jUh4InrHUsOP+SSQ79gNptx5MgR5OfnO19TqVTIz89HaWmpS5/R29uLgYEBJCQkjPr74uJixMbGOn8yMzNlaXuwc2YW+vj+PUytwtJpcQAmfmA/XNWGhs
5+ROs1uDEn2eNzSkuRP67kJAKRtBOyrycKclLtJQgutnSjz2x1+X2lFydfr1CywpFZeLK+E/0DrreBXMdgIQ3j681NJNJSSGYWEnmflIGQGKWb9BIAOS11BAvLajsUbQdRMJCupzP8ZELAE9fNtmcWHq/tQCc3QFNca2srrFYrUlJShr2ekpICg8Hg0mf8+Mc/Rnp6+rCA41AbNmxAZ2en86e2tnbS7Q52ZosN9e19AHx//w4MbjQwUd2wN4/ZswpvWZgKfZjntdWuzXYECy8wWEikVGZhSowOiVE62ETXNziqbetFbVsfNCrBuYR4MjITwpEcrcOAVWSigZcwWEjDXFJowMl2FEm92NrtUs0TIvKcvy1BluQ6shOO1bA+GdFkdJssaO6yZw8rETyQS0ZcOLKTImETB5cuUeDavHkzdu3ahTfeeAN6/eh1sXU6HWJiYob90Phq2nphE4FIrRpJ0Tqfn//qLHvdwtLKyxDF0e/h+wes2HOyEQBwu4dLkCVSRtK5pm40d/VP6rOIApkois5goa+v9YIgYFGGexscSdfxxVNjEaXTyNIGLkX2LgYLaRilljFkxocjTC2gf8CGRiMv/ETeVNnsX5ubSHIz7Q8clS09zCIimoQqZ/aw1qNNBPzJKkd24YfnGSxUWmJiItRqNZqamoa93tTUhNTU8Xe23bJlCzZv3oz33nsPixcv9mYzQ06Vs954JATB9zufX52VgPAwNQzGfpxuGD3DaM/JRhj7LUiP1SNvEjugAkBCpBYLHLswl7JuIYWwyz1mdJksEARgWoJvlyEDwCLH5oQnXQwW7q+wlxSR6hHLYQU3OfEqj4KFW7duRVZWFvR6PfLy8nD48OExjz19+jS+9rWvISsrC4Ig4LnnnvO0reRl/QNW1LbZlyFnJ/s2WKhRqzDdEaCsdHNXJSJyz0VHLbNsP8ssTIjUOmuuHK/rULYxRAFMyh4O5KxCyQ3OTU5axsxaIt/QarVYtmzZsM1JpM1KVq5cOeb7fvGLX+DnP/859u7di+XLl/uiqSFF2lhEWqXja/owNVY56giWlI++K/Yrn1QDAL6VNw1qGTZQvHaW/XwHOIngc+7EAbZv345Vq1YhPj4e8fHxyM/PH/d4co+0E3FGXPiklvZ7aqEjWOhKZuGA1ebsr6vnel6z9ErScuaj1e2wWG2yfS7ZuR0s3L17N4qKirBp0yYcPXoUS5YswZo1a9DcPPrFobe3FzNnzsTmzZsnnHUkZVW2dMMmAnERYUiK8v0yBmnHRm5y4t84WRD4zjUp+2AxHm5yQjR50gOEv2UPeyJvZgLC1ALqO/qcdZVJOUVFRdi+fTtefvlllJeX4/7770dPTw8KCwsBAOvWrcOGDRucxz/99NN49NFHsWPHDmRlZcFgMMBgMKC7m/d6cjnn2I10joLX9Px59jqW/zjVOCKof6KuA8dqOhCmFvCNq+XZsOYax1Lkg+MsfSb5uRsH2L9/P+68807s27cPpaWlyMzMxM0334z6+noftzw4nZf6fkq0IudfNNUeLDzf3I1es2XcY49Ut6PLZEFCpBaLHUFGOcxNjUa0XoMesxXlje7tzEwTcztY+Oyzz+K+++5DYWEh5s+fj23btiEiIgI7duwY9firr74azzzzDL75zW9Cp/N9AIpcJz1czE6OUmQZg7NuITc58VucLAh8/QNWZ23SnFRlbi7Gs3SafSlyWS3rFhJ56qzBfsM81w/7uLsitBosn27PHPiIuyIrrqCgAFu2bMHGjRuRm5uLsrIy7N2717npSU1NDRobG53Hv/DCCzCbzfj617+OtLQ058+WLVuU+gpBx3n/nqJcsPDmBSnQalQ4a+gasSTxf0rOAwDWLk5HcvTotSrdtWLG4CRCNScRfMbdOMCrr76KBx54ALm5ucjJycGLL77ozEamyTuvcN9PjdEjPVYPq03E0eqOcY91LkGenQiVDNnFErVKwLLp9meHw1yKLDu3goVmsxlHjhwZto
OZSqVCfn4+SktLZWuUyWSC0Wgc9kPeJ81MzkpW5uFCyoBgZqH/4mRB4LvQ3A2rTURseBhSY+S5aZeTM7OwtoPZAkQekq7nwRAsBIBVc+xLDj88x2ChP1i/fj2qq6thMplw6NAh5OXlOX+3f/9+7Ny50/nnqqoqiKI44udnP/uZ7xsehGw2EeebpICBcv09LkKLLy60Twq/dOCS8/XDl9rwz/JmqARg/ednyXa+CK3GObl4gLsi+4QccYDe3l4MDAwgIWHsnXAZB3CddK2frdCzuyAIWOHYYOTwpfHrh+6vsCeWyLkEWSItRf70EoOFcnMrWNja2gqr1eqcPZSkpKTAYDDI1qji4mLExsY6fzIz5UlZp/FJNxtzFJqdkHZmvcCahX6JkwXBoWJIxpESGcQTmZcWA61GhY7eAWcGJPkXliLwb71mC2oc9YfnKhg8kNP1jk1OSisvw2xhTSIiSX1HH/oGrNCqVZiuwAYHQ31n1UwAwN/LGvBpVRvaesz479ePAwAKrs7ETJnLIlybbZ9E4E7pviFHHODHP/4x0tPThz1LXIlxANcp/ewOAHkz7SUBPhknUHeptQdnDV1QqwRnHWI5SQHLz6rbmGggM7/cDXnDhg3o7Ox0/tTW1irdpJDgTGVWaHZCqrfQ3GVCe49ZkTbQ2DhZEBwqHLOQ8/w040irUWGJowYKdzbzPyxF4P/ON3VDFIHEKB2mKFB/2Bvmp8VgSqQWPWYrjtawRAGR5Hyz/Zo+MykSGrWyj3ULM2LxtaumAgC+veMwbvn1h6i+3Iv0WD0evmWe7Oe7brY9SFFaeRk2GwME/m7z5s3YtWsX3njjDej1Y69sYRzANZe7TbjseF6epWC90jxHoK6spmPMuoXvnGgAYN+YKD5SK3sbFk+NhVajQmu3mYkGMnPrqpKYmAi1Wo2mpqZhrzc1Ncn6EKDT6RATEzPsh7yrf8CK6sv2zqXU7ESUTuPc9r3cwGyyUMWbBO8qb7T3rbmp/juufs4xS3noIoOF/oalCPxfhXMJcuBvbiJRqQTnbqtcikw0SNqwTMlgwVCPf3kBrs6KR4/ZiiajCemxevzh3hWIDQ+T/VyLp8YhUqtGe+8AzjTyucHbJhMH2LJlCzZv3oz33nsPixcvHvdYxgFcIyX5TI0PR4RWo1g7ZiRGIjMhHGarDR+eGz3L9+0T9jq2X1qU5pU26DRqJhp4iVvBQq1Wi2XLlg0rSioVKV25cqXsjSPfudjSA5sIxOg1SIpW7oFO2nCBuxn5H04WBAdpGXJOmn9mFgJA3gzHkoaL3OXQn7AUQWCQ+rhSuyN6y/WOpUsfcpMTIqfBZYj+0d8jdRr86b7P4feFV+OFu67C+0U3eK0Wepha5VwC+THrFnqdp3GAX/ziF/j5z3+OvXv3Yvny5b5oakhQeidkiSAIuHm+/TnwvTMjV5qdbujEWUMXwtQC1izw3goTqW7h4UtcfSAnt/PVi4qKsH37drz88ssoLy/H/fffj56eHhQWFgIA1q1bhw0bNjiPN5vNKCsrQ1lZGcxmM+rr61FWVoYLFy7I9y1o0qRso5zUGEXrmM1LsweGznKG0O9wsiDwtfWY0dxlAqD8zcV4rpoeB41KQENnP+ra+5RuDjmwFEFgcE4I+GmpAU+tctQtPFVvRItjHCMKdVJGnT9d0zVqFW6cm4xbFqUhUufdjKdrsh3BwsrxN1cgebgbB3j66afx6KOPYseOHcjKyoLBYIDBYEB3N+vTT1a5H00MSkHAf55pQv+AddjvXimtBgD828I0xEbIn2EsudqxHJqZhfJyO1hYUFCALVu2YOPGjcjNzUVZWRn27t3rfHioqalBY2Oj8/iGhgYsXboUS5cuRWNjI7Zs2YKlS5fiO9/5jnzfgibtVEMnAGBBhrJZXPMc2U5chuyfOFkQ2E7UdQCwLxmI8vIN/GREaDVY7FhO8MlFPgCEGpYi8Jwois7ruTT5FiySonWY7/hOBy4wu5Cof8DqzC
5a5LhmhprrHOUJPr3UBpPFOsHRNFnuxgFeeOEFmM1mfP3rX0daWprzZ8uWLUp9haBxut5+rV+UoXzfXzY9Hhlx4TD2W/DOicF//5YuE94sqwcA3P256V5vgyAANW29aDL2e/VcocSjp8X169dj/fr1o/5u//79w/6clZXFZWQB4HS9PTi3MF3ZAUd6uDnX1A2L1aZ4sWYarqCgAC0tLdi4cSMMBgNyc3NH3CSoVIP/ZtJkgWTLli3YsmULbrjhhhFjBXnfiTr7jcXiAHio+NzMKTha04FDl9pwx3JmlvkDX5YiYH1Dz9S196GjdwBatQpzgyyzEABumJuEM41GfHiuFV9ZOlXp5hAp6lxTFyw2EfERYUiPHXvDiGA2NyUaiVFatHabcaymw1nzmLzHnThAVVWV9xsUggasNmdm4UKFE30AQK0S8K28aXjm3Qr874eV+HJuOjRqFf6n5Dz6B2xYkhmHq7PivdqGGH0Y5qXG4EyjEZ9WteFLi9O9er5QwUgMwWYTcdqRibBQ4dmJzPgIRGrVMFts3M3IT61fvx7V1dUwmUw4dOgQ8vLynL/bv38/du7c6fyzNFlw5Q8DhcqQMgsXT41TtB2ukOoQMbPQf7AUgf877ujjOWnR0GnUyjbGC653LEX+6HwLdz+lkHdKmujPiFW0hJCSBEHAymx7duFB1i2kEHG+qRtmiw3R+sHNQZV2V940xEeE4VxTN375/jn89UgdXvnEvgT5x/821ydjlBSQ/PQSlyLLhcFCQtXlHvSYrdCHqZCdFKloW1QqwZldeNKRXk1EkyeKIo47MguXBEBm4bLp8VCrBNS196G2rVfp5pADSxH4t5N1/rMsyRuWTY9HpFaN1m4zdz+lkHfKTyb6lXbdLNYtpNDiLB+WruxeA0PFRWjxyK3zAQAv7K/ED/9yHADw7WuycI0joO9tUt3CQwwWyobBQsKpBvsN97y0GL9Y9pubGQcAOFbToWg7iIKJwdiPli4T1CoBCxQuN+CKKJ0GSx1jAXc/9R+sW+zfAqnUgCe0GhVWOjY04LhAoU6qWaZ0CSGlSYGIstoOdPUPKNwaIu/z177/9WVT8bO18xGt10CnUeHe62bgkVvn+ez8K2dOgSAAZw1daGbdQln4b4V78hl/G3Byp8UBsF/0iUgex2vt/XxOSjTCtYGxPHH13CR8Vt2O/RUtuCvPu4WRyXWsW+yfbDYRp5wFz+OUbYwXXT8nCf8sb8aH51rwwOpZSjeHSBH+VrNMSZkJEZiWEIGatl4cvtSGL8xLUbpJRF4lJfr4Y1bxt6+dgXUrs2ATRZ8nIU2J0mFRRixO1HXiw/Ot+Poy1jaeLOXTyEhxR2vaAfjPTmpLp9nrDZQ3Gkdsv05EnjlWa+/ngbAEWbJ6bjIAex0i7nJINL6Lrd3oMlmg06gwOyVK6eZ4zQ1z7HULP6tqR2cvs4goNJ1uMMJssSEuIsxvapYp6dpZ9uzCjy9wKTIFN5PFOjgx6Kf39CqVoNhqReke4YNzXH0gBwYLQ5zJYnXWMbs6K0Hh1tilx+qRFK2DZUiWBBFNzmFH/Q5/6eeumJ8Wg8QoHXrMVnxW1a50c4j82uFL9j6ydFocwvygpIi3TJ8SiTkpUbDYRPyromniNxAFoc+q7Nf05dPj/aZmmZKucwQL91U0M5udgtqpeiNMFhsSIrWYmajsXgP+SAoWfnS+BVZuhDZpwXs3SS45Vd8Js8WGKZFaZE3xj5lJQRCctcpYt5Bo8nrNFufGB3kzAydYqFIJzov+/opmhVtD5N8OX7Jn1KyYMUXhlnjfmgWpAIB3TzFYSKFJmkBbNj1wrunedP2cRGjVKlxq7cH55m6lm0PkNUeq7RMFyzhRMKrczDhE6zXo6B3A8boOpZsT8BgsDHGDNxv+NeBIS5E/reJuRkSTdaS6HRabiIy4cEyN949JAVetnmsPFu6r4HICorGIoujc/S9vRvAHD6Rg4QfnWliuhEKOKIr4rFpaLRCvcG
v8Q7Q+DNfNtmcXvnvKoHBriLznU8ez+/Lp7Puj0ahVWOUYC/bz2WHSGCwMcdKA429LEz/nyH765OJlphATTdLhAA4iXD8nCWFqAReau3G+qUvp5hD5pbr2PjR29kOjErDUsUlYMFuQHoOMuHD0DVjxIesSUYiputyL1m4ztGqVX25woJQ1C+wbm+w9zWAhBSdRFHG02hEs9LNnd39yo6Pm+XscCyaNwcIQNmC14dBFadmSfw04izJiEa3TwNhvwekG1i0kmoyPzrcCAD43M/CWJ8aGh2HVbHt24TsnGxVuDZF/Olhp7+OLpsYiQqtRuDXeJwgCbpYCA8wiohBz4Lw9QJ47LQ76MLXCrfEf+fNSoBLsm7/UtvUq3Rwi2Z01dOFyjxnhYeqQ3wV9PDfNT4FGJeCsoQuVLSxLMBkMFoawstoOdJksiI8I87uZSY1a5aytdrCSO5sReepyt8lZs+MGx5LeQPPFRWkAgD0MFhKNSlpqI9X4DAVfWmwfF949bUCv2aJwa4h854Nz9smBUOrvrpgSpXOulOIkAgUjaYffz81MgE7DiYKxxEVonTuk/4PPDpPCYGEIk5buXDc7CWqV/9QrlFyTbe/kBxxZUUTkvg/OtUAU7cv2UmL0SjfHIzfNT0GYWsC5Ji5FJrrSgNXmzB6Wlt6EgqumxWP6lAj0mK147zQ3OqHQYLbYUOrIJL5+NoOFV/rSknQAwF+P1nFXZAo60rM7Jwom9sVF9trGe05y4mAyGCwMYdKAc72jCKi/kbKgDl26DGP/gMKtIQpM/zpr30U4kIMIseFhzoeiN8vqFW4NkX/5rKod3SYLpkRqscjPVgl4kyAI+OrSqQDsgQGiUHC0ph09ZiumRGqxIJ3LEK902+J0aNUqnDV04XSDUenmEMmmx2Rxbvx5PYOFE7ppfirUKgFnGo1cijwJDBaGqMbOPhyvs9cC9NcBJzspCtlJkRiwitjnCHgQkev6B6zOvnNjTuAGCwHgq1fZgwKvH6mDxWpTuDVE/uNdRwHvG+YmQeWHqwS86StLMwAAH19oRWNnn8KtIfI+qb9fPyf0+rsrYiPCcJOjnunrRziJQMHjg3MtGLCKyEwIx4zESKWb4/cSIrVY7Yhx/PnTWoVbE7gYLAxRUkru1Vnxfr00cc0CewoxlxgRua+kvBk9ZiumxofjqgDfIfWm+SmYEqlFk9GEfRXc/ZQIAKw20bnxj1TDL5RMmxKBvBkJsInAq5/UKN0cIq+y2URn7V6pli+N9PVl9snFv5fVo3/AqnBriOTx9okGAPa+LwicKHBFwdWZAOyrD8wWJhp4gsHCECUNOLf6+c3GzY5g4b6KZvSYWMCcyB1/dyzZvW1JesDfWGg1KnzN8QDw2qFqhVtD5B8+uXgZLV0mxEWE4bpZ/rlKwNu+fU0WAOC1wzUMDFBQ+6y6HU1GE6J1Glw/xz9LCPmDVbMSkR6rR3vvgPM+iCiQ9ZgszrJCX1qUrnBrAseNOclIitahtduMknImHnmCwcIQVNXag2M1HRAE4BY/DxYumRqLGYmR6DVb8c4J7mZE5KqWLhP2VdhvLG7LDY4biztXTIMgAPsqWlBh4EYnRFKtvlsWpkGrCc1bupvmpyAjLhxtPWYGBiiovXHM/t/3TQtSuBPqODRqFQqvnQEAePGjS9zohALee2cM6B+wYfqUCCzMYK1SV4WpVfjGcnuiwfaPLnIs8EBo3lmGuFcdWTnXz07y6yXIgL2A+TeW21OId33KJUZErnrtUA0GrCKumhaHnNTguLGYkRiJf3NkG2/7oFLh1hApq7XbhLeP2yfR7nDcDIcijVqFe66ZDgDYuq8SA6xpSkHI2D+ANx3BwjuWZSrcGv9XsCITUToNzjd3OzOyiALVK6X2Z/evXTU14FcK+do912RBq1HhaE0HDl9qU7o5AYfBwhDTZ7biz5/ZMxHWrZyucGtc87VlGVCrBByt6cCp+k6lm0Pk98wWG/7omBT4tmN2PVg8sH
oWAOCt4w24yN3NKIS9dqgGZqsNSzLjcNW0eKWbo6h//9x0JEZpUdPWiz9/xkLmFHz+eqQOfQNWzEmJwudmJijdHL8Xow/DXXnTAADPvFsBq40ZRRSYTtZ14mhNB8LUAr65ghMF7kqO1uMORxmj3+y7oHBrAg+DhSHmT4dr0Nk3gKnx4Vg9NzB2R02O1mOto3D78/86r3BriPzfa4eq0dJlQkqMDrcsTFW6ObJaNDUWn89JhtUm4sl3ypVuDpEijP0D+P3HlwAAhY6afaEsQqvBgzfaJxJ+/c/zMPYPKNwiIvmYLFa8+JG9v9/9uenMLHLR/auzEaPX4Kyhy1mygSjQSCtpblmYhuRo/14R6K/+8/pshKkFfHS+Ff86y9qF7mCwMIR0myzY6oioP3jjLKhVgXOz8eCNsyAIwLunm5hdSDSObpMFz//L3s+//4XZCFMH3zD/ky/Og0YloORsM/ZxeRGFoO0fXkR77wBmJkWG5C7Io/lW3jRkTYlAc5cJT//jrNLNIZLNq5/UoL6jD6kxetyxnJlFroqL0GL95+2TCE//4ywud5sUbhGRe07Vd+Kdk40QBHvwmzwzbUoE/j/HSqufv10Ok4Wbobkq+J4iaUzPl5zH5R4zZiRGOtNxA8XslGh8abF9k4afvnmKywmIxvD0P846+/k3gvShYlZyFAqvzQIA/L+/nkArHwAohFS2dON3H14EAPy/NTnQBOGEgCd0GjWe+uoiAMCrh2rwwbkWhVtENHnNXf34H8eqmu9/YTb0YdzYxB33XJOFnNRoXO4x45E3TnGDAwoYVpuIx/7vNADgy0vSMS8tOOqPK2X952chKVqHS6092MwJRZfxDjNEfFbVht99ZH+4+Omt8wLy4eKRL85DlE6DstoOvOj4LkQ0aN/ZZrzyib1W4c+/vDAoswolRTfNxZyUKLR0mfBfu8pgtnBTAwp+JosVRX8+DpPFhlWzE7FmQYrSTfIr12QnOuuUff9Px1DV2qNwi4g8J4oifvK3k+joHcCC9JiQ3sjIUzqNGlvuWAKNSsDe0wb8dj83R6PAsOPAJXxa1Y5IrRo/vHmu0s0JeNH6MDz9NfuE4u8/rsKek40KtygwBO+TJDnVXO7Fd/94BKJo30XpC/MC8+EiNVaPDV/MAQA8vfcs9lVw+SGR5FR9J9a/dhQAcM/K6bhudqLCLfKucK0az995FfRhKhy40Iof/LmMu6BSULPZRPz3X07geG0HovUa/OLri1m7bBQb185HbmYcOvsGcNeLhxgwpIC15b0K/LO8GWFqAb/8xpKgngD0poUZsdi0dj4A+2Ynf3RMqhL5q30Vzdi815799sit85GZEKFwi4LD53NS8J3r7MuR/2tXGVcguIBXnSBX3mjEN39XitZuM+anxeCxLy9QukmT8q0V03DHsqmwicB//uEI3jnBWQGifRXNuPN3n6DHbMU12VPwyK3zlW6ST8xNjcb/3r0cYWoB75xoxL+/eAjNXf1KN4tIdr1mCx587SjeOt4AjUrAb++6Cmmx4Uo3yy/pNGr87u5lmJkYifqOPnzthYOcXKSAYrWJeGpPObbus2fBPfmVRchJ5RLEybh7ZRbuW2UPEvz0zVN44u0z6B9g3TLyP3tONuK7rxyB1Sbiq1dl4E7ugCyrh2/JwS0LU2G22nDvzk+x8+NLLE8wDo+ChVu3bkVWVhb0ej3y8vJw+PDhcY//y1/+gpycHOj1eixatAh79uzxqLHkus6+ATz3z3P48taP0dDZj5lJkfh94dWI0mmUbtqkCIKAJ76yEGsWpMBsteHB147iB7vLUHO5V+mmhRSOAf6hqrUHP/zzcRT+/lN0mSxYMSMB2+5eBq0mdOaBbpiThG3/vgyRWjUOXWrDF7Z8gG0fVKKzl7uhegv7v+9YrDb83/EG3PyrD/GPUwZo1Sr8qiAXq2YnKd00v5Yco8fu/1yJBekxuNxjRuHvP8UDrx7hBmky4RjgHaIo4vClNnzthYPOuq
QP35ITtPWHfe0nX5yH/8qfDQB48cAlfPHXH+GNY3VcleAm9n/vqG3rRdGfy/DAq0dhstiQPy8Fm7/KFQRy06hV+PU3l2LtknRYbCJ+9n9ncMe2Unx8oZVBw1EIopt/K7t378a6deuwbds25OXl4bnnnsNf/vIXVFRUIDk5ecTxBw8exPXXX4/i4mJ86UtfwmuvvYann34aR48excKFC106p9FoRGxsLDo7OxETw5m1K1ltIjp6zai63IOzhi4cvHAZ+yqa0Wu2z5itnpuEXxcsRWxEmMItlY/FasMz71Xgdx9ehCgCKgHImzEFq+cmYVFGLKYnRiIpShdSQRN3edqvfD0GsP/bHyD6Bqxo7TLjQksXzjQYsb+iBUdq2iGN4PesnI6f3DoPOk1oFj8/39SFoj8fx0lHMEAfpsK12YlYmT0Fc1OjMTMpClMitSwO7xAo/X8ybQ00NpsIY/8A6tr7cK6pC59Vt+OfZ5rQ3GXfwCcjLhy/KsjFihkJCrc0cPQPWLH5H2fxcmmVc6ycmxKNa2ZNwVXT4jEjMRKZCRGI0WtC7oEsUMaAYO3/0nW9yWhChaELJ+s7UFLejLOGLgBAtE6DJ76yEF/OzVC4pcHnvdMGPPLmKbQ4xtbY8DB8PicZV02Px/y0aKTHhSMxShfUy74Dpf9Ppq3+zGyxoa3HjAvN3TjT2IkPz7XiYGUrbCIgCMB9q2bix/+WA7UqtK5LviSKIn7/cRWeebcCfY4s42kJEbhhThKWTY9HdlIUMhPCEaMPgyoI/x1c7VduBwvz8vJw9dVX4ze/+Q0AwGazITMzE9/73vfw8MMPjzi+oKAAPT09ePvtt52vfe5zn0Nubi62bds26jlMJhNMpsHdLY1GIzIzM8f9Mts+qMQ/zzRB+jJDv9bga8P/PPTFkceIVx4y4v2j/dWN+/4r3je8HRMfc+VniyLQ1W+BsX8Ao/0r5qRG48EbZ+FLi9OC9ia4rLYDv3yvAh+dbx3197HhYQgPU0MXpoJeo0aYRoAAAYIAOP9GBAGC/X+crwmO1wJVRnw4fv3NpeMe4+nF19tjgCf9v89sxd0vHRqz/4zo+xP0e1f7vCiO8p4x2jBRHx+rDQNWG9p7B8bcwGP13CR87/OzsWx6/Ki/DyVWm4i/Hq3DjgOXnA9cVwoPUyM2PAxajQoatQCt2v6/apVqRJ+/ctgcbUy4cmz1l3HjuzdkI3/+2PVp/bX/A56NAX/+tBZ//qx2WP8ar/+P1+/G6/djjhUYv8+P199FEegxWdDZNwDbKNfy+Igw3HNNFr6zambArw5QSoWhC8//6zzeO90E8yhZRCoBiNJpEK0Pg06jgkolQKMSoBIEqFX2HzmfFeS6JxvrU8K1arxyb9647/XXMcCT/g8Ad790yDlJfmX/H6vvj3bNH6vfDxsjhhwzUX+/sq8D41/XtRoVvnbVVDz0hdlIjdWP+X1pcoz9A3iltBq//7gKrd2mEb8XHGNCeJga+jA1wsPUUKsczw8CoJKeExz/qxL869nhG8sz8Y2rx85I9df+D3g2Brx/pgn/+0HlOP1v8A9jjQGjXc+B0fvwWGPAWPcg0p9sNvvqv26TZdTvcf2cJPwgfzaWTuM9va80dvZh674LeONoPXrMI0sTqAR7PCFCq0GYWoBGrYJGJSBMbb9XmOjZAXDt+WGs4zz1yr15CNeOnSDh6hjg1l2n2WzGkSNHsGHDBudrKpUK+fn5KC0tHfU9paWlKCoqGvbamjVr8Oabb455nuLiYjz22GPuNA3Vl3vwWXW7W+8JNqkxesxNjcaSqbH4/LwULJkaG7RBQkluZhxeuTcPtW292HvKgGO17ThVb0RjZx8GrCI6+wbQ2Rd6yxHbe81e+VxfjAGe9H+rKIZE/9eqVZiZFInZKdFYMSMBX8hJRnoc65ZJ1CoB31ieiTuWTcWZRiM+PNeKozXtqGzpRs3lXlhs9kyOvhCoUzTaw89k+fM9QH1HX9CMAXERYZiTEo
35aTFYPTcJK7OnhGzGsFzmpkbjN9+6Ch29ZnxwrgWfVrXhdIMRtW29aO02wyYCxn4LjP2jP8AFmmgvBZX99R4AAI5Wt4/6oOfvdBoVZqdEYW5KDK6bPQU3zk1GXIRW6WYFvRh9GB68cRa+e0M2jlS344NzzThVb8T5pi40d5lgsYno6regK0DHBG9scufP9wAtXaaAuwdQCUDWlEjMTonC1VkJyJ+XgqzESKWbFXLSYsPxxO2L8PAt83DwQis+vtCK8sYuVLZ043KP/f6gvXcA7QFW3sgm05Jqt+4mWltbYbVakZIyPFshJSUFZ8+eHfU9BoNh1OMNBsOY59mwYcOwgUWaURjPXXnTccMce/qzFB8bGiaTgmaDmWNDfycdP/yNrrx/aDx5xHmHnsPxhyuPGRrMG/m7kR905THReg3iIrSIDQ8L6nT5iWQmROC+62c6/2yziejoG0Bbjwn9AzaYLFb0D9jsGQXi4OzQaDPE9r4lTwdTSqSXHhR8MQZ40v/1GhW2/fsyAGP3sRF9d4J+P1GfH62/T9TPJ+rjo71PoxIQGx6G+EgtIrXqoJ8AkIMgCFiQHosF6bHO10RRhLHfgo5eM4x9FpitNlisNgxYRQzYbLBah8wiX/F5o2WqD/5u2J9k+w6TNfS7y8Wf7wHWLknDvDT7zOjQfjRe/x+v74/X76+8NI/2u9H6/Hj9PVKnQVx4GGIjwhgY9KK4CC2+nJsxbGln/4AVxr4BGPst6Oq3Z3tZRRFWm/3HJoqwWEXZerd8JZHG/iC1yjv3g/56DwAAz31zKaw2ccz+P1bfH+2aP7TfD7suj/H6RP39ymOl63pCpBYRvK4rSq0SsGJGwrDyDjabiMs9Zhj7B9A/YH9+6B+wwmIThzwrDD5H2Ia85i9mJUfJ/pn+fA+wanbisOeAsa7no937j3nsBNf9scaA4fcPw19XCQJiwsMQHxEWtMtbA1WUToObF6Ti5gWpztdMFis6HYHCXrMFFpuIAasNFqsIi82GKxcqjLrqdJRzjT5UyDt+6GQqxeaX61l0Oh10Op1b71mYEYuFGfI/HFFgUqkEJERqkRDJ2dlA40n/16hV+LeFqRMfSCFLEOwPZ7HhwVO7NVh5MgbMSo7GrORoL7WIgpnescwwOThKYQU8T/o/ANw0TtkFIneoVAKSonVIinb/v0OaPE/GgMyECGQmRHipRRSqdBo1kmPUSI4J3ZIQboUcExMToVar0dTUNOz1pqYmpKaO/qCemprq1vFE5L84BhCFLvZ/otDGMYAodLH/E4Uet4KFWq0Wy5YtQ0lJifM1m82GkpISrFy5ctT3rFy5ctjxAPD++++PeTwR+S+OAUShi/2fKLRxDCAKXez/RCFIdNOuXbtEnU4n7ty5Uzxz5oz4H//xH2JcXJxoMBhEURTFu+++W3z44Yedx3/88ceiRqMRt2zZIpaXl4ubNm0Sw8LCxJMnT7p8zs7OThGA2NnZ6W5ziWgMnvYrX48B7P9E8guU/j+ZthLR2AJlDGD/J5JfoPT/ybSViMbmar9yu2ZhQUEBWlpasHHjRhgMBuTm5mLv3r3O4qU1NTVQDSmsfM011+C1117DT3/6U/zkJz/B7Nmz8eabb2LhwoXuBDQB2AucEpE8pP4kulmQ2ddjAPs/kfwCpf8PbSPHACL5BMoYwP5PJL9A6f9D28gxgEg+ro4BgujuKKGAuro6l3ZCIyL31dbWYurUqUo3Y0zs/0Te4+/9H+AYQORN/j4GsP8TeY+/93+AYwCRN000BgREsNBms6GhoQHR0dHOrcmvJG2rXltbi5iYwNzSjt9BeYHefsD17yCKIrq6upCenj5sFtDfuNL/gdD6t/NXgd5+IPC/Q7D1fyB0xoBAbz8Q+N8h0NsPBN8YECr9Hwj87xDo7QcC/zsEW/8HQmcMCPT2A4H/HQK9/YD8Y4Dby5CVoFKpXJ71iImJCdh/XAm/g/ICvf2Aa98hNjbWR63xnDv9Hwidfzt/Fu
jtBwL/OwRL/wdCbwwI9PYDgf8dAr39QPCMAaHW/4HA/w6B3n4g8L9DsPR/IPTGgEBvPxD43yHQ2w/INwb491QCERERERERERER+QyDhURERERERERERAQgiIKFOp0OmzZtgk6nU7opHuN3UF6gtx8Iju/giWD43oH+HQK9/UDgf4dAb/9kBPp3D/T2A4H/HQK9/UBwfAdPBMP3DvTvEOjtBwL/OwR6+ycj0L97oLcfCPzvEOjtB+T/DgGxwQkRERERERERERF5X9BkFhIREREREREREdHkMFhIREREREREREREABgsJCIiIiIiIiIiIgcGC4mIiIiIiIiIiAgAg4VERERERERERETkEBTBwieffBLXXHMNIiIiEBcXN+oxNTU1uPXWWxEREYHk5GT893//NywWi28b6oasrCwIgjDsZ/PmzUo3a1xbt25FVlYW9Ho98vLycPjwYaWb5LKf/exnI/6+c3JylG7WuD788EOsXbsW6enpEAQBb7755rDfi6KIjRs3Ii0tDeHh4cjPz8f58+eVaayXcQzwD4E6BrD/Bzb2f/8QqP0f4BgQ6DgG+AeOAb7D/j8oGPs/EHhjAPu/b/lqDAiKYKHZbMYdd9yB+++/f9TfW61W3HrrrTCbzTh48CBefvll7Ny5Exs3bvRxS93z+OOPo7Gx0fnzve99T+kmjWn37t0oKirCpk2bcPToUSxZsgRr1qxBc3Oz0k1z2YIFC4b9fR84cEDpJo2rp6cHS5YswdatW0f9/S9+8Qv8z//8D7Zt24ZDhw4hMjISa9asQX9/v49b6n0cA5QX6GMA+3/gYv9XXqD3f4BjQCDjGKA8jgG+xf4/KFj7PxA4YwD7v+/5bAwQg8jvf/97MTY2dsTre/bsEVUqlWgwGJyvvfDCC2JMTIxoMpl82ELXTZ8+XfzVr36ldDNctmLFCvHBBx90/tlqtYrp6elicXGxgq1y3aZNm8QlS5Yo3QyPARDfeOMN559tNpuYmpoqPvPMM87XOjo6RJ1OJ/7pT39SoIW+wTFAOYE8BrD/Bwf2f+UEcv8XRY4BwYJjgHI4BiiH/d8umPq/KAbWGMD+ryxvjgFBkVk4kdLSUixatAgpKSnO19asWQOj0YjTp08r2LLxbd68GVOmTMHSpUvxzDPP+G26tNlsxpEjR5Cfn+98TaVSIT8/H6WlpQq2zD3nz59Heno6Zs6cibvuugs1NTVKN8ljly5dgsFgGPZvEhsbi7y8vID6N5ELxwDvCoYxgP0/eLH/e1cw9H+AY0Aw4xjgXRwD/Av7/3CB2v+BwBgD2P/9j5xjgEbuxvkjg8EwbIAA4PyzwWBQokkT+v73v4+rrroKCQkJOHjwIDZs2IDGxkY8++yzSjdthNbWVlit1lH/js+ePatQq9yTl5eHnTt3Yu7cuWhsbMRjjz2GVatW4dSpU4iOjla6eW6T/rse7d/EX/+b9yaOAd4V6GMA+39wY//3rkDv/wDHgGDHMcC7OAb4F/b/4QKx/wOBMwaw//sfOccAv80sfPjhh0cUmrzyJ1D+A5S4852KioqwevVqLF68GN/97nfxy1/+Es8//zxMJpPC3yI43XLLLbjjjjuwePFirFmzBnv27EFHRwf+/Oc/K920kMUxgGOAr7D/+x/2f/Z/X+IY4H84BnAM8CWOAf4lGPs/wDHAX7H/j81vMwt/+MMf4tvf/va4x8ycOdOlz0pNTR2xI09TU5Pzd74yme+Ul5cHi8WCqqoqzJ071wut81xiYiLUarXz71TS1NTk079fOcXFxWHOnDm4cOGC0k3xiPT33tTUhLS0NOfrTU1NyM3NVahV7uEYMBzHAN9h/1ce+/9w7P++xTFAeRwDhuMY4FuBPAaw/w/nL/0fCM4xgP3f/8g5BvhtsDApKQlJSUmyfNbKlSvx5JNPorm5GcnJyQCA999/HzExMZg/f74s53DFZL5TWVkZVCqVs/3+RKvVYtmyZSgpKcHtt98OALDZbCgpKcH69e
uVbZyHuru7UVlZibvvvlvppnhkxowZSE1NRUlJiXNQMBqNOHTo0Ji7hfkbjgHDcQzwHfZ/5bH/D8f+71scA5THMWA4jgG+FchjAPv/cP7S/4HgHAPY//2PrGOAXLuwKKm6ulo8duyY+Nhjj4lRUVHisWPHxGPHjoldXV2iKIqixWIRFy5cKN58881iWVmZuHfvXjEpKUncsGGDwi0f3cGDB8Vf/epXYllZmVhZWSn+8Y9/FJOSksR169Yp3bQx7dq1S9TpdOLOnTvFM2fOiP/xH/8hxsXFDdt5yp/98Ic/FPfv3y9eunRJ/Pjjj8X8/HwxMTFRbG5uVrppY+rq6nL+tw5AfPbZZ8Vjx46J1dXVoiiK4ubNm8W4uDjx73//u3jixAnxy1/+sjhjxgyxr69P4ZbLj2OA8gJ5DGD/D2zs/8oL5P4vihwDAh3HAOVxDPAt9v9Bwdb/RTHwxgD2f9/z1RgQFMHCe+65RwQw4mffvn3OY6qqqsRbbrlFDA8PFxMTE8Uf/vCH4sDAgHKNHseRI0fEvLw8MTY2VtTr9eK8efPEp556Suzv71e6aeN6/vnnxWnTpolarVZcsWKF+MknnyjdJJcVFBSIaWlpolarFTMyMsSCggLxwoULSjdrXPv27Rv1v/t77rlHFEX7tumPPvqomJKSIup0OvELX/iCWFFRoWyjvYRjgH8I1DGA/T+wsf/7h0Dt/6LIMSDQcQzwDxwDfIf9f1Cw9X9RDMwxgP3ft3w1BgiiKIru5SISERERERERERFRMPLb3ZCJiIiIiIiIiIjItxgsJCIiIiIiIiIiIgAMFhIREREREREREZEDg4VEREREREREREQEgMFCIiIiIiIiIiIicmCwkIiIiIiIiIiIiAAwWEhEREREREREREQODBYSERERERERERERAAYLiYiIiIiIiIiIyIHBQiIiIiIiIiIiIgLAYCERERERERERERE5/P+UQla20ppaKQAAAABJRU5ErkJggg==", 296 | "text/plain": [ 297 | "
" 298 | ] 299 | }, 300 | "metadata": {}, 301 | "output_type": "display_data" 302 | } 303 | ], 304 | "source": [ 305 | "x = jnp.linspace(-10,10,1000).reshape(-1,1)\n", 306 | "\n", 307 | "plt.figure(figsize=(16,2))\n", 308 | "n_steps = 5\n", 309 | "for i,t in enumerate(jnp.linspace(1,0,n_steps)):\n", 310 | " plt.subplot(100+10*n_steps+i+1)\n", 311 | " plt.plot(x,jnp.exp(-jax.vmap(model.E,in_axes=(None,0,None))(model.params,x,jnp.array([t]))))\n", 312 | " #plt.yticks([])" 313 | ] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "310", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.10.12" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 2 337 | } 338 | --------------------------------------------------------------------------------