├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── baselines ├── heuristics.py ├── mip.py └── pop.py ├── check_data.py ├── data ├── .gitkeep ├── params.json ├── reset_vm_pm_id_clear2.json └── reset_vm_pm_id_clear_big2.json ├── env_patch.py ├── eval.py ├── experiments └── pretrain │ ├── attn │ └── params.json │ └── mlp │ └── params.json ├── find_mean.py ├── gym-reschdule_combination ├── .gitkeep ├── gym_reschdule_combination.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── gym_reschdule_combination │ ├── .gitkeep │ ├── __init__.py │ └── envs │ │ ├── .gitkeep │ │ ├── __init__.py │ │ ├── vm_rescheduler_env.py │ │ ├── vm_rescheduler_env_heuristic.py │ │ └── vm_rescheduler_env_static.py └── setup.py ├── main-onehead.py ├── main-optim-less.py ├── main-optim.py ├── main.py ├── main_graph.py ├── models ├── __init__.py ├── components │ ├── helpers.py │ ├── multihead.py │ ├── multihead_activation.py │ └── pt_transformer.py ├── gcn_embed.py ├── pm_attn.py ├── pm_attn_graph.py ├── pm_detail_attn.py ├── pm_mlp.py ├── vm_attn.py ├── vm_attn_graph.py ├── vm_lite_sparse_attn.py ├── vm_mlp.py └── vm_sparse_attn.py ├── saved_model_weights └── mlp.ckpt ├── ultimate-attn.py ├── ultra-attn.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.pyc 7 | *.pyo 8 | 9 | nohup.out 10 | .DS_Store 11 | *.npy 12 | *.csv 13 | .idea 14 | env 15 | runs 16 | venv 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution guidelines 2 | 3 | First of all, thanks for taking the time to contribute! 4 | 5 | Please refer to the following guidelines to contribute new functionality or bug fixes: 6 | 7 | 1. 
Use [autopep8](https://github.com/hhatto/autopep8) to format the Python code. 8 | 2. Use [clang-format](https://clang.llvm.org/docs/ClangFormat.html) to format C++ code. Changes to C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). 9 | 3. Add unit tests for any new code you write. 10 | 4. Run unit tests in both CI and GPU environments. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Deep Reinforcement Learning-based Virtual Machine Rescheduling 2 | 3 | We are still working on this repository. A more complete and clean version will be provided soon. 4 | 5 | 6 | ### Installation Steps 7 | 8 | 1. Install Anaconda: 9 | 10 | ``` 11 | $ conda create -n rl_vm_scheduling python=3.7 12 | $ conda activate rl_vm_scheduling 13 | ``` 14 | 15 | 2. Install RLlib: 16 | 17 | ``` 18 | $ pip install gym==0.23.1 19 | $ pip install "ray[rllib]" tensorflow torch 20 | $ pip install -e gym-reschdule_combination 21 | ``` 22 | 23 | ### Running Steps 24 | 25 | - Train PPO-based agent 26 | ``` 27 | $ python3 main.py 28 | ``` 29 | - To use pretrained model for VM selection 30 | ``` 31 | $ python3 main.py --track --model [mlp/attn] --pretrain 32 | ``` 33 | - Evaluation 34 | ``` 35 | $ python3 eval.py --restore-name [] --restore-file-name [] --model [mlp/attn] 36 | ``` 37 | 38 | ### Environments 39 | * generalizer-v0: Base environment. 
Fixed number of VMs. 40 | * generalizer-v1: Dynamic number of VMs. 41 | * graph-v1: Dynamic number of VMs with vm-pm affiliations to support graph models. 42 | -------------------------------------------------------------------------------- /baselines/heuristics.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved. 2 | # 3 | # This program is free software; you can redistribute it and/or modify it under 4 | # the terms of the Apache-2.0 license. 5 | # 6 | # This program is distributed in the hope that it will be useful, but WITHOUT ANY 7 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 8 | # PARTICULAR PURPOSE. See the Apache-2.0 License for more details. 9 | 10 | 11 | from gym_reschdule_combination.envs.vm_rescheduler_env import parse_input 12 | import numpy as np 13 | 14 | 15 | def get_vm_dec_contrib(vm, cpu=16): 16 | # vm从原物理机,释放,减少的碎片 17 | free_cpu = np.array([vm.pm.numas[0].free_cpu, vm.pm.numas[1].free_cpu]) 18 | vm_cpu = np.zeros(2) 19 | for i in vm.deploy_numa: 20 | vm_cpu[i] += vm.cpu * vm.numa_coeff 21 | return sum(free_cpu % cpu) - sum((free_cpu + vm_cpu) % cpu) 22 | 23 | 24 | def get_vm_inc_contrib(pm, numa, vm, cpu=16): 25 | # vm绑定到pm的numa,减少的碎片 26 | free_cpu = np.array([pm.numas[0].free_cpu, pm.numas[1].free_cpu]) 27 | vm_cpu = np.zeros(2) 28 | for i in numa: 29 | vm_cpu[i] += vm.cpu * vm.numa_coeff 30 | if (free_cpu < vm_cpu).any(): 31 | return -1000 32 | return sum(free_cpu % cpu) - sum((free_cpu - vm_cpu) % cpu) 33 | 34 | 35 | def filter(vms): 36 | best_choice = [-1000, None] 37 | for vm in vms: 38 | if vm.cpu > 16: 39 | continue 40 | contrib = get_vm_dec_contrib(vm) 41 | if contrib > best_choice[0]: 42 | best_choice = [contrib, vm] 43 | if best_choice[1] is not None: 44 | return best_choice[1] 45 | return 46 | 47 | 48 | def scorer(pms, vm): 49 | best_choice = [-1000, None, None] 50 | for pm in pms: 51 | if pm is 
vm.pm: 52 | continue 53 | for numa in [[0], [1]] if not vm.double_numa else [[0, 1]]: 54 | contrib = get_vm_inc_contrib(pm, numa, vm) 55 | if contrib > best_choice[0]: 56 | best_choice = [contrib, pm, numa] 57 | return best_choice[1:] 58 | 59 | 60 | def heuristic_move(instance_json_file, max_migration_num=10): 61 | scheduler = parse_input(instance_json_file) 62 | 63 | pms = list(scheduler.get_all_pms().values()) 64 | vms = list(scheduler.get_all_vms().values()) 65 | print(f"该集群有物理机{len(pms)}台,虚拟机{len(vms)}台") 66 | print(f"碎片治理前,集群碎片率为{scheduler.get_fragment_rate() * 100:.2f}%") 67 | 68 | for step in range(max_migration_num): 69 | move_vm = filter(vms) 70 | if move_vm is None: 71 | print("early stop") 72 | break 73 | frag_dec1 = get_vm_dec_contrib(move_vm) 74 | src_pm = move_vm.pm 75 | # print(src_pm.get_free_cpu_arr()) 76 | src_pm.release_a_vm(move_vm) 77 | # print(src_pm.get_free_cpu_arr()) 78 | 79 | target_pm, target_numa = scorer(pms, move_vm) 80 | frag_dec2 = get_vm_inc_contrib(target_pm, target_numa, move_vm) 81 | if frag_dec1 + frag_dec2 <= 0: 82 | src_pm.add_a_vm(move_vm) 83 | print("early stop") 84 | break 85 | target_pm.add_a_vm(move_vm, target_numa) 86 | print(f"第{step + 1}步,迁出减少碎片{frag_dec1}个,迁入减少碎片{frag_dec2}个") 87 | 88 | print(f"碎片治理后,集群碎片率为{scheduler.get_fragment_rate() * 100:.2f}%") 89 | return 90 | 91 | 92 | if __name__ == '__main__': 93 | json_file = "../data/reset_vm_pm_id_clear_big2.json" 94 | heuristic_move(json_file, max_migration_num=100) 95 | -------------------------------------------------------------------------------- /baselines/mip.py: -------------------------------------------------------------------------------- 1 | from gurobipy import Model, GRB, quicksum 2 | import numpy as np 3 | from gym_reschdule_combination.envs.vm_rescheduler_env import parse_input 4 | import time 5 | 6 | 7 | def mip_move(instance_json_file, max_migration_num=10): 8 | scheduler = parse_input(instance_json_file) 9 | 10 | pms = 
list(scheduler.get_all_pms().values()) 11 | vms = list(scheduler.get_all_vms().values()) 12 | print(f"该集群有物理机{len(pms)}台,虚拟机{len(vms)}台") 13 | print(f"碎片治理前,集群碎片率为{scheduler.get_fragment_rate() * 100:.2f}%") 14 | 15 | vms_1numa = [vm for vm in vms if not vm.double_numa] 16 | vms_2numa = [vm for vm in vms if vm.double_numa] 17 | 18 | for i, vm in enumerate(vms): 19 | vm.lid = i 20 | for i, pm in enumerate(pms): 21 | pm.lid = i 22 | 23 | num_numa = 2 24 | init_mat2 = np.zeros((len(vms), len(pms))) 25 | init_mat1 = np.zeros((len(vms), len(pms), num_numa)) 26 | for vm in vms: 27 | if vm.double_numa: 28 | init_mat2[vm.lid, vm.pm.lid] = 1 29 | else: 30 | init_mat1[vm.lid, vm.pm.lid, vm.deploy_numa[0]] = 1 31 | 32 | 33 | start_tick = time.time() 34 | m = Model() 35 | x1 = m.addVars([(vm.lid, j, k) for vm in vms_1numa for j in range(len(pms)) for k in range(num_numa)], vtype=GRB.BINARY) 36 | x2 = m.addVars([(vm.lid, j) for vm in vms_2numa for j in range(len(pms))], vtype=GRB.BINARY) 37 | y = m.addVars(len(pms), num_numa, vtype=GRB.INTEGER) # 每台物理机每个numa的剩余可部署的16core虚拟机数量 38 | z = m.addVars(len(pms), vtype=GRB.BINARY) 39 | 40 | for j, pm in enumerate(pms): 41 | for k in range(num_numa): 42 | m.addLConstr( 43 | quicksum([vm.cpu * x1[vm.lid, j, k] for vm in vms_1numa]) + 44 | quicksum([vm.cpu * x2[vm.lid, j] / num_numa for vm in vms_2numa]) + 16 * y[j, k] <= 45 | pm.cpu / num_numa * z[j] 46 | ) 47 | m.addLConstr( 48 | quicksum([vm.mem * x1[vm.lid, j, k] for vm in vms_1numa]) + 49 | quicksum([vm.mem * x2[vm.lid, j] / num_numa for vm in vms_2numa]) <= pm.mem / num_numa * z[j] 50 | ) 51 | for vm in vms: 52 | if vm.double_numa: 53 | m.addLConstr( 54 | quicksum(x2.select(vm.lid, "*")) == 1 55 | ) 56 | else: 57 | m.addLConstr( 58 | quicksum(x1.select(vm.lid, "*")) == 1 59 | ) 60 | 61 | m.addLConstr( 62 | quicksum([1 - x1[vm.lid, vm.pm.lid, vm.deploy_numa[0]] for vm in vms_1numa]) + 63 | quicksum([1 - x2[vm.lid, vm.pm.lid] for vm in vms_2numa]) <= max_migration_num 64 | ) 65 | 66 | 
# 1. 腾空主机目标 67 | # m.setObjective(quicksum(z)) 68 | # 2. 碎片治理目标(只算16core的碎片率) 69 | total_free = sum([pm.get_free_cpu_arr().sum() for pm in pms]) 70 | m.setObjective( 71 | (total_free - quicksum(y) * 16) / total_free 72 | ) 73 | end_tick = time.time() 74 | print(f"建模用时={end_tick - start_tick:.2f}s") 75 | 76 | m.optimize() 77 | print(f"建模用时={end_tick - start_tick:.2f}s, 求解用时={m.RunTime:.2f}s") 78 | 79 | assert m.status in [GRB.OPTIMAL] 80 | 81 | migration = 0 82 | for i, vm in enumerate(vms): 83 | for j, pm in enumerate(pms): 84 | if vm.double_numa and x2[vm.lid, pm.lid].x > 0.5 and vm.pm is not pm: 85 | vm.pm.release_a_vm(vm) 86 | pm.add_a_vm(vm) 87 | migration += 1 88 | if not vm.double_numa: 89 | for k in range(num_numa): 90 | if x1[vm.lid, pm.lid, k].x > 0.5 and (vm.pm is not pm or vm.deploy_numa[0] != k): 91 | vm.pm.release_a_vm(vm) 92 | vm.deploy_numa = [k] 93 | pm.add_a_vm(vm) 94 | migration += 1 95 | print(f"总迁移次数={migration}") 96 | 97 | print(f"碎片治理后,集群碎片率为{scheduler.get_fragment_rate() * 100:.2f}%") 98 | return 99 | 100 | 101 | if __name__ == '__main__': 102 | # json_file = "../data/reset_vm_pm_id_clear_big2.json" 103 | json_file = "../data/reset_vm_pm_id_clear2.json" 104 | mip_move(json_file, max_migration_num=100) 105 | -------------------------------------------------------------------------------- /baselines/pop.py: -------------------------------------------------------------------------------- 1 | from gurobipy import Model, GRB, quicksum 2 | import numpy as np 3 | from gym_reschdule_combination.envs.vm_rescheduler_env import parse_input 4 | import time 5 | 6 | 7 | def pop_move(instance_json_file, max_migration_num=10, pop=2): 8 | scheduler = parse_input(instance_json_file) 9 | 10 | all_pms = list(scheduler.get_all_pms().values()) 11 | all_vms = list(scheduler.get_all_vms().values()) 12 | 13 | # POP加速:将原问题分解为pop个子问题 14 | num_pm_epoch = len(all_pms) // pop 15 | for epoch in range(pop): 16 | pms = all_pms[num_pm_epoch * epoch: num_pm_epoch * (epoch + 
1)] 17 | vms = [] 18 | for pm in pms: 19 | for vm in pm.vms.values(): 20 | vms.append(vm) 21 | 22 | print(f"该集群有物理机{len(pms)}台,虚拟机{len(vms)}台") 23 | print(f"碎片治理前,集群碎片率为{scheduler.get_fragment_rate() * 100:.2f}%") 24 | 25 | vms_1numa = [vm for vm in vms if not vm.double_numa] 26 | vms_2numa = [vm for vm in vms if vm.double_numa] 27 | 28 | for i, vm in enumerate(vms): 29 | vm.lid = i 30 | for i, pm in enumerate(pms): 31 | pm.lid = i 32 | 33 | num_numa = 2 34 | init_mat2 = np.zeros((len(vms), len(pms))) 35 | init_mat1 = np.zeros((len(vms), len(pms), num_numa)) 36 | for vm in vms: 37 | if vm.double_numa: 38 | init_mat2[vm.lid, vm.pm.lid] = 1 39 | else: 40 | init_mat1[vm.lid, vm.pm.lid, vm.deploy_numa[0]] = 1 41 | 42 | 43 | start_tick = time.time() 44 | m = Model() 45 | x1 = m.addVars([(vm.lid, j, k) for vm in vms_1numa for j in range(len(pms)) for k in range(num_numa)], vtype=GRB.BINARY) 46 | x2 = m.addVars([(vm.lid, j) for vm in vms_2numa for j in range(len(pms))], vtype=GRB.BINARY) 47 | y = m.addVars(len(pms), num_numa, vtype=GRB.INTEGER) # 每台物理机每个numa的剩余可部署的16core虚拟机数量 48 | z = m.addVars(len(pms), vtype=GRB.BINARY) 49 | 50 | for j, pm in enumerate(pms): 51 | for k in range(num_numa): 52 | m.addLConstr( 53 | quicksum([vm.cpu * x1[vm.lid, j, k] for vm in vms_1numa]) + 54 | quicksum([vm.cpu * x2[vm.lid, j] / num_numa for vm in vms_2numa]) + 16 * y[j, k] <= 55 | pm.cpu / num_numa * z[j] 56 | ) 57 | m.addLConstr( 58 | quicksum([vm.mem * x1[vm.lid, j, k] for vm in vms_1numa]) + 59 | quicksum([vm.mem * x2[vm.lid, j] / num_numa for vm in vms_2numa]) <= pm.mem / num_numa * z[j] 60 | ) 61 | for vm in vms: 62 | if vm.double_numa: 63 | m.addLConstr( 64 | quicksum(x2.select(vm.lid, "*")) == 1 65 | ) 66 | else: 67 | m.addLConstr( 68 | quicksum(x1.select(vm.lid, "*")) == 1 69 | ) 70 | 71 | m.addLConstr( 72 | quicksum([1 - x1[vm.lid, vm.pm.lid, vm.deploy_numa[0]] for vm in vms_1numa]) + 73 | quicksum([1 - x2[vm.lid, vm.pm.lid] for vm in vms_2numa]) <= max_migration_num // pop 74 | 
) 75 | 76 | # 1. 腾空主机目标 77 | # m.setObjective(quicksum(z)) 78 | # 2. 碎片治理目标(只算16core的碎片率) 79 | total_free = sum([pm.get_free_cpu_arr().sum() for pm in pms]) 80 | m.setObjective( 81 | (total_free - quicksum(y) * 16) / total_free 82 | ) 83 | end_tick = time.time() 84 | print(f"建模用时={end_tick - start_tick:.2f}s") 85 | 86 | m.optimize() 87 | print(f"建模用时={end_tick - start_tick:.2f}s, 求解用时={m.RunTime:.2f}s") 88 | 89 | assert m.status in [GRB.OPTIMAL] 90 | 91 | migration = 0 92 | for i, vm in enumerate(vms): 93 | for j, pm in enumerate(pms): 94 | if vm.double_numa and x2[vm.lid, pm.lid].x > 0.5 and vm.pm is not pm: 95 | vm.pm.release_a_vm(vm) 96 | pm.add_a_vm(vm) 97 | migration += 1 98 | if not vm.double_numa: 99 | for k in range(num_numa): 100 | if x1[vm.lid, pm.lid, k].x > 0.5 and (vm.pm is not pm or vm.deploy_numa[0] != k): 101 | vm.pm.release_a_vm(vm) 102 | vm.deploy_numa = [k] 103 | pm.add_a_vm(vm) 104 | migration += 1 105 | print(f"总迁移次数={migration}") 106 | 107 | print(f"碎片治理后,集群碎片率为{scheduler.get_fragment_rate() * 100:.2f}%") 108 | return 109 | 110 | 111 | if __name__ == '__main__': 112 | # json_file = "../data/reset_vm_pm_id_clear_big2.json" 113 | json_file = "../data/reset_vm_pm_id_clear2.json" 114 | pop_move(json_file, max_migration_num=100, pop=2) # pop参数是pop算法进行问题分解后的子问题数量。一般而言,pop越大,速度越快,但是解的质量越差。pop=2就是把原问题分解为两个子问题 115 | -------------------------------------------------------------------------------- /check_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved. 2 | # 3 | # This program is free software; you can redistribute it and/or modify it under 4 | # the terms of the Apache-2.0 license. 5 | # 6 | # This program is distributed in the hope that it will be useful, but WITHOUT ANY 7 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 8 | # PARTICULAR PURPOSE. See the Apache-2.0 License for more details. 
9 | 10 | 11 | import os 12 | import json 13 | 14 | 15 | if __name__ == "__main__": 16 | num_train = 6000 17 | num_dev = 200 18 | num_test = 200 19 | 20 | train_set = set() 21 | for i in range(num_train): 22 | with open(f'./data/flex_vm_dataset/M/train/flex_vm_{i}.json', 'r', encoding='utf-8') as f: 23 | train_set.add(f.read()) 24 | 25 | print('train len: ', len(train_set)) 26 | 27 | dev_set = set() 28 | for i in range(num_dev): 29 | with open(f'./data/flex_vm_dataset/M/dev/flex_vm_{num_train + i}.json', 'r', encoding='utf-8') as f: 30 | dev_set.add(f.read()) 31 | 32 | print('dev len: ', len(dev_set)) 33 | test_set = set() 34 | for i in range(num_test): 35 | with open(f'./data/flex_vm_dataset/M/test/flex_vm_{num_train + num_dev + i}.json', 'r', encoding='utf-8') as f: 36 | test_set.add(f.read()) 37 | 38 | print('test len: ', len(test_set)) 39 | 40 | print('train + test: ', len(train_set.union(test_set))) 41 | print('train + dev: ', len(train_set.union(dev_set))) 42 | print('train + dev + test: ', len(train_set.union(dev_set).union(test_set))) 43 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DRL-based-VM-Rescheduling/a1232df1bce7851c991229fd7a34871f0685f5f0/data/.gitkeep -------------------------------------------------------------------------------- /data/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "vm_candidates_shuffled_extra.pkl", 3 | "target_save_name": "vm_candidates_shuffled_extra", 4 | "raw_data_path": "./data/raw_data/vm/", 5 | "val_portion": 0.1, 6 | "test_portion": 0.1, 7 | "num_vm": 2089, 8 | "num_pm": 279, 9 | "pm_cov": 8, 10 | "vm_cov": 14, 11 | "eval_freq": 1, 12 | "plot_freq": 1, 13 | "plot_start": 0, 14 | "quantiles": [0.5, 0.1, 0.9], 15 | "missing_value": -100, 16 | "output_categorical": false 17 | } 18 | 
-------------------------------------------------------------------------------- /env_patch.py: --------------------------------------------------------------------------------
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.

# NOTE(review): this module appears to be a patched copy of gym's AsyncVectorEnv.
# Visible additions relative to the stock class: `call_parse`/`call_async_parse`
# (distribute per-env keyword arguments to sub-environments) and auto-reset on
# `done` that stashes the final observation in info["terminal_observation"].
# Confirm against the gym version pinned by this repo.

from typing import Optional, Union, List
import numpy as np
import multiprocessing as mp
import time
import sys
from enum import Enum
from copy import deepcopy
from gym import logger
from gym.logger import warn
from gym.vector.vector_env import VectorEnv
from gym.error import (
    AlreadyPendingCallError,
    NoAsyncCallError,
    ClosedEnvironmentError,
    CustomSpaceError,
)
from gym.vector.utils import (
    create_shared_memory,
    create_empty_array,
    write_to_shared_memory,
    read_from_shared_memory,
    concatenate,
    iterate,
    CloudpickleWrapper,
    clear_mpi_env_vars,
)

__all__ = ["AsyncVectorEnv_Patch"]


# Lifecycle states for the single pending asynchronous call; `_state` always
# holds exactly one of these between *_async and *_wait.
class AsyncState(Enum):
    DEFAULT = "default"
    WAITING_RESET = "reset"
    WAITING_STEP = "step"
    WAITING_CALL = "call"


class AsyncVectorEnv_Patch(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel. It
    uses `multiprocessing`_ processes, and pipes for communication.

    Parameters
    ----------
    env_fns : iterable of callable
        Functions that create the environments.

    observation_space : :class:`gym.spaces.Space`, optional
        Observation space of a single environment. If ``None``, then the
        observation space of the first environment is taken.

    action_space : :class:`gym.spaces.Space`, optional
        Action space of a single environment. If ``None``, then the action space
        of the first environment is taken.

    shared_memory : bool
        If ``True``, then the observations from the worker processes are
        communicated back through shared variables. This can improve the
        efficiency if the observations are large (e.g. images).

    copy : bool
        If ``True``, then the :meth:`~AsyncVectorEnv.reset` and
        :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.

    context : str, optional
        Context for `multiprocessing`_. If ``None``, then the default context is used.

    daemon : bool
        If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they
        will quit if the head process quits. However, ``daemon=True`` prevents
        subprocesses to spawn children, so for some environments you may want
        to have it set to ``False``.

    worker : callable, optional
        If set, then use that worker in a subprocess instead of a default one.
        Can be useful to override some inner vector env logic, for instance,
        how resets on done are handled.

    Warning
    -------
    :attr:`worker` is an advanced mode option. It provides a high degree of
    flexibility and a high chance to shoot yourself in the foot; thus,
    if you are writing your own worker, it is recommended to start from the code
    for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.

    Raises
    ------
    RuntimeError
        If the observation space of some sub-environment does not match
        :obj:`observation_space` (or, by default, the observation space of
        the first sub-environment).

    ValueError
        If :obj:`observation_space` is a custom space (i.e. not a default
        space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`,
        or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``.

    Example
    -------

    .. code-block::

        >>> env = gym.vector.AsyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
        ... ])
        >>> env.reset()
        array([[-0.8286432 ,  0.5597771 ,  0.90249056],
               [-0.85009176,  0.5266346 ,  0.60007906]], dtype=float32)
    """

    def __init__(
        self,
        env_fns,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
    ):
        ctx = mp.get_context(context)
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy
        # A throwaway env instance is used only to discover metadata and spaces.
        dummy_env = env_fns[0]()
        self.metadata = dummy_env.metadata

        if (observation_space is None) or (action_space is None):
            observation_space = observation_space or dummy_env.observation_space
            action_space = action_space or dummy_env.action_space
        dummy_env.close()
        del dummy_env
        super().__init__(
            num_envs=len(env_fns),
            observation_space=observation_space,
            action_space=action_space,
        )

        if self.shared_memory:
            try:
                _obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                self.observations = read_from_shared_memory(
                    self.single_observation_space, _obs_buffer, n=self.num_envs
                )
            except CustomSpaceError:
                raise ValueError(
                    "Using `shared_memory=True` in `AsyncVectorEnv` "
                    "is incompatible with non-standard Gym observation spaces "
                    "(i.e. custom spaces inheriting from `gym.Space`), and is "
                    "only compatible with default Gym spaces (e.g. `Box`, "
                    "`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
                    "if you use custom observation spaces."
                )
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        # A user-supplied `worker` overrides the default worker target.
        target = _worker_shared_memory if self.shared_memory else _worker
        target = worker or target
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                process = ctx.Process(
                    target=target,
                    name=f"Worker<{type(self).__name__}>-{idx}",
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        _obs_buffer,
                        self.error_queue,
                    ),
                )

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # The child end lives in the subprocess; close our handle to it.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_spaces()

    def seed(self, seed=None):
        """Seed every sub-environment; an int seed is spread as seed + i."""
        super().seed(seed=seed)
        self._assert_is_running()
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `seed` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        # NOTE(review): the loop variable rebinds `seed`; harmless because
        # zip() captured the list before the loop, but renaming would be clearer.
        for pipe, seed in zip(self.parent_pipes, seed):
            pipe.send(("seed", seed))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def reset_async(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Send the calls to :obj:`reset` to each sub-environment.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        AlreadyPendingCallError
            If the environment is already waiting for a pending call to another
            method (e.g. :meth:`step_async`). This can be caused by two consecutive
            calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in
            between.
        """
        self._assert_is_running()

        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
                self._state.value,
            )

        # Only forward kwargs that were actually provided, so workers can call
        # env.reset(**kwargs) on envs that do not accept seed/options.
        for pipe, single_seed in zip(self.parent_pipes, seed):
            single_kwargs = {}
            if single_seed is not None:
                single_kwargs["seed"] = single_seed
            if return_info:
                single_kwargs["return_info"] = return_info
            if options is not None:
                single_kwargs["options"] = options

            pipe.send(("reset", single_kwargs))
        self._state = AsyncState.WAITING_RESET

    def reset_wait(
        self,
        timeout=None,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `reset_wait` times out. If
            `None`, the call to `reset_wait` never times out.
        seed: ignored
        options: ignored

        Returns
        -------
        element of :attr:`~VectorEnv.observation_space`
            A batch of observations from the vectorized environment.
        infos : list of dicts containing metadata

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`reset_wait` was called without any prior call to
            :meth:`reset_async`.

        TimeoutError
            If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        # NOTE(review): `return_info` here must match the value passed to
        # reset_async, since it determines how the worker payloads are unpacked.
        if return_info:
            results, infos = zip(*results)
            infos = list(infos)

            if not self.shared_memory:
                self.observations = concatenate(
                    self.single_observation_space, results, self.observations
                )

            return (
                deepcopy(self.observations) if self.copy else self.observations
            ), infos
        else:
            if not self.shared_memory:
                self.observations = concatenate(
                    self.single_observation_space, results, self.observations
                )

            return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions):
        """Send the calls to :obj:`step` to each sub-environment.

        Parameters
        ----------
        actions : element of :attr:`~VectorEnv.action_space`
            Batch of actions.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        AlreadyPendingCallError
            If the environment is already waiting for a pending call to another
            method (e.g. :meth:`reset_async`). This can be caused by two consecutive
            calls to :meth:`step_async`, with no call to :meth:`step_wait` in
            between.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        actions = iterate(self.action_space, actions)
        for pipe, action in zip(self.parent_pipes, actions):
            pipe.send(("step", action))
        self._state = AsyncState.WAITING_STEP

    def step_wait(self, timeout=None):
        """Wait for the calls to :obj:`step` in each sub-environment to finish.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to :meth:`step_wait` times out. If
            ``None``, the call to :meth:`step_wait` never times out.

        Returns
        -------
        observations : element of :attr:`~VectorEnv.observation_space`
            A batch of observations from the vectorized environment.

        rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
            A vector of rewards from the vectorized environment.

        dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
            A vector whose entries indicate whether the episode has ended.

        infos : list of dict
            A list of auxiliary diagnostic information dicts from sub-environments.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`step_wait` was called without any prior call to
            :meth:`step_async`.

        TimeoutError
            If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT
        observations_list, rewards, dones, infos = zip(*results)

        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space,
                observations_list,
                self.observations,
            )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def call_parse(self, name, *args, **kwargs):
        """Call `name` on every sub-environment with per-env kwargs and wait
        for the results (see :meth:`call_async_parse`)."""
        self.call_async_parse(name, *args, **kwargs)
        return self.call_wait()

    def call_async_parse(self, name, *args, **kwargs):
        """Like :meth:`call_async`, but each kwarg value is a sequence with one
        entry per environment; entry i is sent to sub-environment i."""
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `call_async` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )
        # Transpose a dict of per-env sequences, e.g. {"vm_id": [a, b]}, into
        # one kwargs dict per env: [{"vm_id": a}, {"vm_id": b}].
        kwargs_list = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]

        for pipe, kwarg in zip(self.parent_pipes, kwargs_list):
            pipe.send(("_call", (name, args, kwarg)))
        self._state = AsyncState.WAITING_CALL

    def call_async(self, name, *args, **kwargs):
        """
        Parameters
        ----------
        name : string
            Name of the method or property to call.

        *args
            Arguments to apply to the method call.

        **kwargs
            Keyword arguments to apply to the method call.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `call_async` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe in self.parent_pipes:
            pipe.send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL

    def call_wait(self, timeout=None):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `step_wait` times out. If
            `None` (default), the call to `step_wait` never times out.

        Returns
        -------
        results : list
            List of the results of the individual calls to the method or
            property for each environment.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results

    def set_attr(self, name, values):
        """
        Parameters
        ----------
        name : string
            Name of the property to be set in each individual environment.

        values : list, tuple, or object
            Values of the property to be set to. If `values` is a list or
            tuple, then it corresponds to the values for each individual
            environment, otherwise a single value is set for all environments.
        """
        self._assert_is_running()
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(values)}` values for "
                f"{self.num_envs} environments."
            )

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `set_attr` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe, value in zip(self.parent_pipes, values):
            pipe.send(("_setattr", (name, value)))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def close_extras(self, timeout=None, terminate=False):
        """Close the environments & clean up the extra resources
        (processes and pipes).

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to :meth:`close` times out. If ``None``,
            the call to :meth:`close` never times out. If the call to :meth:`close`
            times out, then all processes are terminated.

        terminate : bool
            If ``True``, then the :meth:`close` operation is forced and all processes
            are terminated.

        Raises
        ------
        TimeoutError
            If :meth:`close` timed out.
        """
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
                )
                # Drain the pending *_wait so workers are in a known state.
                function = getattr(self, f"{self._state.value}_wait")
                function(timeout)
        except mp.TimeoutError:
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()

    def _poll(self, timeout=None):
        # Returns True iff every worker pipe has data ready within `timeout`
        # seconds total (the budget is shared across pipes, not per pipe).
        self._assert_is_running()
        if timeout is None:
            return True
        end_time = time.perf_counter() + timeout
        delta = None
        for pipe in self.parent_pipes:
            delta = max(end_time - time.perf_counter(), 0)
            if pipe is None:
                return False
            if pipe.closed or (not pipe.poll(delta)):
                return False
        return True

    def _check_spaces(self):
        # Ask every worker to compare its env's spaces against ours.
        self._assert_is_running()
        spaces = (self.single_observation_space, self.single_action_space)
        for pipe in self.parent_pipes:
            pipe.send(("_check_spaces", spaces))
        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        same_observation_spaces, same_action_spaces = zip(*results)
        if not all(same_observation_spaces):
            raise RuntimeError(
                "Some environments have an observation space different from "
                f"`{self.single_observation_space}`. In order to batch observations, "
                "the observation spaces from all environments must be equal."
            )
        if not all(same_action_spaces):
            raise RuntimeError(
                "Some environments have an action space different from "
                f"`{self.single_action_space}`. In order to batch actions, the "
                "action spaces from all environments must be equal."
            )

    def _assert_is_running(self):
        if self.closed:
            raise ClosedEnvironmentError(
                f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
            )

    def _raise_if_errors(self, successes):
        # Each failed worker pushed (index, exc_type, exc_value) onto the error
        # queue; log all of them, then re-raise the last one in this process.
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        for _ in range(num_errors):
            index, exctype, value = self.error_queue.get()
            logger.error(
                f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
            )
            logger.error(f"Shutting down Worker-{index}.")
            self.parent_pipes[index].close()
            self.parent_pipes[index] = None

        logger.error("Raising the last exception back to the main process.")
        raise exctype(value)

    def __del__(self):
        if not getattr(self, "closed", True):
            self.close(terminate=True)


def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    # Subprocess loop for shared_memory=False: observations travel back
    # through the pipe. Every reply is a (payload, success_flag) pair.
    assert shared_memory is None
    env = env_fn()
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                if "return_info" in data and data["return_info"] == True:
                    observation, info = env.reset(**data)
                    pipe.send(((observation, info), True))
                else:
                    observation = env.reset(**data)
                    pipe.send((observation, True))

            elif command == "step":
                observation, reward, done, info = env.step(data)
                if done:
                    # Auto-reset on episode end; preserve the final observation
                    # for the caller in info["terminal_observation"].
                    info["terminal_observation"] = observation
                    observation = env.reset()
                pipe.send(((observation, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                function = getattr(env, name)
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_spaces":
                pipe.send(
                    (
                        (data[0] == env.observation_space, data[1] == env.action_space),
                        True,
                    )
                )
            else:
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, `_call`, "
                    "`_setattr`, `_check_spaces`}."
                )
    except (KeyboardInterrupt, Exception):
        # Report the failure to the parent and reply success=False so the
        # parent's pending recv() does not hang.
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()


def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    # Subprocess loop for shared_memory=True: observations are written into
    # the shared buffer and the pipe payload carries None in their place.
    assert shared_memory is not None
    env = env_fn()
    observation_space = env.observation_space
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                if "return_info" in data and data["return_info"] == True:
                    observation, info = env.reset(**data)
                    write_to_shared_memory(
                        observation_space, index, observation, shared_memory
                    )
                    pipe.send(((None, info), True))
                else:
                    observation = env.reset(**data)
                    write_to_shared_memory(
                        observation_space, index, observation, shared_memory
                    )
                    pipe.send((None, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                if done:
                    # Auto-reset on episode end; preserve the final observation
                    # for the caller in info["terminal_observation"].
                    info["terminal_observation"] = observation
                    observation = env.reset()
                write_to_shared_memory(
                    observation_space, index, observation, shared_memory
                )
                pipe.send(((None, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                function = getattr(env, name)
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_spaces":
                pipe.send(
                    ((data[0] == observation_space, data[1] == env.action_space), True)
                )
            else:
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, `_call`, "
                    "`_setattr`, `_check_spaces`}."
                )
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
-------------------------------------------------------------------------------- /eval.py: --------------------------------------------------------------------------------
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.
9 | 10 | 11 | """ 12 | VM: cpu, cpu, mem, mem, cpu % 16, cpu % 16 (0 is full, 1 is empty) 13 | PM: cpu, cpu, mem, mem, fragment_rate, cpu % 16, fragment_rate, cpu % 16 14 | cpu % 16 = round(normalized_cpu * 88) % 16 / 16 15 | fragment_rate = round(normalized_cpu * 88) % 16 / round(normalized_cpu * 88) 16 | To rescale memory, mem * 368776 17 | """ 18 | 19 | import argparse 20 | import os 21 | import random 22 | import time 23 | from distutils.util import strtobool 24 | 25 | import pandas as pd 26 | import wandb 27 | 28 | import gym 29 | import numpy as np 30 | import torch 31 | import torch.nn as nn 32 | import torch.optim as optim 33 | from torch.distributions.categorical import Categorical 34 | from tqdm import trange 35 | 36 | import gym_reschdule_combination.envs.vm_rescheduler_env 37 | 38 | import models 39 | import utils 40 | from env_patch import AsyncVectorEnv_Patch 41 | from main import make_env 42 | 43 | 44 | def parse_args(): 45 | # fmt: off 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--model", type=str, default="attn", help="model architecture") 48 | parser.add_argument("--restore-name", type=str, required=True, help="restore experiment name") 49 | parser.add_argument("--restore-file-name", type=str, required=True, help="restore file name") 50 | parser.add_argument("--pretrain", action='store_true', 51 | help="if toggled, we will restore pretrained weights for vm selection") 52 | parser.add_argument("--gym-id", type=str, default="generalizer-v1", 53 | help="the id of the gym environment") 54 | parser.add_argument("--vm-data-size", type=str, default="M", choices=["M", "L"], 55 | help="size of the dataset") 56 | parser.add_argument("--max-steps", type=int, default=50, help="maximum number of redeploy steps") 57 | parser.add_argument("--learning-rate", type=float, default=2.5e-4, 58 | help="the learning rate of the optimizer") 59 | parser.add_argument("--seed", type=int, default=1, 60 | help="seed of the experiment") 61 | 
parser.add_argument("--total-timesteps", type=int, default=2000000, 62 | help="total timesteps of the experiments") 63 | parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 64 | help="if toggled, `torch.backends.cudnn.deterministic=False`") 65 | parser.add_argument("--normalize", action='store_true', 66 | help="if toggled, we will normalize the input features") 67 | parser.add_argument("--track", action='store_true', 68 | help="if toggled, this experiment will be tracked with Weights and Biases") 69 | parser.add_argument("--debug", action='store_true', 70 | help="if toggled, this experiment will save run details") 71 | 72 | # Algorithm specific arguments 73 | parser.add_argument("--num-envs", type=int, default=8, 74 | help="the number of parallel game environments") 75 | parser.add_argument("--num-steps", type=int, default=256, 76 | help="the number of steps to run in each environment per policy rollout") 77 | parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 78 | help="Toggle learning rate annealing for policy and value networks") 79 | parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 80 | help="Use GAE for advantage computation") 81 | parser.add_argument("--gamma", type=float, default=0.99, 82 | help="the discount factor gamma") 83 | parser.add_argument("--gae-lambda", type=float, default=0.95, 84 | help="the lambda for the general advantage estimation") 85 | parser.add_argument("--num-minibatches", type=int, default=4, 86 | help="the number of mini-batches") 87 | parser.add_argument("--accum-iter", type=int, default=4, 88 | help="the number of mini-batches") 89 | parser.add_argument("--update-epochs", type=int, default=4, 90 | help="the K epochs to update the policy") 91 | parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 92 | 
help="Toggles advantages normalization") 93 | parser.add_argument("--clip-coef", type=float, default=0.1, 94 | help="the surrogate clipping coefficient") 95 | parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 96 | help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") 97 | parser.add_argument("--ent-coef", type=float, default=0.005, # 0.01 98 | help="coefficient of the entropy") 99 | parser.add_argument("--vf-coef", type=float, default=1e-2, # 1e-4 100 | help="coefficient of the value function") 101 | parser.add_argument("--max-grad-norm", type=float, default=0.5, 102 | help="the maximum norm for the gradient clipping") 103 | parser.add_argument("--target-kl", type=float, default=None, 104 | help="the target KL divergence threshold") 105 | args = parser.parse_args() 106 | args.batch_size = int(args.num_envs * args.num_steps) 107 | args.minibatch_size = int(args.batch_size // (args.num_minibatches * args.accum_iter)) 108 | return args 109 | 110 | 111 | class CategoricalMasked(Categorical): 112 | def __init__(self, logits=None, probs=None, masks=None): 113 | if masks is None or torch.sum(masks) == 0: 114 | self.masks = None 115 | super(CategoricalMasked, self).__init__(logits=logits) 116 | else: 117 | self.masks = masks 118 | if logits is not None: 119 | logits = torch.where(self.masks, torch.tensor(-1e8, device=logits.device), logits) 120 | super(CategoricalMasked, self).__init__(logits=logits) 121 | else: 122 | probs = torch.where(self.masks, torch.tensor(0.0, device=probs.device), probs) 123 | small_val_mask = torch.sum(probs, dim=1) < 1e-4 124 | probs[small_val_mask] = torch.where(self.masks[small_val_mask], torch.tensor(0.0, device=probs.device), 125 | torch.tensor(1.0, device=probs.device)) 126 | super(CategoricalMasked, self).__init__(probs=probs) 127 | 128 | def entropy(self): 129 | if self.masks is None: 130 | return super(CategoricalMasked, self).entropy() 131 
| p_log_p = self.logits * self.probs 132 | p_log_p = torch.where(self.masks, torch.tensor(0.0, device=p_log_p.device), p_log_p) 133 | return -p_log_p.sum(-1) 134 | 135 | 136 | """ 137 | class Agent(nn.Module): 138 | def __init__(self, vm_net, params, args_model): 139 | super(Agent, self).__init__() 140 | 141 | self.vm_net = vm_net 142 | self.device = params.device 143 | self.model = args_model 144 | 145 | def get_value(self, obs_info_pm, obs_info_all_vm, obs_info_num_steps, obs_info_num_vms): 146 | num_vms_mask = torch.arange(obs_info_all_vm.shape[1], 147 | device=obs_info_all_vm.device)[None, :] >= obs_info_num_vms[:, None] 148 | return self.vm_net(obs_info_all_vm, obs_info_num_steps, obs_info_pm, num_vms_mask)[1] 149 | 150 | def get_action_and_value(self, envs, obs_info_pm, obs_info_all_vm, obs_info_num_steps, obs_info_num_vms, 151 | pm_mask=None, selected_vm=None, selected_pm=None): 152 | if pm_mask is None: 153 | assert selected_vm is None and selected_pm is None, \ 154 | 'action must be None when action_mask is not given!' 155 | else: 156 | assert selected_vm is not None and selected_pm is not None, \ 157 | 'action must be given when action_mask is given!' 
158 | 159 | num_vms_mask = torch.arange(obs_info_all_vm.shape[1], device=self.device)[None, :] >= obs_info_num_vms[:, None] 160 | 161 | b_sz = obs_info_pm.shape[0] 162 | # obs_info_all_vm: torch.Size([8, 2089, 14]) 163 | # obs_info_pm: torch.Size([8, 279, 8]) 164 | if self.model == "attn": 165 | vm_logits, critic_score, attn_score = self.vm_net(obs_info_all_vm, obs_info_num_steps, obs_info_pm, 166 | num_vms_mask, return_attns=True) 167 | else: 168 | raise ValueError(f'self.model={self.model} is not implemented') 169 | # vm_pred: torch.Size([8, 2089]) 170 | # critic_score: torch.Size([8, 1]) 171 | vm_cat = CategoricalMasked(logits=vm_logits, masks=num_vms_mask) 172 | if selected_vm is None: 173 | selected_vm = vm_cat.sample() 174 | vm_log_prob = vm_cat.log_prob(selected_vm) 175 | # selected_vm: torch.Size([8]) 176 | # vm_log_prob: torch.Size([8]) 177 | # entropy: torch.Size([8]) 178 | 179 | if pm_mask is None: 180 | pm_mask = torch.tensor(np.array(envs.call_parse('get_pm_mask', vm_id=selected_vm.cpu().tolist())), 181 | dtype=torch.bool, device=self.device) # pm_mask: torch.Size([8, 279]) 182 | 183 | # obs_info_all_vm: torch.Size([8, 2089, 14]) 184 | pm_probs = attn_score[-1][torch.arange(b_sz, device=self.device), selected_vm][:, 1:] 185 | # pm_logits: torch.Size([8, 279]) 186 | pm_cat = CategoricalMasked(probs=pm_probs, masks=pm_mask) 187 | if selected_pm is None: 188 | pm_probs = torch.where(pm_mask, torch.tensor(0.0, device=pm_probs.device), pm_probs) 189 | selected_pm = torch.argmax(pm_probs, dim=1) 190 | # print('torch max: ', torch.amax(pm_probs, dim=1)) 191 | # selected_pm = pm_cat.sample() 192 | # selected_pm: torch.Size([8]) 193 | pm_log_prob = pm_cat.log_prob(selected_pm) 194 | # pm_log_prob: torch.Size([8]) 195 | log_prob = vm_log_prob + pm_log_prob 196 | entropy = vm_cat.entropy() + pm_cat.entropy() 197 | 198 | return selected_vm, selected_pm, log_prob, entropy, critic_score, pm_mask 199 | """ 200 | 201 | 202 | class Agent(nn.Module): 203 | def 
class Agent(nn.Module):
    """Two-headed PPO policy for VM rescheduling.

    vm_net scores which VM to migrate (and doubles as the critic head);
    pm_net scores the destination PM for the chosen VM.
    """

    def __init__(self, vm_net, pm_net, params, args_model):
        super(Agent, self).__init__()

        self.vm_net = vm_net          # VM-selection actor + critic
        self.pm_net = pm_net          # PM-selection actor
        self.device = params.device
        self.model = args_model       # "attn" or "mlp"
        self.num_vm = params.num_vm

    def get_value(self, obs_info_pm, obs_info_all_vm, obs_info_num_steps, obs_info_num_vms):
        """Return the critic value for a batch of observations.

        Padding positions beyond each env's actual VM count are masked out.
        """
        num_vms_mask = torch.arange(self.num_vm, device=obs_info_all_vm.device)[None, :] >= obs_info_num_vms[:, None]
        if self.model == "attn":
            return self.vm_net(obs_info_all_vm, obs_info_num_steps, obs_info_pm, num_vms_mask)[1]
        elif self.model == "mlp":
            return self.vm_net(obs_info_all_vm, obs_info_pm)[1]
        # FIX: previously fell through and returned None for an unknown model;
        # raise instead, consistent with get_action_and_value.
        raise ValueError(f'self.model={self.model} is not implemented')

    def get_action_and_value(self, envs, obs_info_pm, obs_info_all_vm, obs_info_num_steps, obs_info_num_vms,
                             pm_mask=None, selected_vm=None, selected_pm=None):
        """Sample (or score) a (vm, pm) action pair and return the critic value.

        When pm_mask is given, selected_vm/selected_pm must be given too and
        their log-probs are re-evaluated (PPO update path); when pm_mask is
        None, fresh actions are sampled (rollout path).
        Returns (selected_vm, selected_pm, log_prob, entropy, critic_score, pm_mask).
        """
        if pm_mask is None:
            assert selected_vm is None and selected_pm is None, \
                'action must be None when action_mask is not given!'
        else:
            assert selected_vm is not None and selected_pm is not None, \
                'action must be given when action_mask is given!'

        num_vms_mask = torch.arange(self.num_vm, device=self.device)[None, :] >= obs_info_num_vms[:, None]

        b_sz = obs_info_pm.shape[0]
        # obs_info_all_vm: torch.Size([8, 2089, 14]); obs_info_pm: torch.Size([8, 279, 8])
        if self.model == "attn":
            vm_logits, critic_score = self.vm_net(obs_info_all_vm, obs_info_num_steps, obs_info_pm, num_vms_mask)
        elif self.model == "mlp":
            vm_logits, critic_score = self.vm_net(obs_info_all_vm, obs_info_pm)
        else:
            raise ValueError(f'self.model={self.model} is not implemented')
        # vm_logits: torch.Size([8, 2089]); critic_score: torch.Size([8, 1])
        vm_cat = CategoricalMasked(logits=vm_logits, masks=num_vms_mask)
        if selected_vm is None:
            selected_vm = vm_cat.sample()
        vm_log_prob = vm_cat.log_prob(selected_vm)
        # selected_vm / vm_log_prob: torch.Size([8])

        if pm_mask is None:
            # Ask each env which PMs can host the chosen VM.
            pm_mask = torch.tensor(np.array(envs.call_parse('get_pm_mask', vm_id=selected_vm.cpu().tolist())),
                                   dtype=torch.bool, device=self.device)  # pm_mask: torch.Size([8, 279])

        # Feed only the selected VM's features to the PM head.
        if self.model == "attn":
            pm_logits = self.pm_net(obs_info_all_vm[torch.arange(b_sz), selected_vm].unsqueeze(1), obs_info_num_steps,
                                    obs_info_pm)  # b_sz
        elif self.model == "mlp":
            pm_logits = self.pm_net(obs_info_all_vm[torch.arange(b_sz), selected_vm].unsqueeze(1), obs_info_pm)  # b_sz
        else:
            raise ValueError(f'self.model={self.model} is not implemented')
        # pm_logits: torch.Size([8, 279])
        pm_cat = CategoricalMasked(logits=pm_logits, masks=pm_mask)
        if selected_pm is None:
            selected_pm = pm_cat.sample()  # selected_pm: torch.Size([8])
        pm_log_prob = pm_cat.log_prob(selected_pm)  # pm_log_prob: torch.Size([8])
        log_prob = vm_log_prob + pm_log_prob
        entropy = vm_cat.entropy() + pm_cat.entropy()

        return selected_vm, selected_pm, log_prob, entropy, critic_score, pm_mask


if __name__ == "__main__":
    args = parse_args()
    num_train = 6000
    num_dev = 200
    num_test = 200
    num_envs = args.num_envs
    num_steps = args.num_steps
    run_name = f'{args.restore_name}'
    np.set_printoptions(precision=4)
    np.set_printoptions(suppress=True)
    torch.backends.cudnn.benchmark = True
    torch.set_default_dtype(torch.float32)
    print('vf_coef: ', args.vf_coef)

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # env setup
    envs = AsyncVectorEnv_Patch(
        [make_env(args.gym_id, args.seed + i, args.vm_data_size, args.max_steps,
                  args.normalize) for i in range(num_envs)]
    )

    # assert isinstance(envs.single_action_space, gym.spaces.MultiDiscrete), \
    #     "only MultiDiscrete action space is supported"

    params = utils.Params(f'./experiments/pretrain/{args.model}/params.json')
    params.update('./data/params.json')
    params.device = device
    params.batch_size = args.num_envs
    params.accum_iter = args.accum_iter

    print('clip_vloss: ', args.clip_vloss)

    # input the vm candidate model
    if args.model == 'attn':
        vm_cand_model = models.VM_Attn_Wrapper(params, args.pretrain).model
        pm_cand_model = models.PM_Attn_Wrapper(params).model
    elif args.model == 'mlp':
        vm_cand_model = models.VM_MLP_Wrapper(params, args.pretrain).model
        pm_cand_model = models.PM_MLP_Wrapper(params).model
    else:
        raise ValueError(f'args.model = {args.model} is not defined!')

    agent = Agent(vm_cand_model, pm_cand_model, params, args.model)
    # FIX: was `optim = optim.Adam(...)`, which shadowed the torch.optim module
    # with the optimizer instance.
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
    agent.eval()
    global_step = utils.load_checkpoint(args.restore_name, args.restore_file_name, agent)
    print(f"- Restored file (global step {global_step}) "
          f"from {os.path.join(args.restore_name, args.restore_file_name + '.pth.tar')}")

    if args.track:
        wandb.watch(agent, log_freq=100)

    # ALGO Logic: Storage setup
    obs_vm = torch.zeros(num_steps, args.num_envs, params.num_vm, params.vm_cov, device=device)
    obs_pm = torch.zeros(num_steps, args.num_envs, params.num_pm, params.pm_cov, device=device)
    obs_num_steps = torch.zeros(num_steps, args.num_envs, 1, 1, device=device)
    obs_num_vms = torch.zeros(num_steps, args.num_envs, dtype=torch.int32, device=device)
    vm_actions = torch.zeros(num_steps, args.num_envs, device=device)
    pm_actions = torch.zeros(num_steps, args.num_envs, device=device)
    logprobs = torch.zeros(num_steps, args.num_envs, device=device)
    rewards = torch.zeros(num_steps, args.num_envs, device=device)
    dones = torch.zeros(num_steps, args.num_envs, device=device)
    values = torch.zeros(num_steps, args.num_envs, device=device)
    # envs.single_action_space.nvec: [2089, 279] (#vm, #pm)
    action_masks = torch.zeros(num_steps, args.num_envs, envs.single_action_space.nvec[1], dtype=torch.bool,
                               device=device)

    # TRY NOT TO MODIFY: start the game
    if args.debug:
        col_names = ['step']
        for i in range(params.num_vm):
            for j in range(params.vm_cov):
                col_names.append(f'vm_{i}_cov_{j}')

        for i in range(params.num_pm):
            for j in range(params.pm_cov):
                col_names.append(f'pm_{i}_cov_{j}')

        col_names += ['num_steps', 'num_vms', 'vm_action', 'pm_action', 'logprob', 'rewards', 'done']
        col_names += ['values', 'ep_return', 'fragment_rate']
        plot_step = np.tile(np.expand_dims(np.arange(num_steps), -1), 3).reshape((num_steps, 3, 1))

    num_updates = args.total_timesteps // args.batch_size

    with torch.no_grad():
        envs.call('set_mode', mode='dev')

        dev_all_frag_rate = np.ones((num_dev, num_steps))
        dev_all_min_frag_rate = np.ones((num_dev, num_steps))
        dev_pbar = trange(0, num_dev, num_envs, desc='Dev')
        for file_index in dev_pbar:
            file_ids = [num_train + file_index + env_id for env_id in range(num_envs)]
            envs.call_parse('set_current_env', env_id=file_ids)

            # col 0: episode return, col 1: fragment rate; -1000 marks "unset"
            current_ep_info = np.zeros((num_steps, args.num_envs, 2)) - 1000
            next_obs_dict = envs.reset()
            next_obs_pm = torch.tensor(next_obs_dict['pm_info'], device=device)  # torch.Size([8, 279, 8])
            next_obs_vm = torch.tensor(next_obs_dict['vm_info'], device=device)  # torch.Size([8, 279, 14])
            next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
            next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
            next_done = torch.zeros(args.num_envs, device=device)

            for step in range(0, num_steps):
                obs_pm[step] = next_obs_pm
                obs_vm[step] = next_obs_vm
                obs_num_steps[step] = next_obs_num_steps
                obs_num_vms[step] = next_obs_num_vms
                dones[step] = next_done

                vm_action, pm_action, logprob, _, value, action_mask \
                    = agent.get_action_and_value(envs, next_obs_pm, next_obs_vm, next_obs_num_steps, next_obs_num_vms)
                values[step] = value.flatten()  # value: torch.Size([8, 1])
                action_masks[step] = action_mask
                vm_actions[step] = vm_action
                pm_actions[step] = pm_action
                logprobs[step] = logprob

                # TRY NOT TO MODIFY: execute the game and log data.
                # print(f'vm_action: {vm_action.cpu().numpy()}, pm_action: {pm_action.cpu().numpy()}')
                next_obs_dict, reward, done, info = envs.step(torch.stack([vm_action, pm_action],
                                                                          dim=-1).cpu().numpy())
                next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
                next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
                next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
                next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
                rewards[step] = torch.tensor(reward, device=device).view(-1)
                next_done = torch.Tensor(done).to(device)

                for env_id, item in enumerate(info):
                    dev_all_frag_rate[file_index + env_id, step] = item['fragment_rate']
                    current_ep_info[step, env_id, 1] = item['fragment_rate']
                    if "episode" in item.keys():
                        current_ep_info[step, env_id, 0] = item["episode"]["r"]
                        dev_all_min_frag_rate[file_index + env_id, step] = item['fragment_rate']

            if args.debug:
                # Dump the first 3 envs' trajectories for offline inspection.
                plot_obs_vm = obs_vm[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1)
                plot_obs_pm = obs_pm[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1)
                plot_obs_num_steps = obs_num_steps[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1)
                plot_obs_num_vms = obs_num_vms[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_vm_actions = vm_actions[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_pm_actions = pm_actions[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_logprobs = logprobs[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_rewards = rewards[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_dones = dones[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_values = values[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_ep_info = current_ep_info[:, :3]
                # FIX: keyword was misspelled 'axies' (TypeError at runtime);
                # np.concatenate takes 'axis'.
                plot_update_all = np.swapaxes(np.concatenate([plot_step, plot_obs_vm, plot_obs_pm, plot_obs_num_steps,
                                                              plot_obs_num_vms, plot_vm_actions, plot_pm_actions,
                                                              plot_logprobs, plot_rewards, plot_dones,
                                                              plot_values, plot_ep_info], axis=-1), axis1=1, axis2=0)
                plot_update_all = plot_update_all.reshape((num_steps * 3, -1))
                episode_df = pd.DataFrame(plot_update_all, columns=col_names)
                # FIX: current_ep_info has only 2 trailing columns, so index 2
                # was out of bounds — fragment rate lives at index 1.
                plot_fr_mean = np.mean(plot_ep_info[:, :, 1][plot_ep_info[:, :, 1] != -1000])
                episode_df.to_pickle(f'runs/{run_name}/dev_{num_train + file_index}'
                                     f'-{num_train + file_index + 2}.pkl')

        for i in range(num_dev):
            print(f'dev {i}: {dev_all_min_frag_rate[i][dev_all_min_frag_rate[i] != 1]}')
        current_dev_frag_rate = np.mean(np.amin(dev_all_min_frag_rate, axis=1))
        np.save(os.path.join('runs', args.restore_name, 'dev_all_frag_rate.npy'), dev_all_frag_rate)
        print(f'Dev fragment rate: {current_dev_frag_rate:.4f}')

        envs.call('set_mode', mode='test')

        test_all_min_frag_rate = np.ones((num_test, num_steps))
        test_pbar = trange(0, num_test, num_envs, desc='Test')
        for file_index in test_pbar:
            file_ids = [num_train + num_dev + file_index + env_id for env_id in range(num_envs)]
            envs.call_parse('set_current_env', env_id=file_ids)

            # col 0: episode return, col 1: fragment rate; -1000 marks "unset"
            current_ep_info = np.zeros((num_steps, args.num_envs, 2)) - 1000
            next_obs_dict = envs.reset()
            next_obs_pm = torch.tensor(next_obs_dict['pm_info'], device=device)  # torch.Size([8, 279, 8])
            next_obs_vm = torch.tensor(next_obs_dict['vm_info'], device=device)  # torch.Size([8, 279, 14])
            next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
            next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
            next_done = torch.zeros(args.num_envs, device=device)

            for step in range(0, num_steps):
                obs_pm[step] = next_obs_pm
                obs_vm[step] = next_obs_vm
                obs_num_steps[step] = next_obs_num_steps
                obs_num_vms[step] = next_obs_num_vms
                dones[step] = next_done

                vm_action, pm_action, logprob, _, value, action_mask \
                    = agent.get_action_and_value(envs, next_obs_pm, next_obs_vm, next_obs_num_steps, next_obs_num_vms)
                values[step] = value.flatten()  # value: torch.Size([8, 1])
                action_masks[step] = action_mask
                vm_actions[step] = vm_action
                pm_actions[step] = pm_action
                logprobs[step] = logprob

                # TRY NOT TO MODIFY: execute the game and log data.
                # print(f'vm_action: {vm_action.cpu().numpy()}, pm_action: {pm_action.cpu().numpy()}')
                next_obs_dict, reward, done, info = envs.step(torch.stack([vm_action, pm_action],
                                                                          dim=-1).cpu().numpy())
                next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
                next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
                next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
                next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
                rewards[step] = torch.tensor(reward, device=device).view(-1)
                next_done = torch.Tensor(done).to(device)

                for env_id, item in enumerate(info):
                    current_ep_info[step, env_id, 1] = item['fragment_rate']
                    if "episode" in item.keys():
                        current_ep_info[step, env_id, 0] = item["episode"]["r"]
                        test_all_min_frag_rate[file_index + env_id, step] = item['fragment_rate']

            if args.debug:
                plot_obs_vm = obs_vm[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1)
                plot_obs_pm = obs_pm[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1)
                plot_obs_num_steps = obs_num_steps[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1)
                plot_obs_num_vms = obs_num_vms[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_vm_actions = vm_actions[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_pm_actions = pm_actions[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_logprobs = logprobs[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_rewards = rewards[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_dones = dones[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_values = values[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1)
                plot_ep_info = current_ep_info[:, :3]
                # FIX: 'axies' -> 'axis' (same misspelling as in the dev branch).
                plot_update_all = np.swapaxes(np.concatenate([plot_step, plot_obs_vm, plot_obs_pm, plot_obs_num_steps,
                                                              plot_obs_num_vms, plot_vm_actions, plot_pm_actions,
                                                              plot_logprobs, plot_rewards, plot_dones,
                                                              plot_values, plot_ep_info], axis=-1), axis1=1, axis2=0)
                plot_update_all = plot_update_all.reshape((num_steps * 3, -1))
                episode_df = pd.DataFrame(plot_update_all, columns=col_names)
                # FIX: out-of-bounds column 2 -> column 1 (fragment rate).
                plot_fr_mean = np.mean(plot_ep_info[:, :, 1][plot_ep_info[:, :, 1] != -1000])
                episode_df.to_pickle(f'runs/{run_name}/'
                                     f'test_{num_train + num_dev + file_index}'
                                     f'-{num_train + num_dev + file_index + 2}.pkl')

        current_test_frag_rate = np.mean(np.amin(test_all_min_frag_rate, axis=1))
        print(f'Test fragment rate: {current_test_frag_rate:.4f}')

        np.save(f"runs/{run_name}/{args.restore_file_name}_dev_all_min_frag_rate.npy", dev_all_min_frag_rate)
        np.save(f"runs/{run_name}/{args.restore_file_name}_test_all_min_frag_rate.npy", test_all_min_frag_rate)

        envs.close()
# ------------------------------------------------------------------------------
# /experiments/pretrain/attn/params.json (non-Python file embedded in the dump):
# {
#     "learning_rate": 1e-3, "warmup_portion": 0.1, "max_grad_norm": -1,
#     "batch_size": 256, "d_hidden": 8, "num_head": 2, "transformer_blocks": 2,
#     "d_ff": 10, "num_epochs": 100, "dropout": 0.1, "predict_batch": 128,
#     "num_loss": 1, "weighted_sampler": false
# }
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# /experiments/pretrain/mlp/params.json (non-Python file embedded in the dump):
# {
#     "learning_rate": 1e-3, "warmup_portion": 0.1, "max_grad_norm": -1,
#     "batch_size": 256, "d_hidden": 10, "num_epochs": 100, "dropout": 0.1,
#     "predict_batch": 128, "num_loss": 1, "weighted_sampler": false
# }
# ------------------------------------------------------------------------------
# /find_mean.py:
# ------------------------------------------------------------------------------
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


"""
VM: cpu, cpu, mem, mem, cpu % 16, cpu % 16 (0 is full, 1 is empty)
PM: cpu, cpu, mem, mem, fragment_rate, cpu % 16, fragment_rate, cpu % 16
cpu % 16 = round(normalized_cpu * 88) % 16 / 16
fragment_rate = round(normalized_cpu * 88) % 16 / round(normalized_cpu * 88)
To rescale memory, mem * 368776
"""

import argparse
import os
import random
import time
from distutils.util import strtobool

import pandas as pd
import wandb

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange

import gym_reschdule_combination.envs.vm_rescheduler_env

import models
import utils
from env_patch import AsyncVectorEnv_Patch


def parse_args():
    """Build the CLI parser and return the parsed arguments.

    Also derives args.batch_size / args.minibatch_size from the rollout shape.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="attn", help="model architecture")
    parser.add_argument("--pretrain", action='store_true',
                        help="if toggled, we will restore pretrained weights for vm selection")
    parser.add_argument("--gym-id", type=str, default="generalizer-v1",
                        help="the id of the gym environment")
    parser.add_argument("--vm-data-size", type=str, default="M", choices=["M", "L"],
                        help="size of the dataset")
    parser.add_argument("--max-steps", type=int, default=50, help="maximum number of redeploy steps")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
                        help="the learning rate of the optimizer")
    parser.add_argument("--seed", type=int, default=1,
                        help="seed of the experiment")
    parser.add_argument("--total-timesteps", type=int, default=2000000,
                        help="total timesteps of the experiments")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--normalize", action='store_true',
                        help="if toggled, we will normalize the input features")
    parser.add_argument("--track", action='store_true',
                        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--debug", action='store_true',
                        help="if toggled, this experiment will save run details")

    # Algorithm specific arguments
    parser.add_argument("--num-envs", type=int, default=8,
                        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128,
                        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="Use GAE for advantage computation")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
                        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
                        help="the number of mini-batches")
    parser.add_argument("--accum-iter", type=int, default=4,
                        help="number of iterations where gradient is accumulated before the weights are updated;"
                             " used to increase the effective batch size")
    parser.add_argument("--update-epochs", type=int, default=4,
                        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
                        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.005,  # 0.01
                        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=1e-2,  # 1e-4
                        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
                        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
                        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // (args.num_minibatches * args.accum_iter))
    return args


def make_env(gym_id, seed, vm_data_size, max_steps, normalize):
    """Return a thunk that builds one fully seeded, stats-wrapped env instance."""
    def thunk():
        env = gym.make(gym_id, seed=seed, vm_data_size=vm_data_size, max_steps=max_steps, normalize=normalize)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


if __name__ == "__main__":
    args = parse_args()
    num_train = 6000
    num_dev = 200
    num_test = 200
    num_envs = args.num_envs
    np.set_printoptions(precision=4)
    np.set_printoptions(suppress=True)
    torch.backends.cudnn.benchmark = True
    torch.set_default_dtype(torch.float32)

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # env setup
    envs = AsyncVectorEnv_Patch(
        [make_env(args.gym_id, args.seed + i, args.vm_data_size, args.max_steps,
                  args.normalize) for i in range(num_envs)]
    )

    params = utils.Params(f'./experiments/pretrain/{args.model}/params.json')
    params.update('./data/params.json')

    # ---- dev split: collect per-file feature means/stds reported by reset() ----
    envs.call('set_mode', mode='dev')
    pm_mean = np.zeros((num_dev, 8))
    pm_std = np.zeros((num_dev, 8))
    vm_mean = np.zeros((num_dev, 6))
    vm_std = np.zeros((num_dev, 6))

    dev_pbar = trange(0, num_dev, num_envs, desc='Dev')
    for batch, file_index in enumerate(dev_pbar):
        file_ids = [num_train + file_index + env_id for env_id in range(num_envs)]
        envs.call_parse('set_current_env', env_id=file_ids)

        info = envs.reset()
        row = batch * num_envs
        vm_mean[row: row + num_envs] = info['vm_mean'][:, 0]
        vm_std[row: row + num_envs] = info['vm_std'][:, 0]
        pm_mean[row: row + num_envs] = info['pm_mean'][:, 0]
        pm_std[row: row + num_envs] = info['pm_std'][:, 0]

    vm_all_mean = np.mean(vm_mean, axis=0)
    vm_all_std = np.mean(vm_std, axis=0)
    pm_all_mean = np.mean(pm_mean, axis=0)
    pm_all_std = np.mean(pm_std, axis=0)
    print(f'dev: vm_mean = {vm_all_mean}, vm_std = {vm_all_std}, pm_mean = {pm_all_mean}, pm_std = {pm_all_std}')
    print(f'dev frag: {pm_mean[:, 4]}')
    print(f'dev frag: {pm_std[:, 4]}')

    # ---- test split: same aggregation over the held-out test files ----
    envs.call('set_mode', mode='test')
    pm_mean = np.zeros((num_test, 8))
    pm_std = np.zeros((num_test, 8))
    vm_mean = np.zeros((num_test, 6))
    vm_std = np.zeros((num_test, 6))

    test_pbar = trange(0, num_test, num_envs, desc='Test')
    for batch, file_index in enumerate(test_pbar):
        file_ids = [num_train + num_dev + file_index + env_id for env_id in range(num_envs)]
        envs.call_parse('set_current_env', env_id=file_ids)

        info = envs.reset()
        row = batch * num_envs
        vm_mean[row: row + num_envs] = info['vm_mean'][:, 0]
        vm_std[row: row + num_envs] = info['vm_std'][:, 0]
        pm_mean[row: row + num_envs] = info['pm_mean'][:, 0]
        pm_std[row: row + num_envs] = info['pm_std'][:, 0]

    vm_all_mean = np.mean(vm_mean, axis=0)
    vm_all_std = np.mean(vm_std, axis=0)
    pm_all_mean = np.mean(pm_mean, axis=0)
    pm_all_std = np.mean(pm_std, axis=0)
    print(f'test: vm_mean = {vm_all_mean}, vm_std = {vm_all_std}, pm_mean = {pm_all_mean}, pm_std = {pm_all_std}')
    print(f'test frag: {pm_mean[:, 4]}')
    print(f'test frag: {pm_std[:, 4]}')

    # ---- train split: first 4000 training files ----
    envs.call('set_mode', mode='train')
    pm_mean = np.zeros((4000, 8))
    pm_std = np.zeros((4000, 8))
    vm_mean = np.zeros((4000, 6))
    vm_std = np.zeros((4000, 6))

    train_pbar = trange(0, 4000, num_envs, desc='Train')
    for batch, file_index in enumerate(train_pbar):
        file_ids = [file_index + env_id for env_id in range(num_envs)]
        envs.call_parse('set_current_env', env_id=file_ids)

        info = envs.reset()
        row = batch * num_envs
        vm_mean[row: row + num_envs] = info['vm_mean'][:, 0]
        vm_std[row: row + num_envs] = info['vm_std'][:, 0]
        pm_mean[row: row + num_envs] = info['pm_mean'][:, 0]
        pm_std[row: row + num_envs] = info['pm_std'][:, 0]

    vm_all_mean = np.mean(vm_mean, axis=0)
    vm_all_std = np.mean(vm_std, axis=0)
    pm_all_mean = np.mean(pm_mean, axis=0)
    pm_all_std = np.mean(pm_std, axis=0)
    print(f'train: vm_mean = {vm_all_mean}, vm_std = {vm_all_std}, pm_mean = {pm_all_mean}, pm_std = {pm_all_std}')

    # Stats for the last 200 training files.
    vm_all_mean1 = np.mean(vm_mean[-200:], axis=0)
    vm_all_std1 = np.mean(vm_std[-200:], axis=0)
    pm_all_mean1 = np.mean(pm_mean[-200:], axis=0)
    pm_all_std1 = np.mean(pm_std[-200:], axis=0)
    print(f'train1: vm_mean = {vm_all_mean1}, vm_std = {vm_all_std1}, pm_mean = {pm_all_mean1}, pm_std = {pm_all_std1}')
    print(f'train1 frag: {pm_mean[-200:, 4]}')
    print(f'train1 frag: {pm_std[-200:, 4]}')

    # Stats for the 200 training files before those.
    vm_all_mean2 = np.mean(vm_mean[-400:-200], axis=0)
    vm_all_std2 = np.mean(vm_std[-400:-200], axis=0)
    pm_all_mean2 = np.mean(pm_mean[-400:-200], axis=0)
    pm_all_std2 = np.mean(pm_std[-400:-200], axis=0)
    print(f'train2: vm_mean = {vm_all_mean2}, vm_std = {vm_all_std2}, pm_mean = {pm_all_mean2}, pm_std = {pm_all_std2}')
    print(f'train2 frag: {pm_mean[-400:-200, 4]}')
    print(f'train2 frag: {pm_std[-400:-200, 4]}')

    envs.close()
# ------------------------------------------------------------------------------
# /gym-reschdule_combination/.gitkeep:
# https://raw.githubusercontent.com/bytedance/DRL-based-VM-Rescheduling/a1232df1bce7851c991229fd7a34871f0685f5f0/gym-reschdule_combination/.gitkeep
# ------------------------------------------------------------------------------
# /gym-reschdule_combination/gym_reschdule_combination.egg-info/PKG-INFO:
# Metadata-Version: 2.1
# Name: gym-reschdule-combination
# Version: 1.0.0
# Summary: UNKNOWN
# License: UNKNOWN
# Platform: UNKNOWN
#
# UNKNOWN
# ------------------------------------------------------------------------------
# /gym-reschdule_combination/gym_reschdule_combination.egg-info/SOURCES.txt:
# setup.py
# gym_reschdule_combination/__init__.py
# gym_reschdule_combination.egg-info/PKG-INFO
# gym_reschdule_combination.egg-info/SOURCES.txt
# gym_reschdule_combination.egg-info/dependency_links.txt
# gym_reschdule_combination.egg-info/requires.txt
# gym_reschdule_combination.egg-info/top_level.txt
# gym_reschdule_combination/envs/__init__.py
# gym_reschdule_combination/envs/vm_rescheduler_env.py
# gym_reschdule_combination/envs/vm_rescheduler_env_heuristic.py
gym_reschdule_combination/envs/vm_rescheduler_env_static.py -------------------------------------------------------------------------------- /gym-reschdule_combination/gym_reschdule_combination.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /gym-reschdule_combination/gym_reschdule_combination.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | gym 2 | -------------------------------------------------------------------------------- /gym-reschdule_combination/gym_reschdule_combination.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | gym_reschdule_combination 2 | -------------------------------------------------------------------------------- /gym-reschdule_combination/gym_reschdule_combination/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DRL-based-VM-Rescheduling/a1232df1bce7851c991229fd7a34871f0685f5f0/gym-reschdule_combination/gym_reschdule_combination/.gitkeep -------------------------------------------------------------------------------- /gym-reschdule_combination/gym_reschdule_combination/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved. 2 | # 3 | # This program is free software; you can redistribute it and/or modify it under 4 | # the terms of the Apache-2.0 license. 5 | # 6 | # This program is distributed in the hope that it will be useful, but WITHOUT ANY 7 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 8 | # PARTICULAR PURPOSE. See the Apache-2.0 License for more details. 
# ---- gym-reschdule_combination/gym_reschdule_combination/__init__.py --------


from gym.envs.registration import register


# Register the rescheduling environments; same four ids, entry points, and
# registration order as before, expressed as a data table.
for _env_id, _entry_point in (
    ("generalizer-v0", "gym_reschdule_combination.envs:VM_generlizer_v0"),
    ("generalizer-v1", "gym_reschdule_combination.envs:VM_generlizer_v1"),
    ("graph-v1", "gym_reschdule_combination.envs:VM_graph_v1"),
    ("graph-v2", "gym_reschdule_combination.envs:VM_graph_v2"),
):
    register(id=_env_id, entry_point=_entry_point)
# ------------------------------------------------------------------------------
# /gym-reschdule_combination/gym_reschdule_combination/envs/.gitkeep:
# https://raw.githubusercontent.com/bytedance/DRL-based-VM-Rescheduling/a1232df1bce7851c991229fd7a34871f0685f5f0/gym-reschdule_combination/gym_reschdule_combination/envs/.gitkeep
# ------------------------------------------------------------------------------
# /gym-reschdule_combination/gym_reschdule_combination/envs/__init__.py:
# ------------------------------------------------------------------------------
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


# Re-export the environment classes so the registration entry points resolve.
from gym_reschdule_combination.envs.vm_rescheduler_env import VM_generlizer_v0
from gym_reschdule_combination.envs.vm_rescheduler_env import VM_generlizer_v1
from gym_reschdule_combination.envs.vm_rescheduler_env import VM_graph_v1
from gym_reschdule_combination.envs.vm_rescheduler_env import VM_graph_v2
# ------------------------------------------------------------------------------
# /gym-reschdule_combination/setup.py:
# ------------------------------------------------------------------------------
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


from setuptools import setup

# Minimal package metadata; gym is the only install-time dependency.
setup(
    name="gym_reschdule_combination",
    version="1.0.0",
    install_requires=["gym"],
)
# ------------------------------------------------------------------------------
# /main-onehead.py:
# ------------------------------------------------------------------------------
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.
9 | 10 | 11 | """ 12 | VM: cpu, cpu, mem, mem, cpu % 16, cpu % 16 (0 is full, 1 is empty) 13 | PM: cpu, cpu, mem, mem, fragment_rate, cpu % 16, fragment_rate, cpu % 16 14 | cpu % 16 = round(normalized_cpu * 88) % 16 / 16 15 | fragment_rate = round(normalized_cpu * 88) % 16 / round(normalized_cpu * 88) 16 | To rescale memory, mem * 368776 17 | """ 18 | 19 | import argparse 20 | import os 21 | import random 22 | import time 23 | from distutils.util import strtobool 24 | 25 | import pandas as pd 26 | import wandb 27 | 28 | import gym 29 | import numpy as np 30 | import torch 31 | import torch.nn as nn 32 | import torch.optim as optim 33 | from torch.distributions.categorical import Categorical 34 | from torch.utils.tensorboard import SummaryWriter 35 | from tqdm import trange 36 | 37 | import gym_reschdule_combination.envs.vm_rescheduler_env 38 | 39 | import models 40 | import utils 41 | from env_patch import AsyncVectorEnv_Patch 42 | from main import make_env, CategoricalMasked 43 | 44 | 45 | def parse_args(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--model", type=str, default="attn", help="model architecture") 48 | parser.add_argument("--pretrain", action='store_true', 49 | help="if toggled, we will restore pretrained weights for vm selection") 50 | parser.add_argument("--gym-id", type=str, default="generalizer-v1", 51 | help="the id of the gym environment") 52 | parser.add_argument("--vm-data-size", type=str, default="M", choices=["M", "L"], 53 | help="size of the dataset") 54 | parser.add_argument("--max-steps", type=int, default=50, help="maximum number of redeploy steps") 55 | parser.add_argument("--learning-rate", type=float, default=2.5e-4, 56 | help="the learning rate of the optimizer") 57 | parser.add_argument("--seed", type=int, default=1, 58 | help="seed of the experiment") 59 | parser.add_argument("--total-timesteps", type=int, default=2000000, 60 | help="total timesteps of the experiments") 61 | 
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 62 | help="if toggled, `torch.backends.cudnn.deterministic=False`") 63 | parser.add_argument("--normalize", action='store_true', 64 | help="if toggled, we will normalize the input features") 65 | parser.add_argument("--track", action='store_true', 66 | help="if toggled, this experiment will be tracked with Weights and Biases") 67 | parser.add_argument("--debug", action='store_true', 68 | help="if toggled, this experiment will save run details") 69 | 70 | # Algorithm specific arguments 71 | parser.add_argument("--num-envs", type=int, default=8, 72 | help="the number of parallel game environments") 73 | parser.add_argument("--num-steps", type=int, default=128, 74 | help="the number of steps to run in each environment per policy rollout") 75 | parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 76 | help="Toggle learning rate annealing for policy and value networks") 77 | parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 78 | help="Use GAE for advantage computation") 79 | parser.add_argument("--gamma", type=float, default=0.99, 80 | help="the discount factor gamma") 81 | parser.add_argument("--gae-lambda", type=float, default=0.95, 82 | help="the lambda for the general advantage estimation") 83 | parser.add_argument("--num-minibatches", type=int, default=4, 84 | help="the number of mini-batches") 85 | parser.add_argument("--accum-iter", type=int, default=4, 86 | help="number of iterations where gradient is accumulated before the weights are updated;" 87 | " used to increase the effective batch size") 88 | parser.add_argument("--update-epochs", type=int, default=4, 89 | help="the K epochs to update the policy") 90 | parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 91 | help="Toggles advantages 
normalization") 92 | parser.add_argument("--clip-coef", type=float, default=0.1, 93 | help="the surrogate clipping coefficient") 94 | parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 95 | help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") 96 | parser.add_argument("--ent-coef", type=float, default=0.005, # 0.01 97 | help="coefficient of the entropy") 98 | parser.add_argument("--vf-coef", type=float, default=1e-2, # 1e-4 99 | help="coefficient of the value function") 100 | parser.add_argument("--max-grad-norm", type=float, default=0.5, 101 | help="the maximum norm for the gradient clipping") 102 | parser.add_argument("--target-kl", type=float, default=None, 103 | help="the target KL divergence threshold") 104 | args = parser.parse_args() 105 | args.batch_size = int(args.num_envs * args.num_steps) 106 | args.minibatch_size = int(args.batch_size // (args.num_minibatches * args.accum_iter)) 107 | return args 108 | 109 | 110 | class Agent(nn.Module): 111 | def __init__(self, vm_net, params, args_model): 112 | super(Agent, self).__init__() 113 | 114 | self.vm_net = vm_net 115 | self.device = params.device 116 | self.model = args_model 117 | 118 | def get_value(self, obs_info_pm, obs_info_all_vm, obs_info_num_steps, obs_info_num_vms): 119 | num_vms_mask = torch.arange(obs_info_all_vm.shape[1], 120 | device=obs_info_all_vm.device)[None, :] >= obs_info_num_vms[:, None] 121 | return self.vm_net(obs_info_all_vm, obs_info_num_steps, obs_info_pm, num_vms_mask)[1] 122 | 123 | def get_action_and_value(self, envs, obs_info_pm, obs_info_all_vm, obs_info_num_steps, obs_info_num_vms, 124 | pm_mask=None, selected_vm=None, selected_pm=None): 125 | if pm_mask is None: 126 | assert selected_vm is None and selected_pm is None, \ 127 | 'action must be None when action_mask is not given!' 
128 | else: 129 | assert selected_vm is not None and selected_pm is not None, \ 130 | 'action must be given when action_mask is given!' 131 | 132 | num_vms_mask = torch.arange(obs_info_all_vm.shape[1], device=self.device)[None, :] >= obs_info_num_vms[:, None] 133 | 134 | b_sz = obs_info_pm.shape[0] 135 | # obs_info_all_vm: torch.Size([8, 2089, 14]) 136 | # obs_info_pm: torch.Size([8, 279, 8]) 137 | if self.model == "attn": 138 | vm_logits, critic_score, attn_score = self.vm_net(obs_info_all_vm, obs_info_num_steps, obs_info_pm, 139 | num_vms_mask, return_attns=True) 140 | else: 141 | raise ValueError(f'self.model={self.model} is not implemented') 142 | # vm_pred: torch.Size([8, 2089]) 143 | # critic_score: torch.Size([8, 1]) 144 | vm_cat = CategoricalMasked(logits=vm_logits, masks=num_vms_mask) 145 | if selected_vm is None: 146 | selected_vm = vm_cat.sample() 147 | vm_log_prob = vm_cat.log_prob(selected_vm) 148 | # selected_vm: torch.Size([8]) 149 | # vm_log_prob: torch.Size([8]) 150 | # entropy: torch.Size([8]) 151 | 152 | if pm_mask is None: 153 | pm_mask = torch.tensor(np.array(envs.call_parse('get_pm_mask', vm_id=selected_vm.cpu().tolist())), 154 | dtype=torch.bool, device=self.device) # pm_mask: torch.Size([8, 279]) 155 | 156 | # obs_info_all_vm: torch.Size([8, 2089, 14]) 157 | pm_probs = attn_score[-1][torch.arange(b_sz, device=self.device), selected_vm][:, 1:] 158 | # pm_logits: torch.Size([8, 279]) 159 | pm_cat = CategoricalMasked(probs=pm_probs, masks=pm_mask) 160 | if selected_pm is None: 161 | selected_pm = pm_cat.sample() 162 | # selected_pm: torch.Size([8]) 163 | pm_log_prob = pm_cat.log_prob(selected_pm) 164 | # pm_log_prob: torch.Size([8]) 165 | log_prob = vm_log_prob + pm_log_prob 166 | entropy = vm_cat.entropy() + pm_cat.entropy() 167 | 168 | return selected_vm, selected_pm, log_prob, entropy, critic_score, pm_mask 169 | 170 | 171 | if __name__ == "__main__": 172 | args = parse_args() 173 | save_every_step = 50 174 | plot_every_step = 20 175 | 
    # ---- experiment schedule constants ----
    test_every_step = 30
    num_train = 6000      # dataset file ids [0, num_train) are training envs
    num_dev = 200         # next num_dev ids are the dev split
    num_test = 200        # final num_test ids are the test split
    num_envs = args.num_envs
    num_steps = args.num_steps
    # Evaluation rollouts are capped at 256 env steps.
    num_test_steps = min(256, int(num_steps * args.max_steps / 10))
    run_name = f"{args.vm_data_size}{args.max_steps}_{args.gym_id}_{args.model}_{args.seed}" \
               f"_{utils.name_with_datetime()}"
    np.set_printoptions(precision=4)
    np.set_printoptions(suppress=True)
    torch.backends.cudnn.benchmark = True
    torch.set_default_dtype(torch.float32)
    if args.track:
        # Weights & Biases + TensorBoard logging (only when --track is set;
        # all later `writer` uses are likewise guarded by args.track).
        wandb.init(entity="zhykoties",
                   project="vm_scheduling",
                   name=run_name,
                   sync_tensorboard=True,
                   monitor_gym=True, config={
                       'model': args.model,
                       'ent_coef': args.ent_coef,
                       'vf_coef': args.vf_coef,
                       'eff_b_sz': args.accum_iter * args.num_minibatches},
                   # notes="",
                   tags=[args.model, args.gym_id, 'onehead', 'norm'] if args.normalize else [args.model, args.gym_id, 'onehead'],
                   save_code=True
                   )
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
    print('vf_coef: ', args.vf_coef)

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # env setup: num_envs async vector workers, each with a distinct seed.
    envs = AsyncVectorEnv_Patch(
        [make_env(args.gym_id, args.seed + i, args.vm_data_size, args.max_steps,
                  args.normalize) for i in range(num_envs)]
    )

    # assert isinstance(envs.single_action_space, gym.spaces.MultiDiscrete), \
    #     "only MultiDiscrete action space is supported"

    # Model hyperparameters come from the pretrain config, overlaid with the
    # dataset-level config (num_vm, num_pm, covariate dims, ...).
    params = utils.Params(f'./experiments/pretrain/{args.model}/params.json')
    params.update('./data/params.json')
    params.device = device
    params.batch_size = args.num_envs
    params.accum_iter = args.accum_iter

    print('clip_vloss: ', args.clip_vloss)

    # input the vm candidate model
    if args.model == 'attn':
        vm_cand_model = models.VM_Attn_Wrapper(params, args.pretrain).model
    elif args.model == 'mlp':
        vm_cand_model = models.VM_MLP_Wrapper(params, args.pretrain).model
    else:
        raise ValueError(f'args.model = {args.model} is not defined!')

    agent = Agent(vm_cand_model, params, args.model)
    # Single optimizer: the one-head agent trains only the VM network.
    vm_optim = optim.Adam(vm_cand_model.parameters(), lr=args.learning_rate, eps=1e-5)

    if args.track:
        wandb.watch(agent, log_freq=100)

    # ALGO Logic: Storage setup — one rollout buffer of num_steps x num_envs.
    obs_vm = torch.zeros(num_steps, args.num_envs, params.num_vm, params.vm_cov, device=device)
    obs_pm = torch.zeros(num_steps, args.num_envs, params.num_pm, params.pm_cov, device=device)
    obs_num_steps = torch.zeros(num_steps, args.num_envs, 1, 1, device=device)
    obs_num_vms = torch.zeros(num_steps, args.num_envs, dtype=torch.int32, device=device)
    vm_actions = torch.zeros(num_steps, args.num_envs, device=device)
    pm_actions = torch.zeros(num_steps, args.num_envs, device=device)
    logprobs = torch.zeros(num_steps, args.num_envs, device=device)
    rewards = torch.zeros(num_steps, args.num_envs, device=device)
    dones = torch.zeros(num_steps, args.num_envs, device=device)
    values = torch.zeros(num_steps, args.num_envs, device=device)
    # envs.single_action_space.nvec: [2089, 279] (#vm, #pm)
    action_masks = torch.zeros(num_steps, args.num_envs, envs.single_action_space.nvec[1], dtype=torch.bool,
                               device=device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    # Fragment rates are in [0, 1]; initialize "best so far" to the worst case.
    dev_best_frag_rate = 1
    best_frag_rate_step = 0
    test_best_frag_rate = 1
    if args.debug:
        # Column names for the per-update debug pickle: flattened VM and PM
        # observations plus the scalar rollout fields, for the first 3 envs.
        col_names = ['step']
        for i in range(params.num_vm):
            for j in range(params.vm_cov):
                col_names.append(f'vm_{i}_cov_{j}')

        for i in range(params.num_pm):
            for j in range(params.pm_cov):
                col_names.append(f'pm_{i}_cov_{j}')

        col_names += ['num_steps', 'num_vms', 'vm_action', 'pm_action', 'logprob', 'rewards', 'done']
        col_names += ['values', 'ep_return', 'fragment_rate']
        # Step index replicated for the 3 plotted envs: (num_steps, 3, 1).
        plot_step = np.tile(np.expand_dims(np.arange(num_steps), -1), 3).reshape((num_steps, 3, 1))

    num_updates = args.total_timesteps // args.batch_size
    pbar = trange(1, num_updates + 1)
    for update in pbar:
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            vm_optim.param_groups[0]["lr"] = lrnow

        # Per-step episode stats; -1000 marks "no episode ended here".
        # Channels: 0 = episodic return, 1 = fragment rate.
        current_ep_info = np.zeros((num_steps, args.num_envs, 2)) - 1000

        next_obs_dict = envs.reset()
        next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
        next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
        next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
        next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
        next_done = torch.zeros(args.num_envs, device=device)

        for step in range(0, num_steps):
            global_step += 1 * args.num_envs

            # Store the observation that the action below is conditioned on.
            obs_pm[step] = next_obs_pm
            obs_vm[step] = next_obs_vm
            obs_num_steps[step] = next_obs_num_steps
            obs_num_vms[step] = next_obs_num_vms
            dones[step] = next_done

            with torch.no_grad():
                vm_action, pm_action, logprob, _, value, action_mask \
                    = agent.get_action_and_value(envs, next_obs_pm, next_obs_vm, next_obs_num_steps, next_obs_num_vms)
                values[step] = value.flatten()  # value: torch.Size([8, 1])
            action_masks[step] = action_mask
            vm_actions[step] = vm_action
            pm_actions[step] = pm_action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            # print(f'vm_action: {vm_action.cpu().numpy()}, pm_action: {pm_action.cpu().numpy()}')
            next_obs_dict, reward, done, info = envs.step(torch.stack([vm_action, pm_action],
                                                                      dim=-1).cpu().numpy())
            next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
            next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
            next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
            next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
            rewards[step] = torch.tensor(reward, device=device).view(-1)
            next_done = torch.Tensor(done).to(device)

            # Record stats for envs whose episode finished at this step.
            for env_id, item in enumerate(info):
                if "episode" in item.keys():
                    current_ep_info[step, env_id, 0] = item["episode"]["r"]
                    current_ep_info[step, env_id, 1] = item['fragment_rate']

        # Collect stats over all episodes that ended during this rollout.
        no_end_mask = current_ep_info[:, :, 0] != -1000
        current_ep_return = current_ep_info[:, :, 0][no_end_mask]
        current_ep_fr = current_ep_info[:, :, 1][no_end_mask]
        if args.track:
            writer.add_scalar("episode_details/episodic_return", np.mean(current_ep_return), global_step)
            writer.add_scalar("episode_details/fragment_rate", np.mean(current_ep_fr), global_step)
            writer.add_scalar("episode_details/min_fragment_rate", np.amin(current_ep_fr), global_step)
        pbar.set_description(f'Train frag rate: {np.amin(current_ep_fr):.4f}')
        if args.debug:
            print(f'========= global_step: {global_step} ========= '
                  f'\n{np.stack([current_ep_return, current_ep_fr], axis=-1)}')
| plot_obs_vm = obs_vm[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1) 352 | plot_obs_pm = obs_pm[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1) 353 | plot_obs_num_steps = obs_num_steps[:, :3].cpu().data.numpy().reshape(num_steps, 3, -1) 354 | plot_obs_num_vms = obs_num_vms[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 355 | plot_vm_actions = vm_actions[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 356 | plot_pm_actions = pm_actions[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 357 | plot_logprobs = logprobs[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 358 | plot_rewards = rewards[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 359 | plot_dones = dones[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 360 | plot_values = values[:, :3].cpu().data.numpy().reshape(num_steps, 3, 1) 361 | plot_ep_info = current_ep_info[:, :3] 362 | plot_update_all = np.swapaxes(np.concatenate([plot_step, plot_obs_vm, plot_obs_pm, plot_obs_num_steps, 363 | plot_obs_num_vms, plot_vm_actions, plot_pm_actions, 364 | plot_logprobs, plot_rewards, plot_dones, 365 | plot_values, plot_ep_info], axies=-1), axis1=1, axis2=0) 366 | plot_update_all = plot_update_all.reshape((num_steps * 3, -1)) 367 | episode_df = pd.DataFrame(plot_update_all, columns=col_names) 368 | plot_fr_mean = np.mean(plot_ep_info[:, :, 2][plot_ep_info[:, :, 2] != -1000]) 369 | episode_df.to_pickle(f'runs/{run_name}/u_{update}_{plot_fr_mean}.pkl') 370 | 371 | # bootstrap value if not done 372 | with torch.no_grad(): 373 | next_value = agent.get_value(next_obs_pm, next_obs_vm, next_obs_num_steps, next_obs_num_vms).reshape(1, -1) 374 | if args.gae: 375 | advantages = torch.zeros_like(rewards, device=device) 376 | lastgaelam = 0 377 | for t in reversed(range(num_steps)): 378 | if t == num_steps - 1: 379 | nextnonterminal = 1.0 - next_done 380 | nextvalues = next_value 381 | else: 382 | nextnonterminal = 1.0 - dones[t + 1] 383 | nextvalues = values[t + 1] 384 | delta = rewards[t] + args.gamma * 
nextvalues * nextnonterminal - values[t] 385 | advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam 386 | returns = advantages + values 387 | else: 388 | returns = torch.zeros_like(rewards, device=device) 389 | for t in reversed(range(num_steps)): 390 | if t == num_steps - 1: 391 | nextnonterminal = 1.0 - next_done 392 | next_return = next_value 393 | else: 394 | nextnonterminal = 1.0 - dones[t + 1] 395 | next_return = returns[t + 1] 396 | returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return 397 | advantages = returns - values 398 | 399 | # flatten the batch 400 | b_obs_vm = obs_vm.reshape(-1, params.num_vm, params.vm_cov) 401 | b_obs_pm = obs_pm.reshape(-1, params.num_pm, params.pm_cov) 402 | b_obs_num_steps = obs_num_steps.reshape(-1, 1, 1) 403 | b_obs_num_vms = obs_num_vms.reshape(-1) 404 | b_vm_actions = vm_actions.reshape(-1) 405 | b_logprobs = logprobs.reshape(-1) 406 | b_pm_actions = pm_actions.reshape(-1) 407 | b_advantages = advantages.reshape(-1) 408 | b_returns = returns.reshape(-1) 409 | b_values = values.reshape(-1) 410 | b_action_masks = action_masks.reshape(-1, envs.single_action_space.nvec[1]) 411 | 412 | if args.debug: 413 | print('CRITIC CHECK - returns (pred vs real):\n', 414 | torch.stack([b_values, b_returns], dim=-1).cpu().data.numpy()[:50]) 415 | 416 | # Optimizing the policy and value network 417 | b_inds = np.arange(args.batch_size) 418 | clipfracs = [] 419 | for epoch in range(args.update_epochs): 420 | np.random.shuffle(b_inds) 421 | for index, start in enumerate(range(0, args.batch_size, args.minibatch_size)): 422 | end = start + args.minibatch_size 423 | mb_inds = b_inds[start:end] 424 | _, _, logprob, entropy, newvalue, _ = agent.get_action_and_value( 425 | envs, 426 | b_obs_pm[mb_inds], 427 | b_obs_vm[mb_inds], 428 | b_obs_num_steps[mb_inds], 429 | b_obs_num_vms[mb_inds], 430 | pm_mask=b_action_masks[mb_inds], 431 | selected_vm=b_vm_actions.long()[mb_inds], 432 | 
        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            # Fresh minibatch shuffle each epoch.
            np.random.shuffle(b_inds)
            for index, start in enumerate(range(0, args.batch_size, args.minibatch_size)):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]
                # Re-evaluate the stored actions under the current policy;
                # passing pm_mask/selected_* switches get_action_and_value
                # into evaluation mode (no sampling, no env calls).
                _, _, logprob, entropy, newvalue, _ = agent.get_action_and_value(
                    envs,
                    b_obs_pm[mb_inds],
                    b_obs_vm[mb_inds],
                    b_obs_num_steps[mb_inds],
                    b_obs_num_vms[mb_inds],
                    pm_mask=b_action_masks[mb_inds],
                    selected_vm=b_vm_actions.long()[mb_inds],
                    selected_pm=b_pm_actions.long()[mb_inds]
                )

                logratio = logprob - b_logprobs[mb_inds]
                ratio = logratio.exp()
                # if epoch == 0 and start == 0:
                #     print(f'pm_ratio: {pm_ratio}, vm_ratio: {vm_ratio}')

                with torch.no_grad():
                    # KL diagnostics; approx_kl uses the low-variance
                    # (ratio - 1) - log(ratio) estimator.
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    # Normalize advantages per minibatch.
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss: PPO clipped surrogate objective.
                pg_loss1 = -mb_advantages.detach() * ratio
                pg_loss2 = -mb_advantages.detach() * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss (optionally clipped around the rollout values).
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                # Loss is divided by accum_iter because gradients from
                # accum_iter minibatches are accumulated before each step.
                loss = (pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef) / params.accum_iter
                # print(f"vm loss is {vm_loss}, pm loss is {pm_loss}")
                # print(f"VM p: {vm_pg_loss}, e: {-args.ent_coef * vm_entropy_loss}, v:{v_loss * args.vf_coef}")
                # print(f"PM p: {pm_pg_loss}, e: {-args.ent_coef * pm_entropy_loss}")
                loss.backward()
                # Step/clear only every accum_iter minibatches (or at the
                # final, possibly short, minibatch of the epoch).
                if ((index + 1) % params.accum_iter == 0) or (start + args.minibatch_size > args.batch_size):
                    nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                    vm_optim.step()
                    vm_optim.zero_grad(set_to_none=True)

            # Early stop across epochs once KL drifts past the target.
            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
        """
        https://github.com/DLR-RM/stable-baselines3/blob/d5d1a02c15cdce868c72bbc94913e66fdd2efd3a/stable_baselines3/common/utils.py#L46
        Computes fraction of variance that ypred explains about y.
        Returns 1 - Var[y-ypred] / Var[y]
        interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero
        """
        if (update + 1) % test_every_step == 0:

            # Periodic evaluation: greedy-free rollouts on the dev and test
            # splits; the best-on-dev model is checkpointed.
            agent.eval()
            with torch.no_grad():
                envs.call('set_mode', mode='dev')

                # Per-file, per-step fragment rates; 1 (worst case) where no
                # episode ended.  The per-file minimum over steps is scored.
                dev_all_min_frag_rate = np.ones((num_dev, num_test_steps))
                dev_pbar = trange(0, num_dev, num_envs, desc='Dev')
                for file_index in dev_pbar:
                    # Pin each worker to a specific dev dataset file.
                    file_ids = [num_train + file_index + env_id for env_id in range(num_envs)]
                    envs.call_parse('set_current_env', env_id=file_ids)

                    next_obs_dict = envs.reset()
                    next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
                    next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
                    next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
                    next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)

                    for step in range(0, num_test_steps):
                        vm_action, pm_action, logprob, _, value, action_mask \
                            = agent.get_action_and_value(envs, next_obs_pm, next_obs_vm, next_obs_num_steps,
                                                         next_obs_num_vms)

                        next_obs_dict, reward, done, info = envs.step(torch.stack([vm_action, pm_action],
                                                                                  dim=-1).cpu().numpy())
                        next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
                        next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
                        next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
                        next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
                        next_done = torch.Tensor(done).to(device)

                        for env_id, item in enumerate(info):
                            if "episode" in item.keys():
                                dev_all_min_frag_rate[file_index + env_id, step] = item['fragment_rate']

                # Mean over files of the best (minimum) fragment rate reached.
                current_dev_frag_rate = np.mean(np.amin(dev_all_min_frag_rate, axis=1))

                envs.call('set_mode', mode='test')

                # Same evaluation protocol on the test split.
                test_all_min_frag_rate = np.ones((num_test, num_test_steps))
                test_pbar = trange(0, num_test, num_envs, desc='Test')
                for file_index in test_pbar:
                    file_ids = [num_train + num_dev + file_index + env_id for env_id in range(num_envs)]
                    envs.call_parse('set_current_env', env_id=file_ids)

                    next_obs_dict = envs.reset()
                    next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
                    next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
                    next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
                    next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)

                    for step in range(0, num_test_steps):
                        vm_action, pm_action, logprob, _, value, action_mask \
                            = agent.get_action_and_value(envs, next_obs_pm, next_obs_vm, next_obs_num_steps,
                                                         next_obs_num_vms)

                        next_obs_dict, reward, done, info = envs.step(torch.stack([vm_action, pm_action],
                                                                                  dim=-1).cpu().numpy())
                        next_obs_pm = torch.Tensor(next_obs_dict['pm_info']).to(device)
                        next_obs_vm = torch.Tensor(next_obs_dict['vm_info']).to(device)
                        next_obs_num_steps = torch.Tensor(next_obs_dict['num_steps']).to(device)
                        next_obs_num_vms = torch.tensor(next_obs_dict['num_vms'], dtype=torch.int32, device=device)
                        next_done = torch.Tensor(done).to(device)

                        for env_id, item in enumerate(info):
                            if "episode" in item.keys():
                                test_all_min_frag_rate[file_index + env_id, step] = item['fragment_rate']

                current_test_frag_rate = np.mean(np.amin(test_all_min_frag_rate, axis=1))

                # Model selection on dev; the paired test score is reported.
                if current_dev_frag_rate < dev_best_frag_rate:
                    best_frag_rate_step = update
                    dev_best_frag_rate = current_dev_frag_rate
                    test_best_frag_rate = current_test_frag_rate
                    if args.track:
                        np.save(f"runs/{run_name}/dev_all_min_frag_rate.npy", dev_all_min_frag_rate)
                        np.save(f"runs/{run_name}/test_all_min_frag_rate.npy", test_all_min_frag_rate)
                        utils.save_checkpoint({'global_step': global_step,
                                               'state_dict': agent.state_dict(),
                                               'vm_optim_dict': vm_optim.state_dict()},
                                              global_step=global_step,
                                              checkpoint=f"runs/{run_name}",
                                              is_best=True)

                if args.track:
                    writer.add_scalar("Eval/dev_frag_rate", current_dev_frag_rate, global_step)
                    writer.add_scalar("Eval/test_frag_rate", current_test_frag_rate, global_step)

                # Restore training mode for the env workers and the agent.
                envs.call('set_mode', mode='train')
            agent.train()

        if args.track:
            # TRY NOT TO MODIFY: record rewards for plotting purposes
            writer.add_scalar("Charts/vm_learning_rate", vm_optim.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
            writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
            writer.add_scalar("losses/explained_variance", explained_var, global_step)

        # Unconditional periodic checkpoint (independent of dev performance).
        if (update + 1) % save_every_step == 0:
            utils.save_checkpoint({'global_step': global_step,
                                   'state_dict': agent.state_dict(),
                                   'optim_dict': vm_optim.state_dict()},
                                  global_step=global_step,
                                  checkpoint=f"runs/{run_name}")

    envs.close()
    if args.track:
        writer.close()
def accuracy_MAE(predict_all: np.ndarray, gt_all: np.ndarray, missing_value=0):
    """Mean absolute error over positions where the ground truth is observed.

    Entries of `gt_all` equal to `missing_value` are excluded.
    """
    zero_index = (gt_all != missing_value)
    diff = np.mean(np.abs(gt_all[zero_index] - predict_all[zero_index]))
    return diff


def accuracy_MSE(predict_all: np.ndarray, gt_all: np.ndarray, missing_value=0):
    """Mean squared error over positions where the ground truth is observed."""
    zero_index = (gt_all != missing_value)
    diff = np.mean((gt_all[zero_index] - predict_all[zero_index]) ** 2)
    return diff


# for plots
def accuracy_MAPE_(mu: torch.Tensor, labels: torch.Tensor, missing_value=0):
    """Per-row mean absolute percentage error (numpy array of shape [batch]).

    Missing label positions are neutralized by setting both prediction and
    label to 1, contributing 0 error; operates on detached numpy copies, so
    the caller's tensors are untouched.
    """
    mu = mu.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()

    mask = (labels == missing_value)
    mu[mask] = 1
    labels[mask] = 1
    result = np.average(np.abs(mu - labels) / np.abs(labels), axis=1)
    return result


# for plots
def accuracy_AGG_(mu: torch.Tensor, labels: torch.Tensor, missing_value=0):
    """Per-row MAPE after aggregating the series into 30-step sums.

    The trailing remainder (< 30 steps) is dropped.  Rows whose aggregated
    label is 0 are neutralized to contribute 0 error.

    BUG FIX (two issues in the original):
      * `mu[labels == missing_value] = missing_value` wrote through the slice
        view into the CALLER's tensor — now performed on a clone;
      * `.view()` raised on the non-contiguous `[:, :time_step]` slice
        whenever the width was not a multiple of 30 — `.reshape()` handles
        both cases.
    """
    batch_size = mu.shape[0]
    time_step = mu.shape[1] // 30 * 30
    # Clone so the masking below cannot mutate the caller's tensor.
    mu = mu[:, :time_step].clone()
    labels = labels[:, :time_step]
    mu[labels == missing_value] = missing_value
    mu = mu.reshape(batch_size, -1, 30).sum(dim=2)
    labels = labels.reshape(batch_size, -1, 30).sum(dim=2)

    mu = mu.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()

    mask = (labels == 0)
    mu[mask] = 1
    labels[mask] = 1
    result = np.average(np.abs((mu - labels) / labels), axis=1)
    return result


# for plots
def accuracy_ROU_(rou: float, rou_pred: torch.Tensor, labels: torch.Tensor, missing_value=0):
    """Per-row rho-risk (quantile loss) normalized by the label magnitude.

    Rows whose labels sum to 0 are flagged with -1.  Operates on detached
    numpy copies; the caller's tensors are untouched.
    """
    rou_pred = rou_pred.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()

    mask = labels == missing_value
    rou_pred[mask] = 0.
    labels[mask] = 0.

    abs_diff = np.abs(labels - rou_pred)
    # Split under- and over-prediction so they can be weighted by rou/(1-rou).
    abs_diff_1 = abs_diff.copy()
    abs_diff_1[labels < rou_pred] = 0.
    abs_diff_2 = abs_diff.copy()
    abs_diff_2[labels >= rou_pred] = 0.

    numerator = 2 * (rou * np.sum(abs_diff_1, axis=1) + (1 - rou) * np.sum(abs_diff_2, axis=1))
    denominator = np.sum(np.abs(labels), axis=1)

    mask2 = (denominator == 0)
    denominator[mask2] = 1
    result = numerator / denominator
    result[mask2] = -1
    return result
66 | 67 | numerator = 2 * (rou * np.sum(abs_diff_1, axis=1) + (1 - rou) * np.sum(abs_diff_2, axis=1)) 68 | denominator = np.sum(np.abs(labels), axis=1) 69 | 70 | mask2 = (denominator == 0) 71 | denominator[mask2] = 1 72 | result = numerator / denominator 73 | result[mask2] = -1 74 | return result 75 | -------------------------------------------------------------------------------- /models/components/multihead.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple 3 | import torch 4 | from torch.overrides import ( 5 | has_torch_function, handle_torch_function) 6 | from torch.nn.functional import linear, softmax, dropout, pad 7 | import warnings 8 | 9 | Tensor = torch.Tensor 10 | 11 | 12 | # 13 | # multihead attention 14 | # 15 | 16 | def _in_projection_packed( 17 | q: Tensor, 18 | k: Tensor, 19 | v: Tensor, 20 | w: Tensor, 21 | b: Optional[Tensor] = None, 22 | ) -> List[Tensor]: 23 | r""" 24 | Performs the in-projection step of the attention operation, using packed weights. 25 | Output is a triple containing projection tensors for query, key and value. 26 | 27 | Args: 28 | q, k, v: query, key and value tensors to be projected. For self-attention, 29 | these are typically the same tensor; for encoder-decoder attention, 30 | k and v are typically the same tensor. (We take advantage of these 31 | identities for performance if they are present.) Regardless, q, k and v 32 | must share a common embedding dimension; otherwise their shapes may vary. 33 | w: projection weights for q, k and v, packed into a single tensor. Weights 34 | are packed along dimension 0, in q, k, v order. 35 | b: optional projection biases for q, k and v, packed into a single tensor 36 | in q, k, v order. 
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""Project q, k and v with a single packed weight tensor.

    ``w`` stacks the query, key and value projection weights along dim 0 in
    q, k, v order (shape ``(3E, E)``); ``b`` optionally packs the biases the
    same way. Identity between the inputs is exploited: when q, k and v are
    the same tensor (self-attention) one fused projection is done; when only
    k and v coincide (encoder-decoder attention) they share one projection.
    Returns the projected ``(q', k', v')``, each shaped like its input.
    """
    embed = q.size(-1)
    if q is k and k is v:
        # self-attention: one fused matmul, split three ways
        return linear(q, w, b).chunk(3, dim=-1)
    if k is v:
        # encoder-decoder attention: q projected alone, k/v share a source
        w_q, w_kv = w.split([embed, embed * 2])
        b_q = b_kv = None
        if b is not None:
            b_q, b_kv = b.split([embed, embed * 2])
        proj_q = linear(q, w_q, b_q)
        proj_k, proj_v = linear(k, w_kv, b_kv).chunk(2, dim=-1)
        return proj_q, proj_k, proj_v
    # fully distinct inputs: three independent projections
    w_q, w_k, w_v = w.chunk(3)
    b_q = b_k = b_v = None
    if b is not None:
        b_q, b_k, b_v = b.chunk(3)
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
def _in_projection(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor] = None,
    b_k: Optional[Tensor] = None,
    b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Project q, k and v with three separate weight (and optional bias) sets.

    Shape constraints are asserted so every projected output has the query
    embedding dimension ``Eq``: ``w_q`` is ``(Eq, Eq)``, ``w_k`` is
    ``(Eq, Ek)``, ``w_v`` is ``(Eq, Ev)`` and each bias is ``(Eq,)``.
    Returns ``(q', k', v')`` with trailing dimension ``Eq`` and leading
    dimensions matching the respective inputs.
    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    # validate weight shapes before projecting so failures are informative
    assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    proj_q = linear(q, w_q, b_q)
    proj_k = linear(k, w_k, b_k)
    proj_v = linear(v, w_v, b_v)
    return proj_q, proj_k, proj_v
def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""Scaled dot-product attention over batched (B, len, E) tensors.

    ``q`` is ``(B, Nt, E)``; ``k`` and ``v`` are ``(B, Ns, E)``. An optional
    additive ``attn_mask`` (``(Nt, Ns)`` or ``(B, Nt, Ns)``) is applied to the
    raw scores before the softmax, and dropout is applied to the weights when
    ``dropout_p > 0``. Returns the attended values ``(B, Nt, E)`` and the
    attention weights ``(B, Nt, Ns)``.
    """
    _, _, dim = q.shape
    # scale queries once instead of scaling every score
    scaled_q = q / math.sqrt(dim)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    scores = torch.bmm(scaled_q, k.transpose(-2, -1))
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = softmax(scores, dim=-1)
    if dropout_p > 0.0:
        weights = dropout(weights, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    return torch.bmm(weights, v), weights
def multi_head_attention_forward(
        query: Tensor,
        key: Tensor,
        value: Tensor,
        split_point: int,
        embed_dim_to_check: int,
        num_heads: int,
        in_proj_weight1: Tensor,
        in_proj_bias1: Optional[Tensor],
        bias_k1: Optional[Tensor],
        bias_v1: Optional[Tensor],
        in_proj_weight2: Tensor,
        in_proj_bias2: Optional[Tensor],
        bias_k2: Optional[Tensor],
        bias_v2: Optional[Tensor],
        add_zero_attn: bool,
        dropout_p: float,
        out_proj_weight1: Tensor,
        out_proj_bias1: Optional[Tensor],
        out_proj_weight2: Tensor,
        out_proj_bias2: Optional[Tensor],
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        use_separate_proj_weight: bool = False,
        q_proj_weight1: Optional[Tensor] = None,
        k_proj_weight1: Optional[Tensor] = None,
        v_proj_weight1: Optional[Tensor] = None,
        q_proj_weight2: Optional[Tensor] = None,
        k_proj_weight2: Optional[Tensor] = None,
        v_proj_weight2: Optional[Tensor] = None,
        static_k: Optional[Tensor] = None,
        static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Multi-head attention with two projection sets split along the sequence.

    Variant of ``torch.nn.functional.multi_head_attention_forward``: sequence
    positions ``[:split_point]`` are projected with weight/bias set 1 and
    positions ``[split_point:]`` with set 2, for both the input and output
    projections. The attention itself is computed jointly over the full
    concatenated sequence.

    Args:
        query, key, value: ``(L, N, E)`` / ``(S, N, E)`` / ``(S, N, E)``.
        split_point: index on the sequence (first) dimension where projection
            set 1 hands over to set 2.
        embed_dim_to_check: expected total model dimension ``E``.
        num_heads: number of attention heads; ``E`` must be divisible by it.
        in_proj_weight{1,2} / in_proj_bias{1,2}: packed q/k/v projections per set.
        bias_k{1,2}, bias_v{1,2}: optional extra key/value bias rows per set.
        add_zero_attn: append an all-zero key/value position.
        dropout_p: dropout on attention weights (disabled when not ``training``).
        out_proj_weight{1,2} / out_proj_bias{1,2}: output projections per set.
        key_padding_mask: ``(N, S)`` bool/byte mask of ignored key positions.
        attn_mask: ``(L, S)`` or ``(N*num_heads, L, S)`` additive/boolean mask.
        use_separate_proj_weight: use the per-tensor ``*_proj_weight{1,2}``
            arguments instead of the packed ``in_proj_weight{1,2}``.
        static_k, static_v: precomputed ``(N*num_heads, S, E/num_heads)`` keys
            and values that bypass the k/v projections.

    Returns:
        ``(attn_output, attn_output_weights)`` where ``attn_output`` is
        ``(L, N, E)`` and the weights are head-averaged ``(N, L, S)`` when
        ``need_weights`` else ``None``.
    """
    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection: each sequence segment uses its own weight set,
    # then the projected segments are re-concatenated along the sequence dim
    #
    if not use_separate_proj_weight:
        q1, k1, v1 = _in_projection_packed(query[:split_point], key[:split_point], value[:split_point],
                                           in_proj_weight1, in_proj_bias1)
        q2, k2, v2 = _in_projection_packed(query[split_point:], key[split_point:], value[split_point:],
                                           in_proj_weight2, in_proj_bias2)
    else:
        assert q_proj_weight1 is not None and q_proj_weight2 is not None, \
            "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight1 is not None and k_proj_weight2 is not None, \
            "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight1 is not None and v_proj_weight2 is not None, \
            "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias1 is None:
            b_q1 = b_k1 = b_v1 = None
        else:
            b_q1, b_k1, b_v1 = in_proj_bias1.chunk(3)
        if in_proj_bias2 is None:
            b_q2 = b_k2 = b_v2 = None
        else:
            b_q2, b_k2, b_v2 = in_proj_bias2.chunk(3)
        q1, k1, v1 = _in_projection(query[:split_point], key[:split_point], value[:split_point],
                                    q_proj_weight1, k_proj_weight1, v_proj_weight1, b_q1, b_k1, b_v1)
        q2, k2, v2 = _in_projection(query[split_point:], key[split_point:], value[split_point:],
                                    q_proj_weight2, k_proj_weight2, v_proj_weight2, b_q2, b_k2, b_v2)
    q = torch.cat([q1, q2], dim=0)
    k = torch.cat([k1, k2], dim=0)
    v = torch.cat([v1, v2], dim=0)

    # prep attention mask: normalise dtype to bool/float and dim to 3
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second); each segment gets
    # its own bias row, and the masks grow by one key position to match
    if bias_k1 is not None and bias_v1 is not None and bias_k2 is not None and bias_v2 is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([torch.cat([k[:split_point], bias_k1.repeat(1, bsz, 1)]),
                       torch.cat([k[split_point:], bias_k2.repeat(1, bsz, 1)])], dim=0)
        v = torch.cat([torch.cat([v[:split_point], bias_v1.repeat(1, bsz, 1)]),
                       torch.cat([v[split_point:], bias_v2.repeat(1, bsz, 1)])], dim=0)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k1 is None and bias_k2 is None
        assert bias_v1 is None and bias_v2 is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # convert mask to float (additive -inf) for the softmax
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    # output projection is split the same way as the input projection
    attn_output = torch.cat([linear(attn_output[:split_point], out_proj_weight1, out_proj_bias1),
                             linear(attn_output[split_point:], out_proj_weight2, out_proj_bias2)], dim=0)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
class MultiheadAttention_Split(Module):
    r"""Multi-head attention (see `Attention Is All You Need`) whose sequence
    is split at ``split_point``: positions before the split use projection
    set 1 and positions after it use projection set 2, while attention is
    still computed jointly over the whole sequence. The interface otherwise
    mirrors ``torch.nn.MultiheadAttention``.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads; ``embed_dim`` is split
            across heads (each head has dimension ``embed_dim // num_heads``).
        split_point: Sequence index where projection set 1 hands over to set 2.
        dropout: Dropout probability on ``attn_output_weights``. Default: 0.0.
        bias: If specified, adds bias to input / output projection layers.
        add_bias_kv: Adds bias to the key and value sequences at dim=0.
        add_zero_attn: Adds a new batch of zeros to key/value sequences at dim=1.
        kdim: Total number of features for keys (defaults to ``embed_dim``).
        vdim: Total number of features for values (defaults to ``embed_dim``).
        batch_first: If ``True``, inputs/outputs are (batch, seq, feature);
            default is (seq, batch, feature).

    Examples::

        >>> multihead_attn = MultiheadAttention_Split(embed_dim, num_heads, split_point)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, split_point, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention_Split, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # packed q/k/v projection is only possible when all three dims agree
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.split_point = split_point
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # every projection exists twice: set 1 for positions before
        # split_point, set 2 for positions after it
        if self._qkv_same_embed_dim is False:
            # separate per-tensor projection weights (kdim/vdim differ from embed_dim)
            self.q_proj_weight1 = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight1 = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight1 = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.q_proj_weight2 = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight2 = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight2 = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight1', None)
            self.register_parameter('in_proj_weight2', None)
        else:
            # packed (3*E, E) q/k/v projection weights
            self.in_proj_weight1 = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.in_proj_weight2 = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight1', None)
            self.register_parameter('k_proj_weight1', None)
            self.register_parameter('v_proj_weight1', None)
            self.register_parameter('q_proj_weight2', None)
            self.register_parameter('k_proj_weight2', None)
            self.register_parameter('v_proj_weight2', None)

        if bias:
            self.in_proj_bias1 = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
            self.in_proj_bias2 = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias1', None)
            self.register_parameter('in_proj_bias2', None)
        self.out_proj1 = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.out_proj2 = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k1 = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v1 = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_k2 = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v2 = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k1 = self.bias_v1 = self.bias_k2 = self.bias_v2 = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier init for projections, zeros for biases — mirrors the
        # initialisation of torch.nn.MultiheadAttention
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight1)
            xavier_uniform_(self.in_proj_weight2)
        else:
            xavier_uniform_(self.q_proj_weight1)
            xavier_uniform_(self.k_proj_weight1)
            xavier_uniform_(self.v_proj_weight1)
            xavier_uniform_(self.q_proj_weight2)
            xavier_uniform_(self.k_proj_weight2)
            xavier_uniform_(self.v_proj_weight2)

        if self.in_proj_bias1 is not None:
            constant_(self.in_proj_bias1, 0.)
            constant_(self.out_proj1.bias, 0.)
            constant_(self.in_proj_bias2, 0.)
            constant_(self.out_proj2.bias, 0.)
        if self.bias_k1 is not None:
            xavier_normal_(self.bias_k1)
            xavier_normal_(self.bias_k2)
        if self.bias_v1 is not None:
            xavier_normal_(self.bias_v1)
            xavier_normal_(self.bias_v2)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention_Split, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Run split multi-head attention.

        Args:
            query: ``(L, N, E_q)`` (or ``(N, L, E_q)`` when ``batch_first``).
            key: ``(S, N, E_k)`` (or ``(N, S, E_k)`` when ``batch_first``).
            value: ``(S, N, E_v)`` (or ``(N, S, E_v)`` when ``batch_first``).
            key_padding_mask: ``(N, S)`` binary/byte mask; ``True`` / non-zero
                key positions are ignored by the attention.
            need_weights: also return head-averaged attention weights.
            attn_mask: ``(L, S)`` or ``(N*num_heads, L, S)`` mask preventing
                attention to certain positions; bool/byte masks block
                positions, float masks are added to the attention weight.

        Outputs:
            - **attn_output**: ``(L, N, E)`` (or ``(N, L, E)`` when
              ``batch_first``).
            - **attn_output_weights**: ``(N, L, S)``, only when
              ``need_weights=True``.
        """
        if self.batch_first:
            # the functional implementation expects (seq, batch, feature)
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.split_point, self.embed_dim, self.num_heads,
                self.in_proj_weight1, self.in_proj_bias1,
                self.bias_k1, self.bias_v1, self.in_proj_weight2, self.in_proj_bias2,
                self.bias_k2, self.bias_v2, self.add_zero_attn,
                self.dropout, self.out_proj1.weight, self.out_proj1.bias,
                self.out_proj2.weight, self.out_proj2.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight1=self.q_proj_weight1, k_proj_weight1=self.k_proj_weight1,
                v_proj_weight1=self.v_proj_weight1, q_proj_weight2=self.q_proj_weight2,
                k_proj_weight2=self.k_proj_weight2, v_proj_weight2=self.v_proj_weight2)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.split_point, self.embed_dim, self.num_heads,
                self.in_proj_weight1, self.in_proj_bias1,
                self.bias_k1, self.bias_v1, self.in_proj_weight2, self.in_proj_bias2,
                self.bias_k2, self.bias_v2, self.add_zero_attn,
                self.dropout, self.out_proj1.weight, self.out_proj1.bias,
                self.out_proj2.weight, self.out_proj2.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATv2Conv


logger = logging.getLogger('VM.gcn_embed')


class GCN_Embedder(nn.Module):
    """Two-layer GATv2 graph embedder that yields separate PM and VM node embeddings."""

    def __init__(self, params):
        super(GCN_Embedder, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        self.d_hidden = params.d_hidden  # 8
        self.num_pm = params.num_pm
        self.pm_cov = params.pm_cov  # 6
        self.output_dim = self.d_hidden  # 8

        # First layer uses 2 attention heads, so its flattened output width is
        # d_hidden * 2, which is exactly what conv2 expects as input.
        self.conv1 = GATv2Conv(params.pm_cov + 1, params.d_hidden, num_heads=2)
        self.conv2 = GATv2Conv(params.d_hidden * 2, self.output_dim, num_heads=1, feat_drop=0.1, attn_drop=0.1)

    def forward(self, g, in_feat, b_sz):
        """Embed all nodes of (batched) graph ``g`` and split the result.

        Args:
            g: DGL graph (batched over b_sz samples).
            in_feat: node features of width pm_cov + 1.
            b_sz: batch size used to un-flatten the node dimension.

        Returns:
            (pm_embeddings, vm_embeddings) -- the first num_pm nodes per sample
            vs. the rest.  Assumes nodes are ordered PMs first, then VMs --
            TODO confirm against the graph construction code.
        """
        # conv1 output is (num_nodes, 2, d_hidden); flatten the head dim.
        h = self.conv1(g, in_feat)
        h = F.elu(h).reshape(-1, self.d_hidden * 2)
        h = self.conv2(g, h).reshape(b_sz, -1, self.output_dim)
        return h[:, :self.num_pm], h[:, self.num_pm:]


class GCN_Wrapper(nn.Module):
    """Thin wrapper that instantiates GCN_Embedder on the configured device."""

    def __init__(self, params):
        super(GCN_Wrapper, self).__init__()
        self.model = GCN_Embedder(params).to(params.device)

# models/pm_attn.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import torch
import torch.nn as nn
from .components.pt_transformer import Transformer


logger = logging.getLogger('VM.attn')


class PM_candidate_model(nn.Module):
    """Scores every PM as a placement candidate for one chosen VM.

    Encoder input: a step-count token (broadcast to d_hidden) plus the encoded
    chosen VM state; decoder input: the encoded PM states.  A linear head maps
    each PM position's transformer output to a scalar score.
    """

    def __init__(self, params):
        super(PM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        self.d_hidden = params.d_hidden

        self.pm_encode = nn.Linear(params.pm_cov, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov, params.d_hidden)

        self.transformer = Transformer(d_model=params.d_hidden, nhead=params.num_head,
                                       num_encoder_layers=params.transformer_blocks,
                                       num_decoder_layers=params.transformer_blocks, dim_feedforward=params.d_ff,
                                       activation='gelu', batch_first=True, dropout=params.dropout,
                                       need_attn_weights=True, device=params.device)

        # One scalar score per PM position.
        self.output_layer = nn.Linear(params.d_hidden, 1)

    def forward(self, chosen_vm_state, num_step_states, pm_states, return_attns=False):
        """Return per-PM scores (and optionally the attention weights).

        Args:
            chosen_vm_state: state of the VM being placed, e.g. torch.Size([8, 1, 14]).
            num_step_states: scalar step feature per sample, broadcast into a token.
            pm_states: per-PM covariates fed to the decoder side.
            return_attns: also return the transformer's attention weights.
        """
        transformer_output = self.transformer(src=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.vm_encode(chosen_vm_state)], dim=1),
                                              tgt=self.pm_encode(pm_states))
        # Bug fix: squeeze only the trailing feature dim.  A bare squeeze()
        # also drops the batch dim when batch size is 1, silently changing the
        # output rank from (batch, num_pm) to (num_pm,).
        score = torch.squeeze(self.output_layer(transformer_output[0]), dim=-1)
        if return_attns:
            return score, transformer_output[1]
        else:
            return score


class PM_Attn_Wrapper(nn.Module):
    """Builds PM_candidate_model on the configured device."""

    def __init__(self, params):
        super(PM_Attn_Wrapper, self).__init__()
        self.model = PM_candidate_model(params).to(params.device)

# models/pm_attn_graph.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import os
import torch
import torch.nn as nn
from .components.pt_transformer import Transformer


logger = logging.getLogger('VM.attn')


class PM_candidate_model(nn.Module):
    """Graph-embedding variant of the PM candidate scorer.

    Same structure as models/pm_attn.py, except both encoders expect an extra
    d_hidden features per entity (presumably a graph embedding concatenated to
    the raw covariates by the caller -- TODO confirm against the caller).
    """

    def __init__(self, params):
        super(PM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        # self.seq_length = params.num_pm
        self.d_hidden = params.d_hidden
        self.pm_encode = nn.Linear(params.pm_cov + self.d_hidden, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov + self.d_hidden, params.d_hidden)

        self.transformer = Transformer(d_model=params.d_hidden, nhead=params.num_head,
                                       num_encoder_layers=params.transformer_blocks,
                                       num_decoder_layers=params.transformer_blocks, dim_feedforward=params.d_ff,
                                       activation='gelu', batch_first=True, dropout=params.dropout,
                                       need_attn_weights=True, device=params.device)

        # One scalar score per PM position.
        self.output_layer = nn.Linear(params.d_hidden, 1)

    def forward(self, chosen_vm_state, num_step_states, pm_states, return_attns=False):
        """Score each PM as a placement target for the chosen VM."""
        # chosen_vm_state: torch.Size([8, 1, 14])
        transformer_output = self.transformer(src=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.vm_encode(chosen_vm_state)], dim=1),
                                              tgt=self.pm_encode(pm_states))
        # NOTE(review): bare squeeze() also drops the batch dim when batch size
        # is 1 -- confirm callers never run with a single-sample batch.
        score = torch.squeeze(self.output_layer(transformer_output[0]))
        if return_attns:
            return score, transformer_output[1]
        else:
            return score


class PM_Attn_Graph_Wrapper(nn.Module):
    """Builds PM_candidate_model on the configured device."""

    def __init__(self, params):
        super(PM_Attn_Graph_Wrapper, self).__init__()
        self.model = PM_candidate_model(params).to(params.device)

# models/pm_detail_attn.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import torch
import torch.nn as nn
from .components.pt_transformer import Transformer


logger = logging.getLogger('VM.attn')


class PM_candidate_model(nn.Module):
    """PM scorer that reuses embeddings and attention from the VM-selection stage.

    pm_encode takes the raw PM covariates plus one extra channel (the attention
    weight each PM received, ``pm_attn``); precomputed embeddings are added
    residually on both the encoder and decoder inputs.
    """

    def __init__(self, params):
        super(PM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        # self.seq_length = params.num_pm
        self.d_hidden = params.d_hidden

        # +1 input channel: the per-PM attention weight appended in forward().
        self.pm_encode = nn.Linear(params.pm_cov + 1, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov, params.d_hidden)

        self.transformer = Transformer(d_model=params.d_hidden, nhead=params.num_head,
                                       num_encoder_layers=params.transformer_blocks,
                                       num_decoder_layers=params.transformer_blocks, dim_feedforward=params.d_ff,
                                       activation='gelu', batch_first=True, dropout=params.dropout,
                                       need_attn_weights=True, device=params.device)

        # One scalar score per PM position.
        self.output_layer = nn.Linear(params.d_hidden, 1)

    def forward(self, chosen_vm_embed, chosen_vm_state, num_step_states, pm_states, pm_embed, pm_attn,
                return_attns=False):
        """Score each PM, conditioning on upstream embeddings and attention."""
        # chosen_vm_state: torch.Size([8, 1, 14])
        # pm_attn is appended as an extra per-PM feature channel; the indexing
        # assumes pm_attn is (batch, num_pm) -- TODO confirm against caller.
        transformer_output = self.transformer(src=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.vm_encode(chosen_vm_state) + chosen_vm_embed], dim=1),
                                              tgt=self.pm_encode(torch.cat([pm_states,
                                                                            pm_attn[:, :, None]], dim=-1)) + pm_embed)
        # NOTE(review): bare squeeze() also drops the batch dim when batch size is 1.
        score = torch.squeeze(self.output_layer(transformer_output[0]))
        if return_attns:
            return score, transformer_output[1]
        else:
            return score


class PM_Detail_Attn_Wrapper(nn.Module):
    """Builds PM_candidate_model on the configured device."""

    def __init__(self, params):
        super(PM_Detail_Attn_Wrapper, self).__init__()
        self.model = PM_candidate_model(params).to(params.device)

# models/pm_mlp.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import torch
import torch.nn as nn


logger = logging.getLogger('VM.mlp')


class PM_candidate_model(nn.Module):
    """MLP baseline that scores every PM for placing a single chosen VM.

    Flattens all PM states and the chosen VM state into one vector, passes it
    through a two-layer Tanh trunk, and emits one logit per PM.
    """

    def __init__(self, params):
        super(PM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        self.input_size = params.num_pm * params.pm_cov + params.vm_cov  # 2246
        self.output_size = params.num_pm

        # Shared trunk followed by a per-PM scoring head.
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
        )
        self.pm_head = nn.Linear(128, self.output_size)

        self.loss_fn = nn.L1Loss()

    def forward(self, vm_states, pm_states):
        """Return a (batch, num_pm) tensor of PM scores."""
        batch = vm_states.shape[0]
        flat_pm = pm_states.reshape(batch, -1)
        flat_vm = vm_states.reshape(batch, -1)
        features = self.layers(torch.cat([flat_pm, flat_vm], dim=-1))
        return self.pm_head(features)


class PM_MLP_Wrapper(nn.Module):
    """Builds PM_candidate_model on the configured device."""

    def __init__(self, params):
        super(PM_MLP_Wrapper, self).__init__()
        self.model = PM_candidate_model(params).to(params.device)

# models/vm_attn.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import os
import torch
import torch.nn as nn
from .components.pt_transformer import Transformer


logger = logging.getLogger('VM.attn')


class VM_candidate_model(nn.Module):
    """Transformer actor-critic that scores every VM as a migration candidate.

    Encoder: step-count token plus encoded PM states.  Decoder: encoded VM
    states followed by one constant "critic" token whose output position feeds
    the value head.  ``num_vms_mask`` pads out unused VM slots; a zeros column
    keeps the critic token unmasked.
    """

    def __init__(self, params):
        super(VM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        # self.seq_length = params.num_pm
        self.d_hidden = params.d_hidden

        self.pm_encode = nn.Linear(params.pm_cov, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov, params.d_hidden)

        self.transformer = Transformer(d_model=params.d_hidden, nhead=params.num_head,
                                       num_encoder_layers=params.transformer_blocks,
                                       num_decoder_layers=params.transformer_blocks, dim_feedforward=params.d_ff,
                                       activation='gelu', batch_first=True, dropout=params.dropout,
                                       need_attn_weights=True, device=params.device)

        self.output_layer = nn.Linear(params.d_hidden, 1)
        self.critic_layer = nn.Linear(params.d_hidden, 1)
        # Constant (non-trainable) token appended after the VM sequence.
        self.critic_token = -torch.ones(1, 1, params.d_hidden).to(self.device)

    def forward(self, vm_states, num_step_states, pm_states, num_vms_mask=None, return_attns=False):
        """Return (per-VM scores, critic value[, attention weights])."""
        b_sz = vm_states.shape[0]
        transformer_output = self.transformer(src=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.pm_encode(pm_states)], dim=1),
                                              tgt_key_padding_mask=torch.cat([num_vms_mask,
                                                                              torch.zeros(b_sz, 1, dtype=torch.bool,
                                                                                          device=self.device)], dim=1),
                                              tgt=torch.cat([self.vm_encode(vm_states),
                                                             self.critic_token.repeat(b_sz, 1, 1).detach()], dim=1))
        # Last target position is the critic token; the rest are VM scores.
        # NOTE(review): bare squeeze() also drops the batch dim when batch size is 1.
        score = torch.squeeze(self.output_layer(transformer_output[0][:, :-1]))
        critic_score = self.critic_layer(transformer_output[0][:, -1])
        if return_attns:
            return score, critic_score, transformer_output[1]
        else:
            return score, critic_score


class VM_Attn_Wrapper(nn.Module):
    """Builds VM_candidate_model and optionally restores pretrained weights."""

    def __init__(self, params, pretrain=False):
        super(VM_Attn_Wrapper, self).__init__()
        self.model = VM_candidate_model(params).to(params.device)
        if pretrain:
            model_save_path = './saved_model_weights/attn.ckpt'
            assert os.path.isfile(model_save_path)
            self.model.load_state_dict(torch.load(model_save_path))

# models/vm_attn_graph.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import os
import torch
import torch.nn as nn
from .components.pt_transformer import Transformer


logger = logging.getLogger('VM.attn')


class VM_candidate_model(nn.Module):
    """Graph-embedding variant of the VM actor-critic.

    Same structure as models/vm_attn.py, except both encoders expect an extra
    d_hidden features per entity (presumably a graph embedding concatenated to
    the raw covariates by the caller -- TODO confirm against the caller).
    """

    def __init__(self, params):
        super(VM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        # self.seq_length = params.num_pm
        self.d_hidden = params.d_hidden
        self.pm_encode = nn.Linear(params.pm_cov + self.d_hidden, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov + self.d_hidden, params.d_hidden)

        self.transformer = Transformer(d_model=params.d_hidden, nhead=params.num_head,
                                       num_encoder_layers=params.transformer_blocks,
                                       num_decoder_layers=params.transformer_blocks, dim_feedforward=params.d_ff,
                                       activation='gelu', batch_first=True, dropout=params.dropout,
                                       need_attn_weights=True, device=params.device)

        self.output_layer = nn.Linear(params.d_hidden, 1)
        self.critic_layer = nn.Linear(params.d_hidden, 1)
        # Constant (non-trainable) token appended after the VM sequence.
        self.critic_token = -torch.ones(1, 1, params.d_hidden).to(self.device)

    def forward(self, vm_states, num_step_states, pm_states, num_vms_mask=None, return_attns=False):
        """Return (per-VM scores, critic value[, attention weights])."""
        b_sz = vm_states.shape[0]
        transformer_output = self.transformer(src=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.pm_encode(pm_states)], dim=1),
                                              tgt_key_padding_mask=torch.cat([num_vms_mask,
                                                                              torch.zeros(b_sz, 1, dtype=torch.bool,
                                                                                          device=self.device)], dim=1),
                                              tgt=torch.cat([self.vm_encode(vm_states),
                                                             self.critic_token.repeat(b_sz, 1, 1).detach()], dim=1))
        # NOTE(review): bare squeeze() also drops the batch dim when batch size is 1.
        score = torch.squeeze(self.output_layer(transformer_output[0][:, :-1]))
        critic_score = self.critic_layer(transformer_output[0][:, -1])
        if return_attns:
            return score, critic_score, transformer_output[1]
        else:
            return score, critic_score


class VM_Attn_Graph_Wrapper(nn.Module):
    """Builds VM_candidate_model and optionally restores pretrained weights."""

    def __init__(self, params, pretrain=False):
        super(VM_Attn_Graph_Wrapper, self).__init__()
        self.model = VM_candidate_model(params).to(params.device)
        if pretrain:
            model_save_path = './saved_model_weights/attn.ckpt'
            assert os.path.isfile(model_save_path)
            self.model.load_state_dict(torch.load(model_save_path))

# models/vm_lite_sparse_attn.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import os
import torch
import torch.nn as nn
from .components.pt_transformer import TransformerSparseDecoder, TransformerSparseDecoderLayer

logger = logging.getLogger('VM.attn')


class VM_candidate_model(nn.Module):
    """Sparse-attention VM actor-critic over one joint token sequence.

    The decoder sees a sequence of length 2 + num_pm + num_vm: a step-count
    token, all PM states, all VM states, and a constant critic token.
    ``local_mask`` blocks attention between interior (PM/VM) positions that do
    not share a host according to ``vm_pm_relation``; the first and last
    positions remain globally visible.
    """

    def __init__(self, params):
        super(VM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        self.num_pm = params.num_pm
        self.num_vm = params.num_vm
        self.num_head = params.num_head

        self.d_hidden = params.d_hidden

        self.pm_encode = nn.Linear(params.pm_cov, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov, params.d_hidden)

        decoder_layer = TransformerSparseDecoderLayer(d_model=params.d_hidden, nhead=params.num_head,
                                                      # split_point=self.num_pm+1,
                                                      dim_feedforward=params.d_ff, dropout=params.dropout,
                                                      activation='gelu', batch_first=True, norm_first=True,
                                                      need_attn_weights=True, device=params.device)
        self.transformer = TransformerSparseDecoder(decoder_layer=decoder_layer, num_layers=params.transformer_blocks)

        self.output_layer = nn.Linear(params.d_hidden, 1)
        self.critic_layer = nn.Linear(params.d_hidden, 1)
        # Constant (non-trainable) token marking the critic position.
        self.critic_token = -torch.ones(1, 1, params.d_hidden).to(self.device)

    def forward(self, vm_states, num_step_states, pm_states, vm_pm_relation, num_vms_mask=None, return_attns=False):
        """Return (per-VM scores, critic value[, attention weights])."""
        b_sz = vm_states.shape[0]
        local_mask = torch.zeros(b_sz, self.num_pm + self.num_vm + 2, self.num_pm + self.num_vm + 2,
                                 dtype=torch.bool, device=self.device)
        # True (= masked) where two interior positions are not hosted on the
        # same PM; assumes vm_pm_relation is (batch, num_pm + num_vm, 1) --
        # TODO confirm against the environment.
        local_mask[:, 1:-1, 1:-1] = vm_pm_relation != vm_pm_relation[:, None, :, 0]
        # Only the VM slots can be padding; step/PM/critic positions are real.
        tgt_key_pad_mask = torch.zeros(b_sz, 2 + self.num_pm + self.num_vm, dtype=torch.bool, device=self.device)
        tgt_key_pad_mask[:, 1 + self.num_pm:-1] = num_vms_mask
        # repeat_interleave keeps the num_head mask copies of each sample
        # contiguous, matching the (batch * num_heads, L, S) attention layout.
        transformer_output = self.transformer(tgt=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.pm_encode(pm_states), self.vm_encode(vm_states),
                                                             self.critic_token.repeat(b_sz, 1, 1).detach()], dim=1),
                                              local_mask=torch.repeat_interleave(local_mask, self.num_head, dim=0),
                                              tgt_key_padding_mask=tgt_key_pad_mask)
        # VM outputs sit between the PM block and the trailing critic token.
        # NOTE(review): bare squeeze() also drops the batch dim when batch size is 1.
        score = torch.squeeze(self.output_layer(transformer_output[0][:, 1 + self.num_pm:-1]))
        critic_score = self.critic_layer(transformer_output[0][:, -1])
        if return_attns:
            return score, critic_score, transformer_output[1]
        else:
            return score, critic_score


class VM_Lite_Sparse_Attn_Wrapper(nn.Module):
    """Builds VM_candidate_model and optionally restores pretrained weights."""

    def __init__(self, params, pretrain=False):
        super(VM_Lite_Sparse_Attn_Wrapper, self).__init__()
        self.model = VM_candidate_model(params).to(params.device)
        if pretrain:
            model_save_path = './saved_model_weights/attn.ckpt'
            assert os.path.isfile(model_save_path)
            self.model.load_state_dict(torch.load(model_save_path))

# models/vm_mlp.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import os
import torch
import torch.nn as nn


logger = logging.getLogger('VM.mlp')


class VM_candidate_model(nn.Module):
    """MLP baseline actor-critic: per-VM migration scores plus a value estimate.

    Flattens all PM and VM states into a single vector, runs it through a
    two-layer Tanh trunk, then emits one logit per VM and a scalar critic value.
    """

    def __init__(self, params):
        super(VM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        self.input_size = params.num_pm * params.pm_cov + params.vm_cov*params.num_vm  # 2246
        self.output_size = params.num_vm

        # Shared trunk feeding both the actor and the critic heads.
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
        )
        self.vm_head = nn.Linear(128, self.output_size)
        self.critic = nn.Linear(128, 1)

    def forward(self, vm_states, pm_states):
        """Return (vm_scores, critic_value) for a batch of flattened states."""
        batch = vm_states.shape[0]
        flat_pm = pm_states.reshape(batch, -1)
        flat_vm = vm_states.reshape(batch, -1)
        trunk = self.layers(torch.cat([flat_pm, flat_vm], dim=-1))
        return self.vm_head(trunk), self.critic(trunk)


class VM_MLP_Wrapper(nn.Module):
    """Builds VM_candidate_model; optionally restores pretrained trunk + head."""

    def __init__(self, params, pretrain=False):
        super(VM_MLP_Wrapper, self).__init__()
        self.model = VM_candidate_model(params).to(params.device)
        if pretrain:
            model_save_path1 = './saved_model_weights/model_network.ckpt'
            model_save_path2 = './saved_model_weights/model_vm_head.ckpt'
            assert os.path.isfile(model_save_path1)
            assert os.path.isfile(model_save_path2)
            self.model.layers.load_state_dict(torch.load(model_save_path1))
            self.model.vm_head.load_state_dict(torch.load(model_save_path2))

# models/vm_sparse_attn.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the Apache-2.0 License for more details.


import logging
import os
import torch
import torch.nn as nn
from .components.pt_transformer import TransformerSparseDecoder, TransformerSparseDecoderLayer


logger = logging.getLogger('VM.attn')


class VM_candidate_model(nn.Module):
    """Sparse-attention VM actor-critic that also exposes PM/VM embeddings.

    The decoder sees one sequence of length 2 + num_pm + num_vm: a step-count
    token, all PM states, all VM states, and a constant critic token.
    ``local_mask`` blocks attention between interior (PM/VM) positions that do
    not share a host (per ``vm_pm_relation``); the first and last positions
    stay globally visible.  Returns PM embeddings, VM embeddings, per-VM
    scores and a critic value.
    """

    def __init__(self, params):
        super(VM_candidate_model, self).__init__()
        self.device = params.device
        self.batch_size = params.batch_size
        self.num_pm = params.num_pm
        self.num_vm = params.num_vm
        self.num_head = params.num_head

        self.d_hidden = params.d_hidden

        self.pm_encode = nn.Linear(params.pm_cov, params.d_hidden)
        self.vm_encode = nn.Linear(params.vm_cov, params.d_hidden)

        decoder_layer = TransformerSparseDecoderLayer(d_model=params.d_hidden, nhead=params.num_head,
                                                      dim_feedforward=params.d_ff, dropout=params.dropout,
                                                      activation='gelu', batch_first=True, norm_first=True,
                                                      need_attn_weights=True, device=params.device)
        self.transformer = TransformerSparseDecoder(decoder_layer=decoder_layer, num_layers=params.transformer_blocks)

        self.output_layer = nn.Linear(params.d_hidden, 1)
        self.critic_layer = nn.Linear(params.d_hidden, 1)
        # Constant (non-trainable) token marking the critic position.
        self.critic_token = -torch.ones(1, 1, params.d_hidden).to(self.device)

    def forward(self, vm_states, num_step_states, pm_states, vm_pm_relation, num_vms_mask=None, return_attns=False):
        """Return (pm_embed, vm_embed, vm_scores, critic_value[, attn])."""
        b_sz = vm_states.shape[0]
        local_mask = torch.zeros(b_sz, self.num_pm + self.num_vm + 2, self.num_pm + self.num_vm + 2,
                                 dtype=torch.bool, device=self.device)
        # True (= masked) where two interior positions do not share a host;
        # assumes vm_pm_relation is (batch, num_pm + num_vm, 1) -- TODO confirm.
        local_mask[:, 1:-1, 1:-1] = vm_pm_relation != vm_pm_relation[:, None, :, 0]
        # Only the VM slots can be padding; step/PM/critic positions are real.
        tgt_key_pad_mask = torch.zeros(b_sz, 2 + self.num_pm + self.num_vm, dtype=torch.bool, device=self.device)
        tgt_key_pad_mask[:, 1 + self.num_pm:-1] = num_vms_mask
        # Bug fix: expand the per-sample mask per head with repeat_interleave so
        # the num_head copies of each sample stay contiguous, matching the
        # (batch * num_heads, L, S) attention-mask layout and the identical call
        # in models/vm_lite_sparse_attn.py.  The previous
        # ``local_mask.repeat(self.num_head, 1, 1)`` tiled the whole batch
        # instead, pairing each sample's heads with other samples' masks.
        transformer_output = self.transformer(tgt=torch.cat([num_step_states.repeat(1, 1, self.d_hidden),
                                                             self.pm_encode(pm_states), self.vm_encode(vm_states),
                                                             self.critic_token.repeat(b_sz, 1, 1).detach()], dim=1),
                                              local_mask=torch.repeat_interleave(local_mask, self.num_head, dim=0),
                                              tgt_key_padding_mask=tgt_key_pad_mask)
        pm_embed = transformer_output[0][:, 1:1 + self.num_pm]
        vm_embed = transformer_output[0][:, 1 + self.num_pm:-1]
        # NOTE(review): bare squeeze() also drops the batch dim when batch size is 1.
        score = torch.squeeze(self.output_layer(vm_embed))
        critic_score = self.critic_layer(transformer_output[0][:, -1])
        if return_attns:
            return pm_embed, vm_embed, score, critic_score, transformer_output[1]
        else:
            return pm_embed, vm_embed, score, critic_score


class VM_Sparse_Attn_Wrapper(nn.Module):
    """Builds VM_candidate_model and optionally restores pretrained weights."""

    def __init__(self, params, pretrain=False):
        super(VM_Sparse_Attn_Wrapper, self).__init__()
        self.model = VM_candidate_model(params).to(params.device)
        if pretrain:
            model_save_path = './saved_model_weights/attn.ckpt'
            assert os.path.isfile(model_save_path)
            self.model.load_state_dict(torch.load(model_save_path))

# saved_model_weights/mlp.ckpt: binary checkpoint, stored at
# https://raw.githubusercontent.com/bytedance/DRL-based-VM-Rescheduling/a1232df1bce7851c991229fd7a34871f0685f5f0/saved_model_weights/mlp.ckpt
# utils.py:
# Copyright (C) 2022. ByteDance Co., Ltd. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the Apache-2.0 license.
5 | # 6 | # This program is distributed in the hope that it will be useful, but WITHOUT ANY 7 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 8 | # PARTICULAR PURPOSE. See the Apache-2.0 License for more details. 9 | 10 | 11 | from datetime import datetime 12 | import json 13 | import os 14 | import pytz 15 | import sys 16 | from tqdm import tqdm 17 | from models.components.helpers import * 18 | from scipy.ndimage import gaussian_filter1d 19 | import matplotlib 20 | import matplotlib.pyplot as plt 21 | 22 | matplotlib.use('Agg') 23 | matplotlib.rcParams['savefig.dpi'] = 300 # Uncomment for higher plot resolutions 24 | logger = logging.getLogger('TS.utils') 25 | 26 | 27 | class Params: 28 | """ 29 | Class that loads hyperparameters from a json file as a dictionary (also support nested dicts). 30 | Example: 31 | params = Params(json_path) 32 | 33 | # access key-value pairs 34 | params.learning_rate 35 | params['learning_rate'] 36 | 37 | # change the value of learning_rate in params 38 | params.learning_rate = 0.5 39 | params['learning_rate'] = 0.5 40 | 41 | # print params 42 | print(params) 43 | 44 | # combine two json files 45 | params.update(Params(json_path2)) 46 | """ 47 | 48 | def __init__(self, json_path=None): 49 | if json_path is not None and os.path.isfile(json_path): 50 | with open(json_path) as f: 51 | params = json.load(f) 52 | self.__dict__.update(params) 53 | else: 54 | self.__dict__ = {} 55 | 56 | def save(self, json_path): 57 | with open(json_path, 'w') as f: 58 | json.dump(self.__dict__, f, indent=4, ensure_ascii=False) 59 | 60 | def update(self, json_path=None, params=None): 61 | """Loads parameters from json file""" 62 | if json_path is not None: 63 | with open(json_path) as f: 64 | params = json.load(f) 65 | self.__dict__.update(params) 66 | elif params is not None: 67 | self.__dict__.update(vars(params)) 68 | else: 69 | raise Exception('One of json_path and params must be provided in Params.update()!') 70 | 71 
| def __contains__(self, item): 72 | return item in self.__dict__ 73 | 74 | def __getitem__(self, key): 75 | return getattr(self, str(key)) 76 | 77 | def __setitem__(self, key, value): 78 | return setattr(self, key, value) 79 | 80 | def __str__(self): 81 | return json.dumps(self.__dict__, sort_keys=True, indent=4, ensure_ascii=False) 82 | 83 | 84 | def save_dict_to_json(d, json_path): 85 | """ 86 | Saves dict of floats in json file 87 | Args: 88 | d: (dict) of float-castable values (np.float, int, float, etc.) 89 | json_path: (string) path to json file 90 | """ 91 | with open(json_path, 'w') as f: 92 | # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) 93 | d = {k: float(v) for k, v in d.items()} 94 | json.dump(d, f, indent=4) 95 | 96 | 97 | def save_checkpoint(state, global_step, checkpoint, is_best=False): 98 | """ 99 | Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves 100 | checkpoint + 'best.pth.tar' 101 | Args: 102 | state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict 103 | global_step: (int) number of updates performed 104 | checkpoint: (string) folder where parameters are to be saved 105 | is_best: (boolean) 106 | """ 107 | if is_best: 108 | filepath = os.path.join(checkpoint, f'best.pth.tar') 109 | else: 110 | filepath = os.path.join(checkpoint, f'latest.pth.tar') 111 | if not os.path.exists(checkpoint): 112 | logger.info(f'Checkpoint Directory does not exist! Making directory {checkpoint}') 113 | os.mkdir(checkpoint) 114 | torch.save(state, filepath) 115 | logger.info(f'Checkpoint saved to {filepath}') 116 | 117 | 118 | def load_checkpoint(file_dir, restore_file, model, optimizer=None, loss=None): 119 | """ 120 | Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of 121 | optimizer assuming it is present in checkpoint. 
122 | Args: 123 | checkpoint: (string) filename which needs to be loaded 124 | model: (torch.nn.Module) model for which the parameters are loaded 125 | optimizer: (torch.optim) optional: resume optimizer from checkpoint 126 | """ 127 | checkpoint = os.path.join(f'runs/{file_dir}', restore_file + '.pth.tar') 128 | if not os.path.exists(checkpoint): 129 | raise FileNotFoundError(f"File doesn't exist {checkpoint}") 130 | else: 131 | logger.info(f'Restoring parameters from {checkpoint}') 132 | if torch.cuda.is_available(): 133 | checkpoint = torch.load(checkpoint, map_location='cuda') 134 | else: 135 | checkpoint = torch.load(checkpoint, map_location='cpu') 136 | model.load_state_dict(checkpoint['state_dict']) 137 | 138 | if optimizer: 139 | optimizer.load_state_dict(checkpoint['optim_dict']) 140 | if loss: 141 | loss = np.load(os.path.join(file_dir, restore_file + '_loss.npy')) 142 | return checkpoint['global_step'], loss 143 | else: 144 | return checkpoint['global_step'] 145 | 146 | 147 | def plot_all_epoch(variable1, variable2, save_name, plot_title, location='./figures/', plot_start=0): 148 | num_samples = variable1.shape[0] 149 | if num_samples > plot_start: 150 | x = np.arange(start=plot_start, stop=num_samples) 151 | f = plt.figure() 152 | plt.title(plot_title) 153 | ax1 = plt.gca() 154 | line1, = ax1.plot(x, variable1[plot_start:num_samples]) 155 | ax2 = ax1.twinx() 156 | line2, = ax2.plot(x, variable2[plot_start:num_samples], c='r') 157 | plt.legend((line1, line2), ("Test metrics", "Validation metrics")) 158 | ax1.set_ylabel("Test") 159 | ax2.set_ylabel("Validation") 160 | f.savefig(os.path.join(location, save_name + '_summary.png')) 161 | plt.close() 162 | 163 | 164 | def name_with_datetime(): 165 | now = datetime.now(tz=pytz.utc) 166 | now = now.astimezone(pytz.timezone('US/Pacific')) 167 | return now.strftime("%Y-%m-%d_%H:%M:%S") 168 | 169 | 170 | def cum_by_axis1(input_x): 171 | cum_input = np.zeros(input_x.shape) 172 | for i in range(cum_input.shape[1]): 
173 | cum_input[:, i] = np.sum(input_x[:, :(i + 1)], axis=1) 174 | return cum_input 175 | 176 | 177 | def plot_all_loss(loss_summary, save_name, plot_title, location='./figures/'): 178 | gaussian_window_size = 3 179 | loss_cum = cum_by_axis1(loss_summary) 180 | loss_cum = gaussian_filter1d(loss_cum, gaussian_window_size, axis=0) 181 | num_loss = loss_cum.shape[1] 182 | color_list = ['b', 'r', 'g', 'm', 'y'] 183 | loss_list = ['loss_1', 'loss_2', 'loss_3', 'loss_4', 'loss_5'] 184 | f = plt.figure() 185 | plt.title(plot_title) 186 | num_batches = loss_cum.shape[0] 187 | if num_batches > 10000: 188 | pack_size = num_batches // 10000 189 | x = np.arange(num_batches)[0:num_batches:pack_size] 190 | plt.fill_between(x, 0, loss_cum[0:num_batches:pack_size, 0], color='b', alpha=0.2, label='loss_1') 191 | for i in range(num_loss - 1): 192 | plt.fill_between(x, loss_cum[0:num_batches:pack_size, i], loss_cum[0:num_batches:pack_size, i + 1], 193 | color=color_list[i + 1], alpha=0.2, label=loss_list[i + 1]) 194 | else: 195 | x = np.arange(num_batches) 196 | plt.fill_between(x, 0, loss_cum[:, 0], color='b', alpha=0.2, label='loss_1') 197 | for i in range(num_loss - 1): 198 | plt.fill_between(x, loss_cum[:, i], loss_cum[:, i + 1], color=color_list[i + 1], alpha=0.2, 199 | label=loss_list[i + 1]) 200 | plt.yscale('log') 201 | plt.legend() 202 | f.savefig(os.path.join(location, save_name + '_summary.png')) 203 | plt.close() 204 | 205 | 206 | def calc_metrics_all(predict_all, gt_all, params): 207 | summary_metric = dict() 208 | summary_metric['mae'] = accuracy_MAE(predict_all, gt_all, params.missing_value) 209 | summary_metric['mse'] = accuracy_MSE(predict_all, gt_all, params.missing_value) 210 | return summary_metric 211 | 212 | 213 | # for plots 214 | def batch_metrics(sample_params, labels, missing_value=0): 215 | metric = dict() 216 | metric['p50'] = accuracy_ROU_(0.5, sample_params[:, :, 0], labels, missing_value) 217 | metric['p10'] = accuracy_ROU_(0.1, sample_params[:, :, 1], 
labels, missing_value) 218 | metric['p90'] = accuracy_ROU_(0.9, sample_params[:, :, 2], labels, missing_value) 219 | return metric 220 | 221 | 222 | def final_metrics_list_to_int(summary_metric): 223 | final_metric = dict() 224 | for metric, value_array in summary_metric.items(): 225 | for i in range(value_array.shape[0]): 226 | final_metric[metric + ' ' + str(i)] = value_array[i] 227 | return final_metric 228 | 229 | 230 | def model_list(): 231 | """ 232 | List all available models found under ./model. 233 | """ 234 | files = os.listdir('./model') 235 | files = [name.replace('.py', '') for name in files if name.endswith('.py')] 236 | return files 237 | --------------------------------------------------------------------------------