├── quantum_control_rl_server
│   ├── __init__.py
│   ├── version_helper.py
│   ├── dynamic_episode_driver_sim_env.py
│   ├── remote_env_tools.py
│   ├── tf_env_wrappers.py
│   ├── h5log.py
│   ├── tf_env.py
│   └── PPO.py
├── 20231212.h5
├── examples
│   ├── pi_pulse
│   │   ├── 20231213.h5
│   │   ├── train_summaries
│   │   │   └── events.out.tfevents.1702502804.DN-GKP-MSMT.21904.8.v2
│   │   ├── pi_pulse_sim_function.py
│   │   ├── pi_pulse_sim_test.py
│   │   ├── parse_pi_pulse_data.py
│   │   ├── pi_pulse_client.py
│   │   └── pi_pulse_training_server.py
│   └── pi_pulse_oct_style
│       ├── 20231213.h5
│       ├── train_summaries
│       │   └── events.out.tfevents.1702505333.DN-GKP-MSMT.24468.7.v2
│       ├── pi_pulse_oct_style_sim_function.py
│       ├── pi_pulse_oct_style_sim_test.py
│       ├── pi_pulse_oct_style_client.py
│       ├── parse_pi_pulse_oct_style_data.py
│       └── pi_pulse_oct_style_training_server.py
├── train_summaries
│   └── events.out.tfevents.1702427509.DN-GKP-MSMT.15044.8.v2
├── requirements.txt
├── setup.py
├── .gitignore
├── README.md
├── qcrl-server-tf240.yml
└── LICENSE
/quantum_control_rl_server/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/20231212.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbrock89/quantum_control_rl_server/HEAD/20231212.h5
--------------------------------------------------------------------------------
/examples/pi_pulse/20231213.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbrock89/quantum_control_rl_server/HEAD/examples/pi_pulse/20231213.h5
--------------------------------------------------------------------------------
/examples/pi_pulse_oct_style/20231213.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbrock89/quantum_control_rl_server/HEAD/examples/pi_pulse_oct_style/20231213.h5
--------------------------------------------------------------------------------
/train_summaries/events.out.tfevents.1702427509.DN-GKP-MSMT.15044.8.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbrock89/quantum_control_rl_server/HEAD/train_summaries/events.out.tfevents.1702427509.DN-GKP-MSMT.15044.8.v2
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.4.3
2 | numpy==1.18.5
3 | qutip==4.7.1
4 | tensorboard==2.11.2
5 | tensorboard_plugin_profile==2.3.0
6 | tensorflow==2.4.0
7 | tensorflow-probability==0.11.0
8 | tf-agents==0.6.0
9 | scipy==1.4.1
10 | 
--------------------------------------------------------------------------------
/examples/pi_pulse/train_summaries/events.out.tfevents.1702502804.DN-GKP-MSMT.21904.8.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbrock89/quantum_control_rl_server/HEAD/examples/pi_pulse/train_summaries/events.out.tfevents.1702502804.DN-GKP-MSMT.21904.8.v2
--------------------------------------------------------------------------------
/examples/pi_pulse_oct_style/train_summaries/events.out.tfevents.1702505333.DN-GKP-MSMT.24468.7.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbrock89/quantum_control_rl_server/HEAD/examples/pi_pulse_oct_style/train_summaries/events.out.tfevents.1702505333.DN-GKP-MSMT.24468.7.v2 -------------------------------------------------------------------------------- /quantum_control_rl_server/version_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compatibility shim for some versioning 3 | 4 | Created on Tue Aug 25 14:04:38 2020 5 | 6 | @author: Henry Liu 7 | """ 8 | from distutils.version import LooseVersion 9 | 10 | import tf_agents 11 | from tf_agents.policies import tf_policy 12 | 13 | if LooseVersion(tf_agents.__version__) >= "0.6": 14 | TFPolicy = tf_policy.TFPolicy 15 | else: 16 | TFPolicy = tf_policy.Base 17 | 18 | 19 | __all__ = ["TFPolicy"] 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 08, 2023 3 | 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="quantum_control_rl_server", 8 | version="1.0", 9 | description="Fork of v-sivak/quantum-control-rl, frozen version where the server (RL agent) and client (experiment or sim) communicate over tcpip", 10 | author="Volodymyr Sivak, Henry Liu, Ben Brock", 11 | author_email="bbrock89@gmail.com", 12 | url="https://github.com/bbrock89/quantum_control_rl_server", 13 | packages = ["quantum_control_rl_server"], 14 | requires=[], 15 | ) 16 | -------------------------------------------------------------------------------- /examples/pi_pulse_oct_style/pi_pulse_oct_style_sim_function.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | 4 | import numpy as np 5 | import qutip as qt 6 | 7 | def pi_pulse_oct_style_sim(pulse_array_real, # real pulse 8 | pulse_array_imag, # imag pulse 9 | N = 5, # number of transmon states 10 | kerr = -0.1, # kerr nonlinearity of transmon in GHz 11 | n_times = 101 # number of time-steps for the qutip simulation 12 | ): 13 | 14 | q = qt.destroy(N) 15 | psi0 = qt.fock(N,0) 16 | 17 | 18 | # frequencies in GHz, times in ns 19 | t_duration = len(pulse_array_real) 20 | ts = np.linspace(-t_duration/2,t_duration/2,n_times) 21 | 22 | H0 = 2*np.pi*kerr*(q.dag()**2)*(q**2) 23 | H1 = q.dag() 24 | H2 = q 25 | 26 | pulse = pulse_array_real + 1j*pulse_array_imag 27 | pulse_func = qt.interpolate.Cubic_Spline(ts[0], ts[-1], pulse) 28 | pulse_conj_func = qt.interpolate.Cubic_Spline(ts[0], ts[-1], pulse.conj()) 29 | 30 | H = [H0,[H1,pulse_func],[H2,pulse_conj_func]] 31 | result = qt.sesolve(H,psi0,tlist=ts) 32 | this_reward = (2*qt.expect(qt.fock_dm(N,1),result.states[-1]))-1 # return reward as pauli measurement 33 | return this_reward 34 | -------------------------------------------------------------------------------- /examples/pi_pulse/pi_pulse_sim_function.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | 4 | import numpy as np 5 | import qutip as qt 6 | 7 | def pi_pulse_sim(amp, # amp of cos^2 pulse in GHz 8 | drag, # dimensionless drag param 9 | detuning, # detuning of pi pulse in GHz 10 | t_duration = 40, # duration of pi pulse in ns 11 | N = 5, # number of transmon states 12 | kerr = -0.1, # kerr nonlinearity of transmon in GHz 13 | n_times = 101 # number of time-steps for the qutip simulation 14 | ): 15 | 16 | q = 
qt.destroy(N) 17 | psi0 = qt.fock(N,0) 18 | 19 | # frequencies in GHz, times in ns 20 | ts = np.linspace(-t_duration/2,t_duration/2,n_times) 21 | 22 | args = {'amp':amp, 23 | 't_duration':t_duration, 24 | 'drag':drag, 25 | 'kerr':kerr} 26 | 27 | H0 = -2*np.pi*detuning*(q.dag()*q) + 2*np.pi*kerr*(q.dag()**2)*(q**2) 28 | H1 = q.dag() 29 | H2 = q 30 | 31 | def H1_coeff(t,args): 32 | omega_pulse = (np.pi)/args['t_duration'] 33 | pulse = args['amp']*(np.cos(omega_pulse*t)**2) 34 | pulse = pulse + 1j*(args['drag']/(2*args['kerr']))*args['amp']*omega_pulse*np.sin(2*omega_pulse*t) 35 | return pulse 36 | 37 | def H2_coeff(t,args): 38 | omega_pulse = (np.pi)/args['t_duration'] 39 | pulse = args['amp']*(np.cos(omega_pulse*t)**2) 40 | pulse = pulse - 1j*(args['drag']/(2*args['kerr']))*args['amp']*omega_pulse*np.sin(2*omega_pulse*t) 41 | return pulse 42 | 43 | H = [H0,[H1,H1_coeff],[H2,H2_coeff]] 44 | result = qt.sesolve(H,psi0,tlist=ts,args=args) 45 | this_reward = (2*qt.expect(qt.fock_dm(N,1),result.states[-1]))-1 # return reward as pauli measurement 46 | return this_reward 47 | -------------------------------------------------------------------------------- /examples/pi_pulse_oct_style/pi_pulse_oct_style_sim_test.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | #%% 4 | import qutip as qt 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | N = 5 # number of transmon levels 10 | n_times = 101 11 | t_duration = 40 12 | q = qt.destroy(N) 13 | psi0 = qt.fock(N,0) 14 | 15 | # frequencies in GHz, times in ns 16 | kerr = -0.1 17 | ts = np.linspace(-t_duration/2,t_duration/2,n_times) 18 | 19 | omega_pulse = (np.pi)/t_duration 20 | pulse_array_real = ((np.pi)/t_duration)*(np.cos(omega_pulse*ts)**2) 21 | pulse_array_imag = np.zeros_like(pulse_array_real) 22 | 23 | H0 = 2*np.pi*kerr*(q.dag()**2)*(q**2) 24 | H1 = q.dag() 25 | H2 = q 26 | 27 | pulse = pulse_array_real + 1j*pulse_array_imag 28 | pulse_func = qt.interpolate.Cubic_Spline(ts[0], ts[-1], pulse) 29 | pulse_conj_func = qt.interpolate.Cubic_Spline(ts[0], ts[-1], pulse.conj()) 30 | 31 | fig,ax = plt.subplots(1,1) 32 | ax.plot(ts,pulse_func(ts).real*1e3,label='real') 33 | ax.plot(ts,pulse_func(ts).imag*1e3,label='imag') 34 | ax.set_xlabel('t (ns)') 35 | ax.set_ylabel('pulse amp (MHz)') 36 | ax.legend() 37 | plt.show() 38 | 39 | 40 | H = [H0,[H1,pulse_func],[H2,pulse_conj_func]] 41 | 42 | result = qt.sesolve(H,psi0,tlist=ts) 43 | 44 | 45 | probs = np.zeros((N,n_times),dtype=float) 46 | fig,ax = plt.subplots(1,1) 47 | for ii in range(N): 48 | for jj in range(n_times): 49 | probs[ii,jj] = qt.expect(qt.fock_dm(N,ii),result.states[jj]) 50 | 51 | ax.plot(ts,probs[ii],label=str(ii)) 52 | ax.set_xlabel('t (ns)') 53 | ax.set_ylabel('P(n)') 54 | ax.legend() 55 | plt.show() 56 | 57 | fig,ax = plt.subplots(1,1) 58 | ax.plot(ts,probs[2]) 59 | ax.set_xlabel('t (ns)') 60 | ax.set_ylabel('P(2)') 61 | ax.set_yscale('log') 62 | plt.show() 63 | 64 | fig,ax = plt.subplots(1,1) 65 | ax.plot(ts,1-probs[1]) 66 | ax.set_xlabel('t (ns)') 67 | ax.set_ylabel('Infidelity') 68 | ax.set_yscale('log') 69 | plt.show() 70 | 71 | # %% 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 
11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # Mac OS X 104 | .DS_Store 105 | 106 | # Visual Studio Code 107 | .vscode 108 | 109 | ################## 110 | # Project Specific 111 | ################## 112 | 113 | # Don't commit benchmark output plots 114 | benchmark/results/ 115 | 116 | # Tensorboard logdir 117 | logdir/ 118 | 119 | # Output zips 120 | *.zip 121 | 122 | .gitconfig 123 | 124 | 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model-Free Quantum Control with Reinforcement Learning 2 | Fork of [v-sivak/quantum-control-rl](https://github.com/v-sivak/quantum-control-rl); in this version the server (RL agent) and client (experiment or sim) communicate over tcpip. This may be a bottleneck for some applications, but it's very convenient for others. 3 | 4 | This code was used in the following publications: 5 | 6 | [**Model-Free Quantum Control with Reinforcement Learning**; Phys. Rev. X 12, 011059 (2022)](https://journals.aps.org/prx/abstract/10.1103/PhysRevX.12.011059) 7 | 8 | [**Real-time quantum error correction beyond break-even**; Nature 616, 50–55 (2023)](https://www.nature.com/articles/s41586-023-05782-6) 9 | 10 | [**High-fidelity, frequency-flexible two-qubit fluxonium gates with a transmon coupler**; arXiv:2304.06087 (2023)](https://arxiv.org/abs/2304.06087) 11 | 12 | ## Requirements 13 | Requires a variety of packages, but for ease of use one should create the conda environment defined in qcrl-server-tf240.yml. See the installation section for more details. The included qcrl-server environment uses tensorflow v2.4.0 and tf-agents v0.6.0, which has been tested to work with CUDA v11.0 and cudnn v8.0.5 for GPU acceleration. Without CUDA set up, this package will still work using the CPU, but this may limit performance depending on the application. 14 | 15 | ## Installation 16 | To install this package, first clone this repository. This package should be used with the conda environment defined in qcrl-server-tf240.yml. 
To create this environment from the file, open an anaconda cmd prompt, cd into the repo directory, and run: 17 | ```sh 18 | conda env create -f qcrl-server-tf240.yml 19 | ``` 20 | To install this package into this conda environment qcrl-server, first activate the environment using 21 | ```sh 22 | conda activate qcrl-server 23 | ``` 24 | then cd into the repo directory and run: 25 | ```sh 26 | pip install -e . 27 | ``` 28 | 29 | ## CUDA Compatibility 30 | 31 | The qcrl-server conda environment has been tested to work with CUDA v11.0 and cudnn v8.0.5. 32 | 33 | ## Running the examples 34 | 35 | Open two consoles, activate qcrl-server in both, and cd into the directory of the example you want to run in both (pi_pulse or pi_pulse_oct_style). In one console run *_training_server.py, and in the other run *_client.py. 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /examples/pi_pulse/pi_pulse_sim_test.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | #%% 4 | import qutip as qt 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from pi_pulse_sim_function import pi_pulse_sim 9 | 10 | N = 5 # number of transmon levels 11 | n_times = 101 12 | t_duration = 40 13 | q = qt.destroy(N) 14 | psi0 = qt.fock(N,0) 15 | 16 | # frequencies in GHz, times in ns 17 | kerr = -0.1 18 | ts = np.linspace(-t_duration/2,t_duration/2,n_times) 19 | 20 | amp = (np.pi)/t_duration 21 | drag = 0.5 22 | detuning = 0.0 23 | 24 | args = {'amp':amp, 25 | 't_duration':t_duration, 26 | 'drag':drag, 27 | 'kerr':kerr} 28 | 29 | H0 = -2*np.pi*detuning*(q.dag()*q) + 2*np.pi*kerr*(q.dag()**2)*(q**2) 30 | H1 = q 31 | H2 = q.dag() 32 | 33 | e_ops = [q.dag()*q] 34 | 35 | def H1_coeff(t,args): 36 | omega_pulse = (np.pi)/args['t_duration'] 37 | pulse = args['amp']*(np.cos(omega_pulse*t)**2) 38 | pulse = pulse + 1j*(args['drag']/(2*args['kerr']))*args['amp']*omega_pulse*np.sin(2*omega_pulse*t) 39 | return pulse 40 | 41 | def H2_coeff(t,args): 42 | omega_pulse = (np.pi)/args['t_duration'] 43 | pulse = args['amp']*(np.cos(omega_pulse*t)**2) 44 | pulse = pulse - 1j*(args['drag']/(2*args['kerr']))*args['amp']*omega_pulse*np.sin(2*omega_pulse*t) 45 | return pulse 46 | 47 | 48 | fig,ax = plt.subplots(1,1) 49 | ax.plot(ts,np.real(H1_coeff(ts,args))*1e3,label='real') 50 | ax.plot(ts,np.imag(H1_coeff(ts,args))*1e3,label='imag') 51 | ax.set_xlabel('t (ns)') 52 | ax.set_ylabel('pulse amp (MHz)') 53 | ax.legend() 54 | plt.show() 55 | 56 | 57 | H = [H0,[H1,H1_coeff],[H2,H2_coeff]] 58 | 59 | result = qt.sesolve(H,psi0,tlist=ts,args=args) 60 | 61 | 62 | 63 | probs = np.zeros((N,n_times),dtype=float) 64 | fig,ax = plt.subplots(1,1) 65 | for ii in range(N): 66 | for jj in range(n_times): 67 | probs[ii,jj] = qt.expect(qt.fock_dm(N,ii),result.states[jj]) 68 | 69 | ax.plot(ts,probs[ii],label=str(ii)) 70 | ax.set_xlabel('t (ns)') 71 | ax.set_ylabel('P(n)') 72 | ax.legend() 73 | plt.show() 74 | 75 | fig,ax = plt.subplots(1,1) 76 | ax.plot(ts,probs[2]) 77 | ax.set_xlabel('t (ns)') 78 | ax.set_ylabel('P(2)') 79 | ax.set_yscale('log') 80 | plt.show() 81 | 82 | fig,ax = plt.subplots(1,1) 83 | ax.plot(ts,1-probs[1]) 84 | ax.set_xlabel('t (ns)') 85 | ax.set_ylabel('Infidelity') 86 | ax.set_yscale('log') 87 | plt.show() 88 | 89 | # %% 90 | -------------------------------------------------------------------------------- /quantum_control_rl_server/dynamic_episode_driver_sim_env.py: 
-------------------------------------------------------------------------------- 1 | from tf_agents.drivers import dynamic_episode_driver 2 | from quantum_control_rl_server.version_helper import TFPolicy 3 | from quantum_control_rl_server import tf_env_wrappers as wrappers 4 | from quantum_control_rl_server.tf_env import TFEnvironmentQuantumControl 5 | 6 | 7 | class PolicyPlaceholder(TFPolicy): 8 | pass 9 | 10 | class DynamicEpisodeDriverSimEnv(dynamic_episode_driver.DynamicEpisodeDriver): 11 | """ 12 | This driver is a simple wrapper of the standard DynamicEpisodeDriver from 13 | tf-agents. It initializes a simulated environment from which data will be 14 | collected according to the agent's policy. 15 | 16 | """ 17 | def __init__(self, env_kwargs, reward_kwargs, batch_size, 18 | action_script, action_scale, action_spec, to_learn, 19 | learn_residuals=False, remote=False): 20 | """ 21 | Args: 22 | env_kwargs (dict): optional parameters for training environment. 23 | reward_kwargs (dict): optional parameters for reward function. 24 | batch_size (int): number of episodes collected in parallel. 25 | action_script (str): name of action script. Action wrapper will 26 | select actions from this script if they are not learned. 27 | action_scale (dict, str:float): dictionary mapping action dimensions 28 | to scaling factors. Action wrapper will rescale actions produced 29 | by the agent's neural net policy by these factors. 30 | to_learn (dict, str:bool): dictionary mapping action dimensions to 31 | bool flags. Specifies if the action should be learned or scripted. 32 | learn_residuals (bool): flag to learn residual over the scripted 33 | protocol. If False, will learn actions from scratch. If True, 34 | will learn a residual to be added to scripted protocol. 35 | remote (bool): flag for remote environment to close the connection 36 | to a client upon finishing the training. 
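            action_spec (dict, str:TensorSpec): dictionary mapping the learned
                action dimensions to tf-agents TensorSpecs, consistent with the
                shapes of the corresponding action_script entries.

        A minimal sketch of typical argument values, mirroring
        examples/pi_pulse/pi_pulse_training_server.py:

            env_kwargs = {'T': 1}
            reward_kwargs = {'reward_mode': 'remote',
                             'server_socket': server_socket,
                             'epoch_type': 'training'}
            action_script = {'amp': [[0.2]], 'drag': [[0.0]], 'detuning': [[0.0]]}
            action_scale = {'amp': 0.5, 'drag': 2.0, 'detuning': 0.01}
            to_learn = {'amp': True, 'drag': True, 'detuning': True}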
37 | """ 38 | self.remote = remote 39 | # Create training env and wrap it 40 | env = TFEnvironmentQuantumControl( 41 | action_spec=action_spec, 42 | batch_size=batch_size, 43 | reward_kwargs=reward_kwargs, 44 | **env_kwargs) 45 | 46 | env = wrappers.ActionWrapper(env, action_script, action_scale, to_learn, 47 | learn_residuals=learn_residuals) 48 | 49 | # create dummy placeholder policy to initialize parent class 50 | dummy_policy = PolicyPlaceholder(env.time_step_spec(), env.action_spec()) 51 | super().__init__(env, dummy_policy, num_episodes=batch_size) 52 | 53 | def setup(self, policy, observers): 54 | """Setup policy and observers for the driver.""" 55 | self._policy = policy 56 | self._observers = observers or [] 57 | 58 | def finish_training(self): 59 | if self.remote: 60 | self.env.server_socket.disconnect_client() 61 | 62 | def observation_spec(self): 63 | return self.env.observation_spec() 64 | 65 | def action_spec(self): 66 | return self.env.action_spec() 67 | 68 | def time_step_spec(self): 69 | return self.env.time_step_spec() 70 | -------------------------------------------------------------------------------- /quantum_control_rl_server/remote_env_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jan 1 19:22:02 2021 5 | 6 | @author: Vladimir Sivak 7 | """ 8 | import pickle 9 | import socket 10 | import sys 11 | 12 | import logging 13 | logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 14 | 15 | class PickleSocket(socket.socket): 16 | """ This is a simple python socket to send pickled data over TCP/IP.""" 17 | pickle_protocol = 2 # if client is in py2 environment 18 | HEADERSIZE = 10 19 | 20 | def __init__(self): 21 | super(PickleSocket, self).__init__(socket.AF_INET, socket.SOCK_STREAM) 22 | self.py = sys.version[0] 23 | 24 | def recv_data(self, connection): 25 | """ 26 | Returns: 27 | data: decoded python object received from the connection 28 | done (bool): flag that this was the last message 29 | """ 30 | full_msg = b'' 31 | new_msg = True 32 | msg_ended = False 33 | while not msg_ended: 34 | msg = connection.recv(16) 35 | if new_msg: 36 | if msg == b'': return (None, True) 37 | msglen = int(msg[:self.HEADERSIZE]) 38 | logging.info('New msg len: %d' % msglen) 39 | new_msg = False 40 | full_msg += msg 41 | if len(full_msg) - self.HEADERSIZE == msglen: 42 | logging.info('Full msg recieved') 43 | if self.py=='2': 44 | data = pickle.loads(full_msg[self.HEADERSIZE:]) 45 | else: 46 | data = pickle.loads(full_msg[self.HEADERSIZE:], encoding="latin1") 47 | msg_ended = True 48 | return (data, False) 49 | 50 | def send_data(self, data, connection): 51 | msg = pickle.dumps(data, protocol=self.pickle_protocol) 52 | header_str = str(len(msg)).zfill(self.HEADERSIZE) 53 | # there is a slight difference for py2 vs py3 54 | header = header_str if self.py=='2' else bytes(header_str, 'utf-8') 55 | msg = header + msg 56 | connection.send(msg) 57 | 58 | 59 | class Server(PickleSocket): 60 | """ A server for a single client. 61 | 62 | Intended use: server is on the RL agent side, client is on the environment 63 | side. Server sends actions to client, client sends rewards back to server. 
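    Minimal usage sketch on the server side (the localhost address and port are
    simply the values used in the examples):

        server = Server()
        server.bind(('127.0.0.1', 5555))
        server.connect_client()              # blocks until the client connects
        server.send_data(message)            # pickled, with a fixed-width length header
        reward_data, done = server.recv_data()
        server.disconnect_client()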
64 | """ 65 | def send_data(self, data): 66 | super(Server, self).send_data(data, self.client_socket) 67 | 68 | def recv_data(self): 69 | return super(Server, self).recv_data(self.client_socket) 70 | 71 | def connect_client(self): 72 | self.listen(1) 73 | self.client_socket, self.client_address = self.accept() 74 | logging.warning('Connection with: ' + str(self.client_address)) 75 | 76 | def disconnect_client(self): 77 | self.client_socket.close() 78 | 79 | 80 | class Client(PickleSocket): 81 | """ A simple client. 82 | 83 | Intended use: server is on the RL agent side, client is on the environment 84 | side. Server sends actions to client, client sends rewards back to server. 85 | """ 86 | def send_data(self, data): 87 | super(Client, self).send_data(data, self) 88 | 89 | def recv_data(self): 90 | data, done = super(Client, self).recv_data(self) 91 | if done: self.close() 92 | return data, done 93 | 94 | -------------------------------------------------------------------------------- /examples/pi_pulse/parse_pi_pulse_data.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | 4 | #%% 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import h5py 8 | import os 9 | 10 | root_dir = os.getcwd() 11 | filename = root_dir+r'\20230504.h5' 12 | run = '0' 13 | 14 | f = h5py.File(filename,'r') 15 | 16 | rl_params = {key:item 17 | for key,item in f[run]['rl_params'].attrs.items() 18 | } 19 | action_scale = {key:item 20 | for key,item in f[run]['rl_params']['action_scale'].attrs.items() 21 | } 22 | action_script = {key:item 23 | for key,item in f[run]['rl_params']['action_script'].attrs.items() 24 | } 25 | to_learn = {key:item 26 | for key,item in f[run]['rl_params']['to_learn'].attrs.items() 27 | } 28 | 29 | training_actions = {key:item[()] 30 | for key,item in f[run]['training']['actions'].items() 31 | } 32 | training_rewards = f[run]['training']['rewards'][()] 33 | 34 | evaluation_actions = {key:item[()] 35 | for key,item in f[run]['evaluation']['actions'].items() 36 | } 37 | evaluation_rewards = f[run]['evaluation']['rewards'][()] 38 | 39 | f.close() 40 | 41 | # %% 42 | 43 | epochs = np.arange(rl_params['num_epochs']) 44 | 45 | infidelity = (1-training_rewards)/2.0 46 | mean_infidelity = np.mean(infidelity,axis=1) 47 | stdev_infidelity = np.std(infidelity,axis=1) 48 | min_infidelity = np.amin(infidelity,axis=1) 49 | max_infidelity = np.amax(infidelity,axis=1) 50 | fig,ax = plt.subplots(1,1) 51 | ax.plot(epochs,mean_infidelity) 52 | ax.fill_between(epochs, 53 | min_infidelity, 54 | max_infidelity, 55 | alpha = 0.5) 56 | ax.set_xlabel('Epoch') 57 | ax.set_ylabel('1-P(e)') 58 | ax.set_yscale('log') 59 | plt.show() 60 | 61 | # %% 62 | 63 | 64 | 65 | mean_amp = np.mean(training_actions['amp'],axis=1) 66 | stdev_amp = np.std(training_actions['amp'],axis=1) 67 | min_amp = np.amin(training_actions['amp'],axis=1) 68 | max_amp = np.amax(training_actions['amp'],axis=1) 69 | 70 | mean_drag = np.mean(training_actions['drag'],axis=1) 71 | stdev_drag = np.std(training_actions['drag'],axis=1) 72 | min_drag = np.amin(training_actions['drag'],axis=1) 73 | max_drag = np.amax(training_actions['drag'],axis=1) 74 | 75 | mean_detuning = np.mean(training_actions['detuning'],axis=1)*1e3 76 | stdev_detuning = np.std(training_actions['detuning'],axis=1)*1e3 77 | min_detuning = np.amin(training_actions['detuning'],axis=1)*1e3 78 | max_detuning = np.amax(training_actions['detuning'],axis=1)*1e3 79 | 80 | fig,axarr = 
plt.subplots(3,1,figsize=(8,8)) 81 | 82 | axarr[0].plot(epochs,mean_amp) 83 | axarr[0].fill_between(epochs, 84 | mean_amp-stdev_amp, 85 | mean_amp+stdev_amp, 86 | alpha=0.5) 87 | axarr[0].set_ylabel('Amp') 88 | 89 | axarr[1].plot(epochs,mean_drag) 90 | axarr[1].fill_between(epochs, 91 | mean_drag-stdev_drag, 92 | mean_drag+stdev_drag, 93 | alpha=0.5) 94 | axarr[1].set_ylabel('Drag') 95 | 96 | axarr[2].plot(epochs,mean_detuning) 97 | axarr[2].fill_between(epochs, 98 | mean_detuning-stdev_detuning, 99 | mean_detuning+stdev_detuning, 100 | alpha=0.5) 101 | axarr[2].set_ylabel('Detuning (MHz)') 102 | axarr[2].set_xlabel('Epoch') 103 | 104 | plt.tight_layout() 105 | plt.show() 106 | # %% 107 | -------------------------------------------------------------------------------- /quantum_control_rl_server/tf_env_wrappers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tf_agents import specs 3 | from tf_agents.utils import common, nest_utils 4 | from tf_agents.environments.tf_wrappers import TFEnvironmentBaseWrapper 5 | 6 | 7 | class ActionWrapper(TFEnvironmentBaseWrapper): 8 | """ 9 | Wrapper produces a dictionary with action components such as 'alpha', 10 | 'beta', 'epsilon', 'phi' as dictionary keys. Some action components are 11 | taken from the action script provided at initialization, and some are 12 | taken from the input action produced by the agent. Parameter 'to_learn' 13 | controls which action components are to be learned. 14 | 15 | """ 16 | def __init__(self, env, action_script, scale, to_learn, 17 | learn_residuals=False): 18 | """ 19 | Args: 20 | env: GKP environmen 21 | action_script: module or class with attributes corresponding to 22 | action components such as 'alpha', 'phi' etc 23 | scale: dictionary of scaling factors for action components 24 | to_learn: dictionary of bool values for action components 25 | learn_residuals (bool): flag to learn residual over the scripted 26 | protocol. If False, will learn actions from scratch. If True, 27 | will learn a residual to be added to scripted protocol. 28 | 29 | """ 30 | super(ActionWrapper, self).__init__(env) 31 | 32 | self.scale = scale 33 | self.to_learn = to_learn 34 | self.learn_residuals = learn_residuals 35 | 36 | # load the script of actions and convert to tensors 37 | self.script = action_script 38 | for a, val in self.script.items(): 39 | self.script[a] = tf.constant(val, dtype=tf.float32) 40 | 41 | self._action_spec = {a : specs.BoundedTensorSpec( 42 | shape = C.shape[1:], dtype=tf.float32, minimum=-1, maximum=1) 43 | for a, C in self.script.items() if self.to_learn[a]} 44 | 45 | def wrap(self, input_action): 46 | """ 47 | Args: 48 | input_action (dict): nested tensor action produced by the neural 49 | net. Dictionary keys are those marked True 50 | in 'to_learn'. 51 | 52 | Returns: 53 | actions (dict): nested tensor action which includes all action 54 | components expected by the GKP class. 
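        Illustrative example with hypothetical values: for
        to_learn = {'amp': True, 'drag': False}, scale = {'amp': 0.5} and
        learn_residuals=True, an input_action of {'amp': 0.4} is wrapped into
        {'amp': script['amp'][i] + 0.5*0.4, 'drag': script['drag'][i]}.
        Learned components are rescaled (and added to the scripted value as
        residuals), while unlearned components are replicated from the script.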
55 | 56 | """ 57 | # step counter to follow the script of periodicity 'period' 58 | i = self._env._elapsed_steps % self._env.T 59 | out_shape = nest_utils.get_outer_shape(input_action, self._action_spec) 60 | 61 | action = {} 62 | for a in self.to_learn.keys(): 63 | if not self.to_learn[a]: # if not learning: replicate scripted action 64 | action[a] = common.replicate(self.script[a][i], out_shape) 65 | else: # if learning: rescale input tensor 66 | action[a] = input_action[a]*self.scale[a] 67 | if self.learn_residuals: 68 | action[a] += common.replicate(self.script[a][i], out_shape) 69 | 70 | return action 71 | 72 | def action_spec(self): 73 | return self._action_spec 74 | 75 | def _step(self, action): 76 | """ 77 | Take the nested tensor 'action' produced by the neural net and wrap it 78 | into dictionary format expected by the environment. 79 | 80 | Residual feedback learning trick: multiply the neural net prediction 81 | of 'alpha' by the measurement outcome of the last time step. This 82 | ensures that the Markovian part of the feedback is present, and the 83 | agent can focus its efforts on learning residual part. 84 | 85 | """ 86 | action = self.wrap(action) 87 | return self._env.step(action) 88 | -------------------------------------------------------------------------------- /examples/pi_pulse/pi_pulse_client.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Author: Ben Brock 4 | # Created on May 02, 2023 5 | 6 | #%% 7 | import qutip as qt 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | import os 12 | import sys 13 | # add quantum-control-rl dir to path for subsequent imports 14 | #sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 15 | 16 | import logging 17 | import time 18 | logger = logging.getLogger('RL') 19 | logger.propagate = False 20 | logger.handlers = [] 21 | stream_handler = logging.StreamHandler(sys.stdout) 22 | logger.addHandler(stream_handler) 23 | logger.setLevel(logging.INFO) 24 | 25 | from quantum_control_rl_server.remote_env_tools import Client 26 | 27 | from pi_pulse_sim_function import pi_pulse_sim 28 | 29 | client_socket = Client() 30 | (host, port) = '127.0.0.1', 5555 # ip address of RL server, here it's hosted locally 31 | client_socket.connect((host, port)) 32 | 33 | # training loop 34 | done = False 35 | while not done: 36 | 37 | # receive action data from the agent (see tf_env -> reward_remote()) 38 | message, done = client_socket.recv_data() 39 | logger.info('Received message from RL agent server.') 40 | logger.info('Time stamp: %f' %time.time()) 41 | 42 | if done: 43 | logger.info('Training finished.') 44 | break 45 | 46 | # parsing message (see tf_env -> reward_remote()) 47 | epoch_type = message['epoch_type'] 48 | 49 | if epoch_type == 'final': 50 | logger.info('Final Epoch') 51 | 52 | locs = message['locs'] 53 | scales = message['scales'] 54 | for key in locs.keys(): 55 | logger.info('locs['+str(key)+']:') 56 | logger.info(locs[key][0]) 57 | logger.info('scales['+str(key)+']:') 58 | logger.info(scales[key][0]) 59 | 60 | # After using the locs and scales, terminate training 61 | done = True 62 | logger.info('Training finished.') 63 | break 64 | 65 | action_batch = message['action_batch'] 66 | batch_size = message['batch_size'] 67 | epoch = message['epoch'] 68 | 69 | # parsing action_batch and reshaping to get rid of nested 70 | # structure [[float_param]] required by tensorflow 71 | amplitudes = action_batch['amp'].reshape([batch_size]) 72 | drags = 
action_batch['drag'].reshape([batch_size]) 73 | detunings = action_batch['detuning'].reshape([batch_size]) 74 | 75 | logger.info('Start %s epoch %d' %(epoch_type, epoch)) 76 | 77 | # collecting rewards for each policy in the batch 78 | reward_data = np.zeros((batch_size)) 79 | for ii in range(batch_size): 80 | 81 | # evaluating reward for ii'th element of the batch 82 | # - can perform different operations depending on the epoch type 83 | # for example, using more averaging for eval epochs 84 | if epoch_type == 'evaluation': 85 | reward_data[ii] = pi_pulse_sim(amplitudes[ii],drags[ii],detunings[ii]) 86 | elif epoch_type == 'training': 87 | reward_data[ii] = pi_pulse_sim(amplitudes[ii],drags[ii],detunings[ii]) 88 | # elif epoch_type == 'final': 89 | # print('Got to final epoch!') 90 | # print('amplitudes:') 91 | # print(amplitudes) 92 | # print('drags:') 93 | # print(drags) 94 | # print('detunings:') 95 | # print(detunings) 96 | # done = True 97 | # logger.info('Training finished.') 98 | # break 99 | 100 | # Print mean and stdev of reward for monitoring progress 101 | R = np.mean(reward_data) 102 | std_R = np.std(reward_data) 103 | logger.info('Average reward %.3f' %R) 104 | logger.info('STDev reward %.3f' %std_R) 105 | 106 | # send reward data back to server (see tf_env -> reward_remote()) 107 | logger.info('Sending message to RL agent server.') 108 | logger.info('Time stamp: %f' %time.time()) 109 | client_socket.send_data(reward_data) 110 | 111 | 112 | 113 | # %% 114 | -------------------------------------------------------------------------------- /examples/pi_pulse_oct_style/pi_pulse_oct_style_client.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Author: Ben Brock 4 | # Created on May 02, 2023 5 | 6 | #%% 7 | import qutip as qt 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | import os 12 | import sys 13 | # add quantum-control-rl dir to path for subsequent imports 14 | #sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 15 | 16 | import logging 17 | import time 18 | logger = logging.getLogger('RL') 19 | logger.propagate = False 20 | logger.handlers = [] 21 | stream_handler = logging.StreamHandler(sys.stdout) 22 | logger.addHandler(stream_handler) 23 | logger.setLevel(logging.INFO) 24 | 25 | from quantum_control_rl_server.remote_env_tools import Client 26 | 27 | from pi_pulse_oct_style_sim_function import pi_pulse_oct_style_sim 28 | 29 | client_socket = Client() 30 | (host, port) = '127.0.0.1', 5555 # ip address of RL server, here it's hosted locally 31 | client_socket.connect((host, port)) 32 | 33 | # training loop 34 | done = False 35 | while not done: 36 | 37 | # receive action data from the agent (see tf_env -> reward_remote()) 38 | message, done = client_socket.recv_data() 39 | logger.info('Received message from RL agent server.') 40 | logger.info('Time stamp: %f' %time.time()) 41 | 42 | if done: 43 | logger.info('Training finished.') 44 | break 45 | 46 | # parsing message (see tf_env -> reward_remote()) 47 | epoch_type = message['epoch_type'] 48 | 49 | if epoch_type == 'final': 50 | logger.info('Final Epoch') 51 | 52 | locs = message['locs'] 53 | scales = message['scales'] 54 | for key in locs.keys(): 55 | logger.info('locs['+str(key)+']:') 56 | logger.info(locs[key][0]) 57 | logger.info('scales['+str(key)+']:') 58 | logger.info(scales[key][0]) 59 | 60 | # After using the locs and scales, terminate training 61 | done = True 62 | logger.info('Training finished.') 63 | break 64 | 65 | action_batch = 
message['action_batch'] 66 | batch_size = message['batch_size'] 67 | epoch = message['epoch'] 68 | 69 | # parsing action_batch and reshaping to get rid of nested 70 | # structure required by tensorflow 71 | # here env.T=1 so the shape is (batch_size,1,pulse_len) 72 | new_shape_list_real = list(action_batch['pulse_array_real'].shape) 73 | new_shape_list_imag = list(action_batch['pulse_array_imag'].shape) 74 | new_shape_list_real.pop(1) 75 | new_shape_list_imag.pop(1) 76 | real_pulses = action_batch['pulse_array_real'].reshape(new_shape_list_real) 77 | imag_pulses = action_batch['pulse_array_imag'].reshape(new_shape_list_imag) 78 | 79 | 80 | logger.info('Start %s epoch %d' %(epoch_type, epoch)) 81 | 82 | 83 | # collecting rewards for each policy in the batch 84 | reward_data = np.zeros((batch_size)) 85 | for ii in range(batch_size): 86 | 87 | # evaluating reward for ii'th element of the batch 88 | # - can perform different operations depending on the epoch type 89 | # for example, using more averaging for eval epochs 90 | if epoch_type == 'evaluation': 91 | reward_data[ii] = pi_pulse_oct_style_sim(real_pulses[ii], 92 | imag_pulses[ii]) 93 | elif epoch_type == 'training': 94 | reward_data[ii] = pi_pulse_oct_style_sim(real_pulses[ii], 95 | imag_pulses[ii]) 96 | 97 | # Print mean and stdev of reward for monitoring progress 98 | R = np.mean(reward_data) 99 | std_R = np.std(reward_data) 100 | logger.info('Average reward %.3f' %R) 101 | logger.info('STDev reward %.3f' %std_R) 102 | 103 | # send reward data back to server (see tf_env -> reward_remote()) 104 | logger.info('Sending message to RL agent server.') 105 | logger.info('Time stamp: %f' %time.time()) 106 | client_socket.send_data(reward_data) 107 | 108 | 109 | 110 | # %% 111 | -------------------------------------------------------------------------------- /examples/pi_pulse_oct_style/parse_pi_pulse_oct_style_data.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | 4 | #%% 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import h5py 8 | import os 9 | 10 | root_dir = os.getcwd() 11 | filename = root_dir+r'\20230504.h5' 12 | run = '0' 13 | 14 | f = h5py.File(filename,'r') 15 | 16 | rl_params = {key:item 17 | for key,item in f[run]['rl_params'].attrs.items() 18 | } 19 | action_scale = {key:item 20 | for key,item in f[run]['rl_params']['action_scale'].attrs.items() 21 | } 22 | action_script = {key:item 23 | for key,item in f[run]['rl_params']['action_script'].attrs.items() 24 | } 25 | to_learn = {key:item 26 | for key,item in f[run]['rl_params']['to_learn'].attrs.items() 27 | } 28 | 29 | training_actions = {key:item[()] 30 | for key,item in f[run]['training']['actions'].items() 31 | } 32 | training_rewards = f[run]['training']['rewards'][()] 33 | 34 | evaluation_actions = {key:item[()] 35 | for key,item in f[run]['evaluation']['actions'].items() 36 | } 37 | evaluation_rewards = f[run]['evaluation']['rewards'][()] 38 | 39 | f.close() 40 | 41 | # %% 42 | 43 | epochs = np.arange(rl_params['num_epochs']) 44 | 45 | infidelity = (1-training_rewards)/2.0 46 | mean_infidelity = np.mean(infidelity,axis=1) 47 | stdev_infidelity = np.std(infidelity,axis=1) 48 | min_infidelity = np.amin(infidelity,axis=1) 49 | max_infidelity = np.amax(infidelity,axis=1) 50 | fig,ax = plt.subplots(1,1) 51 | ax.plot(epochs,mean_infidelity) 52 | ax.fill_between(epochs, 53 | mean_infidelity-min_infidelity, 54 | mean_infidelity+max_infidelity, 55 | alpha = 0.5) 56 | 
ax.set_xlabel('Epoch') 57 | ax.set_ylabel('Infidelity') 58 | ax.set_yscale('log') 59 | plt.show() 60 | 61 | # %% 62 | 63 | mean_real_pulses = np.mean(training_actions['pulse_array_real'],axis=1) 64 | stdev_real_pulses = np.std(training_actions['pulse_array_real'],axis=1) 65 | min_real_pulses = np.amin(training_actions['pulse_array_real'],axis=1) 66 | max_real_pulses = np.amax(training_actions['pulse_array_real'],axis=1) 67 | 68 | mean_imag_pulses = np.mean(training_actions['pulse_array_imag'],axis=1) 69 | stdev_imag_pulses = np.std(training_actions['pulse_array_imag'],axis=1) 70 | min_imag_pulses = np.amin(training_actions['pulse_array_imag'],axis=1) 71 | max_imag_pulses = np.amax(training_actions['pulse_array_imag'],axis=1) 72 | 73 | #%% 74 | 75 | ts = np.arange(mean_real_pulses.shape[-1]) 76 | n_plots = 50 77 | ind_factor = rl_params['num_epochs']//n_plots 78 | for ii in range(n_plots): 79 | 80 | fig,axarr = plt.subplots(2,1,figsize=(8,8)) 81 | 82 | axarr[0].plot(ts,mean_real_pulses[ind_factor*ii]) 83 | axarr[0].fill_between(ts, 84 | mean_real_pulses[ind_factor*ii]-stdev_real_pulses[ind_factor*ii], 85 | mean_real_pulses[ind_factor*ii]+stdev_real_pulses[ind_factor*ii], 86 | alpha=0.5) 87 | axarr[0].set_title('Epoch = '+str(ind_factor*ii)) 88 | axarr[0].set_ylabel('Real Pulse') 89 | 90 | axarr[1].plot(ts,mean_imag_pulses[ind_factor*ii]) 91 | axarr[1].fill_between(ts, 92 | mean_imag_pulses[ind_factor*ii]-stdev_imag_pulses[ind_factor*ii], 93 | mean_imag_pulses[ind_factor*ii]+stdev_imag_pulses[ind_factor*ii], 94 | alpha=0.5) 95 | axarr[1].set_ylabel('Imag. Pulse') 96 | axarr[1].set_xlabel('Time (ns)') 97 | 98 | plt.tight_layout() 99 | plt.show() 100 | # %% 101 | 102 | fig,axarr = plt.subplots(2,1,figsize=(8,8)) 103 | 104 | axarr[0].plot(ts,mean_real_pulses[-1]) 105 | axarr[0].fill_between(ts, 106 | mean_real_pulses[-1]-stdev_real_pulses[-1], 107 | mean_real_pulses[-1]+stdev_real_pulses[-1], 108 | alpha=0.5) 109 | axarr[0].set_title('Epoch = '+str(1000)) 110 | axarr[0].set_ylabel('Real Pulse') 111 | 112 | axarr[1].plot(ts,mean_imag_pulses[-1]) 113 | axarr[1].fill_between(ts, 114 | mean_imag_pulses[-1]-stdev_imag_pulses[-1], 115 | mean_imag_pulses[-1]+stdev_imag_pulses[-1], 116 | alpha=0.5) 117 | axarr[1].set_ylabel('Imag. 
Pulse') 118 | axarr[1].set_xlabel('Time (ns)') 119 | 120 | plt.tight_layout() 121 | plt.show() 122 | -------------------------------------------------------------------------------- /qcrl-server-tf240.yml: -------------------------------------------------------------------------------- 1 | name: qcrl-server 2 | channels: 3 | - defaults 4 | dependencies: 5 | - asttokens=2.0.5=pyhd3eb1b0_0 6 | - backcall=0.2.0=pyhd3eb1b0_0 7 | - blas=1.0=mkl 8 | - brotli=1.0.9=h2bbff1b_7 9 | - brotli-bin=1.0.9=h2bbff1b_7 10 | - ca-certificates=2023.01.10=haa95532_0 11 | - certifi=2022.12.7=py38haa95532_0 12 | - colorama=0.4.6=py38haa95532_0 13 | - comm=0.1.2=py38haa95532_0 14 | - cycler=0.11.0=pyhd3eb1b0_0 15 | - debugpy=1.5.1=py38hd77b12b_0 16 | - decorator=5.1.1=pyhd3eb1b0_0 17 | - executing=0.8.3=pyhd3eb1b0_0 18 | - fonttools=4.25.0=pyhd3eb1b0_0 19 | - freetype=2.12.1=ha860e81_0 20 | - giflib=5.2.1=h8cc25b3_3 21 | - glib=2.69.1=h5dc1a3c_2 22 | - gst-plugins-base=1.18.5=h9e645db_0 23 | - gstreamer=1.18.5=hd78058f_0 24 | - icc_rt=2022.1.0=h6049295_2 25 | - icu=58.2=ha925a31_3 26 | - importlib_metadata=6.0.0=hd3eb1b0_0 27 | - intel-openmp=2023.1.0=h59b6b97_46319 28 | - ipykernel=6.19.2=py38hd4e2768_0 29 | - ipython=8.12.0=py38haa95532_0 30 | - jedi=0.18.1=py38haa95532_1 31 | - jpeg=9e=h2bbff1b_1 32 | - jupyter_client=8.1.0=py38haa95532_0 33 | - jupyter_core=5.3.0=py38haa95532_0 34 | - kiwisolver=1.4.4=py38hd77b12b_0 35 | - krb5=1.19.4=h5b6d351_0 36 | - lerc=3.0=hd77b12b_0 37 | - libbrotlicommon=1.0.9=h2bbff1b_7 38 | - libbrotlidec=1.0.9=h2bbff1b_7 39 | - libbrotlienc=1.0.9=h2bbff1b_7 40 | - libclang=14.0.6=default_hb5a9fac_1 41 | - libclang13=14.0.6=default_h8e68704_1 42 | - libdeflate=1.17=h2bbff1b_0 43 | - libffi=3.4.2=hd77b12b_6 44 | - libiconv=1.16=h2bbff1b_2 45 | - libogg=1.3.5=h2bbff1b_1 46 | - libpng=1.6.39=h8cc25b3_0 47 | - libsodium=1.0.18=h62dcd97_0 48 | - libtiff=4.5.0=h6c2663c_2 49 | - libvorbis=1.3.7=he774522_0 50 | - libwebp=1.2.4=hbc33d0d_1 51 | - libwebp-base=1.2.4=h2bbff1b_1 52 | - libxml2=2.10.3=h0ad7f3c_0 53 | - libxslt=1.1.37=h2bbff1b_0 54 | - lz4-c=1.9.4=h2bbff1b_0 55 | - matplotlib=3.4.3=py38haa95532_0 56 | - matplotlib-base=3.4.3=py38h49ac443_0 57 | - matplotlib-inline=0.1.6=py38haa95532_0 58 | - mkl=2020.2=256 59 | - mkl-service=2.3.0=py38h196d8e1_0 60 | - mkl_fft=1.3.0=py38h46781fe_0 61 | - mkl_random=1.1.1=py38h47e9c7a_0 62 | - munkres=1.1.4=py_0 63 | - nest-asyncio=1.5.6=py38haa95532_0 64 | - openssl=1.1.1t=h2bbff1b_0 65 | - parso=0.8.3=pyhd3eb1b0_0 66 | - pcre=8.45=hd77b12b_0 67 | - pickleshare=0.7.5=pyhd3eb1b0_1003 68 | - pillow=9.4.0=py38hd77b12b_0 69 | - pip=23.0.1=py38haa95532_0 70 | - platformdirs=2.5.2=py38haa95532_0 71 | - ply=3.11=py38_0 72 | - prompt-toolkit=3.0.36=py38haa95532_0 73 | - psutil=5.9.0=py38h2bbff1b_0 74 | - pure_eval=0.2.2=pyhd3eb1b0_0 75 | - pygments=2.11.2=pyhd3eb1b0_0 76 | - pyparsing=3.0.9=py38haa95532_0 77 | - pyqt=5.15.7=py38hd77b12b_0 78 | - pyqt5-sip=12.11.0=py38hd77b12b_0 79 | - python=3.8.16=h6244533_3 80 | - python-dateutil=2.8.2=pyhd3eb1b0_0 81 | - pywin32=305=py38h2bbff1b_0 82 | - pyzmq=23.2.0=py38hd77b12b_0 83 | - qt-main=5.15.2=he8e5bd7_8 84 | - qt-webengine=5.15.9=hb9a9bb5_5 85 | - qtwebkit=5.212=h2bbfb41_5 86 | - scipy=1.4.1=py38h9439919_0 87 | - setuptools=66.0.0=py38haa95532_0 88 | - sip=6.6.2=py38hd77b12b_0 89 | - sqlite=3.41.2=h2bbff1b_0 90 | - stack_data=0.2.0=pyhd3eb1b0_0 91 | - tk=8.6.12=h2bbff1b_0 92 | - toml=0.10.2=pyhd3eb1b0_0 93 | - tornado=6.2=py38h2bbff1b_0 94 | - traitlets=5.7.1=py38haa95532_0 95 | - vc=14.2=h21ff451_1 96 | - 
vs2015_runtime=14.27.29016=h5e58377_2 97 | - wcwidth=0.2.5=pyhd3eb1b0_0 98 | - wheel=0.38.4=py38haa95532_0 99 | - xz=5.2.10=h8cc25b3_1 100 | - zeromq=4.3.4=hd77b12b_0 101 | - zlib=1.2.13=h8cc25b3_0 102 | - zstd=1.5.5=hd43e919_0 103 | - pip: 104 | - absl-py==0.15.0 105 | - astunparse==1.6.3 106 | - cachetools==4.2.4 107 | - charset-normalizer==3.1.0 108 | - cloudpickle==1.3.0 109 | - dm-tree==0.1.8 110 | - flatbuffers==1.12 111 | - gast==0.3.3 112 | - gin-config==0.3.0 113 | - google-auth==1.35.0 114 | - google-auth-oauthlib==0.4.6 115 | - google-pasta==0.2.0 116 | - grpcio==1.32.0 117 | - gviz-api==1.10.0 118 | - h5py==2.10.0 119 | - idna==3.4 120 | - importlib-metadata==6.5.0 121 | - keras-preprocessing==1.1.2 122 | - markdown==3.4.3 123 | - markupsafe==2.1.2 124 | - numpy==1.19.5 125 | - oauthlib==3.2.2 126 | - opt-einsum==3.3.0 127 | - packaging==23.1 128 | - protobuf==3.20.1 129 | - pyasn1==0.5.0 130 | - pyasn1-modules==0.3.0 131 | - qutip==4.7.1 132 | - requests==2.28.2 133 | - requests-oauthlib==1.3.1 134 | - rsa==4.9 135 | - six==1.15.0 136 | - tensorboard==2.11.2 137 | - tensorboard-data-server==0.6.1 138 | - tensorboard-plugin-profile==2.3.0 139 | - tensorboard-plugin-wit==1.8.1 140 | - tensorflow==2.4.0 141 | - tensorflow-estimator==2.4.0 142 | - tensorflow-probability==0.11.0 143 | - termcolor==1.1.0 144 | - tf-agents==0.6.0 145 | - typing-extensions==3.7.4.3 146 | - urllib3==1.26.15 147 | - werkzeug==2.2.3 148 | - wrapt==1.12.1 149 | - zipp==3.15.0 150 | prefix: C:\Users\DN_GKP\miniconda3\envs\tf-agents 151 | -------------------------------------------------------------------------------- /examples/pi_pulse/pi_pulse_training_server.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | import os 4 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"]='true' 5 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 6 | 7 | # append parent 'gkp-rl' directory to path 8 | import sys 9 | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 10 | 11 | import tensorflow as tf 12 | from tf_agents import specs 13 | from quantum_control_rl_server import PPO 14 | from tf_agents.networks import actor_distribution_network 15 | from quantum_control_rl_server import remote_env_tools as rmt 16 | from quantum_control_rl_server.h5log import h5log 17 | 18 | root_dir = os.getcwd() #r'E:\rl_data\exp_training\pi_pulse' 19 | host_ip = '127.0.0.1' # ip address of RL server, here it's hosted locally 20 | 21 | num_epochs = 100 # total number of training epochs 22 | train_batch_size = 20 # number of batches to send for training epoch 23 | 24 | do_evaluation = True # flag for implementing eval epochs or not 25 | eval_interval = 20 # number of training epochs between eval epochs 26 | eval_batch_size = 1 # number of batches to send for eval epoch 27 | 28 | learn_residuals = True 29 | save_tf_style = False 30 | 31 | # Params for action wrapper 32 | action_script = { 33 | 'amp' : [[0.2]], # shape=[1,1] 34 | 'drag' : [[0.0]], # shape=[1,1] 35 | 'detuning' : [[0.0]] 36 | } 37 | 38 | # specify shapes of actions to be consistent with the objects in action_script 39 | action_spec = { 40 | 'amp' : specs.TensorSpec(shape=[1], dtype=tf.float32), 41 | 'drag' : specs.TensorSpec(shape=[1], dtype=tf.float32), 42 | 'detuning' : specs.TensorSpec(shape=[1], dtype=tf.float32) 43 | } 44 | 45 | # characteristic scale of sigmoid functions used in the neural network, 46 | # and for automatic differentiation of the reward 47 | # optimal point should ideally be within +/- action_scale of the initial 
vals 48 | action_scale = { 49 | 'amp' : 0.5, 50 | 'drag' : 2.0, 51 | 'detuning' : 0.01 # freq in GHz for the sim 52 | } 53 | 54 | # flags indicating whether actions will be learned or scripted 55 | to_learn = { 56 | 'amp' : True, 57 | 'drag' : True, 58 | 'detuning' : True 59 | } 60 | 61 | rl_params = {'num_epochs' : num_epochs, 62 | 'train_batch_size' : train_batch_size, 63 | 'do_evaluation' : do_evaluation, 64 | 'eval_interval' : eval_interval, 65 | 'eval_batch_size' : eval_batch_size, 66 | 'learn_residuals' : learn_residuals, 67 | 'action_script' : action_script, 68 | #'action_spec' : action_spec, # doesn't play nice with h5 files 69 | 'action_scale': action_scale, 70 | 'to_learn' : to_learn, 71 | 'save_tf_style' : save_tf_style} 72 | 73 | log = h5log(root_dir, rl_params) 74 | 75 | 76 | ############################################################ 77 | # Below code shouldn't require modification for normal use # 78 | ############################################################ 79 | 80 | # Create drivers for data collection 81 | from quantum_control_rl_server import dynamic_episode_driver_sim_env 82 | 83 | server_socket = rmt.Server() 84 | (host, port) = (host_ip, 5555) 85 | server_socket.bind((host, port)) 86 | server_socket.connect_client() 87 | 88 | # Params for environment 89 | env_kwargs = eval_env_kwargs = { 90 | 'T' : 1} 91 | 92 | # Params for reward function 93 | reward_kwargs = { 94 | 'reward_mode' : 'remote', 95 | 'server_socket' : server_socket, 96 | 'epoch_type' : 'training'} 97 | 98 | reward_kwargs_eval = { 99 | 'reward_mode' : 'remote', 100 | 'server_socket' : server_socket, 101 | 'epoch_type' : 'evaluation'} 102 | 103 | collect_driver = dynamic_episode_driver_sim_env.DynamicEpisodeDriverSimEnv( 104 | env_kwargs, reward_kwargs, train_batch_size, action_script, action_scale, 105 | action_spec, to_learn, learn_residuals, remote=True) 106 | 107 | eval_driver = dynamic_episode_driver_sim_env.DynamicEpisodeDriverSimEnv( 108 | eval_env_kwargs, reward_kwargs_eval, eval_batch_size, action_script, action_scale, 109 | action_spec, to_learn, learn_residuals, remote=True) 110 | 111 | PPO.train_eval( 112 | root_dir = root_dir, 113 | random_seed = 0, 114 | num_epochs = num_epochs, 115 | # Params for train 116 | normalize_observations = True, 117 | normalize_rewards = False, 118 | discount_factor = 1.0, 119 | lr = 2.5e-3, 120 | lr_schedule = None, 121 | num_policy_updates = 20, 122 | initial_adaptive_kl_beta = 0.0, 123 | kl_cutoff_factor = 0, 124 | importance_ratio_clipping = 0.1, 125 | value_pred_loss_coef = 0.005, 126 | gradient_clipping = 1.0, 127 | entropy_regularization = 0, 128 | log_prob_clipping = 0.0, 129 | # Params for log, eval, save 130 | eval_interval = eval_interval, 131 | save_interval = 2, 132 | checkpoint_interval = None, 133 | summary_interval = 2, 134 | do_evaluation = do_evaluation, 135 | # Params for data collection 136 | train_batch_size = train_batch_size, 137 | eval_batch_size = eval_batch_size, 138 | collect_driver = collect_driver, 139 | eval_driver = eval_driver, 140 | replay_buffer_capacity = 15000, 141 | # Policy and value networks 142 | ActorNet = actor_distribution_network.ActorDistributionNetwork, 143 | zero_means_kernel_initializer = False, 144 | init_action_stddev = 0.08, 145 | actor_fc_layers = (50,20), 146 | value_fc_layers = (), 147 | use_rnn = False, 148 | actor_lstm_size = (12,), 149 | value_lstm_size = (12,), 150 | h5datalog = log, 151 | save_tf_style = save_tf_style, 152 | rl_params = rl_params) 153 | 154 | # %% 155 | 
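# Note (illustrative sketch, not part of the repository): how the settings
# above combine into the actions the client receives. The agent's outputs are
# bounded to [-1, 1] by the wrapper's action spec; ActionWrapper
# (tf_env_wrappers.py) multiplies each output by its action_scale entry and,
# because learn_residuals=True, adds the corresponding action_script value.
# For a hypothetical network output of 0.3 on every component:
amp      = 0.2 + 0.5  * 0.3   # action_script['amp'] + action_scale['amp'] * output = 0.35
drag     = 0.0 + 2.0  * 0.3   # = 0.6
detuning = 0.0 + 0.01 * 0.3   # = 0.003 GHz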
-------------------------------------------------------------------------------- /examples/pi_pulse_oct_style/pi_pulse_oct_style_training_server.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | import os 4 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"]='true' 5 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 6 | 7 | # append parent 'gkp-rl' directory to path 8 | import sys 9 | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 10 | 11 | import tensorflow as tf 12 | from tf_agents import specs 13 | from quantum_control_rl_server import PPO 14 | from tf_agents.networks import actor_distribution_network 15 | from quantum_control_rl_server import remote_env_tools as rmt 16 | from quantum_control_rl_server.h5log import h5log 17 | import numpy as np 18 | 19 | root_dir = os.getcwd() #r'E:\rl_data\exp_training\pi_pulse' 20 | host_ip = '127.0.0.1' # ip address of RL server, here it's hosted locally 21 | 22 | num_epochs = 100 # total number of training epochs 23 | train_batch_size = 20 # number of batches to send for training epoch 24 | 25 | do_evaluation = True # flag for implementing eval epochs or not 26 | eval_interval = 20 # number of training epochs between eval epochs 27 | eval_batch_size = 1 # number of batches to send for eval epoch 28 | 29 | learn_residuals = True 30 | save_tf_style = False 31 | 32 | # setting up initial pulse 33 | n_array_vals = 40 # number of array values to optimize 34 | t_duration = n_array_vals # 1ns resolution, since time is in ns in sim 35 | amp = 0.5*(np.pi/t_duration) # starting 50% away from ideal pulse 36 | init_pulse_real = list(amp*np.ones(n_array_vals)) # square real pulse 37 | init_pulse_imag = list(np.zeros_like(init_pulse_real)) # no imag part 38 | 39 | amp_scale = (2*np.pi)/t_duration 40 | action_scale_array_real = list(np.ones(n_array_vals,dtype=float)*amp_scale) 41 | action_scale_array_imag = list(np.ones(n_array_vals,dtype=float)*amp_scale) 42 | 43 | # Params for action wrapper 44 | action_script = { 45 | 'pulse_array_real' : [init_pulse_real], # shape=[n_array_vals] 46 | 'pulse_array_imag' : [init_pulse_imag] 47 | } 48 | 49 | # specify shapes of actions to be consistent with the objects in action_script 50 | action_spec = { 51 | 'pulse_array_real' : specs.TensorSpec(shape=[n_array_vals], dtype=tf.float32), 52 | 'pulse_array_imag' : specs.TensorSpec(shape=[n_array_vals], dtype=tf.float32) 53 | } 54 | 55 | # characteristic scale of sigmoid functions used in the neural network, 56 | # and for automatic differentiation of the reward 57 | # optimal point should ideally be within +/- action_scale of the initial vals 58 | action_scale = { 59 | 'pulse_array_real':action_scale_array_real, 60 | 'pulse_array_imag':action_scale_array_imag 61 | } 62 | 63 | # flags indicating whether actions will be learned or scripted 64 | to_learn = { 65 | 'pulse_array_real':True, 66 | 'pulse_array_imag':True 67 | } 68 | 69 | 70 | rl_params = {'num_epochs' : num_epochs, 71 | 'train_batch_size' : train_batch_size, 72 | 'do_evaluation' : do_evaluation, 73 | 'eval_interval' : eval_interval, 74 | 'eval_batch_size' : eval_batch_size, 75 | 'learn_residuals' : learn_residuals, 76 | 'action_script' : action_script, 77 | #'action_spec' : action_spec, # doesn't play nice with h5 files 78 | 'action_scale': action_scale, 79 | 'to_learn' : to_learn, 80 | 'save_tf_style' : save_tf_style} 81 | 82 | log = h5log(root_dir, rl_params) 83 | 84 | 85 | ############################################################ 86 | # Below code shouldn't require 
modification for normal use # 87 | ############################################################ 88 | 89 | # Create drivers for data collection 90 | from quantum_control_rl_server import dynamic_episode_driver_sim_env 91 | 92 | server_socket = rmt.Server() 93 | (host, port) = (host_ip, 5555) 94 | server_socket.bind((host, port)) 95 | server_socket.connect_client() 96 | 97 | # Params for environment 98 | env_kwargs = eval_env_kwargs = { 99 | 'T' : 1} 100 | 101 | # Params for reward function 102 | reward_kwargs = { 103 | 'reward_mode' : 'remote', 104 | 'server_socket' : server_socket, 105 | 'epoch_type' : 'training'} 106 | 107 | reward_kwargs_eval = { 108 | 'reward_mode' : 'remote', 109 | 'server_socket' : server_socket, 110 | 'epoch_type' : 'evaluation'} 111 | 112 | collect_driver = dynamic_episode_driver_sim_env.DynamicEpisodeDriverSimEnv( 113 | env_kwargs, reward_kwargs, train_batch_size, action_script, action_scale, 114 | action_spec, to_learn, learn_residuals, remote=True) 115 | 116 | eval_driver = dynamic_episode_driver_sim_env.DynamicEpisodeDriverSimEnv( 117 | eval_env_kwargs, reward_kwargs_eval, eval_batch_size, action_script, action_scale, 118 | action_spec, to_learn, learn_residuals, remote=True) 119 | 120 | PPO.train_eval( 121 | root_dir = root_dir, 122 | random_seed = 0, 123 | num_epochs = num_epochs, 124 | # Params for train 125 | normalize_observations = True, 126 | normalize_rewards = False, 127 | discount_factor = 1.0, 128 | lr = 2.5e-3, 129 | lr_schedule = None, 130 | num_policy_updates = 20, 131 | initial_adaptive_kl_beta = 0.0, 132 | kl_cutoff_factor = 0, 133 | importance_ratio_clipping = 0.1, 134 | value_pred_loss_coef = 0.005, 135 | gradient_clipping = 1.0, 136 | entropy_regularization = 0, 137 | log_prob_clipping = 0.0, 138 | # Params for log, eval, save 139 | eval_interval = eval_interval, 140 | save_interval = 2, 141 | checkpoint_interval = None, 142 | summary_interval = 2, 143 | do_evaluation = do_evaluation, 144 | # Params for data collection 145 | train_batch_size = train_batch_size, 146 | eval_batch_size = eval_batch_size, 147 | collect_driver = collect_driver, 148 | eval_driver = eval_driver, 149 | replay_buffer_capacity = 15000, 150 | # Policy and value networks 151 | ActorNet = actor_distribution_network.ActorDistributionNetwork, 152 | zero_means_kernel_initializer = False, 153 | init_action_stddev = 0.08, 154 | actor_fc_layers = (50,20), 155 | value_fc_layers = (), 156 | use_rnn = False, 157 | actor_lstm_size = (12,), 158 | value_lstm_size = (12,), 159 | h5datalog = log, 160 | save_tf_style = save_tf_style, 161 | rl_params = rl_params) 162 | 163 | # %% 164 | -------------------------------------------------------------------------------- /quantum_control_rl_server/h5log.py: -------------------------------------------------------------------------------- 1 | # Author: Ben Brock 2 | # Created on May 03, 2023 3 | 4 | import h5py 5 | import os 6 | import time 7 | import numpy as np 8 | 9 | 10 | def set_attrs(g, kwargs): 11 | # recursive function for storing nested dicts 12 | # bottom layer should be compatible with storing in an h5 group (no custom objects, etc) 13 | for name, value in kwargs.items(): 14 | if isinstance(value, dict): 15 | sub_g = g.create_group(name) 16 | set_attrs(sub_g, value) 17 | else: 18 | g.attrs[name] = value 19 | 20 | class h5log: 21 | 22 | def __init__(self, dir, rl_params={}): 23 | # dir = str, directory where h5 file will be located 24 | # rl_params = dict containing params for training server 25 | 26 | self.dir = dir 27 | if not 
os.path.isdir(self.dir): 28 | os.mkdir(self.dir) 29 | 30 | self.filename = os.path.join(self.dir,time.strftime('%Y%m%d.h5')) 31 | f = h5py.File(self.filename, 'a') # 'a': create the file if needed, otherwise open for read/write 32 | if f.keys(): 33 | keys = [k for k in f.keys() if k.isdigit()] 34 | group_name = str(max(map(int, keys)) + 1) 35 | else: 36 | group_name = '0' 37 | g = f.create_group(group_name) 38 | self.group_name = group_name 39 | 40 | rl_params['training_epochs_finished'] = 0 41 | rl_params['evaluation_epochs_finished'] = 0 42 | rl_params['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S') 43 | 44 | rl_param_group = g.create_group('rl_params') 45 | set_attrs(rl_param_group, rl_params) 46 | 47 | f.close() 48 | 49 | def parse_actions(self, driver): 50 | # expand_dims is used to add a dimension for epoch number 51 | actions = { 52 | action_name : np.expand_dims(np.squeeze(np.array(action_history)[1:]),0) 53 | for action_name, action_history in driver._env.history.items() 54 | } 55 | #print('actions in parse_actions()') 56 | #print(actions) 57 | return actions 58 | 59 | def parse_reward(self, driver): 60 | # expand_dims is used to add a dimension for epoch number 61 | reward = np.expand_dims(driver._env._episode_return.numpy(),axis=0) 62 | return reward 63 | 64 | def parse_policy_distribution(self, collect_driver, time_step, rl_params): 65 | policy_dist_dict = collect_driver._policy.distribution(time_step).info['dist_params'] 66 | locs = {} 67 | scales = {} 68 | for key in policy_dist_dict.keys(): 69 | locs[key] = np.expand_dims(rl_params['action_script'][key].numpy()[0] + (policy_dist_dict[key]['loc'].numpy()[0]*rl_params['action_scale'][key]),0) 70 | scales[key] = np.expand_dims((policy_dist_dict[key]['scale'].numpy()[0]*rl_params['action_scale'][key]),0) 71 | return locs, scales 72 | 73 | def save_driver_data(self, driver, epoch_type): 74 | # saves relevant data from RL episode driver 75 | # (collect_driver for training epochs, eval_driver for evaluation epochs) 76 | # epoch_type = str, 'evaluation' or 'training' 77 | 78 | these_actions = self.parse_actions(driver) 79 | this_reward = self.parse_reward(driver) 80 | 81 | f = h5py.File(self.filename, 'a') 82 | g = f[self.group_name] 83 | g['rl_params'].attrs[epoch_type+'_epochs_finished'] += 1 84 | h = g.require_group(epoch_type) # creates subgroup if it doesn't exist, otherwise returns the subgroup 85 | 86 | if 'rewards' not in h.keys(): 87 | h.create_dataset('rewards', 88 | data = this_reward, 89 | maxshape = (None,)+this_reward.shape[1:] 90 | ) 91 | else: 92 | h['rewards'].resize(h['rewards'].shape[0]+1,axis=0) 93 | h['rewards'][-1] = this_reward 94 | 95 | action_group = h.require_group('actions') 96 | for action_name, array in these_actions.items(): 97 | if action_name not in action_group.keys(): 98 | action_group.create_dataset(action_name, 99 | data = array, 100 | maxshape = (None,)+array.shape[1:] 101 | ) 102 | else: 103 | action_group[action_name].resize(action_group[action_name].shape[0]+1,axis=0) 104 | action_group[action_name][-1] = array 105 | f.close() # close the file handle before returning 106 | 107 | def save_policy_distribution(self, collect_driver, time_step = None, rl_params = None): 108 | # saves policy distribution from the collect driver 109 | # needs rl_params['action_script'] and rl_params['action_scale'] 110 | # time_step = tensorflow object returned after running the driver each epoch 111 | 112 | these_actions = self.parse_actions(collect_driver) 113 | this_reward = self.parse_reward(collect_driver) 114 | 115 | f = h5py.File(self.filename, 'a') 116 | g = f[self.group_name] 117 | h = g.require_group('policy_distribution') # creates 
subgroup if it doesn't exist, otherwise returns the subgroup 118 | 119 | locs, scales = self.parse_policy_distribution(collect_driver, time_step, rl_params) 120 | 121 | loc_group = h.require_group('locs') 122 | for action_name in locs.keys(): 123 | array = locs[action_name] 124 | if action_name not in loc_group.keys(): 125 | loc_group.create_dataset(action_name, 126 | data = array, 127 | maxshape = (None,)+array.shape[1:] 128 | ) 129 | else: 130 | loc_group[action_name].resize(loc_group[action_name].shape[0]+1,axis=0) 131 | loc_group[action_name][-1] = array 132 | 133 | scale_group = h.require_group('scales') 134 | for action_name in scales.keys(): 135 | array = scales[action_name] 136 | if action_name not in scale_group.keys(): 137 | scale_group.create_dataset(action_name, 138 | data = array, 139 | maxshape = (None,)+array.shape[1:] 140 | ) 141 | else: 142 | scale_group[action_name].resize(scale_group[action_name].shape[0]+1,axis=0) 143 | scale_group[action_name][-1] = array 144 | 145 | f.close() 146 | 147 | -------------------------------------------------------------------------------- /quantum_control_rl_server/tf_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tf_agents import specs 4 | from tf_agents.environments import tf_environment 5 | from tf_agents.trajectories import time_step as ts 6 | from tf_agents.specs import tensor_spec 7 | 8 | 9 | class TFEnvironmentQuantumControl(tf_environment.TFEnvironment): 10 | """ 11 | Custom environment that follows the TensorFlow Agents interface and allows one to 12 | train a reinforcement learning agent to find quantum control policies. 13 | 14 | This implementation relies heavily on TensorFlow to do fast computations 15 | in parallel on GPU by adding a batch dimension to all tensors. The speedup 16 | over an all-QuTiP implementation is about x100 on an NVIDIA RTX 2080Ti. 17 | 18 | This is the base environment class for quantum control problems which 19 | incorporates simulation-independent methods. The QuantumCircuit subclasses 20 | inherit from this base class and from a simulation class. Subclasses 21 | implement 'control_circuit', which is run at each time step. The RL agent's 22 | actions are parametrized according to the sequence of gates applied at 23 | each time step, as defined by 'control_circuit'. 24 | 25 | The environment's step() method returns a TimeStep tuple whose 'observation' 26 | attribute stores the finite-horizon history of applied actions, measurement 27 | outcomes and state wavefunctions. The user needs to define a wrapper for the 28 | environment if some components of this observation are to be discarded. 29 | """ 30 | def __init__( 31 | self, 32 | action_spec={}, 33 | T=4, 34 | batch_size=50, 35 | reward_kwargs={}, 36 | **kwargs): 37 | """ 38 | Args: 39 | T (int, optional): Periodicity of the 'clock' observation. Defaults to 4. 40 | batch_size (int, optional): Vectorized minibatch size. Defaults to 50. 41 | reward_kwargs (dict, optional): optional dictionary of parameters 42 | for the reward function of the RL agent. 
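            Example (an illustrative sketch only; in the examples this environment is
            normally constructed indirectly through dynamic_episode_driver_sim_env,
            the action_spec shape here is hypothetical, and 'server_socket' is assumed
            to be an already-connected remote_env_tools.Server):

                env = TFEnvironmentQuantumControl(
                    action_spec={'pulse_array_real':
                                 specs.TensorSpec(shape=[40], dtype=tf.float32)},
                    T=1,
                    batch_size=20,
                    reward_kwargs={'reward_mode': 'remote',
                                   'server_socket': server_socket,
                                   'epoch_type': 'training'})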
43 | 44 | """ 45 | # Default simulation parameters 46 | self.T = T 47 | self.batch_size = batch_size 48 | 49 | self.setup_reward(reward_kwargs) 50 | self._epoch = 0 51 | 52 | observation_spec = { 53 | 'clock' : specs.TensorSpec(shape=[self.T], dtype=tf.float32), 54 | 'const' : specs.TensorSpec(shape=[1], dtype=tf.float32)} 55 | time_step_spec = ts.time_step_spec(observation_spec) 56 | 57 | super().__init__(time_step_spec, action_spec, self.batch_size) 58 | 59 | 60 | ### STANDARD 61 | 62 | 63 | def _step(self, action): 64 | """ 65 | Execute one time step in the environment. 66 | 67 | Input: 68 | action -- dictionary of batched actions 69 | 70 | Output: 71 | TimeStep object (see tf-agents docs) 72 | 73 | """ 74 | 75 | # Calculate rewards 76 | self._elapsed_steps += 1 77 | self._episode_ended = (self._elapsed_steps == self.T) 78 | 79 | # Add dummy time dimension to tensors and append them to history 80 | 81 | # don't add batch to history on the last time step, 82 | # since this batch of actions isn't actually used 83 | if not self._current_time_step_.is_last().numpy().all(): 84 | for a in action.keys(): 85 | self.history[a].append(action[a]) 86 | 87 | # Make observations of 'msmt' of horizon H, shape=[batch_size,H] 88 | # measurements are selected with hard-coded attention step. 89 | # Also add clock of period 'T' to observations, shape=[batch_size,T] 90 | observation = {} 91 | C = tf.one_hot([self._elapsed_steps%self.T]*self.batch_size, self.T) 92 | observation['clock'] = C 93 | observation['const'] = tf.ones(shape=[self.batch_size,1]) 94 | 95 | reward = self.calculate_reward(action) 96 | self._episode_return += reward 97 | 98 | if self._episode_ended: 99 | self._epoch += 1 100 | self._current_time_step_ = ts.termination(observation, reward) 101 | else: 102 | self._current_time_step_ = ts.transition(observation, reward) 103 | return self.current_time_step() 104 | 105 | 106 | def _reset(self): 107 | """ 108 | Reset the state of the environment to an initial state. States are 109 | represented as batched tensors. 110 | 111 | Output: 112 | TimeStep object (see tf-agents docs) 113 | 114 | """ 115 | self.info = {} # use to cache some intermediate results 116 | 117 | # Bookkeeping of episode progress 118 | self._episode_ended = False 119 | self._elapsed_steps = 0 120 | self._episode_return = 0 121 | 122 | # Initialize history of horizon H with actions=0 and measurements=1 123 | self.history = tensor_spec.zero_spec_nest( 124 | self.action_spec(), outer_dims=(self.batch_size,)) 125 | for key in self.history.keys(): 126 | self.history[key] = [self.history[key]] 127 | 128 | # Make observation of horizon H 129 | observation = { 130 | 'clock' : tf.one_hot([0]*self.batch_size, self.T), 131 | 'const' : tf.ones(shape=[self.batch_size,1])} 132 | 133 | self._current_time_step_ = ts.restart(observation, self.batch_size) 134 | return self.current_time_step() 135 | 136 | def _current_time_step(self): 137 | return self._current_time_step_ 138 | 139 | 140 | ### REWARD FUNCTIONS 141 | 142 | def setup_reward(self, reward_kwargs): 143 | """Setup the reward function based on reward_kwargs. 
""" 144 | try: 145 | mode = reward_kwargs.pop('reward_mode') 146 | assert mode in ['zero', 147 | 'remote'] 148 | self.reward_mode = mode 149 | except: 150 | raise ValueError('reward_mode not specified or not supported.') 151 | 152 | if mode == 'remote': 153 | """ 154 | Required reward_kwargs: 155 | reward_mode (str): 'remote' 156 | N_msmt (int): number of measurements per protocol 157 | epoch_type (str): either 'training' or 'evaluation' 158 | server_socket (Socket): socket for communication 159 | """ 160 | self.server_socket = reward_kwargs.pop('server_socket') 161 | self.calculate_reward = \ 162 | lambda x: self.reward_remote(**reward_kwargs) 163 | 164 | 165 | def reward_remote(self, epoch_type): 166 | """ 167 | Send the action sequence to remote environment and receive rewards. 168 | The data received from the remote env should be Pauli measurement 169 | (i.e., range from -1 to 1) 170 | outcomes of shape [batch_size]. 171 | 172 | """ 173 | 174 | 175 | 176 | # return 0 on all intermediate steps of the episode 177 | if self._elapsed_steps != self.T: 178 | return tf.zeros(self.batch_size, dtype=tf.float32) 179 | 180 | action_batch = {} 181 | for a in self.history.keys() - ['msmt']: 182 | # reshape to [batch_size, T, action_dim] 183 | 184 | action_history = np.array(self.history[a][1:]) 185 | action_batch[a] = np.transpose(action_history, 186 | axes=[1,0]+list(range(action_history.ndim)[2:])) 187 | 188 | # send action sequence and metadata to remote client 189 | message = dict(action_batch=action_batch, 190 | batch_size=self.batch_size, 191 | epoch_type=epoch_type, 192 | epoch=self._epoch) 193 | 194 | self.server_socket.send_data(message) 195 | 196 | # receive sigma_z of shape [batch_size] 197 | msmt, done = self.server_socket.recv_data() 198 | msmt = tf.cast(msmt, tf.float32) 199 | #z = np.mean(msmt, axis=0) 200 | 201 | return msmt #tf.cast(z, tf.float32) 202 | 203 | 204 | 205 | 206 | 207 | @property 208 | def batch_size(self): 209 | return self._batch_size 210 | 211 | @batch_size.setter 212 | def batch_size(self, size): 213 | try: 214 | assert size>0 and isinstance(size,int) 215 | self._batch_size = size 216 | except: 217 | raise ValueError('Batch size should be positive integer.') 218 | -------------------------------------------------------------------------------- /quantum_control_rl_server/PPO.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Mar 4 12:24:37 2020 4 | 5 | @author: Vladimir Sivak 6 | """ 7 | import os 8 | import sys 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | from tf_agents.replay_buffers import tf_uniform_replay_buffer 13 | from tf_agents.policies import policy_saver 14 | from tf_agents.networks import actor_distribution_network, value_network 15 | from tf_agents.networks import actor_distribution_rnn_network, value_rnn_network 16 | from tf_agents.utils import common 17 | from tf_agents.eval import metric_utils 18 | from tf_agents.agents.ppo import ppo_agent 19 | from tf_agents.utils import timer 20 | from tf_agents.metrics import tf_metrics 21 | 22 | def train_eval( 23 | root_dir, 24 | random_seed = 0, 25 | num_epochs = 1000000, 26 | # Params for train 27 | normalize_observations = True, 28 | normalize_rewards = True, 29 | discount_factor = 1.0, 30 | lr = 1e-5, 31 | lr_schedule = None, 32 | num_policy_updates = 20, 33 | initial_adaptive_kl_beta = 0.0, 34 | kl_cutoff_factor = 0, 35 | importance_ratio_clipping = 0.2, 36 | value_pred_loss_coef = 0.5, 37 | 
gradient_clipping = None, 38 | entropy_regularization = 0.0, 39 | log_prob_clipping = 0.0, 40 | # Params for log, eval, save 41 | eval_interval = 100, 42 | save_interval = 1000, 43 | checkpoint_interval = None, 44 | summary_interval = 100, 45 | do_evaluation = True, 46 | # Params for data collection 47 | train_batch_size = 10, 48 | eval_batch_size = 100, 49 | collect_driver = None, 50 | eval_driver = None, 51 | replay_buffer_capacity = 20000, 52 | # Policy and value networks 53 | ActorNet = actor_distribution_network.ActorDistributionNetwork, 54 | zero_means_kernel_initializer = False, 55 | init_action_stddev = 0.35, 56 | actor_fc_layers = (), 57 | value_fc_layers = (), 58 | use_rnn = True, 59 | actor_lstm_size = (12,), 60 | value_lstm_size = (12,), 61 | h5datalog = None, 62 | save_tf_style = True, 63 | rl_params = None, 64 | **kwargs): 65 | """ A simple train and eval for PPO agent. 66 | 67 | Args: 68 | root_dir (str): directory for saving training and evalutaion data 69 | random_seed (int): seed for random number generator 70 | num_epochs (int): number of training epochs. At each epoch a batch 71 | of data is collected according to one stochastic policy, and then 72 | the policy is updated. 73 | normalize_observations (bool): flag for normalization of observations. 74 | Uses StreamingTensorNormalizer which normalizes based on the whole 75 | history of observations. 76 | normalize_rewards (bool): flag for normalization of rewards. 77 | Uses StreamingTensorNormalizer which normalizes based on the whole 78 | history of rewards. 79 | discount_factor (float): rewards discout factor, should be in (0,1] 80 | lr (float): learning rate for Adam optimizer 81 | lr_schedule (callable: int -> float, optional): function to schedule 82 | the learning rate annealing. Takes as argument the int epoch 83 | number and returns float value of the learning rate. 84 | num_policy_updates (int): number of policy gradient steps to do on each 85 | epoch of training. In PPO this is typically >1. 86 | initial_adaptive_kl_beta (float): see tf-agents PPO docs 87 | kl_cutoff_factor (float): see tf-agents PPO docs 88 | importance_ratio_clipping (float): clipping value for importance ratio. 89 | Should demotivate the policy from doing updates that significantly 90 | change the policy. Should be in (0,1] 91 | value_pred_loss_coef (float): weight coefficient for quadratic value 92 | estimation loss. 93 | gradient_clipping (float): gradient clipping coefficient. 94 | entropy_regularization (float): entropy regularization loss coefficient. 95 | log_prob_clipping (float): +/- value for clipping log probs to prevent 96 | inf / NaN values. Default: no clipping. 97 | eval_interval (int): interval between evaluations, counted in epochs. 98 | save_interval (int): interval between savings, counted in epochs. It 99 | updates the log file and saves the deterministic policy. 100 | checkpoint_interval (int): interval between saving checkpoints, counted 101 | in epochs. Overwrites the previous saved one. Defaults to None, 102 | in which case checkpoints are not saved. 103 | summary_interval (int): interval between summary writing, counted in 104 | epochs. tf-agents takes care of summary writing; results can be 105 | later displayed in tensorboard. 106 | do_evaluation (bool): flag to interleave training epochs with 107 | evaluation epochs. 108 | train_batch_size (int): training batch size, collected in parallel. 109 | eval_batch_size (int): batch size for evaluation of the policy. 
110 | collect_driver (Driver): driver for training data collection 111 | eval_driver (Driver): driver for evaluation data collection 112 | replay_buffer_capacity (int): How many transition tuples the buffer 113 | can store. The buffer is emptied and re-populated at each epoch. 114 | ActorNet (network.DistributionNetwork): a distribution actor network 115 | to use for training. The default is ActorDistributionNetwork from 116 | tf-agents, but this can also be customized. 117 | zero_means_kernel_initializer (bool): flag to initialize the means 118 | projection network with zeros. If this flag is not set, it will 119 | use default tf-agent random initializer. 120 | init_action_stddev (float): initial stddev of the normal action dist. 121 | actor_fc_layers (tuple): sizes of fully connected layers in actor net. 122 | value_fc_layers (tuple): sizes of fully connected layers in value net. 123 | use_rnn (bool): whether to use LSTM units in the neural net. 124 | actor_lstm_size (tuple): sizes of LSTM layers in actor net. 125 | value_lstm_size (tuple): sizes of LSTM layers in value net. 126 | """ 127 | # -------------------------------------------------------------------- 128 | # -------------------------------------------------------------------- 129 | tf.compat.v1.set_random_seed(random_seed) 130 | 131 | # Setup directories within 'root_dir' 132 | if not os.path.isdir(root_dir): os.mkdir(root_dir) 133 | policy_dir = os.path.join(root_dir, 'policy') 134 | checkpoint_dir = os.path.join(root_dir, 'checkpoint') 135 | logfile = os.path.join(root_dir,'log.hdf5') 136 | train_dir = os.path.join(root_dir, 'train_summaries') 137 | 138 | # Create tf summary writer 139 | train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir) 140 | train_summary_writer.set_as_default() 141 | summary_interval *= num_policy_updates 142 | global_step = tf.compat.v1.train.get_or_create_global_step() 143 | with tf.compat.v2.summary.record_if( 144 | lambda: tf.math.equal(global_step % summary_interval, 0)): 145 | 146 | # Define action and observation specs 147 | observation_spec = collect_driver.observation_spec() 148 | action_spec = collect_driver.action_spec() 149 | 150 | # Preprocessing: flatten and concatenate observation components 151 | preprocessing_layers = { 152 | obs : tf.keras.layers.Flatten() for obs in observation_spec.keys()} 153 | preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1) 154 | 155 | # Define actor network and value network 156 | if use_rnn: 157 | actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( 158 | input_tensor_spec = observation_spec, 159 | output_tensor_spec = action_spec, 160 | preprocessing_layers = preprocessing_layers, 161 | preprocessing_combiner = preprocessing_combiner, 162 | input_fc_layer_params = None, 163 | lstm_size = actor_lstm_size, 164 | output_fc_layer_params = actor_fc_layers) 165 | 166 | value_net = value_rnn_network.ValueRnnNetwork( 167 | input_tensor_spec = observation_spec, 168 | preprocessing_layers = preprocessing_layers, 169 | preprocessing_combiner = preprocessing_combiner, 170 | input_fc_layer_params = None, 171 | lstm_size = value_lstm_size, 172 | output_fc_layer_params = value_fc_layers) 173 | else: 174 | npn = actor_distribution_network._normal_projection_net 175 | 176 | if zero_means_kernel_initializer: 177 | normal_projection_net = lambda specs: npn(specs, 178 | zero_means_kernel_initializer=zero_means_kernel_initializer, 179 | init_action_stddev=init_action_stddev) 180 | else: 181 | normal_projection_net = lambda specs: 
npn(specs, 182 | init_action_stddev=init_action_stddev) 183 | 184 | actor_net = ActorNet( 185 | input_tensor_spec = observation_spec, 186 | output_tensor_spec = action_spec, 187 | preprocessing_layers = preprocessing_layers, 188 | preprocessing_combiner = preprocessing_combiner, 189 | fc_layer_params = actor_fc_layers, 190 | continuous_projection_net=normal_projection_net) 191 | 192 | value_net = value_network.ValueNetwork( 193 | input_tensor_spec = observation_spec, 194 | preprocessing_layers = preprocessing_layers, 195 | preprocessing_combiner = preprocessing_combiner, 196 | fc_layer_params = value_fc_layers) 197 | 198 | # Create PPO agent 199 | 200 | optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr) 201 | tf_agent = ppo_agent.PPOAgent( 202 | time_step_spec = collect_driver.time_step_spec(), 203 | action_spec = action_spec, 204 | optimizer = optimizer, 205 | actor_net = actor_net, 206 | value_net = value_net, 207 | num_epochs = num_policy_updates, 208 | train_step_counter = global_step, 209 | discount_factor = discount_factor, 210 | normalize_observations = normalize_observations, 211 | normalize_rewards = normalize_rewards, 212 | initial_adaptive_kl_beta = initial_adaptive_kl_beta, 213 | kl_cutoff_factor = kl_cutoff_factor, 214 | importance_ratio_clipping = importance_ratio_clipping, 215 | gradient_clipping = gradient_clipping, 216 | value_pred_loss_coef = value_pred_loss_coef, 217 | entropy_regularization=entropy_regularization, 218 | log_prob_clipping=log_prob_clipping, 219 | debug_summaries = True) 220 | 221 | tf_agent.initialize() 222 | eval_policy = tf_agent.policy 223 | collect_policy = tf_agent.collect_policy 224 | 225 | # Create replay buffer and collection driver 226 | replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( 227 | data_spec=tf_agent.collect_data_spec, 228 | batch_size=train_batch_size, 229 | max_length=replay_buffer_capacity) 230 | 231 | def train_step(): 232 | experience = replay_buffer.gather_all() 233 | return tf_agent.train(experience) 234 | 235 | tf_agent.train = common.function(tf_agent.train) 236 | 237 | avg_return_metric = tf_metrics.AverageReturnMetric( 238 | batch_size=eval_batch_size, buffer_size=eval_batch_size) 239 | 240 | collect_driver.setup(collect_policy, [replay_buffer.add_batch]) 241 | eval_driver.setup(eval_policy, [avg_return_metric]) 242 | 243 | # Create a checkpointer and load the saved agent 244 | train_checkpointer = common.Checkpointer( 245 | ckpt_dir=checkpoint_dir, 246 | max_to_keep=1, 247 | agent=tf_agent, 248 | policy=tf_agent.policy, 249 | replay_buffer=replay_buffer, 250 | global_step=global_step) 251 | 252 | train_checkpointer.initialize_or_restore() 253 | global_step = tf.compat.v1.train.get_global_step() 254 | 255 | # Saver for the deterministic policy 256 | saved_model = policy_saver.PolicySaver( 257 | eval_policy, train_step=global_step) 258 | 259 | # Evaluate policy once before training 260 | if do_evaluation: 261 | eval_driver.run() 262 | if h5datalog is not None: 263 | h5datalog.save_driver_data(eval_driver,'evaluation') 264 | avg_return = avg_return_metric.result().numpy() 265 | avg_return_metric.reset() 266 | log = { 267 | 'returns' : [avg_return], 268 | 'epochs' : [0], 269 | 'policy_steps' : [0], 270 | 'experience_time' : [0.0], 271 | 'train_time' : [0.0] 272 | } 273 | print('-------------------') 274 | print('Epoch 0') 275 | print(' Policy steps: 0') 276 | print(' Experience time: 0.00 mins') 277 | print(' Policy train time: 0.00 mins') 278 | print(' Average return: %.5f' %avg_return) 279 | 280 | 
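        # When save_tf_style is enabled, PolicySaver writes the deterministic policy as a
        # TF SavedModel under root_dir/policy/<epoch zero-padded to 6 digits>.
        # A minimal sketch of reloading a saved policy in a separate analysis script
        # (assumes a policy was saved at epoch 0, i.e. the '000000' subdirectory exists):
        #     reloaded_policy = tf.saved_model.load(os.path.join(policy_dir, '000000'))
        #     action_step = reloaded_policy.action(time_step)  # time_step must match the observation spec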
if save_tf_style: 281 | # Save initial random policy 282 | path = os.path.join(policy_dir,('0').zfill(6)) 283 | saved_model.save(path) 284 | 285 | # Training loop 286 | train_timer = timer.Timer() 287 | experience_timer = timer.Timer() 288 | 289 | #all_policy_dists = None 290 | for epoch in range(1,num_epochs+1): 291 | # Collect new experience 292 | experience_timer.start() 293 | this_time_step,this_policy_state = collect_driver.run() 294 | experience_timer.stop() 295 | 296 | 297 | if h5datalog is not None: 298 | h5datalog.save_driver_data(collect_driver,'training') # saves actions from the epoch that just ran 299 | h5datalog.save_policy_distribution(collect_driver, time_step = this_time_step, rl_params = rl_params) # saves policy dist from the epoch that just ran 300 | 301 | # Update the policy 302 | train_timer.start() 303 | if lr_schedule: optimizer._lr = lr_schedule(epoch) 304 | train_loss = train_step() 305 | replay_buffer.clear() 306 | train_timer.stop() 307 | 308 | if (epoch % eval_interval == 0) and do_evaluation: 309 | # Evaluate the policy 310 | eval_driver.run() 311 | 312 | if h5datalog is not None: 313 | h5datalog.save_driver_data(eval_driver,'evaluation') 314 | 315 | avg_return = avg_return_metric.result().numpy() 316 | avg_return_metric.reset() 317 | 318 | # Print out and log all metrics 319 | print('-------------------') 320 | print('Epoch %d' %epoch) 321 | print(' Policy steps: %d' %(epoch*num_policy_updates)) 322 | print(' Experience time: %.2f mins' %(experience_timer.value()/60)) 323 | print(' Policy train time: %.2f mins' %(train_timer.value()/60)) 324 | print(' Average return: %.5f' %avg_return) 325 | log['epochs'].append(epoch) 326 | log['policy_steps'].append(epoch*num_policy_updates) 327 | log['returns'].append(avg_return) 328 | log['experience_time'].append(experience_timer.value()) 329 | log['train_time'].append(train_timer.value()) 330 | 331 | 332 | if save_tf_style: 333 | if epoch % save_interval == 0: 334 | # Save deterministic policy 335 | path = os.path.join(policy_dir,('%d' % epoch).zfill(6)) 336 | saved_model.save(path) 337 | 338 | if checkpoint_interval is not None and \ 339 | epoch % checkpoint_interval == 0: 340 | # Save training checkpoint 341 | train_checkpointer.save(global_step) 342 | 343 | # save policy dist after final epoch (post-updating) 344 | if h5datalog is not None: 345 | h5datalog.save_policy_distribution(collect_driver, time_step = this_time_step, rl_params = rl_params) 346 | 347 | # End training by sending the mean (loc) and stdev (scale) of the policy distribution to the client 348 | locs, scales = h5datalog.parse_policy_distribution(collect_driver, time_step = this_time_step, rl_params = rl_params) 349 | final_message = dict( 350 | locs=locs, 351 | scales=scales, 352 | epoch_type='final', 353 | ) 354 | 355 | collect_driver.env.server_socket.send_data(final_message) 356 | 357 | collect_driver.finish_training() 358 | eval_driver.finish_training() 359 | 360 | 361 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 
8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 
71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 
128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. 
Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. 
If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 
305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. --------------------------------------------------------------------------------