├── .gitignore
├── GOLF
│   ├── GOLF_actor.py
│   ├── GOLF_trainer.py
│   ├── __init__.py
│   ├── eval.py
│   ├── experience_saver.py
│   ├── make_policies.py
│   ├── make_saver.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── lbfgs.py
│   │   └── lion_pytorch.py
│   ├── replay_buffer.py
│   └── utils.py
├── LICENSE
├── README.md
├── checkpoints
│   ├── GOLF-10k
│   │   ├── NNP_checkpoint_actor
│   │   └── config.json
│   ├── GOLF-1k
│   │   ├── NNP_checkpoint_actor
│   │   └── config.json
│   ├── baseline-NNP
│   │   ├── NNP_checkpoint_actor
│   │   └── config.json
│   ├── traj-100k
│   │   ├── NNP_checkpoint_actor
│   │   └── config.json
│   ├── traj-10k
│   │   ├── NNP_checkpoint_actor
│   │   └── config.json
│   └── traj-500k
│       ├── NNP_checkpoint_actor
│       └── config.json
├── env
│   ├── dft.py
│   ├── dft_worker.py
│   ├── host_names.txt
│   ├── make_envs.py
│   ├── moldynamics_env.py
│   ├── molecules_xyz
│   │   ├── aspirin.xyz
│   │   ├── azobenzene.xyz
│   │   ├── benzene.xyz
│   │   ├── ethanol.xyz
│   │   ├── malonaldehyde.xyz
│   │   ├── naphthalene.xyz
│   │   ├── paracetamol.xyz
│   │   ├── salicylic_acid.xyz
│   │   ├── toluene.xyz
│   │   └── uracil.xyz
│   ├── wrappers.py
│   └── xyz2mol.py
├── evaluate_batch_dft.py
├── main.py
├── read_evaluation_metrics.py
├── requirements.txt
├── scripts
│   ├── babysit_dft.sh
│   ├── setup_dft_workers.sh
│   ├── setup_gpu_env.sh
│   └── training
│       ├── run_training_GOLF_10k.sh
│       ├── run_training_GOLF_1k.sh
│       ├── run_training_baseline.sh
│       ├── run_training_trajectories-100k.sh
│       ├── run_training_trajectories-10k.sh
│       └── run_training_trajectories-500k.sh
├── test_dft_workers.py
└── utils
    ├── arguments.py
    ├── logging.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all .ipynb
2 | *.ipynb
3 | **/.ipynb_checkpoints
4 | 
5 | # Ignore Jmol --> might remove
6 | **/Jmol
7 | 
8 | # Ignore psi4 stuff
9 | **/timer.dat
10 | 
11 | # Ignore all __pycache__ and .DS_Store
12 | **/__pycache__
13 | *.DS_Store
14 | 
15 | # Ignore logs, notebooks and env data
16 | results
17 | data
18 | trajectories
19 | notebooks
20 | evaluate_action_norms
21 | env/data
22 | 
23 | # Ignore different hostname files
24 | **/host_names*
25 | 
26 | # Ignore resulting DBs and evaluation metrics
27 | **/evaluation_config.json
28 | **/evaluation_metrics.json
29 | **/results.db
30 | 
31 | # Ignore models
32 | models
33 | 
34 | # Ignore figures
35 | *.html
36 | *.png
37 | *.pdf
38 | 
39 | # Ignore all trajectories
40 | *.traj
41 | 
42 | # Ignore saved torch tensors
43 | *.pt
44 | 
45 | 
--------------------------------------------------------------------------------
/GOLF/GOLF_actor.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import collections
3 | 
4 | import numpy as np
5 | import schnetpack as spk
6 | import torch
7 | import torch.nn as nn
8 | from schnetpack import properties
9 | from torch.linalg import vector_norm
10 | 
11 | from GOLF import DEVICE
12 | from GOLF.utils import (
13 |     get_conformation_lr_scheduler,
14 |     get_atoms_indices_range,
15 |     unpad_state,
16 |     _atoms_collate_fn,
17 | )
18 | from utils.utils import ignore_extra_args
19 | from GOLF.optim import lbfgs
20 | 
21 | KCALMOL_2_HARTREE = 627.5
22 | 
23 | EPS = 1e-8
24 | 
25 | backbones = {
26 |     "schnet": ignore_extra_args(spk.representation.SchNet),
27 |     "painn": ignore_extra_args(spk.representation.PaiNN),
28 | }
29 | 
30 | 
31 | class Actor(nn.Module):
32 |     def __init__(
33 |         self,
34 |         backbone,
35 |         backbone_args,
36 |         do_postprocessing=False,
37 |         action_norm_limit=None,
38 |     ):
39 |         super(Actor, self).__init__()
40 |         self.action_norm_limit = action_norm_limit
41 | 
42 |         representation = backbones[backbone](**backbone_args)
43 | 
output_modules = [ 44 | spk.atomistic.Atomwise( 45 | n_in=representation.n_atom_basis, 46 | n_out=1, 47 | output_key="energy", 48 | ), 49 | spk.atomistic.Forces(energy_key="energy", force_key="anti_gradient"), 50 | ] 51 | 52 | if do_postprocessing: 53 | postprocessors = [ 54 | spk.transform.AddOffsets(property="energy", add_mean=True) 55 | ] 56 | else: 57 | postprocessors = None 58 | self.model = spk.model.NeuralNetworkPotential( 59 | representation=representation, 60 | input_modules=[spk.atomistic.PairwiseDistances()], 61 | postprocessors=postprocessors, 62 | output_modules=output_modules, 63 | ) 64 | 65 | self.last_energy = None 66 | self.last_forces = None 67 | 68 | def _limit_action_norm(self, actions, n_atoms): 69 | if self.action_norm_limit is None: 70 | return actions 71 | 72 | coefficient = torch.ones( 73 | size=(actions.size(0), 1), dtype=torch.float32, device=actions.device 74 | ) 75 | for molecule_id in range(n_atoms.size(0) - 1): 76 | max_norm = ( 77 | vector_norm( 78 | actions[n_atoms[molecule_id] : n_atoms[molecule_id + 1]], 79 | dim=-1, 80 | keepdims=True, 81 | ) 82 | .max(dim=1, keepdims=True) 83 | .values 84 | ) 85 | max_norm = torch.maximum( 86 | max_norm, torch.full_like(max_norm, fill_value=EPS, dtype=torch.float32) 87 | ) 88 | coefficient[n_atoms[molecule_id] : n_atoms[molecule_id + 1]] = ( 89 | torch.minimum( 90 | self.action_norm_limit / max_norm, 91 | torch.ones_like(max_norm, dtype=torch.float32), 92 | ) 93 | ) 94 | 95 | return actions * coefficient 96 | 97 | def _save_last_output(self, output): 98 | energy = output["energy"].detach() 99 | forces = output["anti_gradient"].detach() 100 | self.last_energy = energy 101 | self.last_forces = forces 102 | 103 | def _get_last_output(self): 104 | if self.last_energy is None or self.last_forces is None: 105 | raise ValueError("Last output has not been set yet!") 106 | return self.last_energy, self.last_forces 107 | 108 | def forward(self, state_dict, active_optimizers_ids=None, train=False): 109 | output = self.model(state_dict) 110 | self._save_last_output(output) 111 | if train: 112 | return output 113 | forces = output["anti_gradient"].detach() 114 | forces = self._limit_action_norm(forces, get_atoms_indices_range(state_dict)) 115 | 116 | return {"action": forces, "energy": output["energy"]} 117 | 118 | 119 | class RdkitActor(nn.Module): 120 | def __init__(self, env): 121 | super().__init__() 122 | self.env = env 123 | 124 | def forward(self, state_dict, active_optimizers_ids=None, train=False): 125 | if active_optimizers_ids is None: 126 | opt_ids = list(range(self.env.n_parallel)) 127 | else: 128 | opt_ids = active_optimizers_ids 129 | 130 | # print(opt_ids) 131 | 132 | # Update atoms inside env 133 | current_coordinates = [ 134 | self.env.unwrapped.atoms[idx].get_positions() for idx in opt_ids 135 | ] 136 | # print("current coordinates: ") 137 | # for coord in current_coordinates: 138 | # print(coord.shape) 139 | 140 | new_coordinates = torch.split( 141 | state_dict[properties.R].detach().cpu(), 142 | state_dict[properties.n_atoms].tolist(), 143 | ) 144 | assert len(new_coordinates) == len(opt_ids) 145 | new_coordinates = [ 146 | np.float64(new_coordinates[i].numpy()) for i in range(len(opt_ids)) 147 | ] 148 | # print("new coordinates") 149 | # for coord in new_coordinates: 150 | # print(coord.shape) 151 | # print("mol size inside the env") 152 | # for mol in self.env.rdkit_oracle.molecules: 153 | # print(mol.GetNumAtoms()) 154 | # Update coordinates inside env 155 | 
self.env.rdkit_oracle.update_coordinates(new_coordinates, indices=opt_ids) 156 | _, energies, forces = self.env.rdkit_oracle.calculate_energies_forces( 157 | indices=opt_ids 158 | ) 159 | 160 | # Restore original coordinates 161 | self.env.rdkit_oracle.update_coordinates(current_coordinates, indices=opt_ids) 162 | 163 | # Forces in (kcal/mol)/angstrom. Transform into hartree/angstrom. 164 | forces = torch.cat( 165 | [torch.tensor(force / KCALMOL_2_HARTREE) for force in forces] 166 | ) 167 | 168 | return {"anti_gradient": forces, "energy": torch.tensor(energies)} 169 | 170 | 171 | class ConformationOptimizer(nn.Module): 172 | def __init__( 173 | self, 174 | n_parallel, 175 | actor, 176 | lr_scheduler, 177 | t_max, 178 | optimizer, 179 | optimizer_kwargs, 180 | ): 181 | super().__init__() 182 | self.n_parallel = n_parallel 183 | self.lr_scheduler = get_conformation_lr_scheduler( 184 | lr_scheduler, optimizer_kwargs["lr"], t_max 185 | ) 186 | self.optimizer = optimizer 187 | self.optimizer_kwargs = optimizer_kwargs 188 | self.optimizer_list = [None] * n_parallel 189 | self.states = [None] * n_parallel 190 | self.actor = actor 191 | 192 | def reset(self, initial_states, indices=None): 193 | if indices is None: 194 | indices = torch.arange(self.n_parallel) 195 | unpad_initial_states = unpad_state(initial_states) 196 | for i, idx in enumerate(indices): 197 | self.states[idx] = { 198 | k: v.detach().clone().to(DEVICE) 199 | for k, v in unpad_initial_states[i].items() 200 | } 201 | self.states[idx][properties.R].requires_grad_(True) 202 | self.optimizer_list[idx] = self.optimizer( 203 | [self.states[idx][properties.R]], **self.optimizer_kwargs 204 | ) 205 | 206 | def act(self, t): 207 | # Update learning rate 208 | lrs = self.lr_scheduler.get(t) 209 | for idx, optim in enumerate(self.optimizer_list): 210 | for g in optim.param_groups: 211 | g["lr"] = lrs[idx] 212 | 213 | # Save current positions 214 | prev_positions = [ 215 | self.states[idx][properties.R].detach().clone() 216 | for idx in range(self.n_parallel) 217 | ] 218 | energy = torch.zeros(self.n_parallel) 219 | 220 | for optim in self.optimizer_list: 221 | optim.zero_grad() 222 | 223 | # Compute forces 224 | states = { 225 | key: value.to(DEVICE) 226 | for key, value in _atoms_collate_fn(self.states).items() 227 | } 228 | output = self.actor(states, train=True) 229 | energy = output["energy"] 230 | gradients = torch.split( 231 | -output["anti_gradient"].detach(), 232 | states[properties.n_atoms].tolist(), 233 | ) 234 | 235 | # Update all molecules' geometry 236 | for i, optim in enumerate(self.optimizer_list): 237 | self.states[i][properties.R].grad = gradients[i].to(DEVICE) 238 | optim.step() 239 | 240 | # Done always False 241 | done = [torch.tensor([False]) for _ in self.optimizer_list] 242 | 243 | # Calculate action based on saved positions and resulting geometries 244 | actions = [ 245 | self.states[idx][properties.R].detach().clone() - prev_positions[idx] 246 | for idx in range(self.n_parallel) 247 | ] 248 | is_finite_action = [ 249 | torch.isfinite(action).all().unsqueeze(dim=0) for action in actions 250 | ] 251 | return { 252 | "action": torch.cat(actions, dim=0), 253 | "energy": energy.detach(), 254 | "done": torch.cat(done, dim=0), 255 | "n_iter": torch.ones_like(energy), 256 | "is_finite_action": torch.cat(is_finite_action), 257 | "anti_gradient": output["anti_gradient"].detach(), 258 | } 259 | 260 | def select_action(self, t): 261 | output = self.act(t) 262 | return {key: value.cpu().numpy() for key, value in output.items()} 
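# --- Added usage sketch (illustrative, not part of the original file) ---
# A ConformationOptimizer wraps one first-order torch optimizer per molecule
# and treats the resulting coordinate updates as "actions". A minimal driving
# loop, assuming a collated schnetpack batch `initial_states`, a trained
# `actor`, and a valid scheduler name (the scheduler string below is assumed):
#
#   co = ConformationOptimizer(
#       n_parallel=2, actor=actor, lr_scheduler="Constant",
#       t_max=100, optimizer=torch.optim.SGD, optimizer_kwargs={"lr": 1e-2},
#   )
#   co.reset(initial_states)
#   out = co.select_action([0, 0])  # numpy dict: action, energy, done, n_iter, ...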
263 | 264 | 265 | class AsyncLBFGS: 266 | def __init__( 267 | self, 268 | state: dict, 269 | policy2optimizer_queue: asyncio.Queue, 270 | optimizer2policy_queue: asyncio.Queue, 271 | optimizer_kwargs: dict, 272 | grad_threshold: float, 273 | ): 274 | super().__init__() 275 | self.state = state 276 | self.state[properties.R].requires_grad_(True) 277 | self.policy2optimizer_queue = policy2optimizer_queue 278 | self.optimizer2policy_queue = optimizer2policy_queue 279 | self.optimizer = lbfgs.LBFGS([self.state[properties.R]], **optimizer_kwargs) 280 | self.energy = None 281 | self.anti_gradient = None 282 | self.n_iter = None 283 | self.grad_threshold = grad_threshold 284 | 285 | async def closure(self): 286 | self.optimizer.zero_grad() 287 | await self.optimizer2policy_queue.put(self.state) 288 | anti_gradient, energy = await self.policy2optimizer_queue.get() 289 | self.state[properties.R].grad = -anti_gradient.type( 290 | self.state[properties.R].dtype 291 | ) 292 | # Energy and anti-gradient before step 293 | if self.n_iter == 0: 294 | self.anti_gradient = anti_gradient 295 | self.energy = energy 296 | 297 | self.n_iter += 1 298 | return energy 299 | 300 | async def step(self): 301 | self.n_iter = 0 302 | previous_position = self.state[properties.R].detach().clone() 303 | await self.optimizer.step(self.closure) 304 | await self.optimizer2policy_queue.put(None) 305 | done = torch.unsqueeze( 306 | self.optimizer._gather_flat_grad().abs().max() <= self.grad_threshold, dim=0 307 | ) 308 | action = self.state[properties.R].detach().clone() - previous_position 309 | is_finite_action = torch.isfinite(action).all().unsqueeze(dim=0) 310 | return { 311 | "action": action, 312 | "energy": self.energy, 313 | "done": done, 314 | "is_finite_action": is_finite_action, 315 | "n_iter": torch.tensor([self.n_iter]), 316 | "anti_gradient": self.anti_gradient, 317 | } 318 | 319 | 320 | class LBFGSConformationOptimizer(nn.Module): 321 | def __init__( 322 | self, 323 | n_parallel, 324 | actor, 325 | optimizer_kwargs, 326 | grad_threshold=1e-5, 327 | lbfgs_device="cuda", 328 | ): 329 | super().__init__() 330 | self.n_parallel = n_parallel 331 | self.grad_threshold = grad_threshold 332 | self.optimizer_kwargs = optimizer_kwargs 333 | self.actor = actor 334 | self.lbfgs_device = torch.device(lbfgs_device) 335 | self.loop = asyncio.new_event_loop() 336 | self.policy2optimizer_queues = None 337 | self.optimizer2policy_queues = None 338 | self.loop.run_until_complete(self.set_queues()) 339 | self.conformation_optimizers = [None] * self.n_parallel 340 | 341 | async def set_queues(self): 342 | self.policy2optimizer_queues = [ 343 | asyncio.Queue(maxsize=1) for _ in range(self.n_parallel) 344 | ] 345 | self.optimizer2policy_queues = [ 346 | asyncio.Queue(maxsize=1) for _ in range(self.n_parallel) 347 | ] 348 | 349 | def reset(self, initial_states, indices=None): 350 | if indices is None: 351 | indices = torch.arange(self.n_parallel) 352 | unpad_initial_states = unpad_state(initial_states) 353 | torch.set_grad_enabled(True) 354 | for i, idx in enumerate(indices): 355 | state = { 356 | key: value.detach().clone().to(self.lbfgs_device) 357 | for key, value in unpad_initial_states[i].items() 358 | } 359 | self.conformation_optimizers[idx] = AsyncLBFGS( 360 | state, 361 | self.policy2optimizer_queues[idx], 362 | self.optimizer2policy_queues[idx], 363 | self.optimizer_kwargs, 364 | self.grad_threshold, 365 | ) 366 | 367 | async def _act_task(self): 368 | conformation_optimizers_ids = set(range(self.n_parallel)) 369 | while 
True: 370 | individual_states = {} 371 | stopped_optimizers_ids = set() 372 | for conformation_optimizer_id in conformation_optimizers_ids: 373 | individual_state = await self.optimizer2policy_queues[ 374 | conformation_optimizer_id 375 | ].get() 376 | if individual_state is None: 377 | stopped_optimizers_ids.add(conformation_optimizer_id) 378 | continue 379 | 380 | individual_states[conformation_optimizer_id] = individual_state 381 | 382 | conformation_optimizers_ids -= stopped_optimizers_ids 383 | if len(individual_states) == 0: 384 | break 385 | 386 | states = _atoms_collate_fn(list(individual_states.values())) 387 | torch.set_grad_enabled(True) 388 | states = {key: value.to(DEVICE) for key, value in states.items()} 389 | output = self.actor( 390 | state_dict=states, 391 | active_optimizers_ids=list(conformation_optimizers_ids), 392 | train=True, 393 | ) 394 | anti_gradients = torch.split( 395 | output["anti_gradient"].detach().to(self.lbfgs_device), 396 | states[properties.n_atoms].tolist(), 397 | ) 398 | energies = output["energy"].detach().to(self.lbfgs_device).view(-1, 1) 399 | for i, optimizer_idx in enumerate(individual_states.keys()): 400 | await self.policy2optimizer_queues[optimizer_idx].put( 401 | (anti_gradients[i], energies[i]) 402 | ) 403 | 404 | async def _act_async(self): 405 | tasks = [ 406 | conformation_optimizer.step() 407 | for conformation_optimizer in self.conformation_optimizers 408 | ] 409 | tasks.append(self._act_task()) 410 | task_results = await asyncio.gather(*tasks) 411 | 412 | result = collections.defaultdict(list) 413 | for task in task_results[:-1]: 414 | for key, value in task.items(): 415 | result[key].append(value) 416 | 417 | for key, value in result.items(): 418 | result[key] = torch.cat(value, dim=0) 419 | 420 | return result 421 | 422 | def act(self, t): 423 | return self.loop.run_until_complete(self._act_async()) 424 | 425 | def select_action(self, t): 426 | output = self.act(t) 427 | return {key: value.cpu().numpy() for key, value in output.items()} 428 | -------------------------------------------------------------------------------- /GOLF/GOLF_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from schnetpack import properties 3 | from schnetpack.nn import scatter_add 4 | from torch.nn.functional import mse_loss 5 | 6 | from GOLF import DEVICE 7 | from GOLF.utils import calculate_gradient_norm, get_lr_scheduler, get_optimizer_class 8 | 9 | 10 | class GOLF(object): 11 | def __init__( 12 | self, 13 | policy, 14 | lr, 15 | batch_size=256, 16 | clip_value=None, 17 | lr_scheduler=None, 18 | energy_loss_coef=0.01, 19 | force_loss_coef=0.99, 20 | load_model=None, 21 | total_steps=0, 22 | utd_ratio=1, 23 | optimizer_name="adam", 24 | ): 25 | self.actor = policy.actor 26 | self.optimizer = get_optimizer_class(optimizer_name)( 27 | self.actor.parameters(), lr=lr 28 | ) 29 | if load_model: 30 | self.load(load_model) 31 | last_epoch = int(load_model.split("/")[-1].split("_")[-1]) * utd_ratio 32 | else: 33 | last_epoch = -1 34 | 35 | self.use_lr_scheduler = lr_scheduler is not None 36 | if self.use_lr_scheduler: 37 | lr_kwargs = { 38 | "gamma": 0.1, 39 | "total_steps": total_steps, 40 | "final_div_factor": 1e3, 41 | "last_epoch": last_epoch, 42 | } 43 | lr_kwargs["initial_lr"] = lr 44 | self.lr_scheduler = get_lr_scheduler( 45 | lr_scheduler, self.optimizer, **lr_kwargs 46 | ) 47 | 48 | self.batch_size = batch_size 49 | self.clip_value = clip_value 50 | self.energy_loss_coef = energy_loss_coef 
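        # (Added note) The two coefficients here weight the energy- and
        # force-matching MSE terms in the combined objective used by
        # update()/eval():
        #     loss = force_loss_coef * force_loss + energy_loss_coef * energy_loss,
        # where force_loss averages per-atom squared errors within each molecule
        # before averaging over the batch.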
51 | self.force_loss_coef = force_loss_coef 52 | self.total_it = 0 53 | 54 | def update(self, replay_buffer): 55 | metrics = dict() 56 | 57 | # Train model 58 | state, force, energy = replay_buffer.sample(self.batch_size) 59 | output = self.actor(state, train=True) 60 | predicted_energy = output["energy"] 61 | predicted_force = output["anti_gradient"] 62 | n_atoms = state[properties.n_atoms] 63 | 64 | energy_loss = mse_loss(predicted_energy, energy.squeeze(1)) 65 | force_loss = torch.sum( 66 | scatter_add( 67 | mse_loss(predicted_force, force, reduction="none").mean(-1), 68 | state[properties.idx_m], 69 | dim_size=n_atoms.size(0), 70 | ) 71 | / n_atoms 72 | ) / n_atoms.size(0) 73 | loss = self.force_loss_coef * force_loss + self.energy_loss_coef * energy_loss 74 | 75 | if not torch.all(torch.isfinite(loss)): 76 | print(f"Non finite values in loss") 77 | return metrics 78 | 79 | self.optimizer.zero_grad() 80 | loss.backward() 81 | 82 | if self.clip_value is not None: 83 | grad_norm = torch.nn.utils.clip_grad_norm_( 84 | self.actor.parameters(), self.clip_value 85 | ) 86 | else: 87 | grad_norm = calculate_gradient_norm(self.actor) 88 | if not torch.all(torch.isfinite(grad_norm)): 89 | print("Non finite values in GD grad_norm") 90 | return metrics 91 | 92 | self.optimizer.step() 93 | # Update lr 94 | if self.use_lr_scheduler: 95 | self.lr_scheduler.step() 96 | 97 | self.total_it += 1 98 | metrics["train/loss"] = loss.item() 99 | metrics["train/energy_loss"] = energy_loss.item() 100 | metrics["train/force_loss"] = force_loss.item() 101 | metrics["train/energy_loss_contrib"] = ( 102 | energy_loss.item() * self.energy_loss_coef 103 | ) 104 | metrics["train/force_loss_contrib"] = force_loss.item() * self.force_loss_coef 105 | metrics["grad_norm"] = grad_norm.item() 106 | if self.use_lr_scheduler: 107 | metrics["lr"] = self.lr_scheduler.get_last_lr()[0] 108 | return metrics 109 | 110 | def eval(self, replay_buffer): 111 | metrics = dict() 112 | 113 | # Evaluate on test dataset 114 | eval_state, eval_force, eval_energy = replay_buffer.sample_eval(self.batch_size) 115 | output = self.actor(eval_state, train=True) 116 | predicted_energy = output["energy"] 117 | predicted_force = output["anti_gradient"] 118 | n_atoms = eval_state[properties.n_atoms] 119 | 120 | with torch.no_grad(): 121 | eval_energy_loss = mse_loss(predicted_energy, eval_energy.squeeze(1)) 122 | eval_force_loss = torch.sum( 123 | scatter_add( 124 | mse_loss(predicted_force, eval_force, reduction="none").mean(-1), 125 | eval_state[properties.idx_m], 126 | dim_size=n_atoms.size(0), 127 | ) 128 | / n_atoms 129 | ) / n_atoms.size(0) 130 | eval_loss = ( 131 | self.force_loss_coef * eval_force_loss 132 | + self.energy_loss_coef * eval_energy_loss 133 | ) 134 | 135 | metrics["eval/loss"] = eval_loss.item() 136 | metrics["eval/energy_loss"] = eval_energy_loss.item() 137 | metrics["eval/force_loss"] = eval_force_loss.item() 138 | metrics["eval/energy_loss_contrib"] = ( 139 | eval_energy_loss.item() * self.energy_loss_coef 140 | ) 141 | metrics["eval/force_loss_contrib"] = ( 142 | eval_force_loss.item() * self.force_loss_coef 143 | ) 144 | return metrics 145 | 146 | def save(self, filename): 147 | self.light_save(filename) 148 | torch.save(self.optimizer.state_dict(), f"{filename}_optimizer") 149 | 150 | def light_save(self, filename): 151 | torch.save(self.actor.state_dict(), f"{filename}_actor") 152 | 153 | def load(self, filename): 154 | self.light_load(filename) 155 | self.optimizer.load_state_dict( 156 | 
torch.load(f"{filename}_optimizer", map_location=DEVICE) 157 | ) 158 | self.optimizer.param_groups[0]["capturable"] = True 159 | 160 | def light_load(self, filename): 161 | self.actor.load_state_dict(torch.load(f"{filename}_actor", map_location=DEVICE)) 162 | -------------------------------------------------------------------------------- /GOLF/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | -------------------------------------------------------------------------------- /GOLF/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | from collections import defaultdict 5 | 6 | from GOLF import DEVICE 7 | from GOLF.utils import recollate_batch 8 | 9 | CONVERGENCE_THRESHOLD = 1e-5 10 | 11 | 12 | def run_policy(env, actor, fixed_atoms, smiles, max_timestamps, eval_termination_mode): 13 | teminate_episode_condition = False 14 | delta_energy = 0 15 | t = 0 16 | 17 | # Reset initial state in actor 18 | state = env.set_initial_positions( 19 | fixed_atoms, smiles, energy_list=[None], force_list=[None] 20 | ) 21 | actor.reset({k: v.to(DEVICE) for k, v in state.items()}) 22 | 23 | # Get initial final energies in case of an optimization failure 24 | initial_energy = env.get_energies() 25 | 26 | while not teminate_episode_condition: 27 | select_action_result = actor.select_action([t]) 28 | action = select_action_result["action"] 29 | done = select_action_result["done"] 30 | state, reward, _, info = env.step(action) 31 | state = {k: v.to(DEVICE) for k, v in state.items()} 32 | delta_energy += reward[0] 33 | t += 1 34 | if eval_termination_mode == "grad_norm": 35 | teminate_episode_condition = done[0] 36 | elif eval_termination_mode == "negative_reward": 37 | teminate_episode_condition = reward[0] < 0 38 | # Terminate if max len is reached 39 | teminate_episode_condition = teminate_episode_condition or t >= max_timestamps 40 | 41 | if delta_energy < 0: 42 | final_energy = initial_energy[0] 43 | delta_energy = 0 44 | # Reset env to initial state 45 | state = env.set_initial_positions( 46 | fixed_atoms, smiles, energy_list=[None], force_list=[None] 47 | ) 48 | else: 49 | final_energy = info["final_energy"][0] 50 | 51 | return delta_energy, final_energy, t 52 | 53 | 54 | def rdkit_minimize_until_convergence(env, fixed_atoms, smiles, max_its=0): 55 | M_init = 1000 56 | env.set_initial_positions( 57 | fixed_atoms, smiles, energy_list=[None], force_list=[None], max_its=max_its 58 | ) 59 | initial_energy = env.initial_energy["rdkit"][0] 60 | not_converged, final_energy, _ = env.minimize_rdkit(idx=0, max_its=M_init) 61 | while not_converged: 62 | M_init *= 2 63 | not_converged, final_energy, _ = env.minimize_rdkit(idx=0, max_its=M_init) 64 | if M_init > 5000: 65 | print("Minimization did not converge!") 66 | return initial_energy, final_energy 67 | return initial_energy, final_energy 68 | 69 | 70 | def eval_policy_dft(actor, env, eval_episodes=10): 71 | start = time.perf_counter() 72 | max_timestamps = env.unwrapped.TL 73 | result = defaultdict(list) 74 | episode_returns = np.zeros(env.n_parallel) 75 | dft_pct_of_minimized_energy = [] 76 | 77 | # Reset env and actor 78 | state = env.reset() 79 | actor.reset(state) 80 | actor.eval() 81 | 82 | # Calculate optimal delta energy DFT 83 | optimized_delta_energy_dft = np.zeros(env.n_parallel) 84 | for i in range(env.n_parallel): 85 | 
70 | def eval_policy_dft(actor, env, eval_episodes=10):
71 |     start = time.perf_counter()
72 |     max_timestamps = env.unwrapped.TL
73 |     result = defaultdict(list)
74 |     episode_returns = np.zeros(env.n_parallel)
75 |     dft_pct_of_minimized_energy = []
76 | 
77 |     # Reset env and actor
78 |     state = env.reset()
79 |     actor.reset(state)
80 |     actor.eval()
81 | 
82 |     # Calculate optimal delta energy DFT
83 |     optimized_delta_energy_dft = np.zeros(env.n_parallel)
84 |     for i in range(env.n_parallel):
85 |         optimized_delta_energy_dft[i] = (
86 |             env.unwrapped.energy[i] - env.unwrapped.optimal_energy[i]
87 |         )
88 | 
89 |     # Calculate optimal delta energy Rdkit
90 |     initial_energy_rdkit = env.rdkit_oracle.initial_energies
91 | 
92 |     # First save all the data
93 |     molecules = [molecule.copy() for molecule in env.unwrapped.atoms]
94 |     smiles = env.unwrapped.smiles.copy()
95 |     dft_initial_energies = env.unwrapped.energy.copy()
96 |     dft_forces = env.unwrapped.force.copy()
97 | 
98 |     # Get optimal rdkit energy and calculate delta
99 |     _, optimal_energy_rdkit, _ = env.rdkit_oracle.calculate_energies_forces(
100 |         max_its=5000
101 |     )
102 |     optimized_delta_energy_rdkit = initial_energy_rdkit - optimal_energy_rdkit
103 | 
104 |     # Reset the environment again
105 |     env.set_initial_positions(molecules, smiles, dft_initial_energies, dft_forces)
106 | 
107 |     while len(result["eval/dft_delta_energy"]) < eval_episodes:
108 |         episode_timesteps = env.unwrapped.get_env_step()
109 |         # TODO incorporate actor dones into DFT evaluation
110 |         select_action_result = actor.select_action(episode_timesteps)
111 |         action = select_action_result["action"]
112 |         # actor_dones = select_action_result["done"]
113 | 
114 |         # Observe reward and next obs
115 |         state, rdkit_rewards, _, info = env.step(action)
116 |         dones = [(t + 1) > max_timestamps for t in episode_timesteps]
117 |         episode_returns += rdkit_rewards
118 | 
119 |         if "calculate_dft_energy_env_ids" in info:
120 |             dft_pct_of_minimized_energy.extend(
121 |                 optimized_delta_energy_dft[
122 |                     info["calculate_dft_energy_env_ids"]
123 |                 ].tolist()[: env.n_parallel - len(dft_pct_of_minimized_energy)]
124 |             )
125 | 
126 |         # If task queue is full wait for all tasks to finish
127 |         if env.dft_oracle.task_queue_full_flag:
128 |             _, _, _, episode_total_delta_energies = env.dft_oracle.get_data(eval=True)
129 |             # Log total delta energy and pct of optimized energy
130 |             result["eval/dft_delta_energy"].extend(
131 |                 episode_total_delta_energies.tolist()
132 |             )
133 |             dft_pct_of_minimized_energy = episode_total_delta_energies / np.array(
134 |                 dft_pct_of_minimized_energy
135 |             )
136 |             result["eval/dft_pct_of_minimized_energy"].extend(
137 |                 dft_pct_of_minimized_energy.tolist()
138 |             )
139 | 
140 |         # All trajectories terminate at the same time
141 |         if np.all(dones):
142 |             rdkit_pct_of_minimized_energy = (
143 |                 episode_returns / optimized_delta_energy_rdkit
144 |             )
145 |             rdkit_delta_energy = episode_returns
146 |             final_energies = info["final_energy"]
147 | 
148 |             # Optimization failure
149 |             optimization_failure_mask = episode_returns < 0
150 |             rdkit_pct_of_minimized_energy[optimization_failure_mask] = 0.0
151 |             rdkit_delta_energy[optimization_failure_mask] = 0.0
152 |             final_energies[optimization_failure_mask] = initial_energy_rdkit[
153 |                 optimization_failure_mask
154 |             ]
155 | 
156 |             # Log results
157 |             result["eval/rdkit_pct_of_minimized_energy"].extend(
158 |                 rdkit_pct_of_minimized_energy.tolist()
159 |             )
160 |             result["eval/rdkit_delta_energy"].extend(rdkit_delta_energy.tolist())
161 |             result["eval/final_energy"].extend(final_energies.tolist())
162 |             result["eval/episode_len"].extend(episode_timesteps)
163 | 
164 |             # Reset episode returns
165 |             episode_returns = np.zeros(env.n_parallel)
166 | 
167 |             # Reset env and actor
168 |             state = env.reset()
169 |             actor.reset(state)
170 |             actor.eval()
171 | 
172 |             # Update optimal delta energy DFT
173 |             for i in range(env.n_parallel):
174 |                 optimized_delta_energy_dft[i] = (
175 |                     env.unwrapped.energy[i] - env.unwrapped.optimal_energy[i]
176 |                 )
177 | 
178 |             # Update optimal delta energy Rdkit
179 |             initial_energy_rdkit = env.rdkit_oracle.initial_energies
180 | 
181 |             # First save all the data
182 |             molecules = [molecule.copy() for molecule in env.unwrapped.atoms]
183 |             smiles = env.unwrapped.smiles.copy()
184 |             dft_initial_energies = env.unwrapped.energy.copy()
185 |             dft_forces = env.unwrapped.force.copy()
186 | 
187 |             # Get optimal rdkit energy and calculate delta
188 |             _, optimal_energy_rdkit, _ = env.rdkit_oracle.calculate_energies_forces(
189 |                 max_its=5000
190 |             )
191 |             optimized_delta_energy_rdkit = initial_energy_rdkit - optimal_energy_rdkit
192 | 
193 |             # Reset the environment again
194 |             env.set_initial_positions(
195 |                 molecules, smiles, dft_initial_energies, dft_forces
196 |             )
197 | 
198 |     actor.train()
199 |     result = {k: np.array(v).mean() for k, v in result.items()}
200 |     print(
201 |         "Full Evaluation time: {:.3f}, results: {}".format(
202 |             time.perf_counter() - start, result
203 |         )
204 |     )
205 |     return result
206 | 
207 | 
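# --- Added note (illustrative, not part of the original file) ---
# Both the DFT routine above and the RDKit routine below report a
# "pct_of_minimized_energy" metric:
#     pct = (E_initial - E_policy_final) / (E_initial - E_optimal),
# i.e. the fraction of the achievable energy decrease that the policy
# recovered. Example with assumed numbers: E_initial = -10.0 Ha,
# E_optimal = -10.5 Ha and a final policy energy of -10.4 Ha give
# pct = 0.4 / 0.5 = 0.8.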
208 | def eval_policy_rdkit(
209 |     actor,
210 |     env,
211 |     eval_episodes=10,
212 |     eval_termination_mode=False,
213 | ):
214 |     assert env.n_parallel == 1, "Eval env is supposed to have n_parallel=1."
215 | 
216 |     max_timestamps = env.unwrapped.TL
217 |     result = defaultdict(lambda: 0.0)
218 |     for _ in range(eval_episodes):
219 |         env.reset()
220 |         if hasattr(env.unwrapped, "smiles"):
221 |             smiles = env.unwrapped.smiles.copy()
222 |         else:
223 |             smiles = [None]
224 |         fixed_atoms = env.unwrapped.atoms.copy()
225 | 
226 |         # Evaluate policy in eval mode
227 |         actor.eval()
228 |         eval_delta_energy, eval_final_energy, eval_episode_len = run_policy(
229 |             env, actor, fixed_atoms, smiles, max_timestamps, eval_termination_mode
230 |         )
231 |         result["eval/delta_energy"] += eval_delta_energy
232 |         result["eval/final_energy"] += eval_final_energy
233 |         result["eval/episode_len"] += eval_episode_len
234 | 
235 |         # Compute minimal energy of the molecule
236 |         initial_energy, final_energy = rdkit_minimize_until_convergence(
237 |             env, fixed_atoms, smiles, max_its=0
238 |         )
239 |         pct = (initial_energy - eval_final_energy) / (initial_energy - final_energy)
240 |         result["eval/pct_of_minimized_energy"] += pct
241 |         if pct > 1.0 or pct < -100:
242 |             print(
243 |                 "Strange conformation encountered: pct={:.3f} \nSmiles: {} \
244 |                 \n Coords: \n{}".format(
245 |                     result["eval/pct_of_minimized_energy"],
246 |                     smiles,
247 |                     fixed_atoms[0].get_positions(),
248 |                 )
249 |             )
250 | 
251 |         # Switch actor back to training mode
252 |         actor.train()
253 | 
254 |     result = {k: v / eval_episodes for k, v in result.items()}
255 |     return result
256 | 
--------------------------------------------------------------------------------
/GOLF/experience_saver.py:
--------------------------------------------------------------------------------
1 | from schnetpack.data.loader import _atoms_collate_fn
2 | 
3 | from GOLF.utils import unpad_state
4 | 
5 | 
6 | class BaseSaver:
7 |     def __init__(self, env, replay_buffer):
8 |         self.env = env
9 |         self.replay_buffer = replay_buffer
10 | 
11 |     def get_forces(self, indices=None):
12 |         return self.env.get_forces(indices=indices)
13 | 
14 |     def save(self, states, envs_to_store):
15 |         if len(envs_to_store) > 0:
16 |             energies = self.env.get_energies(indices=envs_to_store)
17 |             forces = self.env.get_forces(indices=envs_to_store)
18 |             state_list = unpad_state(states)
19 |             state_list = [state_list[i] for i in envs_to_store]
20 |             self.replay_buffer.add(_atoms_collate_fn(state_list), forces, energies)
21 | 
22 | 
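# --- Added note (illustrative, not part of the original file) ---
# The saver subclasses below are callables taking (states, rewards, dones)
# and deciding which transitions enter the replay buffer; BaseSaver.save()
# then re-queries the environment for the stored conformations' energies and
# forces, so the buffer always holds ground-truth labels for training.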
23 | class RewardThresholdSaver(BaseSaver):
24 |     def __init__(self, env, replay_buffer, reward_threshold):
25 |         super().__init__(env, replay_buffer)
26 |         self.reward_threshold = reward_threshold
27 | 
28 |     def __call__(self, states, rewards, _):
29 |         # Only store states with reward > REWARD_THRESHOLD
30 |         envs_to_store = [
31 |             i for i, reward in enumerate(rewards) if reward > self.reward_threshold
32 |         ]
33 |         super().save(states, envs_to_store)
34 | 
35 | 
36 | class LastConformationSaver(BaseSaver):
37 |     def __init__(self, env, replay_buffer, reward_threshold):
38 |         super().__init__(env, replay_buffer)
39 |         self.reward_threshold = reward_threshold
40 | 
41 |     def __call__(self, states, rewards, dones):
42 |         # Save last states of trajectories (and with reward > REWARD_THRESHOLD)
43 |         envs_to_store = [
44 |             i
45 |             for i, (done, reward) in enumerate(zip(dones, rewards))
46 |             if done and reward > self.reward_threshold
47 |         ]
48 |         super().save(states, envs_to_store)
--------------------------------------------------------------------------------
/GOLF/make_policies.py:
--------------------------------------------------------------------------------
1 | import schnetpack
2 | from torch.optim import SGD, Adam
3 | 
4 | from GOLF import DEVICE
5 | from GOLF.GOLF_actor import (
6 |     Actor,
7 |     RdkitActor,
8 |     ConformationOptimizer,
9 |     LBFGSConformationOptimizer,
10 | )
11 | from GOLF.utils import get_cutoff_by_string, get_radial_basis_by_string
12 | from GOLF.optim.lion_pytorch import Lion
13 | from utils.utils import ignore_extra_args
14 | 
15 | actors = {
16 |     "GOLF": ignore_extra_args(Actor),
17 |     "rdkit": ignore_extra_args(RdkitActor),
18 | }
19 | 
20 | 
21 | def make_policies(env, eval_env, args):
22 |     # Backbone args
23 |     backbone_args = {
24 |         "n_interactions": args.n_interactions,
25 |         "n_atom_basis": args.n_atom_basis,
26 |         "radial_basis": get_radial_basis_by_string(args.radial_basis_type)(
27 |             n_rbf=args.n_rbf, cutoff=args.cutoff
28 |         ),
29 |         "cutoff_fn": get_cutoff_by_string("cosine")(args.cutoff),
30 |     }
31 | 
32 |     # Actor args
33 |     actor_args = {
34 |         "env": env,
35 |         "backbone": args.backbone,
36 |         "backbone_args": backbone_args,
37 |         "do_postprocessing": args.do_postprocessing,
38 |         "action_norm_limit": args.action_norm_limit,
39 |     }
40 |     actor = actors[args.actor](**actor_args)
41 | 
42 |     policy_args = {
43 |         "n_parallel": args.n_parallel,
44 |         "lr_scheduler": args.conf_opt_lr_scheduler,
45 |         "t_max": args.timelimit_train,
46 |     }
47 | 
48 |     if args.conformation_optimizer == "LBFGS":
49 |         policy_args.update(
50 |             {
51 |                 "grad_threshold": args.grad_threshold,
52 |                 "lbfgs_device": args.lbfgs_device,
53 |                 "optimizer_kwargs": {
54 |                     "lr": 1,
55 |                     "max_iter": args.max_iter,
56 |                 },
57 |             }
58 |         )
59 |     elif args.conformation_optimizer == "GD":
60 |         policy_args.update(
61 |             {
62 |                 "optimizer": SGD,
63 |                 "optimizer_kwargs": {
64 |                     "lr": args.conf_opt_lr,
65 |                     "momentum": args.momentum,
66 |                 },
67 |             }
68 |         )
69 |     elif args.conformation_optimizer == "Lion":
70 |         policy_args.update(
71 |             {
72 |                 "optimizer": Lion,
73 |                 "optimizer_kwargs": {
74 |                     "lr": args.conf_opt_lr,
75 |                     "betas": (args.lion_beta1, args.lion_beta2),
76 |                 },
77 |             }
78 |         )
79 |     elif args.conformation_optimizer == "Adam":
80 |         policy_args.update(
81 |             {"optimizer": Adam, "optimizer_kwargs": {"lr": args.conf_opt_lr}}
82 |         )
83 |     else:
84 |         raise NotImplementedError("Unknown conformation optimizer: {}!".format(args.conformation_optimizer))
85 | 
86 |     if args.conformation_optimizer == "LBFGS":
87 |         policy = ignore_extra_args(LBFGSConformationOptimizer)(
88 |             actor=actor,
**policy_args 89 | ).to(DEVICE) 90 | else: 91 | policy = ignore_extra_args(ConformationOptimizer)( 92 | actor=actor, **policy_args 93 | ).to(DEVICE) 94 | 95 | # Initialize eval policy 96 | if args.reward == "rdkit": 97 | n_parallel_eval = 1 98 | else: 99 | n_parallel_eval = args.n_eval_runs 100 | 101 | # Update arguments and initialize new actor 102 | policy_args["n_parallel"] = n_parallel_eval 103 | 104 | actor_args.update({"env": eval_env}) 105 | eval_actor = actors[args.actor](**actor_args) 106 | 107 | if args.conformation_optimizer == "LBFGS": 108 | eval_policy = ignore_extra_args(LBFGSConformationOptimizer)( 109 | actor=eval_actor, **policy_args 110 | ).to(DEVICE) 111 | else: 112 | eval_policy = ignore_extra_args(ConformationOptimizer)( 113 | actor=eval_actor, **policy_args 114 | ).to(DEVICE) 115 | 116 | return policy, eval_policy 117 | -------------------------------------------------------------------------------- /GOLF/make_saver.py: -------------------------------------------------------------------------------- 1 | from GOLF.experience_saver import ( 2 | RewardThresholdSaver, 3 | LastConformationSaver, 4 | ) 5 | from utils.utils import ignore_extra_args 6 | 7 | 8 | savers = { 9 | "reward_threshold": ignore_extra_args(RewardThresholdSaver), 10 | "last": ignore_extra_args(LastConformationSaver), 11 | } 12 | 13 | 14 | def make_saver(args, env, replay_buffer, actor, reward_thresh): 15 | if args.reward == "dft": 16 | thresh = reward_thresh / 627.5 17 | else: 18 | thresh = reward_thresh 19 | 20 | return savers[args.experience_saver]( 21 | env=env, 22 | replay_buffer=replay_buffer, 23 | reward_threshold=thresh, 24 | actor=actor, 25 | ) 26 | -------------------------------------------------------------------------------- /GOLF/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/GOLF/optim/__init__.py -------------------------------------------------------------------------------- /GOLF/optim/lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | from torch.optim import Optimizer 4 | 5 | __all__ = ['LBFGS'] 6 | 7 | def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): 8 | # ported from https://github.com/torch/optim/blob/master/polyinterp.lua 9 | # Compute bounds of interpolation area 10 | if bounds is not None: 11 | xmin_bound, xmax_bound = bounds 12 | else: 13 | xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) 14 | 15 | # Code for most common case: cubic interpolation of 2 points 16 | # w/ function and derivative values for both 17 | # Solution in this case (where x2 is the farthest point): 18 | # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); 19 | # d2 = sqrt(d1^2 - g1*g2); 20 | # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); 21 | # t_new = min(max(min_pos,xmin_bound),xmax_bound); 22 | d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) 23 | d2_square = d1**2 - g1 * g2 24 | if d2_square >= 0: 25 | d2 = d2_square.sqrt() 26 | if x1 <= x2: 27 | min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) 28 | else: 29 | min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) 30 | return min(max(min_pos, xmin_bound), xmax_bound) 31 | else: 32 | return (xmin_bound + xmax_bound) / 2. 
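# --- Added worked example (illustrative, not part of the ported file) ---
# Sanity check of the interpolation on f(x) = x**2 bracketed by x1=0, x2=2
# (so f1=0, g1=0, f2=4, g2=4):
#   d1 = 0 + 4 - 3 * (0 - 4) / (0 - 2) = -2
#   d2 = sqrt((-2)**2 - 0 * 4) = 2
#   min_pos = 2 - 2 * ((4 + 2 - (-2)) / (4 - 0 + 2 * 2)) = 0
# which is exactly the minimizer of x**2.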
33 | 34 | 35 | def _strong_wolfe(obj_func, 36 | x, 37 | t, 38 | d, 39 | f, 40 | g, 41 | gtd, 42 | c1=1e-4, 43 | c2=0.9, 44 | tolerance_change=1e-9, 45 | max_ls=25): 46 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua 47 | d_norm = d.abs().max() 48 | g = g.clone(memory_format=torch.contiguous_format) 49 | # evaluate objective and gradient using initial step 50 | f_new, g_new = obj_func(x, t, d) 51 | ls_func_evals = 1 52 | gtd_new = g_new.dot(d) 53 | 54 | # bracket an interval containing a point satisfying the Wolfe criteria 55 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd 56 | done = False 57 | ls_iter = 0 58 | while ls_iter < max_ls: 59 | # check conditions 60 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): 61 | bracket = [t_prev, t] 62 | bracket_f = [f_prev, f_new] 63 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 64 | bracket_gtd = [gtd_prev, gtd_new] 65 | break 66 | 67 | if abs(gtd_new) <= -c2 * gtd: 68 | bracket = [t] 69 | bracket_f = [f_new] 70 | bracket_g = [g_new] 71 | done = True 72 | break 73 | 74 | if gtd_new >= 0: 75 | bracket = [t_prev, t] 76 | bracket_f = [f_prev, f_new] 77 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 78 | bracket_gtd = [gtd_prev, gtd_new] 79 | break 80 | 81 | # interpolate 82 | min_step = t + 0.01 * (t - t_prev) 83 | max_step = t * 10 84 | tmp = t 85 | t = _cubic_interpolate( 86 | t_prev, 87 | f_prev, 88 | gtd_prev, 89 | t, 90 | f_new, 91 | gtd_new, 92 | bounds=(min_step, max_step)) 93 | 94 | # next step 95 | t_prev = tmp 96 | f_prev = f_new 97 | g_prev = g_new.clone(memory_format=torch.contiguous_format) 98 | gtd_prev = gtd_new 99 | f_new, g_new = obj_func(x, t, d) 100 | ls_func_evals += 1 101 | gtd_new = g_new.dot(d) 102 | ls_iter += 1 103 | 104 | # reached max number of iterations? 105 | if ls_iter == max_ls: 106 | bracket = [0, t] 107 | bracket_f = [f, f_new] 108 | bracket_g = [g, g_new] 109 | 110 | # zoom phase: we now have a point satisfying the criteria, or 111 | # a bracket around it. We refine the bracket until we find the 112 | # exact point satisfying the criteria 113 | insuf_progress = False 114 | # find high and low points in bracket 115 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) 116 | while not done and ls_iter < max_ls: 117 | # line-search bracket is so small 118 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: 119 | break 120 | 121 | # compute new trial value 122 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], 123 | bracket[1], bracket_f[1], bracket_gtd[1]) 124 | 125 | # test that we are making sufficient progress: 126 | # in case `t` is so close to boundary, we mark that we are making 127 | # insufficient progress, and if 128 | # + we have made insufficient progress in the last step, or 129 | # + `t` is at one of the boundary, 130 | # we will move `t` to a position which is `0.1 * len(bracket)` 131 | # away from the nearest boundary point. 
132 |         eps = 0.1 * (max(bracket) - min(bracket))
133 |         if min(max(bracket) - t, t - min(bracket)) < eps:
134 |             # interpolation close to boundary
135 |             if insuf_progress or t >= max(bracket) or t <= min(bracket):
136 |                 # evaluate at 0.1 away from boundary
137 |                 if abs(t - max(bracket)) < abs(t - min(bracket)):
138 |                     t = max(bracket) - eps
139 |                 else:
140 |                     t = min(bracket) + eps
141 |                 insuf_progress = False
142 |             else:
143 |                 insuf_progress = True
144 |         else:
145 |             insuf_progress = False
146 | 
147 |         # Evaluate new point
148 |         f_new, g_new = obj_func(x, t, d)
149 |         ls_func_evals += 1
150 |         gtd_new = g_new.dot(d)
151 |         ls_iter += 1
152 | 
153 |         if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
154 |             # Armijo condition not satisfied or not lower than lowest point
155 |             bracket[high_pos] = t
156 |             bracket_f[high_pos] = f_new
157 |             bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
158 |             bracket_gtd[high_pos] = gtd_new
159 |             low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
160 |         else:
161 |             if abs(gtd_new) <= -c2 * gtd:
162 |                 # Wolfe conditions satisfied
163 |                 done = True
164 |             elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
165 |                 # old high becomes new low
166 |                 bracket[high_pos] = bracket[low_pos]
167 |                 bracket_f[high_pos] = bracket_f[low_pos]
168 |                 bracket_g[high_pos] = bracket_g[low_pos]
169 |                 bracket_gtd[high_pos] = bracket_gtd[low_pos]
170 | 
171 |             # new point becomes new low
172 |             bracket[low_pos] = t
173 |             bracket_f[low_pos] = f_new
174 |             bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
175 |             bracket_gtd[low_pos] = gtd_new
176 | 
177 |     # return stuff
178 |     t = bracket[low_pos]
179 |     f_new = bracket_f[low_pos]
180 |     g_new = bracket_g[low_pos]
181 |     return f_new, g_new, t, ls_func_evals
182 | 
183 | 
184 | class LBFGS(Optimizer):
185 |     """Implements L-BFGS algorithm, heavily inspired by `minFunc
186 |     <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.
187 | 
188 |     .. warning::
189 |         This optimizer doesn't support per-parameter options and parameter
190 |         groups (there can be only one).
191 | 
192 |     .. warning::
193 |         Right now all parameters have to be on a single device. This will be
194 |         improved in the future.
195 | 
196 |     .. note::
197 |         This is a very memory intensive optimizer (it requires additional
198 |         ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
199 |         try reducing the history size, or use a different algorithm.
200 | 
201 |     Args:
202 |         lr (float): learning rate (default: 1)
203 |         max_iter (int): maximal number of iterations per optimization step
204 |             (default: 20)
205 |         max_eval (int): maximal number of function evaluations per optimization
206 |             step (default: max_iter * 1.25).
207 |         tolerance_grad (float): termination tolerance on first order optimality
208 |             (default: 1e-7).
209 |         tolerance_change (float): termination tolerance on function
210 |             value/parameter changes (default: 1e-9).
211 |         history_size (int): update history size (default: 100).
212 |         line_search_fn (str): either 'strong_wolfe' or None (default: None).
213 | """ 214 | 215 | def __init__(self, 216 | params, 217 | lr=1, 218 | max_iter=20, 219 | max_eval=None, 220 | tolerance_grad=1e-7, 221 | tolerance_change=1e-9, 222 | history_size=100, 223 | line_search_fn=None): 224 | if max_eval is None: 225 | max_eval = max_iter * 5 // 4 226 | defaults = dict( 227 | lr=lr, 228 | max_iter=max_iter, 229 | max_eval=max_eval, 230 | tolerance_grad=tolerance_grad, 231 | tolerance_change=tolerance_change, 232 | history_size=history_size, 233 | line_search_fn=line_search_fn) 234 | super().__init__(params, defaults) 235 | 236 | if len(self.param_groups) != 1: 237 | raise ValueError("LBFGS doesn't support per-parameter options " 238 | "(parameter groups)") 239 | 240 | self._params = self.param_groups[0]['params'] 241 | self._numel_cache = None 242 | 243 | def _numel(self): 244 | if self._numel_cache is None: 245 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 246 | return self._numel_cache 247 | 248 | def _gather_flat_grad(self): 249 | views = [] 250 | for p in self._params: 251 | if p.grad is None: 252 | view = p.new(p.numel()).zero_() 253 | elif p.grad.is_sparse: 254 | view = p.grad.to_dense().view(-1) 255 | else: 256 | view = p.grad.view(-1) 257 | views.append(view) 258 | return torch.cat(views, 0) 259 | 260 | def _add_grad(self, step_size, update): 261 | offset = 0 262 | for p in self._params: 263 | numel = p.numel() 264 | # view as to avoid deprecated pointwise semantics 265 | p.add_(update[offset:offset + numel].view_as(p), alpha=step_size) 266 | offset += numel 267 | assert offset == self._numel() 268 | 269 | def _clone_param(self): 270 | return [p.clone(memory_format=torch.contiguous_format) for p in self._params] 271 | 272 | def _set_param(self, params_data): 273 | for p, pdata in zip(self._params, params_data): 274 | p.copy_(pdata) 275 | 276 | def _directional_evaluate(self, closure, x, t, d): 277 | self._add_grad(t, d) 278 | loss = float(closure()) 279 | flat_grad = self._gather_flat_grad() 280 | self._set_param(x) 281 | return loss, flat_grad 282 | 283 | async def step(self, closure): 284 | """Performs a single optimization step. 285 | 286 | Args: 287 | closure (Callable): A closure that reevaluates the model 288 | and returns the loss. 
289 | """ 290 | assert len(self.param_groups) == 1 291 | 292 | is_grad_enabled = torch.is_grad_enabled() 293 | torch.set_grad_enabled(False) 294 | 295 | # Make sure the closure is always called with grad enabled 296 | closure = torch.enable_grad()(closure) 297 | 298 | group = self.param_groups[0] 299 | lr = group['lr'] 300 | max_iter = group['max_iter'] 301 | max_eval = group['max_eval'] 302 | tolerance_grad = group['tolerance_grad'] 303 | tolerance_change = group['tolerance_change'] 304 | line_search_fn = group['line_search_fn'] 305 | history_size = group['history_size'] 306 | 307 | # NOTE: LBFGS has only global state, but we register it as state for 308 | # the first param, because this helps with casting in load_state_dict 309 | state = self.state[self._params[0]] 310 | state.setdefault('func_evals', 0) 311 | state.setdefault('n_iter', 0) 312 | 313 | # evaluate initial f(x) and df/dx 314 | 315 | # ++++++++ Async code block start ++++++++ 316 | with torch.enable_grad(): 317 | orig_loss = await closure() 318 | torch.set_grad_enabled(False) 319 | # ++++++++ Async code block end ++++++++ 320 | 321 | loss = float(orig_loss) 322 | current_evals = 1 323 | state['func_evals'] += 1 324 | 325 | flat_grad = self._gather_flat_grad() 326 | opt_cond = flat_grad.abs().max() <= tolerance_grad 327 | 328 | # optimal condition 329 | if opt_cond: 330 | return orig_loss 331 | 332 | # tensors cached in state (for tracing) 333 | d = state.get('d') 334 | t = state.get('t') 335 | old_dirs = state.get('old_dirs') 336 | old_stps = state.get('old_stps') 337 | ro = state.get('ro') 338 | H_diag = state.get('H_diag') 339 | prev_flat_grad = state.get('prev_flat_grad') 340 | prev_loss = state.get('prev_loss') 341 | 342 | n_iter = 0 343 | # optimize for a max of max_iter iterations 344 | while n_iter < max_iter: 345 | # keep track of nb of iterations 346 | n_iter += 1 347 | state['n_iter'] += 1 348 | 349 | ############################################################ 350 | # compute gradient descent direction 351 | ############################################################ 352 | if state['n_iter'] == 1: 353 | d = flat_grad.neg() 354 | old_dirs = [] 355 | old_stps = [] 356 | ro = [] 357 | H_diag = 1 358 | else: 359 | # do lbfgs update (update memory) 360 | y = flat_grad.sub(prev_flat_grad) 361 | s = d.mul(t) 362 | ys = y.dot(s) # y*s 363 | if ys > 1e-10: 364 | # updating memory 365 | if len(old_dirs) == history_size: 366 | # shift history by one (limited-memory) 367 | old_dirs.pop(0) 368 | old_stps.pop(0) 369 | ro.pop(0) 370 | 371 | # store new direction/step 372 | old_dirs.append(y) 373 | old_stps.append(s) 374 | ro.append(1. 
/ ys) 375 | 376 | # update scale of initial Hessian approximation 377 | H_diag = ys / y.dot(y) # (y*y) 378 | 379 | # compute the approximate (L-BFGS) inverse Hessian 380 | # multiplied by the gradient 381 | num_old = len(old_dirs) 382 | 383 | if 'al' not in state: 384 | state['al'] = [None] * history_size 385 | al = state['al'] 386 | 387 | # iteration in L-BFGS loop collapsed to use just one buffer 388 | q = flat_grad.neg() 389 | for i in range(num_old - 1, -1, -1): 390 | al[i] = old_stps[i].dot(q) * ro[i] 391 | q.add_(old_dirs[i], alpha=-al[i]) 392 | 393 | # multiply by initial Hessian 394 | # r/d is the final direction 395 | d = r = torch.mul(q, H_diag) 396 | for i in range(num_old): 397 | be_i = old_dirs[i].dot(r) * ro[i] 398 | r.add_(old_stps[i], alpha=al[i] - be_i) 399 | 400 | if prev_flat_grad is None: 401 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) 402 | else: 403 | prev_flat_grad.copy_(flat_grad) 404 | prev_loss = loss 405 | 406 | ############################################################ 407 | # compute step length 408 | ############################################################ 409 | # reset initial guess for step size 410 | if state['n_iter'] == 1: 411 | t = min(1., 1. / flat_grad.abs().sum()) * lr 412 | else: 413 | t = lr 414 | 415 | # directional derivative 416 | gtd = flat_grad.dot(d) # g * d 417 | 418 | # directional derivative is below tolerance 419 | if gtd > -tolerance_change: 420 | break 421 | 422 | # optional line search: user function 423 | ls_func_evals = 0 424 | if line_search_fn is not None: 425 | raise NotImplementedError('Linear search has not been implemented!') 426 | # perform line search, using user function 427 | # if line_search_fn != "strong_wolfe": 428 | # raise RuntimeError("only 'strong_wolfe' is supported") 429 | # else: 430 | # x_init = self._clone_param() 431 | # 432 | # def obj_func(x, t, d): 433 | # return self._directional_evaluate(closure, x, t, d) 434 | # 435 | # loss, flat_grad, t, ls_func_evals = _strong_wolfe( 436 | # obj_func, x_init, t, d, loss, flat_grad, gtd) 437 | # self._add_grad(t, d) 438 | # opt_cond = flat_grad.abs().max() <= tolerance_grad 439 | else: 440 | # no line search, simply move with fixed-step 441 | self._add_grad(t, d) 442 | if n_iter != max_iter: 443 | # re-evaluate function only if not in last iteration 444 | # the reason we do this: in a stochastic setting, 445 | # no use to re-evaluate that function here 446 | 447 | # ++++++++ Async code block start ++++++++ 448 | with torch.enable_grad(): 449 | loss = float(await closure()) 450 | torch.set_grad_enabled(False) 451 | # ++++++++ Async code block end ++++++++ 452 | 453 | flat_grad = self._gather_flat_grad() 454 | opt_cond = flat_grad.abs().max() <= tolerance_grad 455 | ls_func_evals = 1 456 | 457 | # update func eval 458 | current_evals += ls_func_evals 459 | state['func_evals'] += ls_func_evals 460 | 461 | ############################################################ 462 | # check conditions 463 | ############################################################ 464 | if n_iter == max_iter: 465 | break 466 | 467 | if current_evals >= max_eval: 468 | break 469 | 470 | # optimal condition 471 | if opt_cond: 472 | break 473 | 474 | # lack of progress 475 | if d.mul(t).abs().max() <= tolerance_change: 476 | break 477 | 478 | if abs(loss - prev_loss) < tolerance_change: 479 | break 480 | 481 | state['d'] = d 482 | state['t'] = t 483 | state['old_dirs'] = old_dirs 484 | state['old_stps'] = old_stps 485 | state['ro'] = ro 486 | state['H_diag'] = H_diag 
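        # (Added note) The entries cached in `state` here persist the L-BFGS
        # curvature memory (direction/step history, ro, H_diag) and the
        # previous gradient/loss across successive step() calls.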
487 | state['prev_flat_grad'] = prev_flat_grad 488 | state['prev_loss'] = prev_loss 489 | 490 | torch.set_grad_enabled(is_grad_enabled) 491 | 492 | return orig_loss 493 | -------------------------------------------------------------------------------- /GOLF/optim/lion_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copy-pasted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py 2 | 3 | # Copyright 2023 Google Research. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """PyTorch implementation of the Lion optimizer.""" 18 | import torch 19 | from torch.optim.optimizer import Optimizer 20 | 21 | 22 | class Lion(Optimizer): 23 | r"""Implements Lion algorithm.""" 24 | 25 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): 26 | """Initialize the hyperparameters. 27 | Args: 28 | params (iterable): iterable of parameters to optimize or dicts defining 29 | parameter groups 30 | lr (float, optional): learning rate (default: 1e-4) 31 | betas (Tuple[float, float], optional): coefficients used for computing 32 | running averages of gradient and its square (default: (0.9, 0.99)) 33 | weight_decay (float, optional): weight decay coefficient (default: 0) 34 | """ 35 | 36 | if not 0.0 <= lr: 37 | raise ValueError("Invalid learning rate: {}".format(lr)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 43 | super().__init__(params, defaults) 44 | 45 | @torch.no_grad() 46 | def step(self, closure=None): 47 | """Performs a single optimization step. 48 | Args: 49 | closure (callable, optional): A closure that reevaluates the model 50 | and returns the loss. 51 | Returns: 52 | the loss. 
53 | """ 54 | loss = None 55 | if closure is not None: 56 | with torch.enable_grad(): 57 | loss = closure() 58 | 59 | for group in self.param_groups: 60 | for p in group["params"]: 61 | if p.grad is None: 62 | continue 63 | 64 | # Perform stepweight decay 65 | p.data.mul_(1 - group["lr"] * group["weight_decay"]) 66 | 67 | grad = p.grad 68 | state = self.state[p] 69 | # State initialization 70 | if len(state) == 0: 71 | # Exponential moving average of gradient values 72 | state["exp_avg"] = torch.zeros_like(p) 73 | 74 | exp_avg = state["exp_avg"] 75 | beta1, beta2 = group["betas"] 76 | 77 | # Weight update 78 | update = exp_avg * beta1 + grad * (1 - beta1) 79 | p.add_(torch.sign(update), alpha=-group["lr"]) 80 | # Decay the momentum running average coefficient 81 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 82 | 83 | return loss 84 | -------------------------------------------------------------------------------- /GOLF/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from schnetpack import properties 5 | from schnetpack.data.loader import _atoms_collate_fn 6 | from schnetpack.nn import scatter_add 7 | 8 | from GOLF.utils import unpad_state 9 | from env.moldynamics_env import env_fn 10 | from env.wrappers import EnergyWrapper 11 | 12 | 13 | NORM_THRESHOLD = 10.5 14 | 15 | 16 | class ReplayBuffer(object): 17 | def __init__( 18 | self, 19 | device, 20 | max_size, 21 | max_total_conformations, 22 | atomrefs=None, 23 | initial_RB=None, 24 | eval_RB=None, 25 | initial_conf_pct=0.0, 26 | ): 27 | self.device = device 28 | self.max_size = max_size 29 | self.max_total_conformations = max_total_conformations 30 | self.initial_RB = initial_RB 31 | self.eval_RB = eval_RB 32 | self.ptr = 0 33 | self.size = 0 34 | self.replay_buffer_full = False 35 | 36 | if self.initial_RB: 37 | self.initial_conf_pct = initial_conf_pct 38 | else: 39 | self.initial_conf_pct = 0.0 40 | 41 | self.states = [None] * self.max_size 42 | self.energy = torch.empty((max_size, 1), dtype=torch.float32) 43 | self.forces = [None] * self.max_size 44 | 45 | if atomrefs: 46 | self.atomrefs = torch.tensor(atomrefs, device=device) 47 | else: 48 | self.atomrefs = None 49 | 50 | def add(self, states, forces, energies): 51 | energies = torch.tensor(energies, dtype=torch.float32) 52 | force_norms = np.array([np.linalg.norm(force) for force in forces]) 53 | individual_states = unpad_state(states) 54 | # Exclude conformations with forces that have a high norm 55 | # from the replay buffer 56 | for i in np.where(force_norms < NORM_THRESHOLD)[0]: 57 | self.states[self.ptr] = individual_states[i] 58 | self.energy[self.ptr] = energies[i] 59 | self.forces[self.ptr] = torch.tensor(forces[i], dtype=torch.float32) 60 | self.ptr = (self.ptr + 1) % self.max_size 61 | self.size = self.size + 1 62 | 63 | self.replay_buffer_full = self.size >= self.max_total_conformations 64 | 65 | def sample(self, batch_size): 66 | new_samples_batch_size = int(batch_size * (1 - self.initial_conf_pct)) 67 | states, forces, energy = self.sample_wo_collate(new_samples_batch_size) 68 | 69 | if self.initial_RB and self.initial_conf_pct: 70 | initial_conf_batch_size = batch_size - new_samples_batch_size 71 | init_states, init_forces, init_energy = self.initial_RB.sample_wo_collate( 72 | initial_conf_batch_size 73 | ) 74 | states = states + init_states 75 | forces = forces + init_forces 76 | energy = torch.cat((energy, init_energy), dim=0) 77 | 78 | state_batch = { 79 | key: 
value.to(self.device) 80 | for key, value in _atoms_collate_fn(states).items() 81 | } 82 | forces = torch.cat(forces).to(self.device) 83 | energy = energy.to(self.device) 84 | 85 | if self.atomrefs is not None: 86 | # Get system index 87 | idx_m = state_batch[properties.idx_m] 88 | 89 | # Get num molecules in the batch 90 | max_m = int(idx_m[-1]) + 1 91 | 92 | # Get atomization energy for each molecule in the batch 93 | atomization_energy = scatter_add( 94 | self.atomrefs[state_batch[properties.Z]], idx_m, dim_size=max_m 95 | ).unsqueeze(-1) 96 | energy -= atomization_energy 97 | 98 | return state_batch, forces, energy 99 | 100 | def sample_eval(self, batch_size): 101 | states, forces, energy = self.eval_RB.sample_wo_collate(batch_size) 102 | state_batch = { 103 | key: value.to(self.device) 104 | for key, value in _atoms_collate_fn(states).items() 105 | } 106 | forces = torch.cat(forces).to(self.device) 107 | energy = energy.to(self.device) 108 | 109 | if self.atomrefs is not None: 110 | # Get system index 111 | idx_m = state_batch[properties.idx_m] 112 | 113 | # Get num molecules in the batch 114 | max_m = int(idx_m[-1]) + 1 115 | 116 | # Get atomization energy for each molecule in the batch 117 | atomization_energy = scatter_add( 118 | self.atomrefs[state_batch[properties.Z]], idx_m, dim_size=max_m 119 | ).unsqueeze(-1) 120 | energy -= atomization_energy 121 | 122 | return state_batch, forces, energy 123 | 124 | def sample_wo_collate(self, batch_size): 125 | ind = np.random.choice(min(self.size, self.max_size), batch_size, replace=False) 126 | states = [self.states[i] for i in ind] 127 | forces = [self.forces[i] for i in ind] 128 | energy = self.energy[ind] 129 | return states, forces, energy 130 | 131 | 132 | def fill_initial_replay_buffer( 133 | device, db_path, timelimit, num_initial_conformations, atomrefs=None 134 | ): 135 | # Env kwargs 136 | env_kwargs = { 137 | "db_path": db_path, 138 | "n_parallel": 1, 139 | "timelimit": timelimit, 140 | "sample_initial_conformations": False, 141 | "num_initial_conformations": num_initial_conformations, 142 | } 143 | # Initialize env 144 | env = env_fn(**env_kwargs) 145 | if num_initial_conformations == -1: 146 | total_confs = env.get_db_length() 147 | else: 148 | total_confs = num_initial_conformations 149 | 150 | initial_replay_buffer = ReplayBuffer( 151 | device, 152 | max_size=total_confs, 153 | max_total_conformations=total_confs, 154 | atomrefs=atomrefs, 155 | ) 156 | 157 | # Fill up the replay buffer 158 | for _ in range(total_confs): 159 | state = env.reset() 160 | # Save initial state in replay buffer 161 | energies = np.array([env.energy]) 162 | forces = [np.array(force) for force in env.force] 163 | initial_replay_buffer.add(state, forces, energies) 164 | 165 | return initial_replay_buffer 166 | -------------------------------------------------------------------------------- /GOLF/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import schnetpack.nn as snn 3 | import torch 4 | from schnetpack import properties 5 | 6 | from GOLF import DEVICE 7 | from GOLF.optim.lion_pytorch import Lion 8 | 9 | 10 | class LRCosineAnnealing: 11 | def __init__(self, lr, lr_min=1e-5, t_max=1000): 12 | self.lr = lr 13 | self.lr_min = lr_min 14 | self.t_max = t_max 15 | 16 | def get(self, t): 17 | return torch.FloatTensor( 18 | [ 19 | self.lr_min 20 | + 0.5 21 | * (self.lr - self.lr_min) 22 | * (1 + np.cos(min(t_, self.t_max) * np.pi / self.t_max)) 23 | for t_ in t 24 | ] 25 | ).to(DEVICE) 
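# Example usage of LRCosineAnnealing (hypothetical values, for illustration only):
#   scheduler = LRCosineAnnealing(lr=1.0, lr_min=1e-5, t_max=1000)
#   scheduler.get([0, 500, 1000])  # -> roughly tensor([1.0, 0.5, 1e-5]) on DEVICE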
26 | 27 | 28 | class LRConstant: 29 | def __init__(self, lr): 30 | self.lr = lr 31 | 32 | def get(self, t): 33 | return torch.FloatTensor([self.lr for t_ in t]).to(DEVICE) 34 | 35 | 36 | def get_conformation_lr_scheduler(lr_scheduler_type, lr, t_max): 37 | if lr_scheduler_type == "Constant": 38 | return LRConstant(lr) 39 | elif lr_scheduler_type == "CosineAnnealing": 40 | return LRCosineAnnealing(lr, t_max=t_max) 41 | else: 42 | raise ValueError( 43 | "Unknown conformation LR scheduler type: {}".format(lr_scheduler_type) 44 | ) 45 | 46 | 47 | def get_lr_scheduler(scheduler_type, optimizer, **kwargs): 48 | if scheduler_type == "OneCycleLR": 49 | return torch.optim.lr_scheduler.OneCycleLR( 50 | optimizer, 51 | max_lr=10 * kwargs["initial_lr"], 52 | final_div_factor=kwargs["final_div_factor"], 53 | total_steps=kwargs["total_steps"], 54 | last_epoch=kwargs["last_epoch"], 55 | ) 56 | elif scheduler_type == "CosineAnnealing": 57 | return torch.optim.lr_scheduler.CosineAnnealingLR( 58 | optimizer, 59 | T_max=kwargs["total_steps"], 60 | eta_min=kwargs["initial_lr"] / kwargs["final_div_factor"], 61 | last_epoch=kwargs["last_epoch"], 62 | ) 63 | elif scheduler_type == "StepLR": 64 | return torch.optim.lr_scheduler.StepLR( 65 | optimizer, step_size=kwargs["total_steps"] // 3, gamma=kwargs["gamma"] 66 | ) 67 | else: 68 | raise ValueError("Unknown LR scheduler type: {}".format(scheduler_type)) 69 | 70 | 71 | def get_optimizer_class(optimizer_name): 72 | if optimizer_name == "adam": 73 | return torch.optim.Adam 74 | elif optimizer_name == "lion": 75 | return Lion 76 | else: 77 | raise ValueError(f"Unknown optimizer: {optimizer_name}") 78 | 79 | 80 | def recollate_batch(state_batch, indices, new_state_batch): 81 | # Transform state_batch and new_state_batch to lists. 82 | individual_states = unpad_state(state_batch) 83 | new_individual_states = unpad_state(new_state_batch) 84 | 85 | # Replaces some states with new ones and collates them into batch. 86 | for new_idx, idx in enumerate(indices): 87 | individual_states[idx] = new_individual_states[new_idx] 88 | return {k: v.to(DEVICE) for k, v in _atoms_collate_fn(individual_states).items()} 89 | 90 | 91 | def calculate_atoms_in_cutoff(state): 92 | n_atoms = state[properties.n_atoms] 93 | atoms_indices_range = get_atoms_indices_range(state) 94 | indices, counts = torch.unique( 95 | state[properties.idx_i], sorted=False, return_counts=True 96 | ) 97 | n_atoms_expanded = torch.ones_like(indices) 98 | for molecule_id in range(n_atoms.size(0)): 99 | molecule_indices = (atoms_indices_range[molecule_id] <= indices) & ( 100 | indices < atoms_indices_range[molecule_id + 1] 101 | ) 102 | n_atoms_expanded[molecule_indices] = n_atoms[molecule_id] 103 | 104 | return torch.sum(counts / (n_atoms_expanded * n_atoms.size(0))) 105 | 106 | 107 | def calculate_molecule_metrics(state, next_state): 108 | n_atoms = state[properties.n_atoms] 109 | atoms_indices_range = get_atoms_indices_range(state) 110 | assert ( 111 | state[properties.idx_m].size(0) == atoms_indices_range[-1].item() 112 | ), "Assume that all atoms are listed in _idx_m property!" 
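# The loop below computes per-molecule min/mean/max interatomic distances
# from the pairwise vectors in properties.Rij and averages them over the
# molecules in the batch.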
113 | 114 | min_r, avg_r, max_r = 0, 0, 0 115 | rij = torch.linalg.norm(state[properties.Rij], dim=-1) 116 | for molecule_id in range(n_atoms.size(0)): 117 | molecule_indices = ( 118 | atoms_indices_range[molecule_id] <= state[properties.idx_i] 119 | ) & (state[properties.idx_j] < atoms_indices_range[molecule_id + 1]) 120 | current_molecule_r = rij[molecule_indices] 121 | min_r += current_molecule_r.min() 122 | avg_r += current_molecule_r.mean() 123 | max_r += current_molecule_r.max() 124 | 125 | min_r, avg_r, max_r = ( 126 | min_r / n_atoms.size(0), 127 | avg_r / n_atoms.size(0), 128 | max_r / n_atoms.size(0), 129 | ) 130 | avg_atoms_in_cutoff_before = calculate_atoms_in_cutoff(state) 131 | avg_atoms_in_cutoff_after = calculate_atoms_in_cutoff(next_state) 132 | 133 | metrics = { 134 | "Molecule/min_interatomic_dist": min_r.item(), 135 | "Molecule/avg_interatomic_dist": avg_r.item(), 136 | "Molecule/max_interatomic_dist": max_r.item(), 137 | "Molecule/avg_atoms_inside_cutoff_state": avg_atoms_in_cutoff_before.item(), 138 | "Molecule/avg_atoms_inside_cutoff_next_state": avg_atoms_in_cutoff_after.item(), 139 | } 140 | 141 | return metrics 142 | 143 | 144 | def calculate_gradient_norm(model): 145 | total_norm = 0.0 146 | params = [p for p in model.parameters() if p.grad is not None and p.requires_grad] 147 | for p in params: 148 | param_norm = p.grad.detach().data.norm(2) 149 | total_norm += param_norm**2 150 | total_norm = total_norm ** (0.5) 151 | return total_norm 152 | 153 | 154 | def calculate_action_norm(actions, cumsum_numbers_atoms): 155 | actions_norm = np.linalg.norm(actions, axis=1) 156 | mean_norm = 0 157 | for idx in range(len(cumsum_numbers_atoms) - 1): 158 | mean_norm += actions_norm[ 159 | cumsum_numbers_atoms[idx] : cumsum_numbers_atoms[idx + 1] 160 | ].mean() 161 | 162 | return mean_norm / (len(cumsum_numbers_atoms) - 1) 163 | 164 | 165 | def get_cutoff_by_string(cutoff_type): 166 | if cutoff_type == "cosine": 167 | return snn.cutoff.CosineCutoff 168 | 169 | raise ValueError(f"Unexpected cutoff type:{cutoff_type}") 170 | 171 | 172 | def get_radial_basis_by_string(radial_basis_type): 173 | if radial_basis_type == "Bessel": 174 | return snn.BesselRBF 175 | elif radial_basis_type == "Gaussian": 176 | return snn.GaussianRBF 177 | 178 | raise ValueError(f"Unexpected radial basis type:{radial_basis_type}") 179 | 180 | 181 | def get_atoms_indices_range(states): 182 | return torch.nn.functional.pad( 183 | torch.cumsum(states[properties.n_atoms], dim=0), pad=(1, 0) 184 | ) 185 | 186 | 187 | def unpad_state(states): 188 | individual_states = [] 189 | n_molecules = states[properties.n_atoms].size(0) 190 | n_atoms = get_atoms_indices_range(states) 191 | for i in range(n_molecules): 192 | state = { 193 | properties.n_atoms: torch.unsqueeze(n_atoms[i + 1] - n_atoms[i], dim=0) 194 | .clone() 195 | .cpu() 196 | } 197 | for key in (properties.Z, properties.position): 198 | state[key] = states[key][n_atoms[i] : n_atoms[i + 1]].clone().cpu() 199 | 200 | for key in (properties.cell, properties.pbc, properties.idx): 201 | state[key] = states[key][i].unsqueeze(0).clone().cpu() 202 | 203 | assert ( 204 | states[properties.idx_m].size(0) == n_atoms[-1].item() 205 | ), "Assume that all atoms are listed in _idx_m property!" 
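# Select the neighbor-list entries belonging to molecule i and convert the
# batch-level pair indices back to molecule-local indices.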
206 | molecule_indices = (n_atoms[i] <= states[properties.idx_i]) & ( 207 | states[properties.idx_i] < n_atoms[i + 1] 208 | ) 209 | for key in (properties.lidx_i, properties.lidx_j, properties.offsets): 210 | state[key] = states[key][molecule_indices].clone().cpu() 211 | 212 | for key in (properties.idx_i, properties.idx_j): 213 | state[key] = state[f"{key}_local"].clone().cpu() 214 | 215 | state[properties.idx_m] = torch.zeros_like(state[properties.Z]) 216 | 217 | individual_states.append(state) 218 | 219 | return individual_states 220 | 221 | 222 | def _atoms_collate_fn(batch): 223 | """ 224 | Build batch from systems and properties & apply padding 225 | 226 | Args: 227 | examples (list): 228 | 229 | Returns: 230 | dict[str->torch.Tensor]: mini-batch of atomistic systems 231 | """ 232 | elem = batch[0] 233 | idx_keys = {properties.idx_i, properties.idx_j, properties.idx_i_triples} 234 | # Atom triple indices must be treated separately 235 | idx_triple_keys = {properties.idx_j_triples, properties.idx_k_triples} 236 | 237 | coll_batch = {} 238 | for key in elem: 239 | if (key not in idx_keys) and (key not in idx_triple_keys): 240 | coll_batch[key] = torch.cat([d[key] for d in batch], 0) 241 | elif key in idx_keys: 242 | coll_batch[key + "_local"] = torch.cat([d[key] for d in batch], 0) 243 | 244 | seg_m = torch.cumsum(coll_batch[properties.n_atoms], dim=0) 245 | seg_m = torch.cat( 246 | [torch.zeros((1,), dtype=seg_m.dtype, device=seg_m.device), seg_m], dim=0 247 | ) 248 | idx_m = torch.repeat_interleave( 249 | torch.arange(len(batch), device=seg_m.device), 250 | repeats=coll_batch[properties.n_atoms], 251 | dim=0, 252 | ) 253 | coll_batch[properties.idx_m] = idx_m 254 | 255 | for key in idx_keys: 256 | if key in elem.keys(): 257 | coll_batch[key] = torch.cat( 258 | [d[key] + off for d, off in zip(batch, seg_m)], 0 259 | ) 260 | 261 | # Shift the indices for the atom triples 262 | for key in idx_triple_keys: 263 | if key in elem.keys(): 264 | indices = [] 265 | offset = 0 266 | for idx, d in enumerate(batch): 267 | indices.append(d[key] + offset) 268 | offset += d[properties.idx_j].shape[0] 269 | coll_batch[key] = torch.cat(indices, 0) 270 | 271 | return coll_batch 272 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AIRI - Artificial Intelligence Research Institute 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (ICLR2024 Poster) Gradual Optimization Learning for Conformational Energy Minimization 2 | 3 |

*Badges: Code style: black · ICLR poster page · OpenReview Paper URL*

6 | 7 | This repository is the official implementation of the paper: 8 | > Tsypin, A., Ugadiarov, L. A., Khrabrov, K., Telepov, A., Rumiantsev, E., Skrynnik, A., ... & Kadurin, A. (2023, October).
9 | > **Gradual Optimization Learning for Conformational Energy Minimization.**
> In The Twelfth International Conference on Learning Representations.

**Experiments and results on the [SPICE](https://www.nature.com/articles/s41597-022-01882-6) dataset can be found in the "GOLF-SPICE" branch.**

| Model | $\overline{\text{pct}}_T (\%) \uparrow$ | $\text{pct}_{\text{div}} (\%) \downarrow$ | $\overline{E^{\text{res}}}_T \tiny{\text{(kcal/mol)}} \downarrow$ | $\text{pct}_{\text{success}} (\%) \uparrow$ | $\text{COV} (\%) \uparrow$ | $\text{MAT} (\text{Å}) \downarrow$ |
| --- | --- | --- | --- | --- | --- | --- |
| RDKit | $84.92 \pm 10.6$ | $\mathbf{0.05}$ | $5.5$ | $4.1$ | $62.24$ | $0.509$ |
| Torsional Diffusion | $25.63 \pm 21.4$ | $46.9$ | $33.8$ | $0.0$ | $11.3$ | $1.333$ |
| ConfOpt | $36.48 \pm 23.0$ | $84.5$ | $27.9$ | $0.2$ | $19.88$ | $1.05$ |
| Uni-Mol+ | $62.20 \pm 17.2$ | $2.8$ | $18.6$ | $0.2$ | $68.79$ | $0.407$ |
| $f^{\text{baseline}}$ | $76.8 \pm 22.4$ | $7.5$ | $8.6$ | $8.2$ | $65.22$ | $0.482$ |
| $f^{\text{rdkit}}$ | $93.09 \pm 11.9$ | $3.8$ | $2.8$ | $35.4$ | $71.6$ | $0.426$ |
| $f^{\text{traj-10k}}$ | $95.3 \pm 7.3$ | $4.5$ | $2.0$ | $37.0$ | $70.55$ | $0.440$ |
| $f^{\text{traj-100k}}$ | $96.3 \pm 9.8$ | $2.9$ | $1.5$ | $52.7$ | $71.43$ | $0.432$ |
| $f^{\text{traj-500k}}$ | $98.4 \pm 9.2$ | $1.8$ | $\mathbf{0.5}$ | $73.4$ | $72.15$ | $0.442$ |
| $f^{\text{GOLF-1k}}$ | $98.5 \pm 5.3$ | $3.6$ | $1.1$ | $62.9$ | $76.54$ | $\mathbf{0.349}$ |
| $f^{\text{GOLF-10k}}$ | $\mathbf{99.4 \pm 5.2}$ | $2.4$ | $\mathbf{0.5}$ | $\mathbf{77.3}$ | $\mathbf{76.84}$ | $0.355$ |
## Training the NNP baseline
1. Set up the environment on the GPU machine.
```
# On the GPU machine
./scripts/setup_gpu_env.sh
conda activate GOLF_schnetpack
pip install -r requirements.txt
```
2. Download the training dataset $\mathcal{D}_0$ and the evaluation dataset $\mathcal{D}\_{\text{test}}$.
```
mkdir data && cd data
wget https://sc.link/FpEvS -O D-0.db
wget https://sc.link/W6RUA -O D-test.db
cd ../
```
3. Train the baseline PaiNN model.
```
cd scripts/training
./run_training_baseline.sh
```
Running this script creates a folder in the specified `log_dir` directory (we use "./results" in our configs and scripts). The name of the folder is set by the `exp_name` hyperparameter. The folder will contain checkpoints, a metrics file, and a config file with the hyperparameters.

## Training the NNP on optimization trajectories
1. Set up the environment on the GPU machine as in [the first section](#training-the-nnp-baseline).
2. Download the optimization-trajectory datasets.
```
cd data
wget https://sc.link/ZQRiV -O D-traj-10k.db
wget https://sc.link/Z0ebo -O D-traj-100k.db
wget https://sc.link/hj1JX -O D-traj-500k.db
cd ../
```
3. Train PaiNN.
```
cd scripts/training
./run_training_trajectories-10k.sh
./run_training_trajectories-100k.sh
./run_training_trajectories-500k.sh
```

## Training NNPs with GOLF

### Distributed Gradient Calculation with Psi4
To speed up training, we parallelize the DFT computations across several CPU-rich machines. The NNP itself is trained on the parent machine with a GPU.
1. Set up the environment on the GPU machine as in [the first section](#training-the-nnp-baseline).
2. Log in to the CPU-rich machines. They must be accessible via `ssh`.
3. Set up environments on the CPU-rich machines.
```
# On CPU-rich machines
git clone https://github.com/AIRI-Institute/GOLF
cd GOLF/scripts
./setup_dft_workers.sh <n_ports> <ports_range_begin>
```
Here, `n_ports` denotes the number of workers on a CPU-rich machine, and `ports_range_begin` denotes the starting port number for the workers. Workers calculate energies and forces with `psi4` for newly generated conformations. For example, `./setup_dft_workers.sh 24 20000` will launch a total of 48 workers listening on ports `20000, ..., 20023`. You can change `ports_range_begin` in `env/dft.py`.

By default, we assume that each worker uses 4 CPU cores (this can be changed in `env/dft_worker.py`, line 22), which means that `n_ports` must be less than or equal to `total_cpu_number / 4`.
4. Add the IP addresses of the CPU-rich machines to a text file. We use `env/host_names.txt`.

### Training with GOLF
Train PaiNN with GOLF.
```
cd scripts/training
./run_training_GOLF_10k.sh
```

## Evaluating NNPs
The evaluation can be done with or without `psi4` energy estimation for the NNP-optimization trajectories. The argument `eval_early_stop_steps` controls at which steps of an optimization trajectory energies/forces are evaluated with `psi4`. For example, setting `eval_early_stop_steps` to an empty list results in no additional `psi4` energy estimations, while setting it to `[1 2 3 5 8 13 21 30 50 75 100]` results in 11 additional energy evaluations for each conformation in the evaluation dataset.
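For reference, the percentage of optimized energy for a single trajectory is computed from the initial energy $E(s_0)$, the final energy $E(s_T)$, and the optimal energy $E(s_{\text{opt}})$ found by the oracle $\mathcal{O}$ (a sketch of the definition; see the paper for the exact formulation):

$$\text{pct}(s_T) = \frac{E(s_0) - E(s_T)}{E(s_0) - E(s_{\text{opt}})} \times 100\%,$$

and $\overline{\text{pct}}_T$ denotes this quantity averaged over the evaluation dataset.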
Note that in order to compute $\overline{\text{pct}}_T$, optimal energies obtained with the genuine oracle $\mathcal{O}$ must be available. In our work, `psi4.optimize` with a spherical representation of the molecule was used (approximately 30 steps until convergence).

In this repo, we provide NNPs pre-trained on different datasets and with GOLF in the `checkpoints` directory:
- $f^{\text{baseline}}$ (`checkpoints/baseline-NNP/NNP_checkpoint`)
- $f^{\text{traj-10k}}$ (`checkpoints/traj-10k/NNP_checkpoint`)
- $f^{\text{traj-100k}}$ (`checkpoints/traj-100k/NNP_checkpoint`)
- $f^{\text{traj-500k}}$ (`checkpoints/traj-500k/NNP_checkpoint`)
- $f^{\text{GOLF-1k}}$ (`checkpoints/GOLF-1k/NNP_checkpoint`)
- $f^{\text{GOLF-10k}}$ (`checkpoints/GOLF-10k/NNP_checkpoint`)

For example, to evaluate GOLF-10k and additionally calculate `psi4` energies/forces along the optimization trajectory, run:
```
python evaluate_batch_dft.py --checkpoint_path checkpoints/GOLF-10k --agent_path NNP_checkpoint_actor --n_parallel 240 --n_threads 24 --conf_number -1 --host_file_path env/host_names.txt --eval_db_path data/GOLF_test.db --timelimit 100 --terminate_on_negative_reward False --reward dft --minimize_on_every_step False --eval_early_stop_steps 1 2 3 5 8 13 21 30 50 75 100
```
Make sure that `n_threads` is equal to the number of workers on each CPU-rich machine; setting `n_threads` to a larger number will result in optimization failures. If you wish to evaluate only the last state in each optimization trajectory, set `timelimit` and `eval_early_stop_steps` to the same number: `--timelimit T --eval_early_stop_steps T`.

After the evaluation is finished, an `evaluation_metrics.json` file with per-step metrics is created. Each record in `evaluation_metrics.json` describes the optimization statistics for a single conformation and contains metrics such as the forces/energies MSE, the percentage of optimized energy, and the predicted and ground-truth energies. The final NNP-optimized conformations are stored in the `results.db` database.
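A minimal sketch of aggregating these records (it assumes `evaluation_metrics.json` is a JSON list and that the percentage of optimized energy is stored under a key named `pct_of_minimized_energy`; the actual key names may differ, so inspect one record first):
```
import json

# Hypothetical field name; check a record of your evaluation_metrics.json
PCT_KEY = "pct_of_minimized_energy"

with open("evaluation_metrics.json") as f:
    records = json.load(f)

pcts = [r[PCT_KEY] for r in records if PCT_KEY in r]
print(f"conformations: {len(records)}")
if pcts:
    print(f"mean pct of optimized energy: {sum(pcts) / len(pcts):.2f}%")
```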
212 | 213 | ## Citation 214 | To cite this work, please use: 215 | ``` 216 | @inproceedings{tsypin2023gradual, 217 | title={Gradual Optimization Learning for Conformational Energy Minimization}, 218 | author={Tsypin, Artem and Ugadiarov, Leonid Anatolievich and Khrabrov, Kuzma and Telepov, Alexander and Rumiantsev, Egor and Skrynnik, Alexey and Panov, Aleksandr and Vetrov, Dmitry P and Tutubalina, Elena and Kadurin, Artur}, 219 | booktitle={The Twelfth International Conference on Learning Representations}, 220 | year={2023} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /checkpoints/GOLF-10k/NNP_checkpoint_actor: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/checkpoints/GOLF-10k/NNP_checkpoint_actor -------------------------------------------------------------------------------- /checkpoints/GOLF-10k/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_parallel": 120, 3 | "n_threads": 24, 4 | "db_path": "../../data/GOLF_train.db", 5 | "eval_db_path": "../../data/GOLF_test.db", 6 | "num_initial_conformations": -1, 7 | "sample_initial_conformations": true, 8 | "timelimit_train": 100, 9 | "timelimit_eval": 50, 10 | "terminate_on_negative_reward": true, 11 | "max_num_negative_rewards": 1, 12 | "reward": "dft", 13 | "minimize_on_every_step": true, 14 | "backbone": "painn", 15 | "n_interactions": 3, 16 | "cutoff": 5.0, 17 | "n_rbf": 50, 18 | "n_atom_basis": 128, 19 | "actor": "GOLF", 20 | "conformation_optimizer": "LBFGS", 21 | "conf_opt_lr": 1.0, 22 | "conf_opt_lr_scheduler": "Constant", 23 | "experience_saver": "reward_threshold", 24 | "store_only_initial_conformations": false, 25 | "max_iter": 5, 26 | "lbfgs_device": "cpu", 27 | "momentum": 0.0, 28 | "lion_beta1": 0.9, 29 | "lion_beta2": 0.99, 30 | "batch_size": 64, 31 | "lr": 0.0001, 32 | "optimizer": "adam", 33 | "lr_scheduler": "CosineAnnealing", 34 | "clip_value": "1.0", 35 | "energy_loss_coef": 0.01, 36 | "force_loss_coef": 0.99, 37 | "initial_conf_pct": 0.1, 38 | "max_oracle_steps": 10000, 39 | "replay_buffer_size": 1000000, 40 | "utd_ratio": 50, 41 | "subtract_atomization_energy": true, 42 | "action_norm_limit": 1.0, 43 | "eval_freq": 120, 44 | "n_eval_runs": 64, 45 | "eval_termination_mode": "fixed_length", 46 | "grad_threshold": 1e-05, 47 | "exp_name": "GOLF-10k", 48 | "host_file_path": "../../env/host_names.txt", 49 | "seed": 235361, 50 | "full_checkpoint_freq": 600, 51 | "light_checkpoint_freq": 1200, 52 | "save_checkpoints": true, 53 | "load_baseline": "../../checkpoints/baseline-NNP/NNP_checkpoint", 54 | "load_model": null, 55 | "log_dir": "../../results", 56 | "run_id": "run-0", 57 | "env": "GOLF_train" 58 | } -------------------------------------------------------------------------------- /checkpoints/GOLF-1k/NNP_checkpoint_actor: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/checkpoints/GOLF-1k/NNP_checkpoint_actor -------------------------------------------------------------------------------- /checkpoints/GOLF-1k/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_parallel": 48, 3 | "n_threads": 24, 4 | "db_path": "../../data/GOLF_train.db", 5 | "eval_db_path": "../../data/GOLF_test.db", 6 | 
"num_initial_conformations": -1, 7 | "sample_initial_conformations": true, 8 | "timelimit_train": 100, 9 | "timelimit_eval": 50, 10 | "terminate_on_negative_reward": true, 11 | "max_num_negative_rewards": 1, 12 | "reward": "dft", 13 | "minimize_on_every_step": true, 14 | "backbone": "painn", 15 | "n_interactions": 3, 16 | "cutoff": 5.0, 17 | "n_rbf": 50, 18 | "n_atom_basis": 128, 19 | "actor": "GOLF", 20 | "conformation_optimizer": "LBFGS", 21 | "conf_opt_lr": 1.0, 22 | "conf_opt_lr_scheduler": "Constant", 23 | "experience_saver": "reward_threshold", 24 | "store_only_initial_conformations": false, 25 | "max_iter": 5, 26 | "lbfgs_device": "cpu", 27 | "momentum": 0.0, 28 | "lion_beta1": 0.9, 29 | "lion_beta2": 0.99, 30 | "batch_size": 64, 31 | "lr": 0.0001, 32 | "optimizer": "adam", 33 | "lr_scheduler": "CosineAnnealing", 34 | "clip_value": "1.0", 35 | "energy_loss_coef": 0.01, 36 | "force_loss_coef": 0.99, 37 | "initial_conf_pct": 0.1, 38 | "max_oracle_steps": 1000, 39 | "replay_buffer_size": 1000000, 40 | "utd_ratio": 500, 41 | "subtract_atomization_energy": true, 42 | "action_norm_limit": 1.0, 43 | "eval_freq": 48, 44 | "n_eval_runs": 48, 45 | "eval_termination_mode": "fixed_length", 46 | "grad_threshold": 1e-05, 47 | "exp_name": "GOLF-1k", 48 | "host_file_path": "../../env/host_names.txt", 49 | "seed": 958582, 50 | "full_checkpoint_freq": 96, 51 | "light_checkpoint_freq": 192, 52 | "save_checkpoints": true, 53 | "load_baseline": "../../checkpoints/baseline-NNP/NNP_checkpoint", 54 | "load_model": null, 55 | "log_dir": "../../results", 56 | "run_id": "run-0", 57 | "env": "GOLF_train" 58 | } -------------------------------------------------------------------------------- /checkpoints/baseline-NNP/NNP_checkpoint_actor: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/checkpoints/baseline-NNP/NNP_checkpoint_actor -------------------------------------------------------------------------------- /checkpoints/baseline-NNP/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_parallel": 240, 3 | "n_threads": 24, 4 | "db_path": "../../data/GOLF_train.db", 5 | "eval_db_path": "../../data/GOLF_test.db", 6 | "num_initial_conformations": -1, 7 | "sample_initial_conformations": true, 8 | "timelimit_train": 1, 9 | "timelimit_eval": 50, 10 | "terminate_on_negative_reward": true, 11 | "max_num_negative_rewards": 1, 12 | "reward": "dft", 13 | "minimize_on_every_step": true, 14 | "backbone": "painn", 15 | "n_interactions": 3, 16 | "cutoff": 5.0, 17 | "n_rbf": 50, 18 | "n_atom_basis": 128, 19 | "actor": "GOLF", 20 | "conformation_optimizer": "LBFGS", 21 | "conf_opt_lr": 1.0, 22 | "conf_opt_lr_scheduler": "Constant", 23 | "experience_saver": "reward_threshold", 24 | "store_only_initial_conformations": true, 25 | "max_iter": 5, 26 | "lbfgs_device": "cpu", 27 | "momentum": 0.0, 28 | "lion_beta1": 0.9, 29 | "lion_beta2": 0.99, 30 | "batch_size": 64, 31 | "lr": 0.0001, 32 | "optimizer": "adam", 33 | "lr_scheduler": "CosineAnnealing", 34 | "clip_value": "1.0", 35 | "energy_loss_coef": 0.01, 36 | "force_loss_coef": 0.99, 37 | "initial_conf_pct": 1.0, 38 | "max_oracle_steps": 100000, 39 | "replay_buffer_size": 1000000, 40 | "utd_ratio": 5, 41 | "subtract_atomization_energy": true, 42 | "action_norm_limit": 1.0, 43 | "eval_freq": 1200, 44 | "n_eval_runs": 64, 45 | "eval_termination_mode": "fixed_length", 46 | "grad_threshold": 1e-05, 47 | 
"exp_name": "baseline-NNP", 48 | "host_file_path": null, 49 | "seed": 402079, 50 | "full_checkpoint_freq": 10000, 51 | "light_checkpoint_freq": 50000, 52 | "save_checkpoints": true, 53 | "load_baseline": null, 54 | "load_model": null, 55 | "log_dir": "../../results", 56 | "run_id": "run-0", 57 | "env": "GOLF_train" 58 | } -------------------------------------------------------------------------------- /checkpoints/traj-100k/NNP_checkpoint_actor: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/checkpoints/traj-100k/NNP_checkpoint_actor -------------------------------------------------------------------------------- /checkpoints/traj-100k/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_parallel": 240, 3 | "n_threads": 24, 4 | "db_path": "../../data/traj-100k.db", 5 | "eval_db_path": "../../GOLF_test.db", 6 | "num_initial_conformations": -1, 7 | "sample_initial_conformations": true, 8 | "timelimit_train": 1, 9 | "timelimit_eval": 50, 10 | "terminate_on_negative_reward": true, 11 | "max_num_negative_rewards": 1, 12 | "reward": "dft", 13 | "minimize_on_every_step": true, 14 | "backbone": "painn", 15 | "n_interactions": 3, 16 | "cutoff": 5.0, 17 | "n_rbf": 50, 18 | "n_atom_basis": 128, 19 | "actor": "GOLF", 20 | "conformation_optimizer": "LBFGS", 21 | "conf_opt_lr": 1.0, 22 | "conf_opt_lr_scheduler": "Constant", 23 | "experience_saver": "reward_threshold", 24 | "store_only_initial_conformations": true, 25 | "max_iter": 5, 26 | "lbfgs_device": "cpu", 27 | "momentum": 0.0, 28 | "lion_beta1": 0.9, 29 | "lion_beta2": 0.99, 30 | "batch_size": 64, 31 | "lr": 0.0001, 32 | "optimizer": "adam", 33 | "lr_scheduler": "CosineAnnealing", 34 | "clip_value": "1.0", 35 | "energy_loss_coef": 0.01, 36 | "force_loss_coef": 0.99, 37 | "initial_conf_pct": 1.0, 38 | "max_oracle_steps": 100000, 39 | "replay_buffer_size": 1000000, 40 | "utd_ratio": 5, 41 | "subtract_atomization_energy": true, 42 | "action_norm_limit": 1.0, 43 | "eval_freq": 1200, 44 | "n_eval_runs": 64, 45 | "eval_termination_mode": "fixed_length", 46 | "grad_threshold": 1e-05, 47 | "exp_name": "traj-100k", 48 | "host_file_path": null, 49 | "seed": 797164, 50 | "full_checkpoint_freq": 10000, 51 | "light_checkpoint_freq": 50000, 52 | "save_checkpoints": true, 53 | "load_baseline": "../../checkpoints/baseline-NNP/NNP_checkpoint", 54 | "load_model": null, 55 | "log_dir": "../../results", 56 | "run_id": "run-0", 57 | "env": "traj-100k" 58 | } -------------------------------------------------------------------------------- /checkpoints/traj-10k/NNP_checkpoint_actor: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/checkpoints/traj-10k/NNP_checkpoint_actor -------------------------------------------------------------------------------- /checkpoints/traj-10k/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_parallel": 240, 3 | "n_threads": 24, 4 | "db_path": "../../data/traj-10k.db", 5 | "eval_db_path": "../../data/GOLF_test.db", 6 | "num_initial_conformations": -1, 7 | "sample_initial_conformations": true, 8 | "timelimit_train": 1, 9 | "timelimit_eval": 50, 10 | "terminate_on_negative_reward": true, 11 | "max_num_negative_rewards": 1, 12 | "reward": "dft", 13 | "minimize_on_every_step": true, 
14 | "backbone": "painn", 15 | "n_interactions": 3, 16 | "cutoff": 5.0, 17 | "n_rbf": 50, 18 | "n_atom_basis": 128, 19 | "actor": "GOLF", 20 | "conformation_optimizer": "LBFGS", 21 | "conf_opt_lr": 1.0, 22 | "conf_opt_lr_scheduler": "Constant", 23 | "experience_saver": "reward_threshold", 24 | "store_only_initial_conformations": true, 25 | "max_iter": 5, 26 | "lbfgs_device": "cpu", 27 | "momentum": 0.0, 28 | "lion_beta1": 0.9, 29 | "lion_beta2": 0.99, 30 | "batch_size": 64, 31 | "lr": 0.0001, 32 | "optimizer": "adam", 33 | "lr_scheduler": "CosineAnnealing", 34 | "clip_value": "1.0", 35 | "energy_loss_coef": 0.01, 36 | "force_loss_coef": 0.99, 37 | "initial_conf_pct": 1.0, 38 | "max_oracle_steps": 100000, 39 | "replay_buffer_size": 1000000, 40 | "utd_ratio": 5, 41 | "subtract_atomization_energy": true, 42 | "action_norm_limit": 1.0, 43 | "eval_freq": 1200, 44 | "n_eval_runs": 64, 45 | "eval_termination_mode": "fixed_length", 46 | "grad_threshold": 1e-05, 47 | "exp_name": "traj-10k", 48 | "host_file_path": null, 49 | "seed": 729972, 50 | "full_checkpoint_freq": 10000, 51 | "light_checkpoint_freq": 50000, 52 | "save_checkpoints": true, 53 | "load_baseline": "../../checkpoints/baseline-NNP/NNP_checkpoint", 54 | "load_model": null, 55 | "log_dir": "../../results", 56 | "run_id": "run-0", 57 | "env": "traj-10k" 58 | } -------------------------------------------------------------------------------- /checkpoints/traj-500k/NNP_checkpoint_actor: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIRI-Institute/GOLF/deaaab858b9034b9d7a8359d9cdc61618f484334/checkpoints/traj-500k/NNP_checkpoint_actor -------------------------------------------------------------------------------- /checkpoints/traj-500k/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_parallel": 240, 3 | "n_threads": 24, 4 | "db_path": "../../data/traj-500k.db", 5 | "eval_db_path": "../../data/GOLF_test.db", 6 | "num_initial_conformations": -1, 7 | "sample_initial_conformations": true, 8 | "timelimit_train": 1, 9 | "timelimit_eval": 50, 10 | "terminate_on_negative_reward": true, 11 | "max_num_negative_rewards": 1, 12 | "reward": "dft", 13 | "minimize_on_every_step": true, 14 | "backbone": "painn", 15 | "n_interactions": 3, 16 | "cutoff": 5.0, 17 | "n_rbf": 50, 18 | "n_atom_basis": 128, 19 | "actor": "GOLF", 20 | "conformation_optimizer": "LBFGS", 21 | "conf_opt_lr": 1.0, 22 | "conf_opt_lr_scheduler": "Constant", 23 | "experience_saver": "reward_threshold", 24 | "store_only_initial_conformations": true, 25 | "max_iter": 5, 26 | "lbfgs_device": "cpu", 27 | "momentum": 0.0, 28 | "lion_beta1": 0.9, 29 | "lion_beta2": 0.99, 30 | "batch_size": 64, 31 | "lr": 0.0001, 32 | "optimizer": "adam", 33 | "lr_scheduler": "CosineAnnealing", 34 | "clip_value": "1.0", 35 | "energy_loss_coef": 0.01, 36 | "force_loss_coef": 0.99, 37 | "initial_conf_pct": 1.0, 38 | "max_oracle_steps": 200000, 39 | "replay_buffer_size": 1000000, 40 | "utd_ratio": 5, 41 | "subtract_atomization_energy": true, 42 | "action_norm_limit": 1.0, 43 | "eval_freq": 1200, 44 | "n_eval_runs": 64, 45 | "eval_termination_mode": "fixed_length", 46 | "grad_threshold": 1e-05, 47 | "exp_name": "traj-500k", 48 | "host_file_path": null, 49 | "seed": 720226, 50 | "full_checkpoint_freq": 10000, 51 | "light_checkpoint_freq": 50000, 52 | "save_checkpoints": true, 53 | "load_baseline": "../../checkpoints/baseline-NNP/NNP_checkpoint", 54 | "load_model": null, 55 | "log_dir": 
"../../results", 56 | "run_id": "run-0", 57 | "env": "traj-500k" 58 | } -------------------------------------------------------------------------------- /env/dft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import socket 4 | import struct 5 | import subprocess 6 | import tempfile 7 | import traceback 8 | from datetime import datetime 9 | 10 | PORT_RANGE_BEGIN_TRAIN = 20000 11 | PORT_RANGE_BEGIN_EVAL = 30000 12 | HOSTS = [ 13 | "192.168.19.21", 14 | "192.168.19.22", 15 | "192.168.19.23", 16 | "192.168.19.24", 17 | "192.168.19.25", 18 | "192.168.19.26", 19 | "192.168.19.27", 20 | "192.168.19.28", 21 | "192.168.19.29", 22 | "192.168.19.30", 23 | ] 24 | 25 | 26 | def recvall(sock, count): 27 | buf = b"" 28 | while count: 29 | newbuf = sock.recv(count) 30 | if not newbuf: 31 | return newbuf 32 | buf += newbuf 33 | count -= len(newbuf) 34 | return buf 35 | 36 | 37 | def send_one_message(sock, data): 38 | length = len(data) 39 | sock.sendall(struct.pack("!I", length)) 40 | sock.sendall(data) 41 | 42 | 43 | def recv_one_message(sock): 44 | buf = recvall(sock, 4) 45 | if not buf: 46 | return buf 47 | (length,) = struct.unpack("!I", buf) 48 | return recvall(sock, length) 49 | 50 | 51 | def log(conformation_id, message, path, logging): 52 | if logging: 53 | with open(path, "a") as file_obj: 54 | print( 55 | f"{get_time()} - conformation_id={conformation_id} {message}", 56 | file=file_obj, 57 | ) 58 | 59 | 60 | def calculate_dft_energy_tcp_client(task, host, port, logging=False): 61 | path = f"client_{host}_{port}.out" 62 | conformation_id, step, ase_atoms = task 63 | try: 64 | log(conformation_id, "going to connect", path, logging) 65 | sock = socket.socket() 66 | sock.connect((host, port)) 67 | log(conformation_id, "connected", path, logging) 68 | 69 | ase_atoms = ase_atoms.todict() 70 | 71 | task = (ase_atoms, step, conformation_id) 72 | task = pickle.dumps(task) 73 | send_one_message(sock, task) 74 | log(conformation_id, "send one message", path, logging) 75 | result = recv_one_message(sock) 76 | log(conformation_id, "received response", path, logging) 77 | idx, not_converged, energy, force = pickle.loads(result) 78 | assert conformation_id == idx 79 | 80 | return conformation_id, step, energy, force 81 | except Exception as e: 82 | description = traceback.format_exc() 83 | log(conformation_id, description, path, logging) 84 | return conformation_id, step, None, None 85 | 86 | 87 | def get_dft_server_destinations(n_workers, host_file_path=None): 88 | if host_file_path: 89 | with open(host_file_path, "r") as f: 90 | hosts = f.readlines() 91 | else: 92 | hosts = HOSTS 93 | port_range_begin = PORT_RANGE_BEGIN_TRAIN 94 | destinations = [] 95 | for host in hosts: 96 | for port in range(port_range_begin, port_range_begin + n_workers): 97 | destinations.append((host, port)) 98 | 99 | return destinations 100 | 101 | 102 | # Get correct hostname 103 | def get_ip(): 104 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 105 | s.settimeout(0) 106 | try: 107 | # doesn't even have to be reachable 108 | s.connect(("10.254.254.254", 1)) 109 | IP = s.getsockname()[0] 110 | except Exception: 111 | IP = "127.0.0.1" 112 | finally: 113 | s.close() 114 | return IP 115 | 116 | 117 | def get_time(): 118 | return datetime.now().strftime("%H:%M:%S") 119 | 120 | 121 | if __name__ == "__main__": 122 | import sys 123 | 124 | host = get_ip() 125 | num_threads = sys.argv[1] 126 | port = int(sys.argv[2]) 127 | 128 | if len(sys.argv) >= 4: 129 | 
timeout_seconds = int(sys.argv[3])  # optional per-task DFT timeout, in seconds 130 | else: 131 | timeout_seconds = 600 132 | 133 | if len(sys.argv) >= 5: 134 | dft_script_path = sys.argv[4] 135 | else: 136 | dir_path = os.path.dirname(os.path.abspath(sys.argv[0])) 137 | dft_script_path = os.path.join(dir_path, "dft_worker.py") 138 | 139 | server_socket = socket.socket() 140 | server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 141 | server_socket.bind((host, port)) 142 | 143 | server_socket.listen(1) 144 | 145 | print(port, "accept") 146 | conn, address = server_socket.accept() 147 | total_processed = 0 148 | 149 | while True: 150 | data = recv_one_message(conn) 151 | while len(data) == 0: 152 | print(port, "connection lost, accept") 153 | conn, address = server_socket.accept() 154 | data = recv_one_message(conn) 155 | 156 | print( 157 | f"{get_time()} -", port, "recv", len(data), "bytes from", address, end=" " 158 | ) 159 | task = pickle.loads(data) 160 | conformation_id = task[2] 161 | print("conformation_id", conformation_id, flush=True) 162 | 163 | with tempfile.NamedTemporaryFile(mode="wb", delete=False) as file_obj: 164 | pickle.dump(task, file_obj) 165 | task_path = file_obj.name 166 | 167 | with tempfile.NamedTemporaryFile(mode="wb", delete=False) as file_obj: 168 | result_path = file_obj.name 169 | 170 | result = conformation_id, True, None, None 171 | try: 172 | completed_process = subprocess.run( 173 | [ 174 | sys.executable, 175 | dft_script_path, 176 | "--task_path", 177 | task_path, 178 | "--result_path", 179 | result_path, 180 | "--num_threads", 181 | num_threads, 182 | ], 183 | stdout=subprocess.PIPE, 184 | stderr=subprocess.STDOUT, 185 | text=True, 186 | timeout=timeout_seconds, 187 | ) 188 | 189 | print( 190 | f"returncode={completed_process.returncode}\nWorker stdout:\n{completed_process.stdout}", 191 | flush=True, 192 | ) 193 | if completed_process.returncode == 0: 194 | with open(result_path, "rb") as file_obj: 195 | result = pickle.load(file_obj) 196 | except subprocess.TimeoutExpired as e: 197 | print(e) 198 | if e.stdout is not None: 199 | print(f'Worker stdout:\n{e.stdout.decode("utf-8")}', flush=True) 200 | if e.stderr is not None: 201 | print(f'Worker stderr:\n{e.stderr.decode("utf-8")}', flush=True) 202 | 203 | os.remove(task_path) 204 | os.remove(result_path) 205 | 206 | result = pickle.dumps(result) 207 | 208 | total_processed += 1 209 | print( 210 | f"{get_time()} -", 211 | port, 212 | "going to send", 213 | len(result), 214 | "bytes to", 215 | address, 216 | flush=True, 217 | ) 218 | send_one_message(conn, result) 219 | print( 220 | f"{get_time()} - data sent, total processed:{total_processed}", 221 | flush=True, 222 | ) 223 | -------------------------------------------------------------------------------- /env/dft_worker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import traceback 4 | 5 | import ase 6 | import ase.io 7 | import time 8 | import psi4 9 | import numpy as np 10 | import pickle 11 | 12 | from psi4 import SCFConvergenceError 13 | from psi4.driver.p4util.exceptions import OptimizationConvergenceError 14 | 15 | PSI4_BOHR2ANGSTROM = 0.52917720859 16 | 17 | # os.environ['PSI_SCRATCH'] = "/dev/shm/tmp" 18 | # psi4.set_options({ "CACHELEVEL": 0 }) 19 | 20 | psi4.set_memory("16 GB") 21 | psi4.core.IOManager.shared_object().set_default_path("/dev/shm/") 22 | psi4.core.set_output_file("/dev/null") 23 | 24 | HEADER = "units ang \n nocom \n noreorient \n" 25 | FUNCTIONAL_STRING = "wb97x-d/def2-svp" 26 | 27 | 28
| def read_xyz_file_block(file, look_for_charge=True): 29 | """ """ 30 | 31 | atomic_symbols = [] 32 | xyz_coordinates = [] 33 | charge = 0 34 | title = "" 35 | 36 | line = file.readline() 37 | if not line: 38 | return None 39 | num_atoms = int(line) 40 | line = file.readline() 41 | if "charge=" in line: 42 | charge = int(line.split("=")[1]) 43 | 44 | for _ in range(num_atoms): 45 | line = file.readline() 46 | atomic_symbol, x, y, z = line.split()[:4] 47 | atomic_symbols.append(atomic_symbol) 48 | xyz_coordinates.append([float(x), float(y), float(z)]) 49 | 50 | return atomic_symbols, xyz_coordinates, charge 51 | 52 | 53 | def read_xyz_file(filename, look_for_charge=True): 54 | """ """ 55 | mol_data = [] 56 | with open(filename) as xyz_file: 57 | while xyz_file: 58 | current_data = read_xyz_file_block(xyz_file, look_for_charge) 59 | if current_data: 60 | mol_data.append(current_data) 61 | else: 62 | break 63 | 64 | return mol_data 65 | 66 | 67 | def xyz2psi4mol(atoms, coordinates): 68 | molecule_string = HEADER + "\n".join( 69 | [ 70 | " ".join( 71 | [ 72 | atom, 73 | ] 74 | + list(map(str, x)) 75 | ) 76 | for atom, x in zip(atoms, coordinates) 77 | ] 78 | ) 79 | mol = psi4.geometry(molecule_string) 80 | return mol 81 | 82 | 83 | def atoms2psi4mol(atoms): 84 | atomic_numbers = [str(atom) for atom in atoms.get_atomic_numbers().tolist()] 85 | coordinates = atoms.get_positions().tolist() 86 | return xyz2psi4mol(atomic_numbers, coordinates) 87 | 88 | 89 | def get_dft_forces_energy(mol): 90 | # Energy in Hartrees, force in Hatrees/Angstrom 91 | try: 92 | gradient, wfn = psi4.driver.gradient( 93 | FUNCTIONAL_STRING, **{"molecule": mol, "return_wfn": True} 94 | ) 95 | energy = wfn.energy() 96 | forces = -np.array(gradient) / PSI4_BOHR2ANGSTROM 97 | return energy, forces 98 | except SCFConvergenceError as e: 99 | # Set energy to some threshold if SOSCF does not converge 100 | description = traceback.format_exc() 101 | print(f"DFT optimization did not converge!\n{description}", file=sys.stderr) 102 | return None, None 103 | finally: 104 | psi4.core.clean() 105 | 106 | 107 | def update_ase_atoms_positions(atoms, positions): 108 | atoms.set_positions(positions) 109 | 110 | 111 | def update_psi4_geometry(molecule, positions): 112 | psi4matrix = psi4.core.Matrix.from_array(positions / PSI4_BOHR2ANGSTROM) 113 | molecule.set_geometry(psi4matrix) 114 | molecule.update_geometry() 115 | 116 | 117 | def calculate_dft_energy_item(task): 118 | # Get molecule from the queue 119 | ase_atoms, _, idx = task 120 | 121 | ase_atoms = ase.Atoms.fromdict(ase_atoms) 122 | 123 | print("task", idx) 124 | 125 | t1 = time.time() 126 | 127 | molecule = atoms2psi4mol(ase_atoms) 128 | 129 | # Perform DFT minimization 130 | # Energy in Hartree 131 | # Calculate DFT energy 132 | 133 | energy, gradient = get_dft_forces_energy(molecule) 134 | not_converged = energy is None 135 | 136 | t = time.time() - t1 137 | 138 | print("time", t) 139 | 140 | return idx, not_converged, energy, gradient 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--task_path", type=str, required=True) 146 | parser.add_argument("--result_path", type=str, required=True) 147 | parser.add_argument("--num_threads", type=int, required=True) 148 | args = parser.parse_args() 149 | 150 | # set number of threads per worker 151 | psi4.core.set_num_threads(args.num_threads) 152 | 153 | with open(args.task_path, "rb") as file_obj: 154 | task = pickle.load(file_obj) 155 | 156 | result = 
calculate_dft_energy_item(task) 157 | 158 | with open(args.result_path, "wb") as file_obj: 159 | pickle.dump(result, file_obj) 160 | -------------------------------------------------------------------------------- /env/host_names.txt: -------------------------------------------------------------------------------- 1 | 192.168.19.101 -------------------------------------------------------------------------------- /env/make_envs.py: -------------------------------------------------------------------------------- 1 | from .moldynamics_env import env_fn 2 | from .wrappers import EnergyWrapper 3 | 4 | 5 | def make_envs(args): 6 | # Env kwargs 7 | env_kwargs = { 8 | "db_path": args.db_path, 9 | "n_parallel": args.n_parallel, 10 | "timelimit": args.timelimit_train, 11 | "sample_initial_conformations": True, 12 | "num_initial_conformations": args.num_initial_conformations, 13 | } 14 | 15 | # Reward wrapper kwargs 16 | reward_wrapper_kwargs = { 17 | "dft": args.reward == "dft", 18 | "n_workers": args.n_workers, 19 | "minimize_on_every_step": args.minimize_on_every_step, 20 | "evaluation": False, 21 | "terminate_on_negative_reward": args.terminate_on_negative_reward, 22 | "max_num_negative_rewards": args.max_num_negative_rewards, 23 | "host_file_path": args.host_file_path, 24 | } 25 | 26 | # Initialize env 27 | env = env_fn(**env_kwargs) 28 | env = EnergyWrapper(env, **reward_wrapper_kwargs) 29 | 30 | # Update kwargs for eval_env 31 | if args.eval_db_path != "": 32 | env_kwargs["db_path"] = args.eval_db_path 33 | env_kwargs.update( 34 | { 35 | "sample_initial_conformations": args.sample_initial_conformations, 36 | "timelimit": args.timelimit_eval, 37 | } 38 | ) 39 | 40 | if args.reward == "rdkit": 41 | env_kwargs["n_parallel"] = 1 42 | else: 43 | env_kwargs["n_parallel"] = args.n_eval_runs 44 | reward_wrapper_kwargs["minimize_on_every_step"] = False 45 | reward_wrapper_kwargs["evaluation"] = True 46 | 47 | # Initialize eval env 48 | eval_env = env_fn(**env_kwargs) 49 | eval_env = EnergyWrapper(eval_env, **reward_wrapper_kwargs) 50 | 51 | return env, eval_env 52 | -------------------------------------------------------------------------------- /env/moldynamics_env.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from sqlite3 import DatabaseError 4 | 5 | import backoff 6 | import gymnasium as gym 7 | import numpy as np 8 | import torch 9 | from ase.db import connect 10 | from schnetpack.interfaces import AtomsConverter 11 | from schnetpack.transform import ASENeighborList 12 | 13 | np.seterr(all="ignore") 14 | warnings.filterwarnings("ignore", category=DeprecationWarning) 15 | 16 | 17 | # For backoff exceptions 18 | def on_giveup(details): 19 | print( 20 | "Giving Up after {} tries. 
Time elapsed: {:.3f} :(".format( 21 | details["tries"], details["elapsed"] 22 | ) 23 | ) 24 | 25 | 26 | class MolecularDynamics(gym.Env): 27 | metadata = {"render_modes": ["human"], "name": "md_v0"} 28 | DISTANCE_THRESH = 0.7 29 | 30 | def __init__( 31 | self, 32 | db_path, 33 | converter, 34 | n_parallel=1, 35 | timelimit=10, 36 | sample_initial_conformations=True, 37 | num_initial_conformations=50000, 38 | ): 39 | self.db_path = db_path 40 | self.converter = converter 41 | self.n_parallel = n_parallel 42 | self.TL = timelimit 43 | self.sample_initial_conformations = sample_initial_conformations 44 | 45 | self.db_len = self.get_db_length() 46 | self.atoms = None 47 | self.mean_energy = 0.0 48 | self.std_energy = 1.0 49 | self.initial_molecule_conformations = [] 50 | self.initial_conformations_ids = [] 51 | 52 | # Store random subset of molecules DB 53 | self.get_initial_molecule_conformations(num_initial_conformations) 54 | self.conformation_idx = 0 55 | 56 | # Initialize lists 57 | self.atoms = [None] * self.n_parallel 58 | self.smiles = [None] * self.n_parallel 59 | self.energy = [None] * self.n_parallel 60 | self.optimal_energy = [None] * self.n_parallel 61 | self.force = [None] * self.n_parallel 62 | self.env_steps = [None] * self.n_parallel 63 | self.atoms_ids = [None] * self.n_parallel 64 | 65 | self.total_num_bad_pairs_before = 0 66 | self.total_num_bad_pairs_after = 0 67 | 68 | def step(self, actions): 69 | # Get number of atoms in each molecule 70 | cumsum_numbers_atoms = self.get_atoms_num_cumsum() 71 | 72 | obs = [] 73 | rewards = [None] * self.n_parallel 74 | dones = [None] * self.n_parallel 75 | info = {} 76 | 77 | for idx in range(self.n_parallel): 78 | # Unpad action 79 | self.atoms[idx].set_positions( 80 | self.atoms[idx].get_positions() 81 | + actions[cumsum_numbers_atoms[idx] : cumsum_numbers_atoms[idx + 1]] 82 | ) 83 | 84 | # Check if there are atoms too close to each other in the molecule 85 | ( 86 | self.atoms[idx], 87 | num_bad_pairs_before, 88 | num_bad_pairs_after, 89 | ) = self.process_molecule(self.atoms[idx]) 90 | self.total_num_bad_pairs_before += num_bad_pairs_before 91 | self.total_num_bad_pairs_after += num_bad_pairs_after 92 | 93 | # Terminate the episode if TL is reached 94 | self.env_steps[idx] += 1 95 | dones[idx] = self.env_steps[idx] >= self.TL 96 | 97 | # Add info about bad pairs 98 | info["total_bad_pairs_before_processing"] = self.total_num_bad_pairs_before 99 | info["total_bad_pairs_after_processing"] = self.total_num_bad_pairs_after 100 | 101 | # Collate observations into a batch 102 | obs = self.converter(self.atoms) 103 | 104 | return obs, rewards, dones, info 105 | 106 | def reset(self, indices=None, increment_conf_idx=True): 107 | # If indices is not provided reset all molecules 108 | if indices is None: 109 | indices = np.arange(self.n_parallel) 110 | 111 | # If sample_initial_conformations iterate over all initial conformations sequentially 112 | if self.sample_initial_conformations: 113 | db_indices = np.random.choice( 114 | len(self.initial_molecule_conformations), len(indices), replace=False 115 | ) 116 | else: 117 | start_conf_idx = self.conformation_idx % len( 118 | self.initial_molecule_conformations 119 | ) 120 | db_indices = np.mod( 121 | np.arange(start_conf_idx, start_conf_idx + len(indices)), 122 | len(self.initial_conformations_ids), 123 | ).astype(np.int64) 124 | if increment_conf_idx: 125 | self.conformation_idx += len(indices) 126 | 127 | rows = [self.initial_molecule_conformations[db_idx] for db_idx in db_indices] 128 | 129 | 
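# Load the sampled DB rows into the parallel env slots: copy the ASE atoms,
# store ids/SMILES, and parse energies/forces, which may be stored in the
# database in several different formats.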
for idx, row, atom_id in zip( 130 | indices, rows, self.initial_conformations_ids[db_indices] 131 | ): 132 | # Copy to avoid changing the atoms object inplace 133 | self.atoms[idx] = row.toatoms().copy() 134 | self.atoms_ids[idx] = int(atom_id) 135 | 136 | # Check if row has Smiles 137 | if hasattr(row, "smiles"): 138 | self.smiles[idx] = row.smiles 139 | 140 | # Energy and optimal_energy in Hartrees. 141 | # Energies and optimal energies can sometimes be stored in different formats. 142 | if hasattr(row.data, "energy"): 143 | if isinstance(row.data["energy"], list): 144 | assert len(row.data["energy"]) == 1 145 | self.energy[idx] = row.data["energy"][0] 146 | elif isinstance(row.data["energy"], np.ndarray): 147 | assert len(row.data["energy"]) == 1 148 | self.energy[idx] = row.data["energy"].item() 149 | else: 150 | self.energy[idx] = row.data["energy"] 151 | # In case the database is the result of optimization 152 | elif hasattr(row.data, "initial_energy"): 153 | self.energy[idx] = row.data["initial_energy"] 154 | 155 | if hasattr(row.data, "optimal_energy"): 156 | if isinstance(row.data["optimal_energy"], list): 157 | assert len(row.data["optimal_energy"]) == 1 158 | self.optimal_energy[idx] = row.data["optimal_energy"][0] 159 | elif isinstance(row.data["optimal_energy"], np.ndarray): 160 | assert len(row.data["optimal_energy"]) == 1 161 | self.optimal_energy[idx] = row.data["optimal_energy"].item() 162 | else: 163 | self.optimal_energy[idx] = row.data["optimal_energy"] 164 | 165 | # forces in Hartees/ Angstrom 166 | if hasattr(row.data, "forces"): 167 | self.force[idx] = row.data["forces"] 168 | elif hasattr(row.data, "final_forces"): 169 | # In case the database is the result of optimization 170 | self.force[idx] = row.data["final_forces"] 171 | 172 | # Reset env_steps 173 | self.env_steps[idx] = 0 174 | 175 | # Collate observations into a batch 176 | obs = self.converter([self.atoms[idx] for idx in indices]) 177 | 178 | return obs 179 | 180 | def update_timelimit(self, new_timelimit): 181 | self.TL = new_timelimit 182 | 183 | def get_db_length(self): 184 | with connect(self.db_path) as conn: 185 | db_len = len(conn) 186 | return db_len 187 | 188 | def get_env_step(self): 189 | return self.env_steps 190 | 191 | def get_initial_molecule_conformations(self, num_initial_conformations): 192 | if num_initial_conformations == -1 or num_initial_conformations == self.db_len: 193 | self.initial_conformations_ids = np.arange(1, self.db_len + 1) 194 | else: 195 | self.initial_conformations_ids = np.random.choice( 196 | np.arange(1, self.db_len + 1), 197 | min(self.db_len, num_initial_conformations), 198 | replace=False, 199 | ) 200 | self.initial_molecule_conformations = [] 201 | for idx in self.initial_conformations_ids: 202 | row = self.get_molecule(int(idx)) 203 | self.initial_molecule_conformations.append(row) 204 | 205 | # Makes sqllite3 database compatible with NFS storages 206 | @backoff.on_exception( 207 | backoff.expo, 208 | exception=DatabaseError, 209 | jitter=backoff.full_jitter, 210 | max_tries=10, 211 | on_giveup=on_giveup, 212 | ) 213 | def get_molecule(self, idx): 214 | with connect(self.db_path) as conn: 215 | return conn.get(idx) 216 | 217 | def get_atoms_num_cumsum(self): 218 | atoms_num_cumsum = [0] 219 | for atom in self.atoms: 220 | atoms_num_cumsum.append( 221 | atoms_num_cumsum[-1] + len(atom.get_atomic_numbers()) 222 | ) 223 | 224 | return atoms_num_cumsum 225 | 226 | def get_bad_pairs_indices(self, positions): 227 | dir_ij = positions[None, :, :] - positions[:, None, :] 
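# Pairwise displacement vectors between all atoms; their norms r_ij are the
# interatomic distances checked against DISTANCE_THRESH below.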
228 | r_ij = np.linalg.norm(dir_ij, axis=2) 229 | 230 | # Set diagonal elements to a large positive number 231 | r_ij[np.diag_indices_from(r_ij)] = 10.0 232 | dir_ij /= r_ij[..., None] 233 | 234 | # Set lower triangle matrix of r_ij to a large positive number 235 | # to avoid finding dublicate pairs 236 | r_ij[np.tri(r_ij.shape[0], k=-1).astype(bool)] = 10.0 237 | 238 | return np.argwhere(r_ij < MolecularDynamics.DISTANCE_THRESH), dir_ij, r_ij 239 | 240 | def process_molecule(self, atoms): 241 | new_atoms = atoms.copy() 242 | positions = new_atoms.get_positions() 243 | 244 | # Detect atoms too close to each other 245 | bad_indices_before, dir_ij, r_ij = self.get_bad_pairs_indices(positions) 246 | 247 | # Move atoms apart. At the moment we assume 248 | # that r_ij < THRESH is a rare event that affects at most 249 | # one pair of atoms so moving them apart should not cause 250 | # any other r_kl to become < THRESH. 251 | for i, j in bad_indices_before: 252 | coef = MolecularDynamics.DISTANCE_THRESH + 0.05 - r_ij[i, j] 253 | positions[i] -= 0.5 * coef * dir_ij[i, j] 254 | positions[j] -= 0.5 * coef * dir_ij[j, i] 255 | new_atoms.set_positions(positions) 256 | 257 | # Check if assumption does not hold 258 | bad_indices_after, _, _ = self.get_bad_pairs_indices(positions) 259 | 260 | return new_atoms, len(bad_indices_before), len(bad_indices_after) 261 | 262 | def seed(self, seed=None): 263 | if seed is None: 264 | seed = np.random.randint(0, 1000000) 265 | np.random.seed(seed) 266 | return seed 267 | 268 | 269 | def env_fn(**kwargs): 270 | """ 271 | To support the AEC API, the raw_env() function just uses the from_parallel 272 | function to convert from a ParallelEnv to an AEC env 273 | """ 274 | converter = AtomsConverter( 275 | neighbor_list=ASENeighborList(cutoff=math.inf), 276 | dtype=torch.float32, 277 | device=torch.device("cpu"), 278 | ) 279 | env = MolecularDynamics(converter=converter, **kwargs) 280 | return env 281 | -------------------------------------------------------------------------------- /env/molecules_xyz/aspirin.xyz: -------------------------------------------------------------------------------- 1 | 21 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 2.23930000 -0.37910000 0.26300000 4 | C 0.84240000 1.92310000 -0.42490000 5 | C 2.87090000 0.84560000 0.27220000 6 | C 2.17510000 1.99350000 -0.07030000 7 | C -3.48380000 0.49530000 -0.08960000 8 | C 0.89100000 -0.46470000 -0.09390000 9 | C 0.19080000 0.69910000 -0.44020000 10 | O -0.96330000 -1.84250000 -0.41850000 11 | O -1.65310000 0.88890000 1.34060000 12 | O 0.88570000 -2.88830000 0.22670000 13 | C 0.20900000 -1.77200000 -0.10690000 14 | C -2.01850000 0.68530000 0.20710000 15 | O -1.11890000 0.62850000 -0.78860000 16 | H 0.39620000 -3.72190000 0.20350000 17 | H 2.78670000 -1.27190000 0.52680000 18 | H 0.30690000 2.82240000 -0.69110000 19 | H 3.91300000 0.91080000 0.54820000 20 | H 2.67810000 2.94920000 -0.06040000 21 | H -3.73600000 -0.56230000 -0.01200000 22 | H -4.07630000 1.06370000 0.62730000 23 | H -3.69880000 0.84710000 -1.09860000 24 | -------------------------------------------------------------------------------- /env/molecules_xyz/azobenzene.xyz: -------------------------------------------------------------------------------- 1 | 24 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | N 0.31271200 -0.18912500 -0.65225700 4 | N -0.42515200 0.02781900 0.34833600 5 | C -1.81381900 0.08970200 0.29497300 6 | C -2.50320500 1.26282100 0.67990700 7 | C -3.86265500 1.27736200 0.42399500 8 | C -4.49319300 0.11096200 
0.05874300 9 | C -3.72723200 -0.94662300 -0.39618200 10 | C -2.33699100 -0.92411100 -0.40497800 11 | C 1.78223700 -0.09516400 -0.28325200 12 | C 2.65201600 -0.83846900 -1.11253200 13 | C 3.98565700 -0.86814700 -0.92576500 14 | C 4.47985100 -0.11713100 0.19169100 15 | C 3.65412100 0.62306900 1.07781600 16 | C 2.28892300 0.62248100 0.79376500 17 | H -1.82272900 1.92833200 0.98749000 18 | H -4.37697900 2.17199700 0.38762300 19 | H -5.57877000 -0.00035700 -0.05419900 20 | H -4.26480500 -1.74474600 -0.93648500 21 | H -1.80907600 -1.76846700 -0.87420400 22 | H 2.34289600 -1.39224700 -1.95058800 23 | H 4.61128600 -1.31671000 -1.78697400 24 | H 5.59969500 0.01113400 0.25177700 25 | H 4.06963200 1.11739800 1.96829700 26 | H 1.53171200 0.89070400 1.48633400 27 | -------------------------------------------------------------------------------- /env/molecules_xyz/benzene.xyz: -------------------------------------------------------------------------------- 1 | 12 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 0.00000000 1.39700000 0.00000000 4 | C 1.20980000 0.69850000 0.00000000 5 | C 1.20980000 -0.69850000 0.00000000 6 | C 0.00000000 -1.39700000 0.00000000 7 | C -1.20980000 -0.69850000 0.00000000 8 | C -1.20980000 0.69850000 0.00000000 9 | H 0.00000000 2.48100000 0.00000000 10 | H 2.14860000 1.24050000 0.00000000 11 | H 2.14860000 -1.24050000 0.00000000 12 | H 0.00000000 -2.48100000 0.00000000 13 | H -2.14860000 -1.24050000 0.00000000 14 | H -2.14860000 1.24050000 0.00000000 15 | -------------------------------------------------------------------------------- /env/molecules_xyz/ethanol.xyz: -------------------------------------------------------------------------------- 1 | 9 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 0.00720000 -0.56870000 0.00000000 4 | C -1.28540000 0.24990000 0.00000000 5 | O 1.13040000 0.31470000 0.00000000 6 | H 0.03920000 -1.19720000 0.89000000 7 | H 0.03920000 -1.19720000 -0.89000000 8 | H -1.31750000 0.87840000 0.89000000 9 | H -1.31750000 0.87840000 -0.89000000 10 | H -2.14220000 -0.42390000 0.00000000 11 | H 1.98570000 -0.13650000 0.00000000 12 | -------------------------------------------------------------------------------- /env/molecules_xyz/malonaldehyde.xyz: -------------------------------------------------------------------------------- 1 | 9 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 1.18649313 0.22185741 -0.38326844 4 | C 0.00529427 -0.63807785 -0.01059477 5 | C -1.17405044 0.22537018 0.37848979 6 | O -2.24159238 0.09785438 -0.17762684 7 | O 2.23128021 0.08948184 0.17978190 8 | H 0.97124340 1.00128230 -1.16123847 9 | H -0.23318273 -1.37893404 -0.80123256 10 | H 0.30917643 -1.25580996 0.89258006 11 | H -1.09648246 0.93475153 1.21885293 12 | -------------------------------------------------------------------------------- /env/molecules_xyz/naphthalene.xyz: -------------------------------------------------------------------------------- 1 | 18 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C -1.22480000 -1.39760000 -0.00020000 4 | C -2.39160000 -0.69680000 -0.00010000 5 | C -2.39160000 0.69680000 -0.00020000 6 | C -1.22480000 1.39760000 0.00010000 7 | C 0.00000000 0.70870000 0.00040000 8 | C 1.22480000 1.39760000 0.00010000 9 | C 2.39160000 0.69680000 0.00000000 10 | C 2.39160000 -0.69680000 -0.00020000 11 | C 1.22480000 -1.39760000 0.00000000 12 | C 0.00000000 -0.70870000 0.00010000 13 | H -1.24020000 -2.47750000 0.00400000 14 | H -3.33150000 -1.22870000 -0.00030000 15 | H -3.33150000 1.22870000 -0.00080000 16 | H -1.24020000 2.47750000 
-0.00030000 17 | H 1.24030000 2.47750000 -0.00040000 18 | H 3.33150000 1.22870000 -0.00020000 19 | H 3.33150000 -1.22870000 -0.00090000 20 | H 1.24020000 -2.47750000 -0.00050000 21 | -------------------------------------------------------------------------------- /env/molecules_xyz/paracetamol.xyz: -------------------------------------------------------------------------------- 1 | 20 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 0.96980000 -0.06870000 0.00887000 4 | C 2.48362000 -0.08513000 0.01528000 5 | O 2.88509000 1.00000000 0.42119000 6 | N 3.07607000 -1.21916000 -0.37930000 7 | C 4.43238000 -1.18856000 -0.35955000 8 | C 5.15464000 -1.90601000 0.59688000 9 | C 6.55668000 -1.80118000 0.69213000 10 | C 7.34416000 -1.00307000 -0.15866000 11 | O 8.61018000 -0.87323000 -0.02662000 12 | C 6.59491000 -0.36801000 -1.16037000 13 | C 5.19403000 -0.45388000 -1.26867000 14 | H 0.62206000 0.88664000 0.34249000 15 | H 0.60206000 -0.83074000 0.66382000 16 | H 0.61534000 -0.25039000 -0.98423000 17 | H 2.56725000 -2.01325000 -0.66171000 18 | H 4.64998000 -2.52699000 1.24859000 19 | H 7.03462000 -2.34193000 1.42982000 20 | H 9.05786000 -1.32281000 0.66289000 21 | H 7.10148000 0.19634000 -1.86032000 22 | H 4.71823000 0.03925000 -2.04033000 23 | -------------------------------------------------------------------------------- /env/molecules_xyz/salicylic_acid.xyz: -------------------------------------------------------------------------------- 1 | 16 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 1.63470000 -0.23020000 0.00380000 4 | C 0.16090000 -0.25710000 0.00380000 5 | C -0.56640000 0.94170000 0.01640000 6 | O 0.08300000 2.13260000 0.02860000 7 | C -1.95260000 0.90420000 0.01580000 8 | C -2.61230000 -0.30890000 0.00280000 9 | C -1.89770000 -1.49540000 -0.00920000 10 | C -0.51970000 -1.47730000 -0.01460000 11 | O 2.22290000 0.83290000 0.01470000 12 | O 2.32960000 -1.38470000 -0.00270000 13 | H 0.27110000 2.48700000 -0.85120000 14 | H -2.51700000 1.82500000 0.02580000 15 | H -3.69210000 -0.33260000 0.00240000 16 | H -2.42280000 -2.43910000 -0.01850000 17 | H 0.03330000 -2.40490000 -0.02400000 18 | H 3.29420000 -1.31650000 -0.00250000 19 | -------------------------------------------------------------------------------- /env/molecules_xyz/toluene.xyz: -------------------------------------------------------------------------------- 1 | 15 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 2.43290000 0.00000000 0.00100000 4 | C 0.92590000 0.00000000 0.00100000 5 | C 0.23470000 1.19720000 -0.00060000 6 | C -1.14760000 1.19720000 -0.00060000 7 | C -1.83880000 0.00000000 0.00070000 8 | C -1.14760000 -1.19720000 0.00140000 9 | C 0.23470000 -1.19710000 -0.00300000 10 | H 2.79620000 0.00140000 1.02860000 11 | H 2.79620000 -0.89070000 -0.51160000 12 | H 2.79620000 0.88930000 -0.51410000 13 | H 0.77470000 2.13250000 -0.00160000 14 | H -1.68760000 2.13250000 -0.00110000 15 | H -2.91880000 0.00000000 0.00150000 16 | H -1.68760000 -2.13250000 0.00270000 17 | H 0.77470000 -2.13250000 -0.00250000 18 | -------------------------------------------------------------------------------- /env/molecules_xyz/uracil.xyz: -------------------------------------------------------------------------------- 1 | 12 2 | Properties=species:S:1:pos:R:3 pbc="F F F" 3 | C 1.63200000 0.29510000 -0.06370000 4 | C 1.44620000 -1.03750000 0.04140000 5 | N 0.15790000 -1.58810000 0.11580000 6 | C -1.00000000 -0.75480000 0.08230000 7 | N -0.81390000 0.64620000 -0.02800000 8 | C 0.49660000 1.22110000 -0.10470000 9 | O -2.11590000 
-1.25560000 0.14840000 10 | O 0.54710000 2.43640000 -0.19770000 11 | H 2.28200000 -1.74920000 0.07380000 12 | H 0.04330000 -2.57330000 0.19270000 13 | H -1.62410000 1.22990000 -0.05140000 14 | H 2.63630000 0.72860000 -0.12190000 15 | -------------------------------------------------------------------------------- /env/wrappers.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import math 3 | import multiprocessing as mp 4 | 5 | import gymnasium as gym 6 | import numpy as np 7 | import torch 8 | from rdkit.Chem import AddHs, AllChem, Conformer, MolFromSmiles 9 | from schnetpack.data.loader import _atoms_collate_fn 10 | from schnetpack.interfaces import AtomsConverter 11 | from schnetpack.transform import ASENeighborList 12 | 13 | from .dft import calculate_dft_energy_tcp_client, get_dft_server_destinations 14 | from .dft_worker import update_ase_atoms_positions 15 | from .xyz2mol import get_rdkit_energy, get_rdkit_force, set_coordinates 16 | 17 | RDKIT_ENERGY_THRESH = 300 18 | KCALMOL2HARTREE = 627.5 19 | 20 | 21 | class BaseOracle: 22 | def __init__(self, n_parallel, update_coordinates_fn): 23 | self.n_parallel = n_parallel 24 | self.update_coordinates_fn = update_coordinates_fn 25 | 26 | self.initial_energies = np.zeros(self.n_parallel) 27 | self.forces = [None] * self.n_parallel 28 | self.molecules = [None] * self.n_parallel 29 | 30 | def get_energies(self, indices): 31 | if indices is None: 32 | indices = np.arange(self.n_parallel) 33 | return self.initial_energies[indices] 34 | 35 | def get_forces(self, indices): 36 | if indices is None: 37 | indices = np.arange(self.n_parallel) 38 | return [np.array(self.forces[i]) for i in indices] 39 | 40 | def update_coordinates(self, positions, indices=None): 41 | if indices is None: 42 | indices = np.arange(self.n_parallel) 43 | assert len(positions) == len( 44 | indices 45 | ), f"Not enough values to update all molecules! Expected {self.n_parallel} but got {len(positions)}" 46 | 47 | # Update current molecules 48 | for i, position in zip(indices, positions): 49 | self.update_coordinates_fn(self.molecules[i], position) 50 | 51 | def update_forces(self, forces, indices=None): 52 | if indices is None: 53 | indices = np.arange(self.n_parallel) 54 | assert len(forces) == len(indices) 55 | for i, force in zip(indices, forces): 56 | self.forces[i] = force 57 | 58 | 59 | class RdkitOracle(BaseOracle): 60 | def __init__(self, n_parallel, update_coordinates_fn): 61 | super().__init__(n_parallel, update_coordinates_fn) 62 | 63 | def calculate_energies_forces(self, max_its=0, indices=None): 64 | if indices is None: 65 | indices = np.arange(self.n_parallel) 66 | 67 | not_converged, energies, forces = ( 68 | np.zeros(len(indices)), 69 | np.zeros(len(indices)), 70 | [None] * len(indices), 71 | ) 72 | 73 | for i, idx in enumerate(indices): 74 | # Perform rdkit minimization 75 | try: 76 | ff = AllChem.MMFFGetMoleculeForceField( 77 | self.molecules[idx], 78 | AllChem.MMFFGetMoleculeProperties(self.molecules[idx]), 79 | confId=0, 80 | ) 81 | ff.Initialize() 82 | not_converged[i] = ff.Minimize(maxIts=max_its) 83 | except Exception as e: 84 | print("Bad SMILES! 
Unable to minimize.") 85 | energies[i] = get_rdkit_energy(self.molecules[idx]) 86 | forces[i] = get_rdkit_force(self.molecules[idx]) 87 | 88 | return not_converged, energies, forces 89 | 90 | def get_rewards(self, new_energies, indices=None): 91 | if indices is None: 92 | indices = np.arange(self.n_parallel) 93 | assert len(new_energies) == len(indices) 94 | 95 | new_energies_ = np.copy(self.initial_energies) 96 | new_energies_[indices] = new_energies 97 | 98 | rewards = self.initial_energies - new_energies_ 99 | self.initial_energies = new_energies_ 100 | return rewards 101 | 102 | def initialize_molecules(self, indices, smiles_list, molecules, max_its=0): 103 | for i, smiles, molecule in zip(indices, smiles_list, molecules): 104 | # Calculate initial rdkit energy 105 | if smiles is not None: 106 | # Initialize molecule from Smiles 107 | self.molecules[i] = MolFromSmiles(smiles) 108 | self.molecules[i] = AddHs(self.molecules[i]) 109 | # Add random conformer 110 | self.molecules[i].AddConformer( 111 | Conformer(len(molecule.get_atomic_numbers())) 112 | ) 113 | else: 114 | raise ValueError( 115 | "Unknown molecule type {}".format(str(molecule.symbols)) 116 | ) 117 | self.update_coordinates_fn(self.molecules[i], molecule.get_positions()) 118 | _, initial_energies, forces = self.calculate_energies_forces(max_its, indices) 119 | 120 | # Set initial energies for new molecules 121 | self.initial_energies[indices] = initial_energies 122 | 123 | # Set forces 124 | for i, force in zip(indices, forces): 125 | self.forces[i] = force 126 | 127 | 128 | class DFTOracle(BaseOracle): 129 | def __init__( 130 | self, 131 | n_parallel, 132 | update_coordinates_fn, 133 | n_workers, 134 | converter, 135 | host_file_path, 136 | ): 137 | super().__init__(n_parallel, update_coordinates_fn) 138 | 139 | self.previous_molecules = [None] * self.n_parallel 140 | self.dft_server_destinations = get_dft_server_destinations( 141 | n_workers, host_file_path 142 | ) 143 | method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn" 144 | self.executors = [ 145 | concurrent.futures.ProcessPoolExecutor( 146 | max_workers=1, mp_context=mp.get_context(method) 147 | ) 148 | for _ in range(len(self.dft_server_destinations)) 149 | ] 150 | self.converter = converter 151 | self.tasks = {} 152 | self.number_processed_conformations = 0 153 | 154 | self.task_queue_full_flag = False 155 | 156 | def update_coordinates(self, positions, indices=None): 157 | if indices is None: 158 | indices = np.arange(self.n_parallel) 159 | 160 | # Update previous molecules 161 | for i in indices: 162 | self.previous_molecules[i] = self.molecules[i].copy() 163 | 164 | # Update current molecules 165 | super().update_coordinates(positions, indices) 166 | 167 | def close_executors(self): 168 | for executor in self.executors: 169 | executor.shutdown(wait=False, cancel_futures=True) 170 | 171 | def get_data(self, eval=False): 172 | assert self.task_queue_full_flag, "The task queue has not filled up yet!" 
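# NOTE: get_data() assumes submit_tasks() has already filled the task queue;
# task_queue_full_flag is only set once len(self.tasks) >= self.n_parallel.
# A minimal sketch of the intended call pattern, mirroring its use in main.py
# (`oracle` stands for a DFTOracle instance):
#   oracle.submit_tasks(indices)        # enqueue conformations for DFT
#   if oracle.task_queue_full_flag:     # enough tasks accumulated
#       obs, energies, forces, deltas = oracle.get_data()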
173 | 174 | # Wait for all computations to finish 175 | results = self.wait_tasks(eval=eval) 176 | if eval: 177 | assert len(results) == self.n_parallel 178 | else: 179 | assert len(results) > 0 180 | 181 | _, energies, forces, obs, initial_energies = zip(*results) 182 | 183 | energies = np.array(energies) 184 | forces = [np.array(force) for force in forces] 185 | obs = _atoms_collate_fn(obs) 186 | episode_total_delta_energies = np.array(initial_energies) - energies 187 | 188 | self.task_queue_full_flag = False 189 | return obs, energies, forces, episode_total_delta_energies 190 | 191 | def initialize_molecules(self, indices, molecules, initial_energies, forces): 192 | no_initial_energy_indices = [] 193 | for i, molecule, initial_energy, force in zip( 194 | indices, molecules, initial_energies, forces 195 | ): 196 | self.molecules[i] = molecule.copy() 197 | self.previous_molecules[i] = molecule.copy() 198 | if initial_energy is not None: 199 | self.initial_energies[i] = initial_energy 200 | self.forces[i] = force 201 | else: 202 | no_initial_energy_indices.append(i) 203 | 204 | # Calculate the initial DFT energy and forces if they are not provided 205 | if no_initial_energy_indices: 206 | self.submit_tasks(no_initial_energy_indices) 207 | # Make sure there were no unfinished tasks 208 | assert len(self.tasks) == len(indices) 209 | 210 | results = self.wait_tasks() 211 | assert len(results) == len(indices) 212 | 213 | # Update initial energies and forces 214 | for result in results: 215 | i, energy, force = result[:3] 216 | self.initial_energies[i] = energy 217 | self.forces[i] = force 218 | 219 | def submit_tasks(self, indices): 220 | self.submitted_indices = indices 221 | for i in indices: 222 | # Replace early_stop_steps with 0 223 | new_task = (i, 0, self.previous_molecules[i].copy()) 224 | 225 | # Select a worker and submit the task 226 | worker_id = self.number_processed_conformations % len( 227 | self.dft_server_destinations 228 | ) 229 | host, port = self.dft_server_destinations[worker_id] 230 | future = self.executors[worker_id].submit( 231 | calculate_dft_energy_tcp_client, 232 | new_task, 233 | host, 234 | port, 235 | False, 236 | ) 237 | 238 | # Store information about the conformation 239 | self.tasks[self.number_processed_conformations] = { 240 | "future": future, 241 | "initial_energy": self.initial_energies[i], 242 | "obs": self.converter(self.previous_molecules[i]), 243 | } 244 | self.number_processed_conformations += 1 245 | 246 | # Check if the task queue is full 247 | if len(self.tasks) >= self.n_parallel: 248 | self.task_queue_full_flag = True 249 | break 250 | 251 | def wait_tasks(self, eval=False): 252 | results = [] 253 | 254 | done_task_ids = [] 255 | # Wait for all active tasks to finish 256 | for key, task in self.tasks.items(): 257 | done_task_ids.append(key) 258 | future = task["future"] 259 | obs = task["obs"] 260 | initial_energy = task["initial_energy"] 261 | i, _, energy, force = future.result() 262 | if energy is None: 263 | print( 264 | f"DFT did not converge for {self.molecules[i].symbols}, id: {i}", 265 | flush=True, 266 | ) 267 | # If eval is True, return initial_energy to correctly detect optimization failures; 268 | # else skip the conformation to avoid adding incorrect forces to the replay buffer.
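# NOTE: future.result() above unpacks to (conformation_id, step, energy, force);
# this whole branch runs only when energy came back as None, i.e. the DFT
# computation failed for this conformation.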
269 | if eval: 270 | energy = initial_energy 271 | else: 272 | continue 273 | results.append((i, energy, force, obs, initial_energy)) 274 | 275 | if not eval: 276 | print(f"Total conformations added: {len(results)}") 277 | 278 | # Delete all finished tasks 279 | for key in done_task_ids: 280 | del self.tasks[key] 281 | 282 | return results 283 | 284 | 285 | class EnergyWrapper(gym.Wrapper): 286 | def __init__( 287 | self, 288 | env, 289 | dft=False, 290 | n_workers=1, 291 | minimize_on_every_step=False, 292 | minimize_on_done=True, 293 | evaluation=False, 294 | terminate_on_negative_reward=False, 295 | max_num_negative_rewards=1, 296 | host_file_path=None, 297 | ): 298 | # Set arguments 299 | self.dft = dft 300 | self.n_workers = n_workers 301 | self.minimize_on_every_step = minimize_on_every_step 302 | self.minimize_on_done = minimize_on_done 303 | self.evaluation = evaluation 304 | self.terminate_on_negative_reward = terminate_on_negative_reward 305 | self.max_num_negative_rewards = max_num_negative_rewards 306 | 307 | # Initialize the environment 308 | super().__init__(env) 309 | self.n_parallel = self.env.n_parallel 310 | 311 | # Initialize the rdkit oracle 312 | self.rdkit_oracle = RdkitOracle( 313 | n_parallel=self.n_parallel, update_coordinates_fn=set_coordinates 314 | ) 315 | 316 | # Initialize the DFT oracle 317 | converter = AtomsConverter( 318 | neighbor_list=ASENeighborList(cutoff=math.inf), 319 | dtype=torch.float32, 320 | device=torch.device("cpu"), 321 | ) 322 | self.dft_oracle = DFTOracle( 323 | n_parallel=self.n_parallel, 324 | update_coordinates_fn=update_ase_atoms_positions, 325 | n_workers=self.n_workers, 326 | converter=converter, 327 | host_file_path=host_file_path, 328 | ) 329 | 330 | self.negative_rewards_counter = np.zeros(self.n_parallel) 331 | 332 | def step(self, actions): 333 | obs, env_rewards, dones, info = super().step(actions) 334 | dones = np.stack(dones) 335 | 336 | # Put rewards from the environment into info 337 | info = dict(info, **{"env_reward": env_rewards}) 338 | 339 | # Rdkit rewards 340 | new_positions = [molecule.get_positions() for molecule in self.env.atoms] 341 | self.rdkit_oracle.update_coordinates(new_positions) 342 | if self.dft: 343 | self.dft_oracle.update_coordinates(new_positions) 344 | 345 | # Calculate rdkit energies and forces 346 | calc_rdkit_energy_indices = np.where( 347 | self.minimize_on_every_step | (self.minimize_on_done & dones) 348 | )[0] 349 | 350 | _, new_energies, forces = self.rdkit_oracle.calculate_energies_forces( 351 | indices=calc_rdkit_energy_indices 352 | ) 353 | 354 | # Update current energies and forces. Calculate reward 355 | self.rdkit_oracle.update_forces(forces, calc_rdkit_energy_indices) 356 | rdkit_rewards = self.rdkit_oracle.get_rewards( 357 | new_energies, calc_rdkit_energy_indices 358 | ) 359 | 360 | # When the agent accumulates 'max_num_negative_rewards' negative rewards, terminate the episode 361 | self.negative_rewards_counter[rdkit_rewards < 0] += 1 362 | dones[self.negative_rewards_counter >= self.max_num_negative_rewards] = True 363 | 364 | # Log final energies of molecules 365 | info["final_energy"] = self.rdkit_oracle.initial_energies 366 | 367 | # DFT rewards 368 | if self.dft: 369 | # Conformations whose energy w.r.t. RDKit's MMFF is higher than 370 | # RDKIT_ENERGY_THRESH are highly improbable and likely to cause 371 | # an error in DFT calculation and/or significantly 372 | # slow them down.
To mitigate this we propose to replace the DFT reward 373 | # in such states with the Rdkit reward, as they are strongly correlated in such states. 374 | rdkit_energy_thresh_exceeded = ( 375 | self.rdkit_oracle.initial_energies >= RDKIT_ENERGY_THRESH 376 | ) 377 | 378 | # Calculate energy and forces with DFT only for terminal states. 379 | # Skip conformations with energy higher than RDKIT_ENERGY_THRESH 380 | calculate_dft_energy_env_ids = np.where( 381 | self.minimize_on_done & dones & ~rdkit_energy_thresh_exceeded 382 | )[0] 383 | if len(calculate_dft_energy_env_ids) > 0: 384 | info["calculate_dft_energy_env_ids"] = calculate_dft_energy_env_ids 385 | self.dft_oracle.submit_tasks(calculate_dft_energy_env_ids) 386 | 387 | return obs, rdkit_rewards, dones, info 388 | 389 | def reset(self, indices=None): 390 | obs = self.env.reset(indices=indices) 391 | if indices is None: 392 | indices = np.arange(self.n_parallel) 393 | 394 | # Reset negative rewards counter 395 | self.negative_rewards_counter[indices] = 0 396 | 397 | # Get sizes of molecules 398 | smiles_list = [self.env.smiles[i] for i in indices] 399 | molecules = [self.env.atoms[i].copy() for i in indices] 400 | self.rdkit_oracle.initialize_molecules(indices, smiles_list, molecules) 401 | 402 | if self.dft: 403 | dft_initial_energies = [self.env.energy[i] for i in indices] 404 | dft_forces = [self.env.force[i] for i in indices] 405 | self.dft_oracle.initialize_molecules( 406 | indices, molecules, dft_initial_energies, dft_forces 407 | ) 408 | 409 | return obs 410 | 411 | def set_initial_positions( 412 | self, molecules, smiles_list, energy_list, force_list, max_its=0 413 | ): 414 | super().reset(increment_conf_idx=False) 415 | indices = np.arange(self.n_parallel) 416 | 417 | # Reset negative rewards counter 418 | self.negative_rewards_counter.fill(0.0) 419 | 420 | obs_list = [] 421 | # Set molecules and get observation 422 | for i, molecule in enumerate(molecules): 423 | self.env.atoms[i] = molecule.copy() 424 | obs_list.append(self.env.converter(molecule)) 425 | 426 | self.rdkit_oracle.initialize_molecules(indices, smiles_list, molecules, max_its) 427 | 428 | if self.dft: 429 | self.dft_oracle.initialize_molecules( 430 | indices, molecules, energy_list, force_list 431 | ) 432 | 433 | obs = _atoms_collate_fn(obs_list) 434 | return obs 435 | 436 | def update_timelimit(self, tl): 437 | return self.env.update_timelimit(tl) 438 | 439 | def get_forces(self, indices=None): 440 | if self.dft: 441 | return self.dft_oracle.get_forces(indices) 442 | else: 443 | return self.rdkit_oracle.get_forces(indices) 444 | 445 | def get_energies(self, indices=None): 446 | if self.dft: 447 | return self.dft_oracle.get_energies(indices) 448 | else: 449 | return self.rdkit_oracle.get_energies(indices) 450 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import glob 4 | import math 5 | import os 6 | import pickle 7 | import random 8 | import time 9 | from pathlib import Path 10 | 11 | from ase.db import connect 12 | import numpy as np 13 | import torch 14 | 15 | from GOLF import DEVICE 16 | from GOLF.GOLF_trainer import GOLF 17 | from GOLF.eval import eval_policy_dft, eval_policy_rdkit 18 | from GOLF.make_policies import make_policies 19 | from GOLF.make_saver import make_saver 20 | from GOLF.replay_buffer import ReplayBuffer, fill_initial_replay_buffer 21 | from GOLF.utils import 
calculate_action_norm, recollate_batch 22 | from env.make_envs import make_envs 23 | from utils.arguments import get_args 24 | from utils.logging import Logger 25 | from utils.utils import ignore_extra_args 26 | 27 | eval_function = { 28 | "rdkit": ignore_extra_args(eval_policy_rdkit), 29 | "dft": ignore_extra_args(eval_policy_dft), 30 | } 31 | 32 | REWARD_THRESHOLD = -100 33 | 34 | 35 | def main(args, experiment_folder): 36 | # Set env name 37 | args.env = args.db_path.split("/")[-1].split(".")[0] 38 | 39 | # Initialize logger 40 | logger = Logger(experiment_folder, args) 41 | 42 | # Initialize envs 43 | env, eval_env = make_envs(args) 44 | 45 | # Initialize replay buffer. 46 | atomrefs = None 47 | initial_replay_buffer = None 48 | # First, initialize an RB with initial conformations 49 | if args.reward == "dft": 50 | # Read per-atom reference energies from the database 51 | with connect(args.db_path) as conn: 52 | if "atomrefs" in conn.metadata and args.subtract_atomization_energy: 53 | atomrefs = conn.metadata["atomrefs"]["energy"] 54 | if args.subtract_atomization_energy: 55 | assert ( 56 | atomrefs is not None 57 | ), "Per-atom energies are not provided but args.subtract_atomization_energy is True." 58 | 59 | if args.load_model and not args.store_only_initial_conformations: 60 | replay_buffer = pickle.load(open(f"{args.load_model}_replay", "rb")) 61 | # For backward compatibility 62 | if not hasattr(replay_buffer, "max_total_conformations"): 63 | replay_buffer.max_total_conformations = args.max_oracle_steps 64 | else: 65 | # Initialize a fixed replay buffer with conformations from the database 66 | print("Filling replay buffer with initial conformations...") 67 | initial_replay_buffer = fill_initial_replay_buffer( 68 | device=DEVICE, 69 | db_path=args.db_path, 70 | timelimit=args.timelimit_train, 71 | num_initial_conformations=args.num_initial_conformations, 72 | atomrefs=atomrefs, 73 | ) 74 | print(f"Done! RB size: {initial_replay_buffer.size}") 75 | print("Filling evaluation buffer with conformations...") 76 | eval_replay_buffer = fill_initial_replay_buffer( 77 | device=DEVICE, 78 | db_path=args.eval_db_path, 79 | timelimit=args.timelimit_train, 80 | num_initial_conformations=args.eval_num_initial_conformations, 81 | atomrefs=atomrefs, 82 | ) 83 | print(f"Done!
Eval RB size: {eval_replay_buffer.size}") 84 | replay_buffer = ReplayBuffer( 85 | device=DEVICE, 86 | max_size=args.replay_buffer_size, 87 | max_total_conformations=args.max_oracle_steps, 88 | atomrefs=atomrefs, 89 | initial_RB=initial_replay_buffer, 90 | eval_RB=eval_replay_buffer, 91 | initial_conf_pct=args.initial_conf_pct, 92 | ) 93 | 94 | if args.load_model and args.store_only_initial_conformations: 95 | # Explicitly set the RB size to the checkpoint iteration 96 | replay_buffer.size = int(args.load_model.split("/")[-1].split("_")[-1]) 97 | 98 | # Initialize policy and eval policy 99 | policy, eval_policy = make_policies(env, eval_env, args) 100 | 101 | # Initialize experience saver 102 | experience_saver = make_saver( 103 | args, 104 | env=env, 105 | replay_buffer=replay_buffer, 106 | actor=policy.actor, 107 | reward_thresh=REWARD_THRESHOLD, 108 | ) 109 | 110 | # Initialize trainer 111 | trainer = GOLF( 112 | policy=policy, 113 | lr=args.lr, 114 | batch_size=args.batch_size, 115 | clip_value=args.clip_value, 116 | lr_scheduler=args.lr_scheduler, 117 | energy_loss_coef=args.energy_loss_coef, 118 | force_loss_coef=args.force_loss_coef, 119 | load_model=args.load_model, 120 | total_steps=args.max_oracle_steps * args.utd_ratio, 121 | utd_ratio=args.utd_ratio, 122 | optimizer_name=args.optimizer, 123 | ) 124 | if args.load_baseline: 125 | trainer.light_load(args.load_baseline) 126 | 127 | if not args.store_only_initial_conformations: 128 | state = env.reset() 129 | # Set initial states in the policy 130 | policy.reset(state) 131 | 132 | episode_returns = np.zeros(args.n_parallel) 133 | 134 | policy.train() 135 | 136 | # Set training flag to False (for dft reward only) 137 | train_model_flag = False 138 | 139 | # Set evaluation/save flags to False (for dft reward only). 140 | # Flags are set to True every time new experience is added 141 | # to the replay buffer.
This is done to avoid multiple 142 | # evaluations of the same model 143 | eval_model_flag = False 144 | full_save_flag = False 145 | light_save_flag = False 146 | 147 | # Train until the number of conformations in 148 | # replay buffer is less than max_oracle_steps 149 | while not replay_buffer.replay_buffer_full: 150 | start = time.perf_counter() 151 | update_condition = replay_buffer.size >= args.batch_size 152 | 153 | if not args.store_only_initial_conformations: 154 | # Get current timesteps 155 | episode_timesteps = env.unwrapped.get_env_step() 156 | # Select next action 157 | actions = policy.act(episode_timesteps)["action"].cpu().numpy() 158 | print("policy.act() time: {:.4f}".format(time.perf_counter() - start)) 159 | 160 | # If action contains non finites then reset everything and continue 161 | if not np.isfinite(actions).all(): 162 | state = env.reset() 163 | policy.reset(state) 164 | episode_returns = np.zeros(args.n_parallel) 165 | continue 166 | 167 | next_state, rewards, dones, info = env.step(actions) 168 | episode_returns += rewards 169 | 170 | if args.reward == "dft": 171 | # If task queue is full wait for all tasks to finish and store data to RB 172 | if env.dft_oracle.task_queue_full_flag: 173 | ( 174 | states, 175 | energies, 176 | forces, 177 | episode_total_delta_energies, 178 | ) = env.dft_oracle.get_data() 179 | replay_buffer.add(states, forces, energies) 180 | 181 | logger.update_dft_return_statistics(episode_total_delta_energies) 182 | 183 | # After new data has been added to replay buffer reset all flags 184 | train_model_flag = True 185 | eval_model_flag = True 186 | full_save_flag = True 187 | light_save_flag = True 188 | else: 189 | experience_saver(next_state, rewards, dones) 190 | else: 191 | dones = np.stack([True for _ in range(args.n_parallel)]) 192 | 193 | if ( 194 | update_condition and (args.reward != "dft" or train_model_flag) 195 | ) or args.store_only_initial_conformations: 196 | # Train agent after collecting sufficient data 197 | for _ in range(args.n_parallel * args.utd_ratio): 198 | step_metrics = trainer.update(replay_buffer) 199 | 200 | # Calculate evaluation metrics 201 | step_metrics.update(trainer.eval(replay_buffer)) 202 | 203 | # Reset flag 204 | train_model_flag = False 205 | 206 | # Increase replay buffer size without adding any data 207 | # so that the training is done for the correct amount of steps 208 | if args.store_only_initial_conformations: 209 | replay_buffer.size += args.n_parallel 210 | replay_buffer.replay_buffer_full = ( 211 | replay_buffer.size >= replay_buffer.max_total_conformations 212 | ) 213 | else: 214 | step_metrics = dict() 215 | 216 | step_metrics["Timestamp"] = str(datetime.datetime.now()) 217 | step_metrics["RB_size"] = min(replay_buffer.size, replay_buffer.max_size) 218 | 219 | if not args.store_only_initial_conformations: 220 | step_metrics["Action_norm"] = calculate_action_norm( 221 | actions, env.get_atoms_num_cumsum() 222 | ).item() 223 | # Calculate average number of pairs of atoms too close together 224 | # in env before and after processing 225 | step_metrics["Molecule/num_bad_pairs_before"] = info[ 226 | "total_bad_pairs_before_processing" 227 | ] 228 | step_metrics["Molecule/num_bad_pairs_after"] = info[ 229 | "total_bad_pairs_after_processing" 230 | ] 231 | # Update training statistics 232 | for i, done in enumerate(dones): 233 | if done: 234 | logger.update_evaluation_statistics( 235 | episode_timesteps[i] + 1, 236 | episode_returns[i].item(), 237 | info["final_energy"][i], 238 | ) 239 | 
episode_returns[i] = 0 240 | 241 | # If the episode is terminated 242 | if not args.store_only_initial_conformations: 243 | envs_to_reset = [i for i, done in enumerate(dones) if done] 244 | 245 | # Recollate state_batch after resets. 246 | # Execute only if at least one env has reset. 247 | if len(envs_to_reset) > 0: 248 | reset_states = env.reset(indices=envs_to_reset) 249 | # Reset initial states in policy 250 | policy.reset(reset_states, indices=envs_to_reset) 251 | 252 | # Print iteration time 253 | print("Full iteration time: {:.4f}".format(time.perf_counter() - start)) 254 | 255 | # Evaluate episode 256 | if ( 257 | args.reward != "dft" 258 | or eval_model_flag 259 | or args.store_only_initial_conformations 260 | ) and ( 261 | (replay_buffer.size // args.n_parallel) 262 | % math.ceil(args.eval_freq / float(args.n_parallel)) 263 | == 0 264 | or replay_buffer.size == 0 265 | ): 266 | print(f"Evaluation at step {replay_buffer.size}...") 267 | # Update eval policy 268 | eval_policy.actor = copy.deepcopy(policy.actor) 269 | step_metrics["Total_timesteps"] = replay_buffer.size 270 | step_metrics["Total_training_steps"] = replay_buffer.size * args.utd_ratio 271 | step_metrics["FPS"] = args.n_parallel / (time.perf_counter() - start) 272 | if not args.store_only_initial_conformations or args.reward == "rdkit": 273 | step_metrics.update( 274 | eval_function[args.reward]( 275 | actor=eval_policy, 276 | env=eval_env, 277 | eval_episodes=args.n_eval_runs, 278 | eval_termination_mode=args.eval_termination_mode, 279 | ) 280 | ) 281 | logger.log(step_metrics) 282 | 283 | # Prevent evaluations until new data is added to replay buffer 284 | eval_model_flag = False 285 | 286 | # Save checkpoints 287 | if ( 288 | ( 289 | args.reward != "dft" 290 | or full_save_flag 291 | or args.store_only_initial_conformations 292 | ) 293 | and (replay_buffer.size // args.n_parallel) 294 | % (args.full_checkpoint_freq // args.n_parallel) 295 | == 0 296 | and args.save_checkpoints 297 | ): 298 | # Remove previous checkpoint 299 | old_checkpoint_files = glob.glob(f"{experiment_folder}/full_cp_iter*") 300 | for cp_file in old_checkpoint_files: 301 | os.remove(cp_file) 302 | 303 | # Save new checkpoint 304 | save_t = replay_buffer.size 305 | trainer_save_name = f"{experiment_folder}/full_cp_iter_{save_t}" 306 | trainer.save(trainer_save_name) 307 | 308 | # Do not save the RB if no new data is generated 309 | if not args.store_only_initial_conformations: 310 | with open( 311 | f"{experiment_folder}/full_cp_iter_{save_t}_replay", "wb" 312 | ) as outF: 313 | pickle.dump(replay_buffer, outF) 314 | 315 | # Prevent checkpoint saving until new data is added to replay buffer 316 | full_save_flag = False 317 | 318 | if ( 319 | ( 320 | args.reward != "dft" 321 | or light_save_flag 322 | or args.store_only_initial_conformations 323 | ) 324 | and (replay_buffer.size // args.n_parallel) 325 | % (args.light_checkpoint_freq // args.n_parallel) 326 | == 0 327 | and args.save_checkpoints 328 | ): 329 | save_t = replay_buffer.size 330 | trainer_save_name = f"{experiment_folder}/light_cp_iter_{save_t}" 331 | trainer.light_save(trainer_save_name) 332 | 333 | # Prevent checkpoint saving until new data is added to replay buffer 334 | light_save_flag = False 335 | 336 | # Save final model 337 | # Remove previous checkpoint 338 | old_checkpoint_files = glob.glob(f"{experiment_folder}/full_cp_iter*") 339 | for cp_file in old_checkpoint_files: 340 | os.remove(cp_file) 341 | 342 | # Save new checkpoint 343 | save_t = replay_buffer.size 344 | 
trainer_save_name = f"{experiment_folder}/full_cp_iter_{save_t}" 345 | trainer.save(trainer_save_name) 346 | 347 | # Do not save the RB if no new data is generated 348 | if not args.store_only_initial_conformations: 349 | with open(f"{experiment_folder}/full_cp_iter_{save_t}_replay", "wb") as outF: 350 | pickle.dump(replay_buffer, outF) 351 | 352 | 353 | if __name__ == "__main__": 354 | args = get_args() 355 | 356 | log_dir = Path(args.log_dir) 357 | if args.seed is None: 358 | args.seed = random.randint(0, 1000000) 359 | 360 | # Check hyperparameters 361 | if args.store_only_initial_conformations: 362 | assert args.timelimit_train == 1 363 | assert args.initial_conf_pct == 1.0 364 | 365 | if args.reward == "rdkit": 366 | assert not args.subtract_atomization_energy 367 | 368 | # Seed everything 369 | torch.manual_seed(args.seed) 370 | np.random.seed(args.seed) 371 | random.seed(args.seed) 372 | # args.git_sha = get_current_gitsha() 373 | 374 | start_time = datetime.datetime.strftime( 375 | datetime.datetime.now(), "%Y_%m_%d_%H_%M_%S" 376 | ) 377 | if args.load_model is not None: 378 | assert os.path.exists(f"{args.load_model}_actor"), "Checkpoint not found!" 379 | exp_folder = log_dir / args.load_model.split("/")[-2] 380 | else: 381 | exp_folder = log_dir / f"{args.exp_name}_{start_time}_{args.seed}" 382 | 383 | main(args, exp_folder) 384 | -------------------------------------------------------------------------------- /read_evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import math 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import wandb 9 | 10 | 11 | def mean_std(values, min_threshold=-math.inf, max_threshold=math.inf): 12 | values = np.asarray(values) 13 | mask = (min_threshold <= values) & (values <= max_threshold) 14 | values = values[mask] 15 | 16 | return values.mean(), values.std() 17 | 18 | 19 | def outliers_fraction(values, min_threshold=-math.inf, max_threshold=math.inf): 20 | values = np.asarray(values) 21 | min_mask = values < min_threshold 22 | min_mean = values[min_mask].mean() 23 | max_mask = max_threshold < values 24 | max_mean = values[max_mask].mean() 25 | 26 | return (min_mask.mean(), min_mean), (max_mask.mean(), max_mean) 27 | 28 | 29 | def read_metrics(path): 30 | metrics = {} 31 | with open(path, 'r') as file_obj: 32 | for line in file_obj: 33 | line = line.strip() 34 | if line == '': 35 | continue 36 | 37 | record = json.loads(line) 38 | metrics[record['conformation_id']] = record 39 | 40 | return metrics 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument( 46 | "--path", 47 | type=str, 48 | required=True, 49 | ) 50 | args = parser.parse_args() 51 | 52 | metrics = read_metrics(args.path) 53 | pct_prefix = 'pct_of_minimized_energy@step:' 54 | energy_rdkit_prefix = 'rdkit_energy@step:' 55 | n_iter_prefix = 'n_iter@step:' 56 | energy_mse_prefix = 'energy_mse@step:' 57 | force_mse_prefix = 'force_mse@step:' 58 | not_finite_action_step = 'not_finite_action_step' 59 | rdkit_initial_energy = 'rdkit_initial_energy' 60 | rdkit_final_energy = 'rdkit_final_energy' 61 | 62 | stats = [] 63 | for conformation_id, record in metrics.items(): 64 | pct = {} 65 | n_iter = {} 66 | energy_rdkit = {} 67 | energy_mse = {} 68 | force_mse = {} 69 | invalid_step = record[not_finite_action_step] 70 | for key, value in record.items(): 71 | if key.startswith(pct_prefix): 72 | step = 
int(key[len(pct_prefix):]) 73 | pct[step] = value 74 | elif key.startswith(energy_mse_prefix): 75 | step = int(key[len(energy_mse_prefix):]) 76 | energy_mse[step] = value 77 | elif key.startswith(force_mse_prefix): 78 | step = int(key[len(force_mse_prefix):]) 79 | force_mse[step] = value 80 | elif key.startswith(n_iter_prefix): 81 | step = int(key[len(n_iter_prefix):]) 82 | n_iter[step] = value 83 | elif key.startswith(energy_rdkit_prefix): 84 | step = int(key[len(energy_rdkit_prefix):]) 85 | energy_rdkit[step] = value 86 | 87 | for step in sorted(pct.keys()): 88 | if 0 <= invalid_step <= step: 89 | continue 90 | stats.append( 91 | {'conformation_id': conformation_id, 'step': step, 'n_iter': n_iter[step], 92 | 'energy_rdkit': energy_rdkit[step], 'pct': pct[step], 'energy_mse': energy_mse[step], 93 | 'force_mse': force_mse[step]} 94 | ) 95 | 96 | print(f'{"conformation_id":15} {"step":6} {"n_iter":6} {"energy_rdkit":20} {"pct":20} {"energy_mse":20} {"force_mse":20}') 97 | for record in stats: 98 | conformation_id = record['conformation_id'] 99 | step = record['step'] 100 | n_iter = record['n_iter'] 101 | pct = record['pct'] 102 | energy_mse = record['energy_mse'] 103 | force_mse = record['force_mse'] 104 | energy_rdkit = record['energy_rdkit'] 105 | print(f'{conformation_id:<15} {step:<6} {n_iter:<6} {energy_rdkit:<20} {pct:<20} {energy_mse:<20} {force_mse:<20}') 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.23.0 2 | h5py==3.11.0 3 | tqdm==4.66.3 4 | PyYAML==6.0.1 5 | gymnasium==0.29.1 6 | schnetpack==2.0.4 7 | backoff==1.11.1 8 | wandb==0.17.1 -------------------------------------------------------------------------------- /scripts/babysit_dft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit immediately if a command exits with a non-zero status. 4 | set -ex 5 | 6 | # Function to clean up resources 7 | cleanup() { 8 | echo "Stopping workers..." 9 | pkill -P $$ 10 | 11 | echo "Cleaning shared memory..." 12 | rm -f /dev/shm/psi* /dev/shm/null* /dev/shm/dfh* 13 | } 14 | 15 | # Trap EXIT signal to clean up resources 16 | trap cleanup EXIT 17 | 18 | # Parameters 19 | NUM_THREADS=$1 20 | NUM_WORKERS=$2 21 | START_PORT=$3 22 | END_PORT=$(($START_PORT + $NUM_WORKERS - 1)) 23 | 24 | # Clean up any leftover shared memory files 25 | rm -f /dev/shm/psi* /dev/shm/null* /dev/shm/dfh* 26 | 27 | # Launch workers 28 | for PORT in $(seq $START_PORT $END_PORT); do 29 | python ../env/dft.py $NUM_THREADS $PORT &> worker_$PORT.out & 30 | done 31 | 32 | # Wait for all background jobs to finish 33 | wait 34 | -------------------------------------------------------------------------------- /scripts/setup_dft_workers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit immediately if a command exits with a non-zero status. 
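# (For comparison, babysit_dft.sh above uses `set -ex`, which additionally
# echoes each command before running it; this script only stops on errors.)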
4 | set -e 5 | 6 | # Function to print usage 7 | print_usage() { 8 | echo "Usage: $0 [-r|--relaunch] <num_threads> <num_workers> <start_port>" 9 | } 10 | 11 | # Parse arguments 12 | RELAUNCH_ONLY=0 13 | if [[ "$1" == "-r" || "$1" == "--relaunch" ]]; then 14 | RELAUNCH_ONLY=1 15 | shift 16 | fi 17 | 18 | # Ensure the correct number of arguments is provided 19 | if [[ $# -ne 3 ]]; then 20 | print_usage 21 | exit 1 22 | fi 23 | 24 | # Parameters 25 | NUM_THREADS=$1 26 | NUM_WORKERS=$2 27 | START_PORT=$3 28 | 29 | # Function to check if the -r flag is in the wrong place 30 | check_argument_order() { 31 | for arg in "$@"; do 32 | if [[ "$arg" == "-r" || "$arg" == "--relaunch" ]]; then 33 | echo "Error: The -r or --relaunch flag should be the first argument." 34 | print_usage 35 | exit 1 36 | fi 37 | done 38 | } 39 | 40 | check_argument_order "$NUM_THREADS" "$NUM_WORKERS" "$START_PORT" 41 | 42 | # Function to install Mamba 43 | install_mamba() { 44 | echo "Installing Mamba..." 45 | mkdir -p ~/mamba3 46 | wget -q https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh -O ~/mamba3/mamba.sh 47 | bash ~/mamba3/mamba.sh -b -u -p ~/mamba3 48 | rm -rf ~/mamba3/mamba.sh 49 | ~/mamba3/bin/conda init bash 50 | } 51 | 52 | # Function to set up the environment 53 | setup_environment() { 54 | echo "Setting up the environment..." 55 | source $CONDA_PREFIX/etc/profile.d/conda.sh # Initialize Conda/Mamba environment 56 | conda create -y -n golf_dft_env python=3.10 # Create the environment with Python 3.10 57 | conda activate golf_dft_env # Activate the newly created environment 58 | conda install -y numpy=1 psi4 -c conda-forge # Install numpy and psi4 via conda-forge 59 | conda install -y ase -c conda-forge # Install ase via conda-forge 60 | } 61 | 62 | # Function to activate the environment 63 | activate_environment() { 64 | echo "Activating the environment..." 65 | source $CONDA_PREFIX/etc/profile.d/conda.sh # Initialize Conda/Mamba environment 66 | conda activate golf_dft_env # Activate the environment 67 | } 68 | 69 | # Function to launch workers 70 | launch_workers() { 71 | echo "Launching workers..." 72 | ./babysit_dft.sh $NUM_THREADS $NUM_WORKERS $START_PORT 73 | } 74 | 75 | # Check if Mamba is already installed 76 | if ! command -v mamba &> /dev/null; then 77 | install_mamba 78 | else 79 | echo "Mamba is already installed." 80 | fi 81 | 82 | # Main script execution 83 | if [[ $RELAUNCH_ONLY -eq 0 ]]; then 84 | setup_environment 85 | else 86 | activate_environment 87 | fi 88 | launch_workers -------------------------------------------------------------------------------- /scripts/setup_gpu_env.sh: -------------------------------------------------------------------------------- 1 | conda create -y -n GOLF_schnetpack python=3.12 2 | conda install -y -n GOLF_schnetpack pytorch pytorch-cuda=12.1 -c pytorch -c nvidia 3 | conda install -y -n GOLF_schnetpack lightning -c conda-forge 4 | conda install -y -n GOLF_schnetpack psi4 -c conda-forge 5 | conda install -y -n GOLF_schnetpack rdkit -c conda-forge -------------------------------------------------------------------------------- /scripts/training/run_training_GOLF_10k.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash -ex 2 | 3 | cuda=$1 4 | 5 | CUDA_VISIBLE_DEVICES=$cuda \ 6 | python ../../main.py --n_parallel 120 \ 7 | --n_threads 24 \ 8 | --db_path ../../data/D-0.db \ 9 | --eval_db_path ../../data/D-test.db \ 10 | --num_initial_conformations -1 \ 11 | --sample_initial_conformations True \ 12 | --timelimit_train 100 \ 13 | --timelimit_eval 50 \ 14 | --terminate_on_negative_reward True \ 15 | --max_num_negative_rewards 1 \ 16 | --reward dft \ 17 | --minimize_on_every_step True \ 18 | --backbone painn \ 19 | --n_interactions 3 \ 20 | --cutoff 5.0 \ 21 | --n_rbf 50 \ 22 | --n_atom_basis 128 \ 23 | --actor GOLF \ 24 | --experience_saver reward_threshold \ 25 | --store_only_initial_conformations False \ 26 | --conformation_optimizer LBFGS \ 27 | --conf_opt_lr 1.0 \ 28 | --conf_opt_lr_scheduler Constant \ 29 | --max_iter 5 \ 30 | --lbfgs_device cpu \ 31 | --momentum 0.0 \ 32 | --lion_beta1 0.9 \ 33 | --lion_beta2 0.99 \ 34 | --batch_size 64 \ 35 | --lr 1e-4 \ 36 | --lr_scheduler CosineAnnealing \ 37 | --optimizer adam \ 38 | --clip_value 1.0 \ 39 | --energy_loss_coef 0.01 \ 40 | --force_loss_coef 0.99 \ 41 | --replay_buffer_size 1000000 \ 42 | --initial_conf_pct 0.1 \ 43 | --max_oracle_steps 10000 \ 44 | --utd_ratio 50 \ 45 | --subtract_atomization_energy True \ 46 | --action_norm_limit 1.0 \ 47 | --eval_freq 120 \ 48 | --n_eval_runs 64 \ 49 | --eval_termination_mode fixed_length \ 50 | --exp_name GOLF-10k \ 51 | --host_file_path ../../env/host_names.txt \ 52 | --full_checkpoint_freq 600 \ 53 | --light_checkpoint_freq 1200 \ 54 | --save_checkpoints True \ 55 | --load_baseline ../../checkpoints/baseline-NNP/NNP_checkpoint \ 56 | --log_dir ../../results \ 57 | -------------------------------------------------------------------------------- /scripts/training/run_training_GOLF_1k.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash -ex 2 | 3 | cuda=$1 4 | 5 | CUDA_VISIBLE_DEVICES=$cuda \ 6 | python ../../main.py --n_parallel 48 \ 7 | --n_threads 24 \ 8 | --db_path ../../data/D-0.db \ 9 | --eval_db_path ../../data/D-test.db \ 10 | --num_initial_conformations -1 \ 11 | --sample_initial_conformations True \ 12 | --timelimit_train 100 \ 13 | --timelimit_eval 50 \ 14 | --terminate_on_negative_reward True \ 15 | --max_num_negative_rewards 1 \ 16 | --reward dft \ 17 | --minimize_on_every_step True \ 18 | --backbone painn \ 19 | --n_interactions 3 \ 20 | --cutoff 5.0 \ 21 | --n_rbf 50 \ 22 | --n_atom_basis 128 \ 23 | --actor GOLF \ 24 | --experience_saver reward_threshold \ 25 | --store_only_initial_conformations False \ 26 | --conformation_optimizer LBFGS \ 27 | --conf_opt_lr 1.0 \ 28 | --conf_opt_lr_scheduler Constant \ 29 | --max_iter 5 \ 30 | --lbfgs_device cpu \ 31 | --momentum 0.0 \ 32 | --lion_beta1 0.9 \ 33 | --lion_beta2 0.99 \ 34 | --batch_size 64 \ 35 | --lr 1e-4 \ 36 | --lr_scheduler CosineAnnealing \ 37 | --optimizer adam \ 38 | --clip_value 1.0 \ 39 | --energy_loss_coef 0.01 \ 40 | --force_loss_coef 0.99 \ 41 | --replay_buffer_size 1000000 \ 42 | --initial_conf_pct 0.1 \ 43 | --max_oracle_steps 1000 \ 44 | --utd_ratio 500 \ 45 | --subtract_atomization_energy True \ 46 | --action_norm_limit 1.0 \ 47 | --eval_freq 48 \ 48 | --n_eval_runs 48 \ 49 | --eval_termination_mode fixed_length \ 50 | --exp_name GOLF-1k \ 51 | --host_file_path ../../env/host_names.txt \ 52 | --full_checkpoint_freq 96 \ 53 | --light_checkpoint_freq 192 \ 54 | --save_checkpoints True \ 55 | --load_baseline ../../checkpoints/baseline-NNP/NNP_checkpoint \ 56 | --log_dir ../../results \ 57 | -------------------------------------------------------------------------------- /scripts/training/run_training_baseline.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash -ex 2 | 3 | cuda=$1 4 | 5 | CUDA_VISIBLE_DEVICES=$cuda \ 6 | python ../../main.py --n_parallel 240 \ 7 | --n_threads 24 \ 8 | --db_path ../../data/D-0.db \ 9 | --eval_db_path ../../data/D-test.db \ 10 | --num_initial_conformations -1 \ 11 | --sample_initial_conformations True \ 12 | --timelimit_train 1 \ 13 | --timelimit_eval 50 \ 14 | --terminate_on_negative_reward True \ 15 | --max_num_negative_rewards 1 \ 16 | --reward dft \ 17 | --minimize_on_every_step True \ 18 | --backbone painn \ 19 | --n_interactions 3 \ 20 | --cutoff 5.0 \ 21 | --n_rbf 50 \ 22 | --n_atom_basis 128 \ 23 | --actor GOLF \ 24 | --experience_saver reward_threshold \ 25 | --store_only_initial_conformations True \ 26 | --conformation_optimizer LBFGS \ 27 | --conf_opt_lr 1.0 \ 28 | --conf_opt_lr_scheduler Constant \ 29 | --max_iter 5 \ 30 | --lbfgs_device cpu \ 31 | --momentum 0.0 \ 32 | --lion_beta1 0.9 \ 33 | --lion_beta2 0.99 \ 34 | --batch_size 64 \ 35 | --lr 1e-4 \ 36 | --lr_scheduler CosineAnnealing \ 37 | --optimizer adam \ 38 | --clip_value 1.0 \ 39 | --energy_loss_coef 0.01 \ 40 | --force_loss_coef 0.99 \ 41 | --replay_buffer_size 1000000 \ 42 | --initial_conf_pct 1.0 \ 43 | --max_oracle_steps 200000 \ 44 | --utd_ratio 5 \ 45 | --subtract_atomization_energy True \ 46 | --action_norm_limit 1.0 \ 47 | --eval_freq 1200 \ 48 | --n_eval_runs 64 \ 49 | --eval_termination_mode fixed_length \ 50 | --exp_name baseline-NNP \ 51 | --full_checkpoint_freq 10000 \ 52 | --light_checkpoint_freq 50000 \ 53 | --save_checkpoints True \ 54 | --log_dir ../../results \ 55 | --project_name GOLF-pyg-baseline \ 56 | -------------------------------------------------------------------------------- /scripts/training/run_training_trajectories-100k.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash -ex 2 | 3 | cuda=$1 4 | 5 | CUDA_VISIBLE_DEVICES=$cuda \ 6 | python ../../main.py --n_parallel 240 \ 7 | --n_threads 24 \ 8 | --db_path ../../data/D-traj-100k.db \ 9 | --eval_db_path ../../D-test.db \ 10 | --num_initial_conformations -1 \ 11 | --sample_initial_conformations True \ 12 | --timelimit_train 1 \ 13 | --timelimit_eval 50 \ 14 | --terminate_on_negative_reward True \ 15 | --max_num_negative_rewards 1 \ 16 | --reward dft \ 17 | --minimize_on_every_step True \ 18 | --backbone painn \ 19 | --n_interactions 3 \ 20 | --cutoff 5.0 \ 21 | --n_rbf 50 \ 22 | --n_atom_basis 128 \ 23 | --actor GOLF \ 24 | --experience_saver reward_threshold \ 25 | --store_only_initial_conformations True \ 26 | --conformation_optimizer LBFGS \ 27 | --conf_opt_lr 1.0 \ 28 | --conf_opt_lr_scheduler Constant \ 29 | --max_iter 5 \ 30 | --lbfgs_device cpu \ 31 | --momentum 0.0 \ 32 | --lion_beta1 0.9 \ 33 | --lion_beta2 0.99 \ 34 | --batch_size 64 \ 35 | --lr 1e-4 \ 36 | --lr_scheduler CosineAnnealing \ 37 | --optimizer adam \ 38 | --clip_value 1.0 \ 39 | --energy_loss_coef 0.01 \ 40 | --force_loss_coef 0.99 \ 41 | --replay_buffer_size 1000000 \ 42 | --initial_conf_pct 1.0 \ 43 | --max_oracle_steps 100000 \ 44 | --utd_ratio 5 \ 45 | --subtract_atomization_energy True \ 46 | --action_norm_limit 1.0 \ 47 | --eval_freq 1200 \ 48 | --n_eval_runs 64 \ 49 | --eval_termination_mode fixed_length \ 50 | --exp_name traj-100k \ 51 | --full_checkpoint_freq 10000 \ 52 | --light_checkpoint_freq 50000 \ 53 | --save_checkpoints True \ 54 | --load_baseline ../../checkpoints/baseline-NNP/NNP_checkpoint \ 55 | --log_dir ../../results \ 56 | -------------------------------------------------------------------------------- /scripts/training/run_training_trajectories-10k.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash -ex 2 | 3 | cuda=$1 4 | 5 | CUDA_VISIBLE_DEVICES=$cuda \ 6 | python ../../main.py --n_parallel 240 \ 7 | --n_threads 24 \ 8 | --db_path ../../data/D-traj-10k.db \ 9 | --eval_db_path ../../data/D-test.db \ 10 | --num_initial_conformations -1 \ 11 | --sample_initial_conformations True \ 12 | --timelimit_train 1 \ 13 | --timelimit_eval 50 \ 14 | --terminate_on_negative_reward True \ 15 | --max_num_negative_rewards 1 \ 16 | --reward dft \ 17 | --minimize_on_every_step True \ 18 | --backbone painn \ 19 | --n_interactions 3 \ 20 | --cutoff 5.0 \ 21 | --n_rbf 50 \ 22 | --n_atom_basis 128 \ 23 | --actor GOLF \ 24 | --experience_saver reward_threshold \ 25 | --store_only_initial_conformations True \ 26 | --conformation_optimizer LBFGS \ 27 | --conf_opt_lr 1.0 \ 28 | --conf_opt_lr_scheduler Constant \ 29 | --max_iter 5 \ 30 | --lbfgs_device cpu \ 31 | --momentum 0.0 \ 32 | --lion_beta1 0.9 \ 33 | --lion_beta2 0.99 \ 34 | --batch_size 64 \ 35 | --lr 1e-4 \ 36 | --lr_scheduler CosineAnnealing \ 37 | --optimizer adam \ 38 | --clip_value 1.0 \ 39 | --energy_loss_coef 0.01 \ 40 | --force_loss_coef 0.99 \ 41 | --replay_buffer_size 1000000 \ 42 | --initial_conf_pct 1.0 \ 43 | --max_oracle_steps 100000 \ 44 | --utd_ratio 5 \ 45 | --subtract_atomization_energy True \ 46 | --action_norm_limit 1.0 \ 47 | --eval_freq 1200 \ 48 | --n_eval_runs 64 \ 49 | --eval_termination_mode fixed_length \ 50 | --exp_name traj-10k \ 51 | --full_checkpoint_freq 10000 \ 52 | --light_checkpoint_freq 50000 \ 53 | --save_checkpoints True \ 54 | --load_baseline ../../checkpoints/baseline-NNP/NNP_checkpoint \ 55 | --log_dir ../../results \ 56 | -------------------------------------------------------------------------------- /scripts/training/run_training_trajectories-500k.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash -ex 2 | 3 | cuda=$1 4 | 5 | CUDA_VISIBLE_DEVICES=$cuda \ 6 | python ../../main.py --n_parallel 240 \ 7 | --n_threads 24 \ 8 | --db_path ../../data/D-traj-500k.db \ 9 | --eval_db_path ../../data/D-test.db \ 10 | --num_initial_conformations -1 \ 11 | --sample_initial_conformations True \ 12 | --timelimit_train 1 \ 13 | --timelimit_eval 50 \ 14 | --terminate_on_negative_reward True \ 15 | --max_num_negative_rewards 1 \ 16 | --reward dft \ 17 | --minimize_on_every_step True \ 18 | --backbone painn \ 19 | --n_interactions 3 \ 20 | --cutoff 5.0 \ 21 | --n_rbf 50 \ 22 | --n_atom_basis 128 \ 23 | --actor GOLF \ 24 | --experience_saver reward_threshold \ 25 | --store_only_initial_conformations True \ 26 | --conformation_optimizer LBFGS \ 27 | --conf_opt_lr 1.0 \ 28 | --conf_opt_lr_scheduler Constant \ 29 | --max_iter 5 \ 30 | --lbfgs_device cpu \ 31 | --momentum 0.0 \ 32 | --lion_beta1 0.9 \ 33 | --lion_beta2 0.99 \ 34 | --batch_size 64 \ 35 | --lr 1e-4 \ 36 | --lr_scheduler CosineAnnealing \ 37 | --optimizer adam \ 38 | --clip_value 1.0 \ 39 | --energy_loss_coef 0.01 \ 40 | --force_loss_coef 0.99 \ 41 | --replay_buffer_size 1000000 \ 42 | --initial_conf_pct 1.0 \ 43 | --max_oracle_steps 200000 \ 44 | --utd_ratio 5 \ 45 | --subtract_atomization_energy True \ 46 | --action_norm_limit 1.0 \ 47 | --eval_freq 1200 \ 48 | --n_eval_runs 64 \ 49 | --eval_termination_mode fixed_length \ 50 | --exp_name traj-500k \ 51 | --full_checkpoint_freq 10000 \ 52 | --light_checkpoint_freq 50000 \ 53 | --save_checkpoints True \ 54 | --load_baseline ../../checkpoints/baseline-NNP/NNP_checkpoint \ 55 | --log_dir ../../results \ 56 | -------------------------------------------------------------------------------- /test_dft_workers.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import multiprocessing as mp 3 | import argparse 4 | import numpy as np 5 | import math 6 | 7 | from ase.db import connect 8 | 9 | from env.dft import calculate_dft_energy_tcp_client, get_dft_server_destinations 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--hostnames", 15 | type=str, 16 | required=True, 17 | help="Path to txt file with ip addresses of CPU-rich machines", 18 | ) 19 | parser.add_argument( 20 | "--db_path", 21 | type=str, 22 | default="data/test_trajectories_initial.db", 23 | help="Path to database. 
Defaults to optimization evaluation database", 24 | ) 25 | parser.add_argument( 26 | "--num_workers_per_server", 27 | type=int, 28 | required=True, 29 | help="Number of DFT workers per CPU-rich machine", 30 | ) 31 | args = parser.parse_args() 32 | 33 | dft_server_destinations = get_dft_server_destinations(args.num_workers_per_server, args.hostnames) 34 | with connect(args.db_path) as conn: 35 | atoms = conn.get(231).toatoms() 36 | 37 | futures = {} 38 | method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn" 39 | executors = [ 40 | concurrent.futures.ProcessPoolExecutor( 41 | max_workers=1, mp_context=mp.get_context(method) 42 | ) 43 | for _ in range(len(dft_server_destinations)) 44 | ] 45 | 46 | print(f"Going to test {len(dft_server_destinations)} workers: ") 47 | for i, (host, port) in enumerate(dft_server_destinations): 48 | print(f"Worker {i}: host {host}, port {port}") 49 | task = (i, 0, atoms.copy()) 50 | worker_id = i % len(dft_server_destinations) 51 | future = executors[worker_id].submit( 52 | calculate_dft_energy_tcp_client, 53 | task, 54 | host, 55 | port, 56 | False, 57 | ) 58 | futures[i] = future 59 | 60 | 61 | 62 | print("Results:") 63 | 64 | while len(futures) > 0: 65 | del_future_ids = [] 66 | for future_id, future in futures.items(): 67 | if not future.done(): 68 | continue 69 | 70 | del_future_ids.append(future_id) 71 | 72 | conformation_id, step, energy, force = future.result() 73 | worker_id = future_id % len(dft_server_destinations) 74 | host, port = dft_server_destinations[worker_id] 75 | if energy is None: 76 | print( 77 | f"Worker {worker_id}: (host={host}, port={port}) returned None for conformation_id={conformation_id}.", 78 | flush=True, 79 | ) 80 | else: 81 | if math.isclose(energy, -899.09231071538): 82 | print(f"Worker {worker_id}: (host={host}, port={port}) OK!") 83 | else: 84 | print( 85 | f"Worker {worker_id}: (host={host}, port={port}).
Returned energy={energy} but energy={-899.09231071538} was expected.", 86 | flush=True, 87 | ) 88 | for future_id in del_future_ids: 89 | del futures[future_id] 90 | -------------------------------------------------------------------------------- /utils/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(s): 5 | """Helper function to support boolean command-line arguments.""" 6 | if s.lower() in ("true", "t", "1"): 7 | return True 8 | elif s.lower() in ("false", "f", "0"): 9 | return False 10 | else: 11 | return s 12 | 13 | 14 | def check_positive(value): 15 | int_value = int(value) 16 | if int_value <= 0: 17 | raise argparse.ArgumentTypeError( 18 | f"{int_value} is an invalid positive int value" 19 | ) 20 | return int_value 21 | 22 | 23 | def none_or_str(value): 24 | if value == "None": 25 | return None 26 | return value 27 | 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser() 31 | 32 | # Env args 33 | parser.add_argument( 34 | "--n_parallel", 35 | default=1, 36 | type=int, 37 | help="Number of copies of env to run in parallel", 38 | ) 39 | parser.add_argument( 40 | "--n_workers", 41 | default=1, 42 | type=int, 43 | help="Number of parallel DFT workers", 44 | ) 45 | parser.add_argument( 46 | "--db_path", 47 | default="env/data/malonaldehyde.db", 48 | type=str, 49 | help="Path to molecules database for training", 50 | ) 51 | parser.add_argument( 52 | "--eval_db_path", 53 | default="", 54 | type=str, 55 | help="Path to molecules database for evaluation", 56 | ) 57 | parser.add_argument( 58 | "--num_initial_conformations", 59 | default=-1, 60 | type=int, 61 | help="Number of initial molecule conformations to sample from the database. \ 62 | If equal to '-1', sample all conformations from the database.", 63 | ) 64 | parser.add_argument( 65 | "--eval_num_initial_conformations", 66 | default=-1, 67 | type=int, 68 | help="Number of initial molecule conformations to sample from the evaluation database. 
\ 69 | If equal to '-1', sample all conformations from the database.", 70 | ) 71 | parser.add_argument( 72 | "--sample_initial_conformations", 73 | default=False, 74 | choices=[True, False], 75 | metavar="True|False", 76 | type=str2bool, 77 | help="Sample new conformation for every seed", 78 | ) 79 | 80 | # Episode termination args 81 | parser.add_argument( 82 | "--timelimit_train", default=100, type=int, help="Max episode len on training" 83 | ) 84 | parser.add_argument( 85 | "--timelimit_eval", default=100, type=int, help="Max episode len on evaluation" 86 | ) 87 | parser.add_argument( 88 | "--terminate_on_negative_reward", 89 | default=True, 90 | choices=[True, False], 91 | metavar="True|False", 92 | type=str2bool, 93 | help="Terminate the episode when enough negative rewards are encountered", 94 | ) 95 | parser.add_argument( 96 | "--max_num_negative_rewards", 97 | default=1, 98 | type=check_positive, 99 | help="Max number of negative rewards to terminate the episode", 100 | ) 101 | 102 | # Reward args 103 | parser.add_argument( 104 | "--reward", 105 | choices=["rdkit", "dft"], 106 | default="rdkit", 107 | help="How the energy is calculated", 108 | ) 109 | parser.add_argument( 110 | "--minimize_on_every_step", 111 | default=True, 112 | choices=[True, False], 113 | metavar="True|False", 114 | type=str2bool, 115 | help="Whether to minimize conformation with rdkit on every step", 116 | ) 117 | 118 | # Backbone args 119 | parser.add_argument( 120 | "--backbone", 121 | choices=["schnet", "painn"], 122 | required=True, 123 | help="Type of backbone to use for actor and critic", 124 | ) 125 | parser.add_argument( 126 | "--n_interactions", 127 | default=3, 128 | type=int, 129 | help="Number of interaction blocks for Schnet in actor/critic", 130 | ) 131 | parser.add_argument( 132 | "--cutoff", default=5.0, type=float, help="Cutoff for Schnet in actor/critic" 133 | ) 134 | parser.add_argument( 135 | "--n_rbf", 136 | default=50, 137 | type=int, 138 | help="Number of Gaussians for Schnet in actor/critic", 139 | ) 140 | parser.add_argument( 141 | "--n_atom_basis", 142 | default=128, 143 | type=int, 144 | help="Number of features to describe atomic environments inside backbone", 145 | ) 146 | parser.add_argument( 147 | "--radial_basis_type", 148 | default="Bessel", 149 | choices=["Bessel", "Gaussian"], 150 | help="Radial basis function type", 151 | ) 152 | parser.add_argument( 153 | "--do_postprocessing", 154 | default=False, 155 | type=str2bool, 156 | help="Postprocess energy by subtracting mean", 157 | ) 158 | 159 | # GOLF args 160 | parser.add_argument( 161 | "--actor", 162 | default="GOLF", 163 | type=str, 164 | choices=["GOLF", "rdkit"], 165 | help="Actor type. 
Rdkit can be used for evaluation only", 166 | ) 167 | parser.add_argument( 168 | "--conformation_optimizer", 169 | default="LBFGS", 170 | type=str, 171 | choices=["GD", "Lion", "LBFGS", "Adam"], 172 | help="Conformation optimizer type", 173 | ) 174 | parser.add_argument( 175 | "--conf_opt_lr", 176 | default=1.0, 177 | type=float, 178 | help="Initial learning rate for conformation optimizer.", 179 | ) 180 | parser.add_argument( 181 | "--conf_opt_lr_scheduler", 182 | choices=["Constant", "CosineAnnealing"], 183 | default="Constant", 184 | help="Conformation optimizer learning rate scheduler type", 185 | ) 186 | parser.add_argument( 187 | "--experience_saver", 188 | default="reward_threshold", 189 | choices=["reward_threshold", "last"], 190 | help="How to save experience to replay buffer", 191 | ) 192 | parser.add_argument( 193 | "--store_only_initial_conformations", 194 | default=False, 195 | choices=[True, False], 196 | metavar="True|False", 197 | type=str2bool, 198 | help="For baseline experiments.", 199 | ) 200 | 201 | # LBFGS args 202 | parser.add_argument( 203 | "--max_iter", 204 | type=int, 205 | default=1, 206 | help="Number of iterations in the inner LBFGS loop", 207 | ) 208 | parser.add_argument( 209 | "--lbfgs_device", 210 | default="cuda", 211 | type=str, 212 | choices=["cuda", "cpu"], 213 | help="LBFGS device type", 214 | ) 215 | 216 | # GD args 217 | parser.add_argument( 218 | "--momentum", 219 | default=0.0, 220 | type=float, 221 | help="Momentum argument for gradient descent conformation optimizer", 222 | ) 223 | 224 | # Lion args 225 | parser.add_argument( 226 | "--lion_beta1", 227 | default=0.9, 228 | type=float, 229 | help="Beta_1 for Lion conformation optimizer", 230 | ) 231 | parser.add_argument( 232 | "--lion_beta2", 233 | default=0.99, 234 | type=float, 235 | help="Beta_2 for Lion conformation optimizer", 236 | ) 237 | 238 | # Training args 239 | parser.add_argument( 240 | "--batch_size", 241 | default=64, 242 | type=int, 243 | help="Batch size for both actor and critic", 244 | ) 245 | parser.add_argument("--lr", default=3e-4, type=float, help="Actor learning rate") 246 | parser.add_argument( 247 | "--optimizer", 248 | default="adam", 249 | type=str, 250 | choices=["adam", "lion"], 251 | help="Optimizer type", 252 | ) 253 | parser.add_argument( 254 | "--lr_scheduler", 255 | default=None, 256 | type=none_or_str, 257 | choices=[None, "OneCycleLR", "CosineAnnealing", "StepLR"], 258 | help="LR scheduler", 259 | ) 260 | parser.add_argument( 261 | "--clip_value", default=None, help="Clipping value for actor gradients" 262 | ) 263 | parser.add_argument( 264 | "--energy_loss_coef", 265 | default=0.01, 266 | type=float, 267 | help="Weight for the energy part of the backbone loss", 268 | ) 269 | parser.add_argument( 270 | "--force_loss_coef", 271 | default=1.0, 272 | type=float, 273 | help="Weight for the forces part of the backbone loss", 274 | ) 275 | parser.add_argument( 276 | "--initial_conf_pct", 277 | default=0.0, 278 | type=float, 279 | help="Percentage of conformations from the initial database in each batch", 280 | ) 281 | parser.add_argument( 282 | "--max_oracle_steps", 283 | default=1_000_000, 284 | type=int, 285 | help="Max number of oracle calls", 286 | ) 287 | parser.add_argument( 288 | "--replay_buffer_size", 289 | default=100_000, 290 | type=int, 291 | help="Max capacity of the replay buffer", 292 | ) 293 | parser.add_argument( 294 | "--utd_ratio", 295 | default=1, 296 | type=int, 297 | help="Number of NN updates per data sample in the replay buffer.\ 298 | Total number 
of training steps = utd_ratio * max_oracle_steps", 299 | ) 300 | parser.add_argument( 301 | "--subtract_atomization_energy", 302 | default=False, 303 | choices=[True, False], 304 | metavar="True|False", 305 | type=str2bool, 306 | help="Subtract atomization energy from the DFT energy for training", 307 | ) 308 | parser.add_argument( 309 | "--action_norm_limit", 310 | default=0.05, 311 | type=float, 312 | help="Upper limit for action norm. Actions with larger norms are scaled down", 313 | ) 314 | 315 | # Eval args 316 | parser.add_argument( 317 | "--eval_freq", default=1000, type=int, help="Evaluation frequency" 318 | ) 319 | parser.add_argument( 320 | "--n_eval_runs", default=10, type=int, help="Number of evaluation episodes" 321 | ) 322 | parser.add_argument( 323 | "--eval_termination_mode", 324 | default="fixed_length", 325 | choices=["fixed_length", "grad_norm", "negative_reward"], 326 | help="When to terminate the episode on evaluation", 327 | ) 328 | parser.add_argument( 329 | "--grad_threshold", 330 | default=1e-5, 331 | type=float, 332 | help="Terminates optimization when the norm of the gradient is smaller than the threshold", 333 | ) 334 | 335 | # Other args 336 | parser.add_argument( 337 | "--exp_name", required=True, type=str, help="Name of the experiment" 338 | ) 339 | parser.add_argument( 340 | "--host_file_path", 341 | default=None, 342 | type=str, 343 | help="Path to the file with a list of server IPs", 344 | ) 345 | parser.add_argument("--seed", default=None, type=int, help="Random seed") 346 | parser.add_argument( 347 | "--full_checkpoint_freq", 348 | type=int, 349 | default=10000, 350 | help="How often full checkpoints are saved.\ 351 | Note that only the most recent full checkpoint is available", 352 | ) 353 | parser.add_argument( 354 | "--light_checkpoint_freq", 355 | type=int, 356 | default=10000, 357 | help="How often light checkpoints are saved", 358 | ) 359 | parser.add_argument( 360 | "--save_checkpoints", 361 | default=False, 362 | choices=[True, False], 363 | metavar="True|False", 364 | type=str2bool, 365 | help="Save light and full checkpoints", 366 | ) 367 | parser.add_argument( 368 | "--load_baseline", 369 | type=str, 370 | default=None, 371 | help="Checkpoint for the actor. 
Does not restore replay buffer", 372 | ) 373 | parser.add_argument( 374 | "--load_model", 375 | type=str, 376 | default=None, 377 | help="Full checkpoint path (conformation optimizer and replay buffer)", 378 | ) 379 | parser.add_argument("--log_dir", default=".", help="Directory where runs are saved") 380 | parser.add_argument( 381 | "--project_name", required=True, type=str, help="Wandb project name" 382 | ) 383 | args = parser.parse_args() 384 | 385 | return args 386 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import warnings 4 | 5 | import numpy as np 6 | import os 7 | 8 | from collections import deque 9 | 10 | try: 11 | import wandb 12 | except ImportError: 13 | pass 14 | 15 | 16 | class Logger: 17 | def __init__(self, experiment_folder, config): 18 | # If training is restarted log to the same directory 19 | if os.path.exists(experiment_folder) and config.load_model is None: 20 | raise Exception("Experiment folder exists, apparent seed conflict!") 21 | if config.load_model is None: 22 | os.makedirs(experiment_folder) 23 | 24 | # If training is restarted truncate metrics file 25 | # to the last checkpoint 26 | self.metrics_file = experiment_folder / "metrics.json" 27 | if config.load_model: 28 | # Load config file from checkpoint to correctly estimate eval_freq 29 | with open(experiment_folder / "config.json", "r") as old_config_file: 30 | old_config = json.load(old_config_file) 31 | with open(self.metrics_file, "rb") as f: 32 | lines = f.readlines() 33 | true_eval_freq = old_config["n_parallel"] * ( 34 | old_config["eval_freq"] // old_config["n_parallel"] 35 | ) 36 | checkpoint_iter = ( 37 | int(config.load_model.split("/")[-1].split("_")[-1]) // true_eval_freq 38 | ) 39 | N = len(lines) - checkpoint_iter 40 | with open(self.metrics_file, "wb") as f: 41 | if N > 0: 42 | f.writelines(lines[:-N]) 43 | elif N == 0: 44 | f.writelines(lines) 45 | else: 46 | warnings.warn( 47 | "Checkpoint iteration is older than the latest record in 'metrics.json'." 
48 | ) 49 | f.writelines(lines) 50 | else: 51 | self.metrics_file.touch() 52 | 53 | with open(experiment_folder / "config.json", "w") as config_file: 54 | json.dump(config.__dict__, config_file) 55 | 56 | if config.__dict__["reward"] == "dft": 57 | self._keep_n_episodes = config.__dict__["n_parallel"] 58 | else: 59 | self._keep_n_episodes = 10 60 | self.exploration_episode_lengths = deque(maxlen=self._keep_n_episodes) 61 | self.exploration_episode_rdkit_returns = deque(maxlen=self._keep_n_episodes) 62 | self.exploration_episode_dft_returns = deque(maxlen=self._keep_n_episodes) 63 | self.exploration_episode_final_energy = deque(maxlen=self._keep_n_episodes) 64 | self.exploration_episode_number = 0 65 | 66 | self.use_wandb = "wandb" in sys.modules # and os.environ.get("WANDB_API_KEY") 67 | if self.use_wandb: 68 | wandb.init(project=config.project_name, save_code=True, config=config) 69 | else: 70 | warnings.warn("Could not configure wandb access.") 71 | 72 | def log(self, metrics): 73 | metrics["Exploration episodes number"] = self.exploration_episode_number 74 | for name, d in zip( 75 | [ 76 | "episode length", 77 | "episode rdkit return", 78 | "episode dft return", 79 | "episode final energy", 80 | ], 81 | [ 82 | self.exploration_episode_lengths, 83 | self.exploration_episode_rdkit_returns, 84 | self.exploration_episode_dft_returns, 85 | self.exploration_episode_final_energy, 86 | ], 87 | ): 88 | if len(d) == 0: 89 | mean = 0.0 90 | std = 0.0 91 | else: 92 | mean = np.mean(d) 93 | std = np.std(d) 94 | metrics[f"Exploration {name}, mean"] = mean 95 | metrics[f"Exploration {name}, std"] = std 96 | with open(self.metrics_file, "a") as out_metrics: 97 | json.dump(metrics, out_metrics) 98 | out_metrics.write("\n") 99 | 100 | if self.use_wandb: 101 | wandb.log(metrics) 102 | 103 | def update_evaluation_statistics( 104 | self, 105 | episode_length, 106 | episode_rdkit_return, 107 | episode_final_energy, 108 | ): 109 | self.exploration_episode_number += 1 110 | self.exploration_episode_lengths.append(episode_length) 111 | self.exploration_episode_rdkit_returns.append(episode_rdkit_return) 112 | self.exploration_episode_final_energy.append(episode_final_energy) 113 | 114 | def update_dft_return_statistics(self, episode_dft_return): 115 | for val in episode_dft_return: 116 | self.exploration_episode_dft_returns.append(val) 117 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | def ignore_extra_args(foo): 4 | def indifferent_foo(**kwargs): 5 | signature = inspect.signature(foo) 6 | expected_keys = [p.name for p in signature.parameters.values() 7 | if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD] 8 | filtered_kwargs = {k: kwargs[k] for k in kwargs if k in expected_keys} 9 | return foo(**filtered_kwargs) 10 | return indifferent_foo --------------------------------------------------------------------------------
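
A minimal usage sketch of ignore_extra_args from utils/utils.py above. The snippet below is not a file from the repository; make_model and its keyword arguments are hypothetical names used only for illustration:

from utils.utils import ignore_extra_args

# Hypothetical constructor that accepts exactly two keyword arguments.
def make_model(n_atom_basis, n_interactions):
    return {"n_atom_basis": n_atom_basis, "n_interactions": n_interactions}

build = ignore_extra_args(make_model)

# "cutoff" is not a parameter of make_model, so the wrapper silently drops it
# instead of raising TypeError for an unexpected keyword argument.
config = {"n_atom_basis": 128, "n_interactions": 3, "cutoff": 5.0}
print(build(**config))  # {'n_atom_basis': 128, 'n_interactions': 3}

Note that the wrapper keeps only POSITIONAL_OR_KEYWORD parameters of the wrapped callable, so keyword-only parameters would be filtered out of the call as well.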