├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data_generator ├── jspDatagen.py └── vrpDatagen.py ├── figs ├── expression_simplification.png ├── job_scheduling.png └── vehicle_routing.png └── src ├── .gitignore ├── Halide_search.py ├── arguments.py ├── jsp_nonNN_baselines.py ├── models ├── BaseModel.py ├── HalideModel.py ├── __init__.py ├── data_utils │ ├── Dag.py │ ├── Seq.py │ ├── Tree.py │ ├── __init__.py │ ├── data_utils.py │ ├── parser.py │ └── utils.py ├── jspModel.py ├── model_utils │ ├── __init__.py │ ├── logger.py │ └── supervisor.py ├── modules │ ├── HalideInputEncoder.py │ ├── __init__.py │ ├── jspInputEncoder.py │ ├── mlp.py │ └── vrpInputEncoder.py ├── rewriter │ ├── HalideRewriter.py │ ├── __init__.py │ ├── jspRewriter.py │ └── vrpRewriter.py └── vrpModel.py ├── run_Halide.py ├── run_jsp.py └── run_vrp.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.json 3 | *.csv 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to NeuralRewriter 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. 
If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. 26 | 27 | ## License 28 | By contributing to NeuralRewriter, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. 
A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. 
Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. 
Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. 
The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. 
Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. 
indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 
287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. 
This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 
398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Rewriter 2 | 3 | This repo provides the code to replicate the experiments in the paper 4 | 5 | > Xinyun Chen, Yuandong Tian, Learning to Perform Local Rewriting for Combinatorial Optimization, in NeurIPS 2019. 6 | 7 | Paper [[arXiv](https://arxiv.org/abs/1810.00337)] 8 | 9 | ## Prerequisites 10 | 11 | [PyTorch](https://pytorch.org) 12 | 13 | ## Tasks 14 | 15 | 16 | ### Expression Simplification 17 | 18 | For expression simplification, given an initial expression (in Halide for our evaluation), the goal is to find an equivalent expression that is simplified, e.g., with a shorter length. 19 | 20 | #### Performance 21 | 22 | ![Expression Simplification](./figs/expression_simplification.png) 23 | 24 | We compare our approach (NeuRewriter) with the following baselines: 25 | 26 | * Z3-simplify [1]: the tactic implemented in Z3, which performs rule-based rewriting. 27 | * Halide-rule [2]: the Halide rule-based rewriter. 28 | * Heuristic search: beam search to find the shortest rewritten expression using the Halide rule set. 29 | * Z3-ctx-solver-simplify [1]: the tactic implemented in Z3, which invokes a solver to find the simplified equivalent expression. 30 | 31 | In the figure, ``Average expression length reduction`` is the decrease of the length defined as the number of characters in the expression, and ``Average tree size reduction`` is the number of nodes decreased from the initial expression parse tree to the rewritten one. 32 | 33 | #### Dataset 34 | 35 | We generate expressions in [Halide](https://github.com/Halide/Halide) using a [random pipeline generator](https://github.com/halide/Halide/tree/new_autoschedule_with_new_simplifier/apps/random_pipeline). 
We obtain rewriting traces using the Halide rule-based rewriter [here](https://github.com/halide/Halide/blob/rl_simplifier_rules/test/correctness/rewriter.cpp). 36 | 37 | #### Usage 38 | 39 | The code includes the implementation of following approaches: 40 | 41 | * Neural Rewriter (Ours): run ``run_Halide.py``. 42 | * Search: run ``Halide_search.py``. 43 | 44 | ### Job Scheduling 45 | 46 | For job scheduling, we have a machine with ``D`` types of resources, and a queue that can hold at most ``W=10`` pending jobs. Each job arrives in an online fashion, with a fixed resource demand and the duration. The goal is to minimize the average slowdown ``(Cj - Aj) / Tj``, where ``Cj`` is the completion time of job ``j``, ``Aj`` is the arrival time, and ``Tj`` is the job duration. 47 | 48 | #### Performance 49 | 50 | ![Job Scheduling](./figs/job_scheduling.png) 51 | 52 | We compare our approach (NeuRewriter) with the following baselines: 53 | 54 | * EJF: earliest job first, schedules each job in the increasing order of their arrival time. 55 | * SJF: shortest job first, schedules the shortest job in the pending job queue. 56 | * SJFS: shortest job first search, searches over the shortest jobs to schedule, then returns the optimal one. 57 | * OR-tools [3]: a generic toolbox for combinatorial optimization. 58 | * DeepRM [4]: a reinforcement learning policy to construct the schedule from scratch. 59 | * SJF-offline: applies the shortest job first heuristic, and assumes an unbounded length of the job queue. 60 | 61 | In the figure, ``D`` denotes the number of resource types. 62 | 63 | #### Dataset 64 | 65 | The dataset generator can be found under [this folder](./data_generator/). 66 | 67 | #### Usage 68 | 69 | The code includes the implementation of following approaches: 70 | 71 | * Neural Rewriter (Ours): run ``run_jsp.py``. 72 | * Baseline algorithms: run ``jsp_nonNN_baselines.py``, set `--alg` from `[SJF, EJF, random]`. 
73 | 74 | ### Vehicle Routing 75 | 76 | For vehicle routing, we have a single vehicle with limited capacity to satisfy the resource demands of a set of customer nodes. To do so, we need to construct multiple routes starting and ending at the depot, so that the resources delivered in each route do not exceed the vehicle capacity, while the total route length is minimized. 77 | 78 | #### Performance 79 | 80 | ![Vehicle Routing](./figs/vehicle_routing.png) 81 | 82 | We compare our approach (NeuRewriter) with the following baselines: 83 | 84 | * Random Sweep [5]: a classic heuristic for vehicle routing. 85 | * Random CW [6]: Clarke-Wright savings heuristic for vehicle routing. 86 | * OR-tools [3]: a generic toolbox for combinatorial optimization. 87 | * Nazari et al. [7]: a reinforcement learning policy to construct the route from scratch. 88 | * AM [8]: a reinforcement learning policy to construct the route from scratch. 89 | 90 | In the figure, ``VRP X, CAP Y`` means that the number of customer nodes is ``X``, and the vehicle capacity is ``Y``. 91 | 92 | #### Dataset 93 | 94 | The dataset generator can be found under [this folder](./data_generator/). 95 | 96 | #### Usage 97 | 98 | Run ``run_vrp.py``. 99 | 100 | ## Run experiments 101 | 102 | In the following we list some important arguments for experiments using neural network models: 103 | * `--train_dataset`, `--val_dataset`, `--test_dataset`: path to the training/validation/test datasets respectively. 104 | * `--model_dir`: path to the directory that stores the models. 105 | * `--log_name`: name of the log file. 106 | * `--load_model`: path to the pretrained model (optional). 107 | * `--eval`: adding this command will enable the evaluation mode; otherwise, the model will be trained by default. 108 | * `--num_epochs`: number of training epochs. The default value is `10`, but usually 1 epoch is enough for a decent performance. 
109 | * `--eval_every_n EVAL_EVERY_N`: evaluating the model and saving checkpoints every `EVAL_EVERY_N` steps. 110 | * `--max_eval_size MAX_EVAL_SIZE`: when the value is not `None`, when performing the validation during training, the model only evaluates the first `MAX_EVAL_SIZE` samples in the validation set. Setting it to a small value if the validation process takes long. 111 | 112 | More details can be found in ``arguments.py``. 113 | 114 | ## Citation 115 | 116 | If you use the code in this repo, please cite the following paper: 117 | 118 | ``` 119 | @inproceedings{chen2019learning, 120 | title={Learning to Perform Local Rewriting for Combinatorial Optimization}, 121 | author={Chen, Xinyun and Tian, Yuandong}, 122 | booktitle={Advances in Neural Information Processing Systems}, 123 | year={2019} 124 | } 125 | ``` 126 | ## License 127 | This repo is CC-BY-NC licensed, as found in the [LICENSE file](./LICENSE). 128 | 129 | ## References 130 | 131 | [1] [Z3](https://github.com/Z3Prover/z3) 132 | 133 | [2] [Halide](https://github.com/halide/Halide) 134 | 135 | [3] [OR-tools](https://developers.google.com/optimization) 136 | 137 | [4] Mao et al. Resource Management with Deep Reinforcement Learning. ACM HotNets 2016 138 | 139 | [5] Wren and Holliday. Computer scheduling of vehicles from one or more depots 140 | to a number of delivery points. Operational Research Quarterly, 1972. 141 | 142 | [6] Clarke and Wright. Scheduling of vehicles from a central depot to a number of 143 | delivery points. Operations research, 1964. 144 | 145 | [7] Nazari et al. Reinforcement Learning for Solving the Vehicle Routing Problem. NeurIPS 2018. 146 | 147 | [8] Kool et al. Attention, Learn to Solve Routing Problems! ICLR 2019. 148 | -------------------------------------------------------------------------------- /data_generator/jspDatagen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import sys
import numpy as np
import argparse
import json

argParser = argparse.ArgumentParser()
argParser.add_argument('--res_file', type=str, default='jsp_r10.json')
argParser.add_argument('--res_train_file', type=str, default='jsp_r10_train.json')
argParser.add_argument('--res_val_file', type=str, default='jsp_r10_val.json')
argParser.add_argument('--res_test_file', type=str, default='jsp_r10_test.json')

argParser.add_argument('--num_samples', type=int, default=100000)
argParser.add_argument('--seed', type=int, default=None)
argParser.add_argument('--num_res', type=int, default=10)
argParser.add_argument('--max_resource_size', type=int, default=10)
argParser.add_argument('--time_horizon', type=int, default=50)
argParser.add_argument('--job_horizon', type=int, default=10)
argParser.add_argument('--job_small_chance', type=float, default=0.8)
argParser.add_argument('--new_job_rate', type=float, default=0.7)
argParser.add_argument('--job_len_big_lower', type=int, default=10)
argParser.add_argument('--job_len_big_upper', type=int, default=15)
argParser.add_argument('--job_len_small_lower', type=int, default=1)
argParser.add_argument('--job_len_small_upper', type=int, default=3)
argParser.add_argument('--dominant_res_lower', type=int, default=5)
argParser.add_argument('--dominant_res_upper', type=int, default=10)
argParser.add_argument('--other_res_lower', type=int, default=1)
argParser.add_argument('--other_res_upper', type=int, default=2)

argParser.add_argument('--uniform_short', action='store_true')
argParser.add_argument('--uniform_long', action='store_true')
argParser.add_argument('--uniform_resource', action='store_true')
argParser.add_argument('--dynamic_new_job_rate', action='store_true')

# Parse an empty argv here so that importing this module neither consumes
# sys.argv nor triggers data generation; the real command line is parsed
# under the __main__ guard below.
args = argParser.parse_args([])


def sample_job():
    """Sample one job: a (job_len, resource_size) pair.

    The length is short with probability ``job_small_chance``, long otherwise.
    Roughly half of the resources are "dominant" (large demand); with
    ``--uniform_resource``, either all or none of them are dominant.

    Returns:
        (int, np.ndarray): job length and per-resource demand vector of
        length ``args.num_res``.
    """
    if np.random.rand() < args.job_small_chance:
        cur_job_len = np.random.randint(args.job_len_small_lower, args.job_len_small_upper + 1)
    else:
        cur_job_len = np.random.randint(args.job_len_big_lower, args.job_len_big_upper + 1)
    cur_resource_size = np.zeros(args.num_res)
    if args.uniform_resource:
        if np.random.rand() < 0.5:
            dominant_res = []
        else:
            dominant_res = range(args.num_res)
    else:
        dominant_res = np.random.randint(low=0, high=args.num_res, size=args.num_res // 2)
    for i in range(args.num_res):
        if i in dominant_res:
            cur_resource_size[i] = np.random.randint(args.dominant_res_lower, args.dominant_res_upper + 1)
        else:
            cur_resource_size[i] = np.random.randint(args.other_res_lower, args.other_res_upper + 1)
    # Cast the numpy scalar to a plain int so the sample is JSON-serializable.
    return int(cur_job_len), cur_resource_size


def main():
    """Generate job sequences and write full/train/val/test JSON splits (80/10/10)."""
    np.random.seed(args.seed)
    samples = []
    for _ in range(args.num_samples):
        cur_sample = []
        if args.uniform_short:
            args.job_small_chance = 1.0
        elif args.uniform_long:
            args.job_small_chance = 0.0
        # Re-sample the whole horizon until at least one job arrives.
        while len(cur_sample) == 0:
            for i in range(args.time_horizon):
                if args.dynamic_new_job_rate:
                    args.new_job_rate = np.random.rand()
                if np.random.rand() < args.new_job_rate:
                    cur_job_len, cur_resource_size = sample_job()
                    cur_sample.append({
                        'start_time': i,
                        'job_len': cur_job_len,
                        # Cast np.float64 entries to plain floats: numpy
                        # scalars are not JSON-serializable.
                        'resource_size': [float(r) for r in cur_resource_size],
                    })
        samples.append(cur_sample)

    path = '../data/jsp/'
    if not os.path.exists(path):
        os.makedirs(path)

    data_size = len(samples)
    print(data_size)
    train_data_size = int(data_size * 0.8)
    val_data_size = int(data_size * 0.9) - train_data_size

    # Context managers guarantee the output files are flushed and closed.
    with open(path + args.res_file, 'w') as fout_res:
        json.dump(samples, fout_res)
    with open(path + args.res_train_file, 'w') as fout_train:
        json.dump(samples[:train_data_size], fout_train)
    with open(path + args.res_val_file, 'w') as fout_val:
        json.dump(samples[train_data_size: train_data_size + val_data_size], fout_val)
    with open(path + args.res_test_file, 'w') as fout_test:
        json.dump(samples[train_data_size + val_data_size:], fout_test)


if __name__ == "__main__":
    args = argParser.parse_args()
    main()
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import sys
import numpy as np
import argparse
import json

argParser = argparse.ArgumentParser()
argParser.add_argument('--res_file', type=str, default='vrp_20_30.json')
argParser.add_argument('--res_train_file', type=str, default='vrp_20_30_train.json')
argParser.add_argument('--res_val_file', type=str, default='vrp_20_30_val.json')
argParser.add_argument('--res_test_file', type=str, default='vrp_20_30_test.json')
argParser.add_argument('--num_samples', type=int, default=100000)
argParser.add_argument('--seed', type=int, default=None)
argParser.add_argument('--num_customers', type=int, default=20)
argParser.add_argument('--max_demand', type=int, default=9)
argParser.add_argument('--position_range', type=float, default=1.0)
argParser.add_argument('--capacity', type=int, default=30, choices=[20, 30, 40, 50])

# Parse an empty argv here so importing this module neither consumes sys.argv
# nor triggers data generation; real arguments are parsed under the guard.
args = argParser.parse_args([])


def sample_pos():
    """Sample a uniform random position in the unit square.

    Returns:
        (float, float): (x, y), each in [0, 1). Cast to plain floats so the
        coordinates are JSON-serializable (np.float64 is not).
    """
    return float(np.random.rand()), float(np.random.rand())


def main():
    """Generate CVRP instances and write full/train/val/test JSON splits (80/10/10)."""
    np.random.seed(args.seed)
    samples = []
    for _ in range(args.num_samples):
        cur_sample = {}
        cur_sample['customers'] = []
        cur_sample['capacity'] = args.capacity
        dx, dy = sample_pos()
        cur_sample['depot'] = (dx, dy)
        for i in range(args.num_customers):
            cx, cy = sample_pos()
            # Cast the numpy scalar to int for JSON serialization.
            demand = int(np.random.randint(1, args.max_demand + 1))
            cur_sample['customers'].append({'position': (cx, cy), 'demand': demand})
        samples.append(cur_sample)

    path = '../data/vrp/'
    if not os.path.exists(path):
        os.makedirs(path)

    data_size = len(samples)
    print(data_size)
    train_data_size = int(data_size * 0.8)
    val_data_size = int(data_size * 0.9) - train_data_size

    # Context managers guarantee the output files are flushed and closed.
    with open(path + args.res_file, 'w') as fout_res:
        json.dump(samples, fout_res)
    with open(path + args.res_train_file, 'w') as fout_train:
        json.dump(samples[:train_data_size], fout_train)
    with open(path + args.res_val_file, 'w') as fout_val:
        json.dump(samples[train_data_size: train_data_size + val_data_size], fout_val)
    with open(path + args.res_test_file, 'w') as fout_test:
        json.dump(samples[train_data_size + val_data_size:], fout_test)


if __name__ == "__main__":
    args = argParser.parse_args()
    main()
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import math
import random
import sys
import os
import json
import numpy as np
import time
import torch

import arguments as arguments
import models as models
import models.data_utils.data_utils as data_utils
from models.rewriter import HalideRewriter

# Beam-style exhaustive search baseline for Halide expression simplification:
# repeatedly applies every rewrite rule at every non-terminal node and recurses
# on the shortest resulting expressions.
argParser = arguments.get_arg_parser("Halide")
args = argParser.parse_args()

DataProcessor = data_utils.HalideDataProcessor()
term_vocab, term_vocab_list = DataProcessor.load_term_vocab()
op_vocab, op_vocab_list = DataProcessor.load_ops()
args.term_vocab_size = len(term_vocab)
args.op_vocab_size = len(op_vocab)
rewriter = HalideRewriter(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)

# Global memo of expression strings already visited during the current sample's
# search; reset per sample in evaluate().
expr_rec = {}


def get_nonterm_idxes(tm, cur_idx=None):
    """Collect the indexes of all non-terminal (internal) nodes in the tree
    rooted at cur_idx (the whole tree when cur_idx is None), in pre-order.

    Raises ValueError if a child's parent pointer is inconsistent, i.e. the
    tree manager's edges are corrupted.
    """
    if cur_idx is None:
        cur_idx = tm.root
    cur_tree = tm.get_tree(cur_idx)
    # Leaves carry no rewrite positions.
    if len(cur_tree.children) == 0:
        return []
    nonterm_idxes = []
    nonterm_idxes += [cur_idx]
    for child in cur_tree.children:
        child_tree = tm.get_tree(child)
        if child_tree.parent != cur_idx:
            raise ValueError('invalid edge: ' + str(cur_idx) + ' ' + cur_tree.root + ' ' + str(cur_tree.children) + ' ' + str(child) + ' ' + str(child_tree.parent))
        nonterm_idxes += get_nonterm_idxes(tm, child)
    return nonterm_idxes


def rewrite(tm, init_expr, len_tm, num_nodes_tm, depth):
    """Depth-limited best-first search over single-step rewrites of tm.

    Tries every rewrite action at every non-terminal node, keeps up to
    args.num_sample_rewrite_pos candidates sorted by resulting expression
    length, and recurses on each. Returns the best tree manager found together
    with its expression length and node count.

    Relies on the module-level expr_rec memo to avoid revisiting expressions.
    """
    expr_rec[init_expr] = 1
    min_len = len_tm
    min_num_nodes = num_nodes_tm
    res_tm = tm
    # Depth limit reuses the training-time bound on rewriting steps.
    if depth >= args.max_reduce_steps:
        return res_tm, min_len, min_num_nodes
    nonterm_idxes = get_nonterm_idxes(tm)
    candidate_tm = []
    for i in nonterm_idxes:
        cur_tree = tm.get_tree(i)
        for j in range(args.num_actions):
            op_list = rewriter.get_rewrite_seq(j)
            op = rewriter.get_rewrite_op(op_list[0])
            new_tm, update_tree_idxes = op(tm, i)
            # An empty update list means the rule was not applicable here.
            if len(update_tree_idxes) == 0:
                continue
            new_expr = new_tm.to_string(new_tm.root)
            new_len = len(new_expr)
            new_num_nodes = new_tm.num_valid_nodes()
            # NOTE(review): Python precedence parses this as
            # A or (B and C); if the intent was (A or B) and C,
            # parentheses are missing — confirm against the paper's search.
            if (new_expr in expr_rec) or len(expr_rec) >= args.num_sample_rewrite_pos and new_len >= min_len:
                continue
            # Insertion sort keeps candidate_tm ascending by expression length,
            # truncated to the top num_sample_rewrite_pos candidates.
            q_idx = len(candidate_tm) - 1
            while q_idx >= 0 and new_len < candidate_tm[q_idx][0]:
                q_idx -= 1
            candidate_tm = candidate_tm[:q_idx + 1] + [(new_len, new_num_nodes, new_expr, new_tm)] + candidate_tm[q_idx + 1: args.num_sample_rewrite_pos]

    # Recurse on each surviving candidate and keep the overall best result.
    for i in range(len(candidate_tm)):
        new_tm, new_len, new_num_nodes = rewrite(candidate_tm[i][3], candidate_tm[i][2], candidate_tm[i][0], candidate_tm[i][1], depth + 1)
        if new_len < min_len:
            res_tm = new_tm
            min_len = new_len
            min_num_nodes = new_num_nodes
    return res_tm, min_len, min_num_nodes


def evaluate(args):
    """Run the search baseline over the test set and print per-sample and
    average rewards (reduction in expression length / tree size), compared
    against the ground-truth rewriting traces.
    """
    print('Search:')

    test_data = data_utils.load_dataset(args.test_dataset, args)
    if args.test_min_len is not None:
        test_data = DataProcessor.prune_dataset(test_data, min_len=args.test_min_len)
    DataProcessor.calc_data_stat(test_data)
    data_size = len(test_data)
    test_data = test_data[:data_size]

    cum_expr_reward = 0
    cum_gt_reward = 0
    cum_tree_reward = 0

    for batch_idx in range(0, data_size, args.batch_size):
        batch_data = DataProcessor.get_batch(test_data, args.batch_size, batch_idx)
        for i, sample in enumerate(batch_data):
            gt_trace, tm = sample
            # Reset the visited-expression memo for each sample.
            global expr_rec
            expr_rec = {}
            init_expr = tm.to_string(tm.root)
            len_tm = len(init_expr)
            # NOTE(review): this uses tm.num_trees while rewrite() measures
            # candidates with num_valid_nodes() — looks inconsistent; confirm
            # both count the same quantity.
            num_nodes_tm = tm.num_trees
            res_tm, res_len, res_num_nodes = rewrite(tm, init_expr, len_tm, num_nodes_tm, 0)
            # Reward = reduction relative to the initial expression; the
            # ground-truth reward compares the first and last trace entries.
            cur_expr_reward = len(init_expr) - res_len
            cur_tree_reward = num_nodes_tm - res_num_nodes
            cur_gt_reward = len(gt_trace[0]) - len(gt_trace[-1])
            cum_expr_reward += cur_expr_reward
            cum_tree_reward += cur_tree_reward
            cum_gt_reward += cur_gt_reward
            print('sample %d cur expr reward: %.4f cur tree reward: %.4f gt reward: %.4f avg expr reward: %.4f avg tree reward: %.4f avg gt reward: %.4f' \
                % (batch_idx + i, cur_expr_reward, cur_tree_reward, cur_gt_reward, cum_expr_reward * 1.0 / (batch_idx + i + 1), cum_tree_reward * 1.0 / (batch_idx + i + 1), cum_gt_reward * 1.0 / (batch_idx + i + 1)))
    cum_expr_reward = cum_expr_reward * 1.0 / data_size
    cum_tree_reward = cum_tree_reward * 1.0 / data_size
    cum_gt_reward = cum_gt_reward * 1.0 / data_size
    print('avg search expr reward: %.4f tree reward: %.4f avg gt reward: %.4f' % (cum_expr_reward, cum_tree_reward, cum_gt_reward))


if __name__ == "__main__":
    evaluate(args)
#

import argparse


def get_arg_parser(title):
    """Build the command-line argument parser for one task.

    ``title`` selects the task-specific data options: 'Halide' (expression
    simplification), 'jsp' (job scheduling), or 'vrp' (vehicle routing).
    Returns a fully configured argparse.ArgumentParser.
    """

    def _add_general_options(parser):
        # Runtime / checkpointing options shared by every task.
        parser.add_argument('--cpu', action='store_true', default=False)
        parser.add_argument('--eval', action='store_true')
        parser.add_argument('--model_dir', type=str, default='../checkpoints/model_0')
        parser.add_argument('--input_format', type=str, default='DAG', choices=['seq', 'DAG'])
        parser.add_argument('--max_eval_size', type=int, default=1000)
        parser.add_argument('--load_model', type=str, default=None)
        parser.add_argument('--resume', type=int, default=0)
        parser.add_argument('--processes', type=int, default=1)
        parser.add_argument('--train_proportion', type=float, default=1.0)

        # Shared model hyper-parameters.
        parser.add_argument('--LSTM_hidden_size', type=int, default=512)
        parser.add_argument('--MLP_hidden_size', type=int, default=256)
        parser.add_argument('--param_init', type=float, default=0.1)
        parser.add_argument('--num_LSTM_layers', type=int, default=1)
        parser.add_argument('--seed', type=int, default=None)
        parser.add_argument('--num_sample_rewrite_pos', type=int, default=10)
        parser.add_argument('--num_sample_rewrite_op', type=int, default=10)
        parser.add_argument('--max_reduce_steps', type=int, default=50)
        parser.add_argument('--cont_prob', type=float, default=0.5)

        # Logging / evaluation cadence.
        parser.add_argument('--keep_last_n', type=int, default=None)
        parser.add_argument('--eval_every_n', type=int, default=100)
        parser.add_argument('--log_interval', type=int, default=100)
        parser.add_argument('--log_name', type=str, default='model_0.csv')

    def _add_data_options(parser):
        # Task-specific datasets and hyper-parameters.
        data_group = parser.add_argument_group('data')
        if title == 'Halide':
            data_group.add_argument('--train_dataset', type=str, default='../data/Halide/rewritten_exprs_train.json')
            data_group.add_argument('--val_dataset', type=str, default='../data/Halide/rewritten_exprs_val.json')
            data_group.add_argument('--test_dataset', type=str, default='../data/Halide/rewritten_exprs_test.json')
            data_group.add_argument('--term_vocab_size', type=int, default=None)
            data_group.add_argument('--op_vocab_size', type=int, default=None)
            data_group.add_argument('--num_actions', type=int, default=19)
            data_group.add_argument('--embedding_size', type=int, default=128)
            data_group.add_argument('--value_loss_coef', type=float, default=10.0)
            data_group.add_argument('--gamma', type=float, default=0.9)
            data_group.add_argument('--lr', type=float, default=1e-4)
            data_group.add_argument('--batch_size', type=int, default=128)
            data_group.add_argument('--num_MLP_layers', type=int, default=1)
            data_group.add_argument('--train_max_len', type=int, default=None)
            data_group.add_argument('--test_min_len', type=int, default=None)
        elif title == 'jsp':
            data_group.add_argument('--train_dataset', type=str, default='../data/jsp/jsp_r10_train.json')
            data_group.add_argument('--val_dataset', type=str, default='../data/jsp/jsp_r10_val.json')
            data_group.add_argument('--test_dataset', type=str, default='../data/jsp/jsp_r10_test.json')
            data_group.add_argument('--max_resource_size', type=int, default=10)
            data_group.add_argument('--job_horizon', type=int, default=10)
            data_group.add_argument('--num_res', type=int, default=10)
            data_group.add_argument('--max_time_horizon', type=int, default=1000)
            data_group.add_argument('--max_job_len', type=int, default=15)
            data_group.add_argument('--lr', type=float, default=5e-5)
            data_group.add_argument('--batch_size', type=int, default=64)
            data_group.add_argument('--num_MLP_layers', type=int, default=1)
            data_group.add_argument('--base_alg', type=str, default='EJF', choices=['EJF', 'SJF', 'random'])
            data_group.add_argument('--value_loss_coef', type=float, default=50.0)
            data_group.add_argument('--gamma', type=float, default=0.0)
        elif title == 'vrp':
            data_group.add_argument('--train_dataset', type=str, default='../data/vrp/vrp_20_30_train.json')
            data_group.add_argument('--val_dataset', type=str, default='../data/vrp/vrp_20_30_val.json')
            data_group.add_argument('--test_dataset', type=str, default='../data/vrp/vrp_20_30_test.json')
            data_group.add_argument('--lr', type=float, default=5e-5)
            data_group.add_argument('--value_loss_coef', type=float, default=0.01)
            data_group.add_argument('--gamma', type=float, default=0.9)
            data_group.add_argument('--batch_size', type=int, default=64)
            data_group.add_argument('--num_MLP_layers', type=int, default=2)
            data_group.add_argument('--embedding_size', type=int, default=7)
            data_group.add_argument('--attention_size', type=int, default=16)

    def _add_output_trace_options(parser):
        # Controls dumping of rewriting traces during evaluation.
        output_trace_group = parser.add_argument_group('output_trace_option')
        output_trace_group.add_argument('--output_trace_flag', type=str, default='nop', choices=['succeed', 'fail', 'complete', 'nop'])
        output_trace_group.add_argument('--output_trace_option', type=str, default='both', choices=['pred', 'both'])
        output_trace_group.add_argument('--output_trace_file', type=str, default=None)

    def _add_train_options(parser):
        # Optimization schedule.
        train_group = parser.add_argument_group('train')
        train_group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd', 'rmsprop'])
        train_group.add_argument('--lr_decay_steps', type=int, default=500)
        train_group.add_argument('--lr_decay_rate', type=float, default=0.9)
        train_group.add_argument('--gradient_clip', type=float, default=5.0)
        train_group.add_argument('--num_epochs', type=int, default=10)
        train_group.add_argument('--dropout_rate', type=float, default=0.0)

    parser = argparse.ArgumentParser(description=title)
    _add_general_options(parser)
    _add_data_options(parser)
    _add_output_trace_options(parser)
    _add_train_options(parser)
    return parser
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import sys
import numpy as np
import argparse
import json
import time

argParser = argparse.ArgumentParser()
argParser.add_argument('--input_file', type=str, default='../data/jsp/jsp_r10_test.json')
argParser.add_argument('--alg', type=str, default='SJF', choices=['random', 'EJF', 'SJF', 'offline'])
argParser.add_argument('--num_res', type=int, default=10)
argParser.add_argument('--job_horizon', type=int, default=10)
argParser.add_argument('--max_resource_size', type=int, default=10)
argParser.add_argument('--max_time_horizon', type=int, default=1000)

# Parse an empty argv here so importing this module is side-effect free;
# the real command line is parsed under the __main__ guard below.
args = argParser.parse_args([])


def add_job(cur_job, scheduled_job, used_resources, num_res=None):
    """Commit cur_job at its 'schedule_time': append it to scheduled_job and
    add its resource demand to the usage profile for the job's duration.

    num_res defaults to the command-line setting (args.num_res).
    Returns the updated (scheduled_job, used_resources) pair (mutated in place).
    """
    if num_res is None:
        num_res = args.num_res
    scheduled_job.append(cur_job)
    cur_resources = cur_job['resource_size']
    st_time = cur_job['schedule_time']
    ed_time = st_time + cur_job['job_len']
    for t in range(st_time, ed_time):
        for j in range(num_res):
            used_resources[t][j] += cur_resources[j]
    return scheduled_job, used_resources


def calc_min_schedule_time(used_resources, cur_job, num_res=None, max_resource_size=None):
    """Return the earliest time >= cur_job['schedule_time'] at which the job
    fits for its whole duration without exceeding max_resource_size on any
    resource.

    num_res / max_resource_size default to the command-line settings.
    """
    if num_res is None:
        num_res = args.num_res
    if max_resource_size is None:
        max_resource_size = args.max_resource_size
    min_schedule_time = cur_job['schedule_time']
    cur_resources = cur_job['resource_size']
    runnable = True
    cur_time_horizon = min_schedule_time
    # Scan forward one time step at a time; whenever a step cannot fit the
    # job, restart the candidate window just after that step.
    # (The original copied used_resources here, but it is never mutated.)
    while not runnable or cur_time_horizon <= min_schedule_time + cur_job['job_len'] - 1:
        runnable = True
        for j in range(num_res):
            if used_resources[cur_time_horizon][j] + cur_resources[j] > max_resource_size:
                runnable = False
                break
        if not runnable:
            min_schedule_time = cur_time_horizon + 1
        cur_time_horizon += 1

    return min_schedule_time


def _pick_random(pending_job, used_resources):
    """Pick a uniformly random pending job; return (index, earliest feasible time)."""
    schedule_idx = np.random.choice(len(pending_job))
    return schedule_idx, calc_min_schedule_time(used_resources, pending_job[schedule_idx])


def _pick_shortest(pending_job, used_resources):
    """Pick the pending job with the earliest feasible start time, breaking
    ties by shorter job length; return (index, that time)."""
    schedule_idx = -1
    schedule_time = -1
    for i in range(len(pending_job)):
        cur_min_schedule_time = calc_min_schedule_time(used_resources, pending_job[i])
        if schedule_idx == -1 or cur_min_schedule_time < schedule_time or cur_min_schedule_time == schedule_time and pending_job[i]['job_len'] < pending_job[schedule_idx]['job_len']:
            schedule_idx = i
            schedule_time = cur_min_schedule_time
    return schedule_idx, schedule_time


def _run_schedule(job_seq, pick, pending_job_cap):
    """Shared online-scheduling driver (previously duplicated across
    random_schedule/sjf/offline).

    Jobs arrive in sequence order and wait in a pending queue; whenever the
    queue reaches pending_job_cap (and once more at the end to drain it),
    `pick` chooses which pending job to commit next.
    """
    used_resources = np.zeros((args.max_time_horizon, args.num_res))
    scheduled_job = []
    pending_job = []
    for job_idx in range(len(job_seq) + 1):
        last = job_idx == len(job_seq)
        if not last:
            cur_job = job_seq[job_idx].copy()
            cur_job['schedule_time'] = cur_job['start_time']
        # After the final arrival, drain the queue completely (cap 1).
        cap = 1 if last else pending_job_cap
        while len(pending_job) >= cap:
            schedule_idx, schedule_time = pick(pending_job, used_resources)
            pending_job[schedule_idx]['schedule_time'] = schedule_time
            scheduled_job, used_resources = add_job(pending_job[schedule_idx], scheduled_job, used_resources)
            pending_job = pending_job[:schedule_idx] + pending_job[schedule_idx + 1:]
        if last:
            break
        pending_job.append(cur_job)

    return scheduled_job


def random_schedule(job_seq):
    """Baseline: schedule pending jobs in uniformly random order."""
    return _run_schedule(job_seq, _pick_random, args.job_horizon)


def ejf(job_seq):
    """Earliest Job First: commit each job immediately in arrival order at
    its earliest feasible time. (Dead locals from the original removed.)"""
    used_resources = np.zeros((args.max_time_horizon, args.num_res))
    scheduled_job = []
    for job in job_seq:
        cur_job = job.copy()
        cur_job['schedule_time'] = cur_job['start_time']
        cur_job['schedule_time'] = calc_min_schedule_time(used_resources, cur_job)
        scheduled_job, used_resources = add_job(cur_job, scheduled_job, used_resources)
    return scheduled_job


def sjf(job_seq):
    """Shortest Job First over a bounded lookahead of args.job_horizon pending jobs."""
    return _run_schedule(job_seq, _pick_shortest, args.job_horizon)


def offline(job_seq):
    """Offline variant of SJF: the pending queue holds the entire sequence,
    so every job is known before any scheduling decision is made."""
    return _run_schedule(job_seq, _pick_shortest, len(job_seq))


def calc_reward(res):
    """Return (average slowdown, average completion time) over scheduled jobs.

    Slowdown of a job is its completion time divided by its length. Returns
    (0.0, 0.0) for an empty schedule instead of dividing by zero.
    """
    if not res:
        return 0.0, 0.0
    avg_slow_down = 0.0
    avg_completion_time = 0.0
    for cur_job in res:
        st_time = cur_job['start_time']
        job_len = cur_job['job_len']
        cur_completion_time = cur_job['schedule_time'] + job_len - st_time
        avg_slow_down += cur_completion_time * 1.0 / job_len
        avg_completion_time += cur_completion_time
    avg_slow_down /= len(res)
    avg_completion_time /= len(res)
    return avg_slow_down, avg_completion_time


if __name__ == "__main__":
    args = argParser.parse_args()
    with open(args.input_file, 'r') as fin:
        samples = json.load(fin)
    avg_slow_down = 0.0
    avg_completion_time = 0.0
    for i, cur_sample in enumerate(samples):
        if args.alg == 'random':
            res = random_schedule(cur_sample)
        elif args.alg == 'EJF':
            res = ejf(cur_sample)
        elif args.alg == 'SJF':
            res = sjf(cur_sample)
        else:  # 'offline' (choices= guarantees this is the only remaining value)
            res = offline(cur_sample)
        cur_avg_slow_down, cur_avg_completion_time = calc_reward(res)
        avg_slow_down += cur_avg_slow_down
        avg_completion_time += cur_avg_completion_time
        print('sample %d slow down: %.4f completion time: %.4f' % (i, cur_avg_slow_down, cur_avg_completion_time))

    avg_slow_down /= len(samples)
    avg_completion_time /= len(samples)
    print('average slow down: %.4f average completion time: %.4f' % (avg_slow_down, avg_completion_time))
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import operator
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch import cuda
from torch.autograd import Variable
# clip_grad_norm was deprecated in PyTorch 0.4 in favor of the in-place
# clip_grad_norm_ and has since been removed; use the supported name.
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F


class BaseModel(nn.Module):
    """
    Base neural rewriter model. The concrete architectures for different applications are derived from it.
    """
    def __init__(self, args):
        super(BaseModel, self).__init__()
        self.processes = args.processes
        self.batch_size = args.batch_size
        self.LSTM_hidden_size = args.LSTM_hidden_size
        self.MLP_hidden_size = args.MLP_hidden_size
        self.num_MLP_layers = args.num_MLP_layers
        self.gradient_clip = args.gradient_clip
        # When resuming, replay the decay schedule to recover the learning
        # rate the checkpoint was trained with.
        if args.lr_decay_steps and args.resume:
            self.lr = args.lr * args.lr_decay_rate ** ((args.resume - 1) // args.lr_decay_steps)
        else:
            self.lr = args.lr
        print('Current learning rate is {}.'.format(self.lr))
        self.dropout_rate = args.dropout_rate
        self.max_reduce_steps = args.max_reduce_steps
        self.num_sample_rewrite_pos = args.num_sample_rewrite_pos
        self.num_sample_rewrite_op = args.num_sample_rewrite_op
        self.value_loss_coef = args.value_loss_coef
        self.gamma = args.gamma
        self.cont_prob = args.cont_prob
        # NOTE(review): arguments.py defines --cpu, not --cuda; the run
        # scripts presumably set args.cuda before constructing the model —
        # confirm against run_*.py.
        self.cuda_flag = args.cuda


    def init_weights(self, param_init):
        """Initialize every parameter uniformly in [-param_init, param_init]."""
        for param in self.parameters():
            param.data.uniform_(-param_init, param_init)


    def lr_decay(self, lr_decay_rate):
        """Multiply the current learning rate by lr_decay_rate and push the
        new value into every optimizer parameter group."""
        self.lr *= lr_decay_rate
        print('Current learning rate is {}.'.format(self.lr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr


    def train(self):
        """Apply one optimizer step, clipping gradients first when enabled.

        NOTE(review): this overrides nn.Module.train(mode=True), so the usual
        train/eval mode switch is shadowed for all subclasses; renaming would
        break existing callers, so it is kept but flagged.
        """
        if self.gradient_clip > 0:
            clip_grad_norm_(self.parameters(), self.gradient_clip)
        self.optimizer.step()
#

import numpy as np
import operator
import random
import time
from multiprocessing.pool import ThreadPool

import torch
import torch.nn as nn
import torch.optim as optim
from torch import cuda
from torch.autograd import Variable
# NOTE(review): clip_grad_norm is the pre-1.0 PyTorch name (now clip_grad_norm_);
# this file targets a legacy PyTorch API throughout (Variable, size_average, .data[0]).
from torch.nn.utils import clip_grad_norm
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from .data_utils import data_utils
from .modules import HalideInputEncoder, mlp
from .rewriter import HalideRewriter
from .BaseModel import BaseModel

# Probability floor used to decide whether a sampled action is still worth
# pushing into the policy-gradient update (see forward()).
eps = 1e-3
log_eps = np.log(eps)


class HalideModel(BaseModel):
    """
    Model for expression simplification.
    """
    def __init__(self, args, term_vocab, term_vocab_list, op_vocab, op_vocab_list):
        super(HalideModel, self).__init__(args)
        self.term_vocab = term_vocab
        self.term_vocab_list = term_vocab_list
        self.op_vocab = op_vocab
        self.op_vocab_list = op_vocab_list
        self.term_vocab_size = args.term_vocab_size
        self.op_vocab_size = args.op_vocab_size
        self.embedding_size = args.embedding_size
        self.num_actions = args.num_actions
        # Rewrites whose predicted reward falls below this threshold are
        # (mostly) skipped; see rewrite()'s reward_thres handling.
        self.reward_thres = -0.05
        self.rewriter = HalideRewriter(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)
        # Tree-LSTM encoder producing per-node states consumed by policy/value heads.
        self.input_encoder = HalideInputEncoder.TreeLSTM(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)
        # Both heads take [root_embedding ; node_embedding] (hence hidden * 2).
        self.policy = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * 2, self.MLP_hidden_size, self.num_actions, self.cuda_flag, self.dropout_rate)
        self.value_estimator = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * 2, self.MLP_hidden_size, 1, self.cuda_flag, self.dropout_rate)

        if args.optimizer == 'adam':
            self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
        elif args.optimizer == 'sgd':
            self.optimizer = optim.SGD(self.parameters(), lr=self.lr)
        elif args.optimizer == 'rmsprop':
            self.optimizer = optim.RMSprop(self.parameters(), lr=self.lr)
        else:
            raise ValueError('optimizer undefined: ', args.optimizer)

    def rewrite(self, tm, ac_logprobs, trace_rec, expr_rec, candidate_rewrite_pos, pending_actions, eval_flag, max_search_pos, reward_thres=None):
        """
        Apply rewrite operators to one tree manager `tm`.

        First drains `pending_actions` (the tail of a multi-step rewrite
        sequence chosen earlier); otherwise samples (train) or greedily ranks
        (eval) rewrite positions by predicted reward, then tries candidate
        operators at each position until `max_search_pos` rewrites succeed.
        Returns six parallel lists: new tree managers, the node indexes whose
        embeddings must be refreshed, bookkeeping records for the applied
        rewrites, leftover pending action sequences, and extra reward/action
        records used for the no-op penalty terms in forward().
        """
        if len(candidate_rewrite_pos) == 0:
            return [], [], [], [], [], []

        # Highest predicted reward first.
        candidate_rewrite_pos.sort(reverse=True, key=operator.itemgetter(0))
        if not eval_flag:
            # Build an (unnormalized) sampling distribution over positions,
            # sharpened by the *10 temperature.
            sample_exp_reward_tensor = []
            for idx, (cur_pred_reward, cur_pred_reward_tensor, cur_ac_prob, rewrite_pos, tensor_idx) in enumerate(candidate_rewrite_pos):
                sample_exp_reward_tensor.append(cur_pred_reward_tensor)
            sample_exp_reward_tensor = torch.cat(sample_exp_reward_tensor, 0)
            sample_exp_reward_tensor = torch.exp(sample_exp_reward_tensor * 10)
            sample_exp_reward = sample_exp_reward_tensor.data.cpu().numpy()

        expr = expr_rec[-1]
        extra_reward_rec = []
        extra_action_rec = []
        candidate_tree_managers = []
        candidate_update_tree_idxes = []
        candidate_rewrite_rec = []
        candidate_expr_rec = []
        candidate_pending_actions = []

        if len(pending_actions) > 0:
            # Continue a previously chosen multi-step rewrite sequence: the
            # next op is fixed (pending_actions[0]); only the position varies.
            for idx, (pred_reward, cur_pred_reward_tensor, cur_ac_prob, rewrite_pos, tensor_idx) in enumerate(candidate_rewrite_pos):
                if len(candidate_tree_managers) > 0 and idx >= max_search_pos:
                    break
                if reward_thres is not None and pred_reward < reward_thres:
                    if eval_flag:
                        break
                    elif np.random.random() > self.cont_prob:
                        # During training, occasionally explore below-threshold positions.
                        continue
                init_expr = tm.to_string(rewrite_pos)
                op_idx = pending_actions[0]
                op_list = self.rewriter.get_rewrite_seq(op_idx)
                op = self.rewriter.get_rewrite_op(op_list[0])
                new_tm, cur_update_tree_idxes = op(tm, rewrite_pos)
                if len(cur_update_tree_idxes) == 0:
                    # Operator was a no-op here; remember it for the policy penalty.
                    extra_action_rec.append((ac_logprobs[tensor_idx], op_idx))
                    continue
                cur_expr = str(new_tm)
                if cur_expr in candidate_expr_rec:
                    # Deduplicate resulting expressions.
                    continue
                candidate_expr_rec.append(cur_expr)
                candidate_update_tree_idxes.append(cur_update_tree_idxes)
                candidate_tree_managers.append(new_tm)
                candidate_rewrite_rec.append((ac_logprobs[tensor_idx], pred_reward, cur_pred_reward_tensor, rewrite_pos, init_expr, int(op_idx)))
                candidate_pending_actions.append(pending_actions[1:])
                if len(candidate_tree_managers) >= max_search_pos:
                    break
            if len(candidate_tree_managers) > 0:
                return candidate_tree_managers, candidate_update_tree_idxes, candidate_rewrite_rec, candidate_pending_actions, extra_reward_rec, extra_action_rec

        if not eval_flag:
            # Training: sample positions proportionally to exp(10 * predicted reward),
            # keeping first occurrences in sampling order.
            sample_rewrite_pos_dist = Categorical(sample_exp_reward_tensor)
            sample_rewrite_pos = sample_rewrite_pos_dist.sample(sample_shape=[len(candidate_rewrite_pos)])
            #sample_rewrite_pos = torch.multinomial(sample_exp_reward_tensor, len(candidate_rewrite_pos))
            sample_rewrite_pos = sample_rewrite_pos.data.cpu().numpy()
            indexes = np.unique(sample_rewrite_pos, return_index=True)[1]
            sample_rewrite_pos = [sample_rewrite_pos[i] for i in sorted(indexes)]
            sample_rewrite_pos = sample_rewrite_pos[:self.num_sample_rewrite_pos]
            sample_exp_reward = [sample_exp_reward[i] for i in sample_rewrite_pos]
            sample_rewrite_pos = [candidate_rewrite_pos[i] for i in sample_rewrite_pos]
        else:
            # Evaluation: consider positions in predicted-reward order.
            sample_rewrite_pos = candidate_rewrite_pos.copy()

        for idx, (pred_reward, cur_pred_reward_tensor, cur_ac_prob, rewrite_pos, tensor_idx) in enumerate(sample_rewrite_pos):
            if len(candidate_tree_managers) > 0 and idx >= max_search_pos:
                break
            if reward_thres is not None and pred_reward < reward_thres:
                if eval_flag:
                    break
                elif np.random.random() > self.cont_prob:
                    continue
            init_expr = tm.to_string(rewrite_pos)
            if eval_flag:
                # Greedy: try actions from most to least probable.
                _, candidate_acs = torch.sort(cur_ac_prob)
                candidate_acs = candidate_acs.data.cpu().numpy()
                candidate_acs = candidate_acs[::-1]
            else:
                # Sample actions from the policy, deduplicated in sampling order.
                candidate_acs_dist = Categorical(cur_ac_prob)
                candidate_acs = candidate_acs_dist.sample(sample_shape=[self.num_actions])
                #candidate_acs = torch.multinomial(cur_ac_prob, self.num_actions)
                candidate_acs = candidate_acs.data.cpu().numpy()
            indexes = np.unique(candidate_acs, return_index=True)[1]
            candidate_acs = [candidate_acs[i] for i in sorted(indexes)]
            cur_active = False
            cur_ac_prob = cur_ac_prob.data.cpu().numpy()
            for i, op_idx in enumerate(candidate_acs):
                if (expr, init_expr, op_idx) in trace_rec:
                    # Avoid repeating an already-tried (state, position, action) triple.
                    continue
                op_list = self.rewriter.get_rewrite_seq(op_idx)
                op = self.rewriter.get_rewrite_op(op_list[0])
                new_tm, cur_update_tree_idxes = op(tm, rewrite_pos)
                if len(cur_update_tree_idxes) == 0:
                    extra_action_rec.append((ac_logprobs[tensor_idx], op_idx))
                    continue
                cur_expr = str(new_tm)
                if cur_expr in candidate_expr_rec:
                    continue
                candidate_expr_rec.append(cur_expr)
                candidate_update_tree_idxes.append(cur_update_tree_idxes)
                candidate_tree_managers.append(new_tm)
                candidate_rewrite_rec.append((ac_logprobs[tensor_idx], pred_reward, cur_pred_reward_tensor, rewrite_pos, init_expr, int(op_list[0])))
                # Remaining steps of a multi-step rewrite become pending actions.
                candidate_pending_actions.append(op_list[1:])
                cur_active = True
                if len(candidate_tree_managers) >= max_search_pos:
                    break
            if not cur_active:
                # No action succeeded at this position: push its predicted
                # reward towards zero via the extra value targets.
                extra_reward_rec.append(cur_pred_reward_tensor)
        return candidate_tree_managers, candidate_update_tree_idxes, candidate_rewrite_rec, candidate_pending_actions, extra_reward_rec, extra_action_rec

    def batch_rewrite(self, tree_managers, ac_logprobs, trace_rec, expr_rec, candidate_rewrite_pos, pending_actions, eval_flag, max_search_pos, reward_thres):
        """Run rewrite() per sample; concatenate the extra reward/action records."""
        candidate_tree_managers = []
        candidate_update_tree_idxes = []
        candidate_rewrite_rec = []
        candidate_pending_actions = []
        extra_reward_rec = []
        extra_action_rec = []
        for i in range(len(tree_managers)):
            cur_candidate_tree_managers, cur_candidate_update_tree_idxes, cur_candidate_rewrite_rec, cur_candidate_pending_actions, cur_extra_reward_rec, cur_extra_action_rec = self.rewrite(tree_managers[i], ac_logprobs, trace_rec[i], expr_rec[i], candidate_rewrite_pos[i], pending_actions[i], eval_flag, max_search_pos, reward_thres)
            candidate_tree_managers.append(cur_candidate_tree_managers)
            candidate_update_tree_idxes.append(cur_candidate_update_tree_idxes)
            candidate_rewrite_rec.append(cur_candidate_rewrite_rec)
            candidate_pending_actions.append(cur_candidate_pending_actions)
            extra_reward_rec = extra_reward_rec + cur_extra_reward_rec
            extra_action_rec = extra_action_rec + cur_extra_action_rec
        return candidate_tree_managers, candidate_update_tree_idxes, candidate_rewrite_rec, candidate_pending_actions, extra_reward_rec, extra_action_rec

    def calc_dependency(self, tm, cur_idx=None):
        """
        DFS over the tree rooted at cur_idx (default: tm.root), setting each
        node's depth and dependency_parent, validating parent/child links, and
        returning the indexes of all non-terminal (internal) nodes.
        """
        if cur_idx is None:
            cur_idx = tm.root
            tm.trees[cur_idx].depth = 0
            tm.trees[cur_idx].dependency_parent = cur_idx
        cur_tree = tm.get_tree(cur_idx)
        if len(cur_tree.children) == 0:
            # Leaves are terminals; they are not rewrite positions.
            return []
        nonterm_idxes = []
        nonterm_idxes += [cur_idx]
        for child in cur_tree.children:
            tm.trees[child].depth = cur_tree.depth + 1
            child_tree = tm.get_tree(child)
            if child_tree.parent != cur_idx:
                raise ValueError('invalid edge: ' + str(cur_idx) + ' ' + cur_tree.root + ' ' + str(cur_tree.children) + ' ' + str(child) + ' ' + str(child_tree.parent))
            nonterm_idxes += self.calc_dependency(tm, child)
        # Every node's dependency parent is the global root in this model.
        tm.trees[cur_idx].dependency_parent = tm.root
        return nonterm_idxes

    def forward(self, batch_data, eval_flag=False):
        """
        Run up to max_reduce_steps rewriting iterations over a batch and
        compute the actor-critic losses.

        Returns (total_loss, total_reward, trace_rec, tm_rec) where
        total_reward is the average best expression-length reduction per
        sample and trace_rec/tm_rec record the rewriting history.
        """
        tree_managers = []
        batch_size = len(batch_data)
        for trace, tm in batch_data:
            tree_managers.append(tm)
        tree_managers = self.input_encoder.calc_embedding(tree_managers, eval_flag)

        active = True
        reduce_steps = 0

        trace_rec = [[] for _ in range(batch_size)]
        rewrite_rec = [[] for _ in range(batch_size)]
        tm_rec = [[] for _ in range(batch_size)]
        expr_rec = [[] for _ in range(batch_size)]
        extra_reward_rec = []
        extra_action_rec = []

        for idx in range(batch_size):
            expr_rec[idx].append(str(tree_managers[idx]))
            trace_rec[idx].append((expr_rec[idx][-1], '', -1))
            tm_rec[idx].append(tree_managers[idx])

        pending_actions = [[] for _ in range(batch_size)]
        while active and ((self.max_reduce_steps is None) or reduce_steps < self.max_reduce_steps):
            active = False
            reduce_steps += 1
            # Collect (sample, node) pairs plus node/root embeddings for scoring.
            nonterm_idxes = []
            tree_embeddings = []
            root_embeddings = []
            for tm_idx in range(batch_size):
                tm = tree_managers[tm_idx]
                cur_nonterm_idxes = self.calc_dependency(tm)
                if len(cur_nonterm_idxes) == 0:
                    continue
                for tree_idx in cur_nonterm_idxes:
                    cur_tree = tm.get_tree(tree_idx)
                    nonterm_idxes.append((tm_idx, tree_idx))
                    tree_embeddings.append(cur_tree.state[0])
                    root_embedding = tm.get_tree(cur_tree.dependency_parent).state[0]
                    root_embeddings.append(root_embedding)
            if len(nonterm_idxes) == 0:
                break
            # Score all candidate positions in mini-batches.
            ac_logits = []
            pred_rewards = []
            for st in range(0, len(nonterm_idxes), self.batch_size):
                cur_tree_embeddings = tree_embeddings[st: st + self.batch_size]
                cur_tree_embeddings = torch.cat(cur_tree_embeddings, 0)
                cur_root_embeddings = root_embeddings[st: st + self.batch_size]
                cur_root_embeddings = torch.cat(cur_root_embeddings, 0)
                cur_inputs = torch.cat([cur_root_embeddings, cur_tree_embeddings], 1)
                cur_ac_logits = self.policy(cur_inputs)
                cur_pred_rewards = self.value_estimator(cur_inputs)
                ac_logits.append(cur_ac_logits)
                pred_rewards.append(cur_pred_rewards)
            ac_logits = torch.cat(ac_logits, 0)
            # NOTE(review): LogSoftmax()/Softmax() rely on the legacy implicit
            # dim; modern PyTorch requires dim=1 here.
            ac_logprobs = nn.LogSoftmax()(ac_logits)
            ac_probs = nn.Softmax()(ac_logits)
            pred_rewards = torch.cat(pred_rewards, 0)
            candidate_rewrite_pos = [[] for _ in range(batch_size)]
            for idx, (tm_idx, tree_idx) in enumerate(nonterm_idxes):
                candidate_rewrite_pos[tm_idx].append((pred_rewards[idx].data[0], pred_rewards[idx], ac_probs[idx], tree_idx, idx))

            update_tree_idxes = [[] for _ in range(batch_size)]
            candidate_tree_managers, candidate_update_tree_idxes, candidate_rewrite_rec, candidate_pending_actions, cur_extra_reward_rec, cur_extra_action_rec = self.batch_rewrite(tree_managers, ac_logprobs, trace_rec, expr_rec, candidate_rewrite_pos, pending_actions, eval_flag, max_search_pos=1, reward_thres=self.reward_thres)
            for tm_idx in range(batch_size):
                cur_candidate_tree_managers = candidate_tree_managers[tm_idx]
                cur_candidate_update_tree_idxes = candidate_update_tree_idxes[tm_idx]
                cur_candidate_rewrite_rec = candidate_rewrite_rec[tm_idx]
                cur_candidate_pending_actions = candidate_pending_actions[tm_idx]
                if len(cur_candidate_tree_managers) > 0:
                    active = True
                    cur_tree_manager = cur_candidate_tree_managers[0]
                    cur_update_tree_idxes = cur_candidate_update_tree_idxes[0]
                    cur_rewrite_rec = cur_candidate_rewrite_rec[0]
                    cur_pending_actions = cur_candidate_pending_actions[0]
                    tree_managers[tm_idx] = cur_tree_manager
                    update_tree_idxes[tm_idx] = cur_update_tree_idxes
                    ac_logprob, pred_reward, cur_pred_reward_tensor, rewrite_pos, init_expr, applied_op = cur_rewrite_rec
                    trace_rec[tm_idx][-1] = (expr_rec[tm_idx][-1], init_expr, applied_op)
                    rewrite_rec[tm_idx].append(cur_rewrite_rec)
                    pending_actions[tm_idx] = cur_pending_actions
                    if cur_pending_actions[0] < 0:
                        # A negative "action" is a back-reference: patch the
                        # earlier step's recorded op and (for < -1) roll the
                        # history back to that point.
                        ac_logprob_st, pred_reward_st, cur_pred_reward_tensor_st, rewrite_pos_st, init_expr_st, applied_op_st = rewrite_rec[tm_idx][cur_pending_actions[0]]
                        expr_st, init_expr_st, applied_op_st = trace_rec[tm_idx][cur_pending_actions[0]]
                        rewrite_rec[tm_idx][cur_pending_actions[0]] = (ac_logprob_st, pred_reward_st, cur_pred_reward_tensor_st, rewrite_pos_st, init_expr_st, cur_pending_actions[1])
                        trace_rec[tm_idx][cur_pending_actions[0]] = (expr_st, init_expr_st, cur_pending_actions[1])
                        if cur_pending_actions[0] < -1:
                            rewrite_rec[tm_idx] = rewrite_rec[tm_idx][:cur_pending_actions[0] + 1]
                            trace_rec[tm_idx] = trace_rec[tm_idx][:cur_pending_actions[0] + 1]
                            expr_rec[tm_idx] = expr_rec[tm_idx][:cur_pending_actions[0] + 1]
                            tm_rec[tm_idx] = tm_rec[tm_idx][:cur_pending_actions[0] + 1]
                        pending_actions[tm_idx] = []
            extra_reward_rec = extra_reward_rec + cur_extra_reward_rec
            extra_action_rec = extra_action_rec + cur_extra_action_rec
            if not active:
                break
            # Refresh embeddings only for the subtrees touched by rewrites.
            updated_tm = self.input_encoder.update_embedding(tree_managers, update_tree_idxes, eval_flag)
            for i in range(batch_size):
                tree_managers[i] = updated_tm[i]
                if len(update_tree_idxes[i]) > 0:
                    expr_rec[i].append(str(updated_tm[i]))
                    trace_rec[i].append((expr_rec[i][-1], '', -1))
                    tm_rec[i].append(updated_tm[i])

        total_policy_loss = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)
        total_value_loss = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)

        pred_actions_rec = []
        pred_actions_logprob_rec = []
        pred_value_rec = []
        value_target_rec = []
        # NOTE(review): pred_dependency_rec/dependency_target_rec are populated
        # nowhere in this method — apparently vestigial.
        pred_dependency_rec = []
        dependency_target_rec = []
        total_reward = 0
        for tm_idx, cur_trace_rec in enumerate(trace_rec):
            # Reward is measured in expression length (shorter is better).
            pred_trace_len = []
            for i, (expr, init_expr, op_idx) in enumerate(cur_trace_rec):
                pred_trace_len.append(len(expr))
            max_reward = 0
            for idx, (ac_logprob, pred_reward, cur_pred_reward_tensor, rewrite_pos, init_expr, applied_op) in enumerate(rewrite_rec[tm_idx]):
                cur_reward = pred_trace_len[idx] - pred_trace_len[idx + 1] - 1
                max_reward = max(max_reward, pred_trace_len[0] - pred_trace_len[idx + 1])
                decay_coef = 1.0
                num_rollout_steps = len(pred_trace_len) - idx - 1
                for i in range(idx + 1, idx + 1 + num_rollout_steps):
                    # Best decayed multi-step reduction, capped by the size of
                    # the rewritten sub-expression.
                    cur_reward = max(decay_coef * (min(pred_trace_len[idx] - pred_trace_len[i] - (i - idx), len(init_expr))), cur_reward)
                    decay_coef *= self.gamma
                cur_reward = cur_reward * 1.0 / len(init_expr)
                cur_reward_tensor = data_utils.np_to_tensor(np.array([cur_reward], dtype=np.float32), 'float', self.cuda_flag, eval_flag)
                if ac_logprob.data.cpu().numpy()[0] > log_eps or cur_reward - pred_reward > 0:
                    # Advantage-weighted one-hot mask over actions.
                    ac_mask = np.zeros(self.num_actions)
                    ac_mask[applied_op] = cur_reward - pred_reward
                    ac_mask = data_utils.np_to_tensor(ac_mask, 'float', self.cuda_flag, eval_flag)
                    ac_mask = ac_mask.unsqueeze(0)
                    pred_actions_rec.append(ac_mask)
                    pred_actions_logprob_rec.append(ac_logprob.unsqueeze(0))
                    pred_value_rec.append(cur_pred_reward_tensor)
                    value_target_rec.append(cur_reward_tensor)
            total_reward += max_reward

        # Positions where no rewrite succeeded get a zero value target.
        for cur_pred_reward in extra_reward_rec:
            pred_value_rec.append(cur_pred_reward)
            value_target = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)
            value_target_rec.append(value_target)

        if len(pred_actions_rec) > 0:
            pred_actions_rec = torch.cat(pred_actions_rec, 0)
            pred_actions_logprob_rec = torch.cat(pred_actions_logprob_rec, 0)
            pred_value_rec = torch.cat(pred_value_rec, 0)
            value_target_rec = torch.cat(value_target_rec, 0)
            pred_value_rec = pred_value_rec.unsqueeze(1)
            value_target_rec = value_target_rec.unsqueeze(1)
            total_policy_loss = -torch.sum(pred_actions_logprob_rec * pred_actions_rec)
            # NOTE(review): size_average=False is the legacy spelling of reduction='sum'.
            total_value_loss = F.smooth_l1_loss(pred_value_rec, value_target_rec, size_average=False)

        total_policy_loss /= batch_size
        total_value_loss /= batch_size
        total_loss = total_policy_loss + total_value_loss * self.value_loss_coef
        total_reward = total_reward * 1.0 / batch_size
        return total_loss, total_reward, trace_rec, tm_rec

# --- src/models/__init__.py (empty module) ---

# --- src/models/data_utils/Dag.py ---
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch
from torch.autograd import Variable
from .utils import *


class Node(object):
    """
    Class to represent each node in a directed acyclic graph (DAG). It can be used for job scheduling.
    """
    def __init__(self, state=None):
        self.children = []
        self.parents = []
        if state is None:
            self.state = None
        else:
            # Clone so the new node does not alias the donor's LSTM state tensors.
            self.state = state[0].clone(), state[1].clone()

    def add_child(self, child):
        self.children += [child]

    def add_parent(self, parent):
        self.parents += [parent]

    def del_child(self, child):
        self.children.remove(child)

    def del_parent(self, parent):
        self.parents.remove(parent)


class JobNode(Node):
    """
    Class to represent each job node for job scheduling.
    """
class JobNode(Node):
    """
    Class to represent each job node for job scheduling.

    Tracks the job's resource demand, arrival (start) time and length, plus
    the derived scheduling statistics (end time, completion time, slow-down).
    """
    def __init__(self, resource_size, start_time, job_len, schedule_time=None, embedding=None, state=None):
        super(JobNode, self).__init__(state)
        self.resource_size = resource_size
        self.st_time = start_time
        self.job_len = job_len
        if schedule_time is None:
            # Not scheduled yet: derived stats are unknown.
            self.schedule_time = None
            self.ed_time = None
            self.completion_time = None
            self.slow_down = None
        else:
            self.update_schedule_time(schedule_time)
        if embedding is None:
            self.embedding = None
        else:
            self.embedding = embedding.copy()

    def update_schedule_time(self, t):
        """(Re)schedule the job at time t and refresh derived statistics."""
        self.schedule_time = t
        self.ed_time = self.schedule_time + self.job_len
        self.completion_time = self.ed_time - self.st_time
        if self.job_len > 0:
            # Slow-down = completion time normalized by job length.
            self.slow_down = self.completion_time * 1.0 / self.job_len
        else:
            self.slow_down = 0

    def update_embedding(self, embedding):
        self.embedding = embedding


class DagManager(object):
    """
    Class to maintain the state for problems with DAG-structured data. Can be used for job scheduling.
    """
    def __init__(self):
        self.nodes = []
        self.num_nodes = 0
        self.root = 0

    def get_node(self, idx):
        return self.nodes[idx]

    def add_edge(self, x, y):
        """Add a directed edge x -> y (x becomes a parent of y)."""
        self.nodes[x].add_child(y)
        self.nodes[y].add_parent(x)

    def del_edge(self, x, y):
        """Remove the directed edge x -> y."""
        self.nodes[x].del_child(y)
        self.nodes[y].del_parent(x)

    def clear_states(self):
        """Drop cached encoder states on every node."""
        for idx in range(self.num_nodes):
            self.nodes[idx].state = None
            self.nodes[idx].rev_state = None


class JobScheduleManager(DagManager):
    """
    Class to maintain the state for job scheduling problems.

    Keeps a resource-usage map over the time horizon, per-time-step
    schedule/terminate lists, and aggregate statistics over all jobs.
    Node 0 is a dummy root job of length 0.
    """
    def __init__(self, num_res, max_time_horizon, max_job_len, max_resource_size):
        super(JobScheduleManager, self).__init__()
        self.num_res = num_res
        self.max_time_horizon = max_time_horizon
        self.max_job_len = max_job_len
        self.max_resource_size = max_resource_size
        self.max_schedule_time = 0
        self.max_ed_time = 0
        # slow_down + own demand per resource + usage per resource for each
        # time step of the longest possible job.
        self.embedding_size = (self.max_job_len + 1) * self.num_res + 1
        self.nodes.append(JobNode(resource_size=0, start_time=0, job_len=0, schedule_time=0, embedding=[0.0 for _ in range(self.embedding_size)]))
        self.num_jobs = 0
        self.resource_map = np.zeros((self.max_time_horizon, self.num_res))
        self.schedule = [[] for _ in range(self.max_time_horizon)]
        self.terminate = [[] for _ in range(self.max_time_horizon)]

    def clone(self):
        """Deep-copy the manager (nodes, edges, schedule and statistics)."""
        res = JobScheduleManager(self.num_res, self.max_time_horizon, self.max_job_len, self.max_resource_size)
        res.root = self.root
        # Rebuild the node list from scratch (drops the dummy node __init__ added).
        res.nodes = []
        for i, node in enumerate(self.nodes):
            res.nodes.append(JobNode(resource_size=node.resource_size, start_time=node.st_time, job_len=node.job_len, schedule_time=node.schedule_time, embedding=node.embedding, state=node.state))
            if i != 0:
                res.schedule[node.schedule_time].append(i)
                res.terminate[node.ed_time].append(i)
            for child in node.children:
                res.nodes[i].add_child(child)
            for parent in node.parents:
                res.nodes[i].add_parent(parent)
        res.num_nodes = self.num_nodes
        res.num_jobs = self.num_jobs
        res.resource_map = self.resource_map.copy()
        res.avg_slow_down = self.avg_slow_down
        res.avg_completion_time = self.avg_completion_time
        res.max_schedule_time = self.max_schedule_time
        res.max_ed_time = self.max_ed_time
        return res

    def add_job(self, node_idx, cur_time):
        """
        Schedule (or reschedule) job node_idx at cur_time, claiming its
        resources and updating the schedule/terminate lists.
        """
        job = self.nodes[node_idx]
        ed_time = cur_time + job.job_len
        for t in range(cur_time, ed_time):
            self.resource_map[t] += job.resource_size
        if job.schedule_time == cur_time:
            # Already scheduled at this time; nothing else to update.
            return
        if job.schedule_time is not None:
            # Rescheduling: remove the old bookkeeping entries first.
            self.schedule[job.schedule_time].remove(node_idx)
            self.terminate[job.ed_time].remove(node_idx)
        self.schedule[cur_time].append(node_idx)
        self.terminate[ed_time].append(node_idx)
        self.nodes[node_idx].update_schedule_time(cur_time)

    def update_embedding(self, node_idx):
        """Recompute the feature embedding of job node_idx, zero-padded to embedding_size."""
        job = self.nodes[node_idx]
        embedding = []
        embedding.append(job.slow_down)
        embedding += [job.resource_size[i] * 1.0 / self.max_resource_size for i in range(self.num_res)]
        for t in range(job.schedule_time, job.ed_time):
            embedding += [self.resource_map[t][i] * 1.0 / self.max_resource_size for i in range(self.num_res)]
        if len(embedding) < self.embedding_size:
            embedding += [0.0 for _ in range(self.embedding_size - len(embedding))]
        self.nodes[node_idx].update_embedding(embedding)

    def update_stat(self):
        """Recompute aggregate statistics over all nodes (the length-0 dummy node contributes zeros)."""
        self.avg_completion_time = 0.0
        self.avg_slow_down = 0.0
        self.max_schedule_time = 0
        self.max_ed_time = 0
        for node in self.nodes:
            self.avg_completion_time += node.completion_time
            self.avg_slow_down += node.slow_down
            self.max_schedule_time = max(self.max_schedule_time, node.schedule_time)
            self.max_ed_time = max(self.max_ed_time, node.ed_time)
        self.avg_slow_down = self.avg_slow_down * 1.0 / self.num_jobs
        self.avg_completion_time = self.avg_completion_time * 1.0 / self.num_jobs

    def get_parent_idxes(self, st, job_horizon):
        """BFS up to job_horizon ancestors of st (excluding st itself)."""
        res = [st]
        idx = 0
        while len(res) < job_horizon + 1 and idx < len(res):
            cur_job = self.get_node(res[idx])
            for parent in cur_job.parents:
                if parent not in res:
                    res.append(parent)
            idx += 1
        return res[1:job_horizon + 1]

    def get_children_idxes(self, st, job_horizon):
        """BFS up to job_horizon descendants of st (excluding st itself)."""
        res = [st]
        idx = 0
        while len(res) < job_horizon + 1 and idx < len(res):
            cur_job = self.get_node(res[idx])
            for child in cur_job.children:
                if child not in res:
                    res.append(child)
            idx += 1
        return res[1:job_horizon + 1]

    def runnable(self, cur_job, schedule_time):
        """Return True if cur_job fits within the resource capacity starting at schedule_time."""
        for t in range(cur_job.job_len):
            for j in range(self.num_res):
                if self.resource_map[t + schedule_time][j] + cur_job.resource_size[j] > self.max_resource_size:
                    return False
        return True

    def calc_min_schedule_time(self, min_schedule_time, cur_job_idx):
        """
        Earliest time >= min_schedule_time at which cur_job_idx can run:
        either min_schedule_time itself, or the smallest end time of another
        job after which the resources suffice (max_time_horizon if none).
        """
        cur_job = self.get_node(cur_job_idx)
        if self.runnable(cur_job, min_schedule_time):
            return min_schedule_time
        new_schedule_time = self.max_time_horizon
        for i, node in enumerate(self.nodes):
            if i == cur_job_idx:
                continue
            if node.ed_time is None:
                continue
            if node.ed_time <= min_schedule_time or node.ed_time >= new_schedule_time:
                continue
            if self.runnable(cur_job, node.ed_time):
                new_schedule_time = node.ed_time
        return new_schedule_time

# --- src/models/data_utils/Seq.py ---
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch
from torch.autograd import Variable
from .utils import *
class VrpNode(object):
    """
    Class to represent each node for vehicle routing.
    """
    def __init__(self, x, y, demand, px, py, capacity, dis, embedding=None):
        self.x = x
        self.y = y
        self.demand = demand
        self.px = px
        self.py = py
        self.capacity = capacity
        self.dis = dis
        # Copy so the node owns its feature vector.
        self.embedding = None if embedding is None else embedding.copy()

class SeqManager(object):
    """
    Base class for sequential input data. Can be used for vehicle routing.
    """
    def __init__(self):
        self.nodes = []
        self.num_nodes = 0

    def get_node(self, idx):
        return self.nodes[idx]


class VrpManager(SeqManager):
    """
    The class to maintain the state for vehicle routing.
    """
    def __init__(self, capacity):
        super(VrpManager, self).__init__()
        self.capacity = capacity
        self.route = []
        self.vehicle_state = []
        self.tot_dis = []
        self.encoder_outputs = None

    def clone(self):
        """Deep-copy the manager, including nodes and the encoder output tensor."""
        dup = VrpManager(self.capacity)
        dup.nodes = [
            VrpNode(x=n.x, y=n.y, demand=n.demand, px=n.px, py=n.py,
                    capacity=n.capacity, dis=n.dis, embedding=n.embedding)
            for n in self.nodes
        ]
        dup.num_nodes = self.num_nodes
        dup.route = self.route[:]
        dup.vehicle_state = self.vehicle_state[:]
        dup.tot_dis = self.tot_dis[:]
        dup.encoder_outputs = self.encoder_outputs.clone()
        return dup

    def get_dis(self, node_1, node_2):
        """Euclidean distance between two nodes."""
        dx = node_1.x - node_2.x
        dy = node_1.y - node_2.y
        return np.sqrt(dx * dx + dy * dy)

    def get_neighbor_idxes(self, route_idx):
        """
        Route positions (excluding the endpoints) that are candidate
        exchange neighbors for the node at route_idx.
        """
        target_node_idx = self.vehicle_state[route_idx][0]
        prev_node_idx, prev_capacity = self.vehicle_state[route_idx - 1]
        neighbors = []
        for pos in range(1, len(self.vehicle_state) - 1):
            other_node_idx = self.vehicle_state[pos][0]
            if other_node_idx == target_node_idx:
                continue
            # Two consecutive depot visits are never useful.
            if prev_node_idx == 0 and other_node_idx == 0:
                continue
            other_node = self.get_node(other_node_idx)
            # A later customer cannot replace a depot stop if it would
            # overflow the remaining capacity.
            if target_node_idx == 0 and pos > route_idx and other_node.demand > prev_capacity:
                continue
            neighbors.append(pos)
        return neighbors

    def add_route_node(self, node_idx):
        """
        Append node_idx to the current route: update vehicle state (remaining
        capacity), cumulative distance, and the node's feature embedding.
        Index 0 is the depot, which resets the vehicle capacity.
        """
        node = self.get_node(node_idx)
        if self.vehicle_state:
            prev_node_idx, prev_capacity = self.vehicle_state[-1]
        else:
            prev_node_idx, prev_capacity = 0, self.capacity
        prev_node = self.get_node(prev_node_idx)
        if node_idx > 0:
            self.vehicle_state.append((node_idx, prev_capacity - self.nodes[node_idx].demand))
        else:
            self.vehicle_state.append((node_idx, self.capacity))
        leg_dis = self.get_dis(node, prev_node)
        if self.tot_dis:
            self.tot_dis.append(self.tot_dis[-1] + leg_dis)
        else:
            self.tot_dis.append(leg_dis)
        fresh = VrpNode(x=node.x, y=node.y, demand=node.demand, px=prev_node.x,
                        py=prev_node.y, capacity=prev_capacity, dis=leg_dis)
        if fresh.capacity == 0:
            # Avoid division by zero when the vehicle arrives empty.
            fresh.embedding = [fresh.x, fresh.y, fresh.demand * 1.0 / self.capacity,
                               fresh.px, fresh.py, 0.0, fresh.dis]
        else:
            fresh.embedding = [fresh.x, fresh.y, fresh.demand * 1.0 / self.capacity,
                               fresh.px, fresh.py, fresh.demand * 1.0 / fresh.capacity, fresh.dis]
        self.nodes[node_idx] = fresh
        self.route.append(fresh.embedding[:])

# --- src/models/data_utils/Tree.py ---
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
from torch.autograd import Variable

# Provides is_var / is_int used by Tree.update_tok_type.
from .utils import *

# Tolerance for comparing cached encoder states in equal_tree().
eps = 1e-6


class Tree(object):
    """
    The class to represent a single tree. It can be used for expression simplification.
    """
    def __init__(self, root, parent, depth, state=None):
        self.root = root
        self.is_const = False
        self.update_tok_type()
        self.children = []
        self.parent = parent
        self.depth = depth
        if state is None:
            self.state = None
        else:
            # Clone so this tree does not alias the donor's state tensors.
            self.state = state[0].clone(), state[1].clone()

    def add_child(self, child, child_idx=-1):
        """
        Append child, or place it at child_idx; missing slots before
        child_idx are padded (with the same child) so the index exists.
        """
        while len(self.children) <= child_idx:
            self.children += [child]
        if child_idx == -1:
            self.children += [child]
        else:
            self.children[child_idx] = child

    def update_tok_type(self):
        """Classify the root token (var / const / unary_op / select / binary_op)."""
        if is_var(self.root):
            self.tok_type = 'var'
        elif is_int(self.root):
            self.tok_type = 'const'
            self.is_const = True
        elif self.root == '!':
            self.tok_type = 'unary_op'
        elif self.root in ['min', 'max', 'select']:
            self.tok_type = 'select'
        else:
            self.tok_type = 'binary_op'

    def num_tokens(self):
        """Total number of tokens in this subtree (assumes children are Tree objects)."""
        res = 1
        for c in self.children:
            res += c.num_tokens()
        return res


class TreeManager(object):
    """
    The class to maintain tree-structured input data. It is used for expression simplification.
    """
    def __init__(self):
        self.clear()

    def clear(self):
        # Trees are stored flat; children/parent are indexes into self.trees.
        self.trees = []
        self.num_trees = 0
        self.root = 0

    def create_tree(self, root, parent, depth, state=None):
        """Allocate a new tree node and return its index."""
        self.trees.append(Tree(root, parent, depth, state))
        self.num_trees += 1
        return self.num_trees - 1

    def get_tree(self, idx):
        return self.trees[idx]

    def num_valid_nodes(self, cur_idx=None):
        """Count nodes reachable from cur_idx (default: the root)."""
        if cur_idx is None:
            cur_idx = self.root
        tot = 1
        cur_tree = self.get_tree(cur_idx)
        if len(cur_tree.children) == 0:
            return tot
        for child in cur_tree.children:
            tot += self.num_valid_nodes(child)
        return tot

    def clone(self):
        """Copy every stored tree node (indexes are preserved verbatim)."""
        res = TreeManager()
        res.root = self.root
        res.num_trees = 0
        for i, tree in enumerate(self.trees):
            res.create_tree(tree.root, tree.parent, tree.depth, tree.state)
            for child in tree.children:
                res.trees[i].children.append(child)
        return res

    def clone_tree(self, tm, ref_tree_idx=-1):
        """Copy the subtree at ref_tree_idx of manager tm into this manager; returns the new root index."""
        if ref_tree_idx == -1:
            ref_tree_idx = tm.root
        ref_tree = tm.get_tree(ref_tree_idx)
        root = self.create_tree(ref_tree.root, -1, 0, ref_tree.state)
        for child in ref_tree.children:
            c = self.clone_tree(tm, child)
            self.update_edge(root, c)
        return root

    def equal_tree(self, t1, t2):
        """
        Structural equality of two subtrees: same root tokens and, when both
        have cached states, states within eps; then children compared
        recursively. NOTE(review): iterates over t1's children only — if t2
        has fewer children this indexes past t2.children; presumably callers
        only compare same-arity trees — verify.
        """
        if t1.root != t2.root:
            return False
        if t1.state is not None and t2.state is not None:
            dis_vec = torch.abs(t1.state[0] - t2.state[0]) + torch.abs(t1.state[1] - t2.state[1])
            dis, _ = dis_vec.max(1)
            if dis.data[0] > eps:
                return False
        for child_idx in range(len(t1.children)):
            t1_c = self.get_tree(t1.children[child_idx])
            t2_c = self.get_tree(t2.children[child_idx])
            if not self.equal_tree(t1_c, t2_c):
                return False
        return True

    def update_edge(self, parent, child, child_idx=-1):
        """Attach child under parent (parent == -1 makes child the root)."""
        if parent == -1:
            self.root = child
        else:
            self.trees[parent].add_child(child, child_idx)
        self.trees[child].parent = parent

    def find_child_idx(self, parent, child):
        """Position of child within parent's children; raises on a broken edge."""
        if parent == -1:
            return -1
        parent_tree = self.get_tree(parent)
        for child_idx in range(len(parent_tree.children)):
            if parent_tree.children[child_idx] == child:
                return child_idx
        raise ValueError('invalid edge: ' + str(parent) + ' ' + parent_tree.root + ' ' + str(child) + ' ' + str(parent_tree.children))
        return -1

    def find_subtree(self, tree_idx, subtree_idx, op):
        """
        Search under tree_idx for a subtree equal to subtree_idx that occurs
        as an operand of op, following '+'/'-' algebra: descending into the
        right child of a '-' flips the sign context. Returns the matching
        index or -1.
        """
        cur_tree = self.get_tree(tree_idx)
        subtree = self.get_tree(subtree_idx)
        if self.equal_tree(cur_tree, subtree):
            return tree_idx

        if op == '-':
            if cur_tree.root == '-':
                lchild = self.get_tree(cur_tree.children[0])
                if not self.equal_tree(lchild, subtree):
                    res = self.find_subtree(cur_tree.children[0], subtree_idx, op)
                    if res != -1:
                        return res
                # Right operand of '-' is searched with the opposite sign.
                res = self.find_subtree(cur_tree.children[1], subtree_idx, '+')
                if res != -1:
                    return res
            if cur_tree.root == '+':
                for child in cur_tree.children:
                    child_tree = self.get_tree(child)
                    if not self.equal_tree(child_tree, subtree):
                        res = self.find_subtree(child, subtree_idx, op)
                        if res != -1:
                            return res
            return -1

        if op == '+' and cur_tree.root == '-':
            res = self.find_subtree(cur_tree.children[0], subtree_idx, op)
            if res != -1:
                return res
            rchild = self.get_tree(cur_tree.children[1])
            if not self.equal_tree(rchild, subtree):
                res = self.find_subtree(cur_tree.children[1], subtree_idx, '-')
                if res != -1:
                    return res

        if cur_tree.root != op:
            return -1
        for child in cur_tree.children:
            res = self.find_subtree(child, subtree_idx, op)
            if res != -1:
                return res
        return -1
def find_alg_const(self, tree_idx, op): 201 | cur_tree = self.get_tree(tree_idx) 202 | if cur_tree.is_const: 203 | return tree_idx 204 | if op == '+': 205 | if cur_tree.root == '+': 206 | for child in cur_tree.children: 207 | cur_idx = self.find_alg_const(child, op) 208 | if cur_idx is not None: 209 | return cur_idx 210 | elif cur_tree.root == '-': 211 | if self.get_tree(cur_tree.children[1]).is_const: 212 | return cur_tree.children[1] 213 | cur_idx = self.find_alg_const(cur_tree.children[0], op) 214 | return cur_idx 215 | if op == '*': 216 | if cur_tree.root == '*': 217 | for child in cur_tree.children: 218 | cur_idx = self.find_alg_const(child, op) 219 | if cur_idx is not None: 220 | return cur_idx 221 | return None 222 | 223 | 224 | def find_minmax_const(self, tree_idx, op): 225 | cur_tree = self.get_tree(tree_idx) 226 | if cur_tree.is_const: 227 | return tree_idx 228 | if cur_tree.root != op: 229 | return None 230 | res_idx = None 231 | res = None 232 | for child in cur_tree.children: 233 | cur_idx = self.find_minmax_const(child, op) 234 | if cur_idx is not None: 235 | if res_idx is None: 236 | res_idx = cur_idx 237 | res_tree = self.get_tree(res_idx) 238 | res = int(res_tree.root) 239 | else: 240 | cur_tree = self.get_tree(cur_idx) 241 | t = int(cur_tree.root) 242 | if op == 'max' and t > res or op == 'min' and t < res: 243 | res = t 244 | res_idx = cur_idx 245 | return res_idx 246 | 247 | 248 | def find_muldiv_term(self, tree_idx): 249 | cur_tree = self.get_tree(tree_idx) 250 | if not cur_tree.root in ['+', '-', '*']: 251 | return [] 252 | ltree_idx = cur_tree.children[0] 253 | ltree = self.get_tree(ltree_idx) 254 | rtree_idx = cur_tree.children[1] 255 | rtree = self.get_tree(rtree_idx) 256 | if cur_tree.root == '*' and rtree.is_const and int(rtree.root) > 0 and ltree.root == '/': 257 | ltree_ltree_idx = ltree.children[0] 258 | ltree_ltree = self.get_tree(ltree_ltree_idx) 259 | ltree_rtree_idx = ltree.children[1] 260 | ltree_rtree = 
self.get_tree(ltree_rtree_idx) 261 | if ltree_rtree.root == rtree.root: 262 | return [tree_idx] 263 | res = [] 264 | if cur_tree.root == '+': 265 | for child in cur_tree.children: 266 | cur_res = self.find_muldiv_term(child) 267 | res = res + cur_res 268 | elif cur_tree.root == '-': 269 | res = self.find_muldiv_term(cur_tree.children[0]) 270 | return res 271 | 272 | 273 | def is_times(self, tree_idx, c): 274 | cur_tree = self.get_tree(tree_idx) 275 | if cur_tree.is_const and int(cur_tree.root) % c == 0: 276 | return True 277 | if cur_tree.root != '*': 278 | return False 279 | ltree_idx = cur_tree.children[0] 280 | ltree = self.get_tree(ltree_idx) 281 | rtree_idx = cur_tree.children[1] 282 | rtree = self.get_tree(rtree_idx) 283 | if rtree.is_const and int(rtree.root) % c == 0: 284 | return True 285 | return False 286 | 287 | 288 | def find_times_term(self, tree_idx, c): 289 | if self.is_times(tree_idx, c): 290 | return tree_idx 291 | cur_tree = self.get_tree(tree_idx) 292 | if not cur_tree.root in ['+', '-']: 293 | return -1 294 | ltree_idx = cur_tree.children[0] 295 | ltree = self.get_tree(ltree_idx) 296 | rtree_idx = cur_tree.children[1] 297 | rtree = self.get_tree(rtree_idx) 298 | res = self.find_times_term(ltree_idx, c) 299 | if res != -1: 300 | return res 301 | res = self.find_times_term(rtree_idx, c) 302 | return res 303 | 304 | 305 | def to_string(self, tree_idx, tok_map=None, log_out=None): 306 | self.trees[tree_idx].update_tok_type() 307 | r = '' 308 | if self.trees[tree_idx].tok_type == 'var' or self.trees[tree_idx].tok_type == 'const': 309 | if tok_map and (self.trees[tree_idx].root in tok_map): 310 | r += tok_map(self.trees[tree_idx].root) 311 | else: 312 | r += str(self.trees[tree_idx].root) 313 | return r 314 | if self.trees[tree_idx].tok_type == 'binary_op': 315 | r += '(' 316 | r += self.to_string(self.trees[tree_idx].children[0], tok_map, log_out) 317 | if not self.trees[tree_idx].root in ['*', '/']: 318 | r += ' ' 319 | if tok_map and 
(self.trees[tree_idx].root in tok_map): 320 | r += tok_map(self.trees[tree_idx].root) 321 | else: 322 | r += str(self.trees[tree_idx].root) 323 | if not self.trees[tree_idx].root in ['*', '/']: 324 | r += ' ' 325 | r += self.to_string(self.trees[tree_idx].children[1], tok_map, log_out) 326 | r += ')' 327 | return r 328 | if self.trees[tree_idx].tok_type == 'unary_op': 329 | r = '' 330 | if tok_map and (self.trees[tree_idx].root in tok_map): 331 | r += tok_map(self.trees[tree_idx].root) 332 | else: 333 | r += str(self.trees[tree_idx].root) 334 | r += '(' 335 | r += self.to_string(self.trees[tree_idx].children[0], tok_map, log_out) 336 | r += ')' 337 | return r 338 | if self.trees[tree_idx].tok_type == 'select': 339 | if tok_map and (self.trees[tree_idx].root in tok_map): 340 | r += tok_map(self.trees[tree_idx].root) 341 | else: 342 | r += str(self.trees[tree_idx].root) 343 | r += '(' 344 | r += self.to_string(self.trees[tree_idx].children[0], tok_map, log_out) 345 | for c in self.trees[tree_idx].children[1:]: 346 | r += ', ' + self.to_string(c, tok_map, log_out) 347 | r += ')' 348 | return r 349 | 350 | 351 | def __str__(self): 352 | return self.to_string(self.root) -------------------------------------------------------------------------------- /src/models/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/neural-rewriter/356c468a6ed54ec2ef8a007cc3c4bbdf6ab9b96b/src/models/data_utils/__init__.py -------------------------------------------------------------------------------- /src/models/data_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
#

import argparse
import collections
import json
import os
import random
import sys
import time
import six
import numpy as np
import copy
import pickle

import torch
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from .parser import *

_PAD = b"_PAD"

PAD_ID = 0
START_VOCAB_SIZE = 1
max_token_len = 5

def np_to_tensor(inp, output_type, cuda_flag, volatile_flag=False):
    """Wrap a numpy array / sequence into a torch Variable.

    output_type: 'float' -> FloatTensor, 'int' -> LongTensor.
    NOTE: Variable/volatile is legacy (pre-0.4) torch API, kept for
    compatibility with the rest of this codebase.
    """
    if output_type == 'float':
        inp_tensor = Variable(torch.FloatTensor(inp), volatile=volatile_flag)
    elif output_type == 'int':
        inp_tensor = Variable(torch.LongTensor(inp), volatile=volatile_flag)
    else:
        # Fix: the original only printed a message and fell through, which
        # then crashed with UnboundLocalError on `inp_tensor`.
        raise ValueError('undefined tensor type')
    if cuda_flag:
        inp_tensor = inp_tensor.cuda()
    return inp_tensor

def load_dataset(filename, args):
    """Load a JSON list of samples from `filename`. `args` is unused but kept
    for interface compatibility with callers."""
    with open(filename, 'r') as f:
        samples = json.load(f)
    print('Number of data samples in ' + filename + ': ', len(samples))
    return samples

class HalideDataProcessor(object):
    """Batching / vocabulary utilities for the Halide expression task."""
    def __init__(self):
        self.parser = HalideParser()
        self.tokenizer = self.parser.tokenizer

    def load_term_vocab(self):
        """Character-level vocabulary for terms: PAD, digits 0-9, 'v', '-'."""
        vocab = {}
        vocab_list = []
        vocab[_PAD] = PAD_ID
        vocab_list.append(_PAD)
        for i in range(10):
            vocab[str(i)] = len(vocab)
            vocab_list.append(str(i))
        vocab['v'] = len(vocab)
        vocab_list.append('v')
        vocab['-'] = len(vocab)
        vocab_list.append('-')
        return vocab, vocab_list

    def load_ops(self):
        """Operator vocabulary: tokenizer ops followed by keywords."""
        ops_list = self.tokenizer.ops + self.tokenizer.keywords
        ops = {}
        for op in ops_list:
            ops[op] = len(ops)
        return ops, ops_list

    def token_to_ids(self, token, vocab):
        """Left-pad the per-character ids of `token` to max_token_len.
        NOTE(review): characters missing from `vocab` map to None — confirm
        every caller guarantees in-vocabulary tokens."""
        token_ids = [vocab.get(c) for c in token]
        token_ids = [PAD_ID for _ in range(max_token_len - len(token_ids))] + token_ids
        return token_ids

    def prune_dataset(self, init_data, min_len=None, max_len=None):
        """Keep only traces whose initial expression length is within
        [min_len, max_len] (either bound may be None)."""
        data = []
        for trace in init_data:
            expr_len = len(trace[0])
            if min_len is not None and expr_len < min_len:
                continue
            if max_len is not None and expr_len > max_len:
                continue
            data.append(trace)
        return data

    def get_batch(self, data, batch_size, start_idx=None):
        """Return a batch of (trace, TreeManager) pairs: sequential slice when
        `start_idx` is given, random sample otherwise."""
        data_size = len(data)
        if start_idx is not None:
            batch_idxes = [i for i in range(start_idx, min(data_size, start_idx + batch_size))]
        else:
            batch_idxes = np.random.choice(len(data), batch_size)
        batch_data = []
        for idx in batch_idxes:
            trace = data[idx]
            tm = self.parser.parse(trace[0])
            batch_data.append((trace, tm))
        return batch_data

    def print_gt_trace(self, trace):
        print('ground truth: ')
        for trace_step in trace:
            print(trace_step)
        print('')

    def print_pred_trace(self, trace_rec):
        print('prediction: ')
        for trace_step in trace_rec:
            print(trace_step)
        print('')

class jspDataProcessor(object):
    """Batching utility for the job-scheduling task."""
    def __init__(self, args):
        self.parser = jspDependencyParser(args)

    def get_batch(self, data, batch_size, start_idx=None):
        """Return a batch of parsed schedule managers (sequential slice when
        `start_idx` is given, random sample otherwise)."""
        data_size = len(data)
        if start_idx is not None:
            batch_idxes = [i for i in range(start_idx, min(data_size, start_idx + batch_size))]
        else:
            batch_idxes = np.random.choice(len(data), batch_size)
        batch_data = []
        for idx in batch_idxes:
            job_seq = data[idx]
            dm = self.parser.parse(job_seq)
            batch_data.append(dm)
        return batch_data


class vrpDataProcessor(object):
    """Batching utility for the vehicle-routing task."""
    def __init__(self):
        self.parser = vrpParser()

    def get_batch(self, data, batch_size, start_idx=None):
        """Return a batch of parsed VRP managers (sequential slice when
        `start_idx` is given, random sample otherwise)."""
        data_size = len(data)
        if start_idx is not None:
            batch_idxes = [i for i in range(start_idx, min(data_size, start_idx + batch_size))]
        else:
            batch_idxes = np.random.choice(len(data), batch_size)
        batch_data = []
        for idx in batch_idxes:
            problem = data[idx]
            dm = self.parser.parse(problem)
            batch_data.append(dm)
        return batch_data

# ----------------------------------------------------------------------------
# src/models/data_utils/parser.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import os
import sys
import argparse
import pyparsing as pyp
from .Seq import *
from .Tree import *
from .Dag import *
from .utils import *

# NOTE(review): the nonterminal names of this grammar were lost in transit
# (angle-bracketed tags stripped); the productions are reproduced as found.
'''
Halide Grammar:
 ::= v0 | v1 | v2 | v3 | v4 | v5 | v6 | v7 | v8 | v9 | v10 | v11 | v12
 ::= |
 ::= | | |
 ::= ( BinaryOp )
 ::= !
 ::= max(, ) | min(, ) | select(, , )
 ::= | | |
'''

class HalideTokenizer(object):
    """Token classes and pyparsing grammar for Halide expressions."""

    def __init__(self):
        self.vars = ['v']
        self.num = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        self.alg_ops = ['+', '-', '*', '/', '%']
        self.cmp_ops = ['<', '<=', '==']
        self.bool_ops = ['&&', '||']
        self.rel_ops = self.cmp_ops + self.bool_ops + ['!']
        self.ops = self.alg_ops + self.rel_ops
        self.keywords = ['max', 'min', 'select']
        self.chars = ['(', ')', ',']

    def preprocess(self, raw_inp):
        """Parse `raw_inp` into a nested pyparsing token structure."""
        Var = pyp.Combine('v' + pyp.Word(pyp.nums))
        Const = pyp.Combine(pyp.Optional(pyp.Literal('-')) + pyp.Word(pyp.nums))
        Term = Var | Const
        BinaryOp = pyp.oneOf('+ - * / % < <= == && ||')
        UnaryOp = pyp.oneOf('!')
        Expr = pyp.Forward()
        BinaryExpr = pyp.Group(pyp.Literal('(') + Expr + BinaryOp + Expr + pyp.Literal(')'))
        UnaryExpr = pyp.Group(UnaryOp + Expr)
        SelectExpr = pyp.Group(pyp.oneOf('min max') + pyp.Literal('(') + Expr + pyp.Literal(',') + Expr + pyp.Literal(')')) \
            | pyp.Group(pyp.Literal('select') + pyp.Literal('(') + Expr + pyp.Literal(',') + Expr + pyp.Literal(',') + Expr + pyp.Literal(')'))
        Expr << (BinaryExpr | UnaryExpr | SelectExpr | Term)
        return Expr.parseString(raw_inp)

class HalideParser(object):
    """Turns a raw Halide expression string into a TreeManager."""
    def __init__(self):
        self.tokenizer = HalideTokenizer()
        # NOTE(review): the EOF marker text appears lost in transit — confirm
        # against the original source.
        self.EOF = ''

    def debug(self, msg):
        if self.is_debug:
            print(msg)

    def parse(self, inp, debug=False):
        """Tokenize `inp` and build its expression tree; returns TreeManager."""
        self.is_debug = debug
        inp = self.tokenizer.preprocess(inp)
        tm = TreeManager()
        tm, root_idx = self.parseExpr(tm, inp, -1, 0)
        return tm

    def parseExpr(self, tm, inp, parent, depth):
        """Recursively convert a pyparsing group into tree nodes under
        `parent`; returns (tm, index of the created node)."""
        if type(inp) is str:
            e1 = tm.create_tree(inp, parent, depth)
            return tm, e1
        elif len(inp) == 1:
            # Unwrap single-element groups.
            return self.parseExpr(tm, inp[0], parent, depth)
        elif (inp[0] == '!'):
            op = inp[0]
            op_node_idx = tm.create_tree(op, parent, depth)
            tm, e1 = self.parseExpr(tm, inp[1], op_node_idx, depth + 1)
            tm.update_edge(op_node_idx, e1)
        elif inp[0] in self.tokenizer.keywords:
            # min/max/select: children are the comma-separated arguments.
            op = inp[0]
            op_node_idx = tm.create_tree(op, parent, depth)
            for expr in inp[1:]:
                if expr in self.tokenizer.chars:
                    continue
                tm, e = self.parseExpr(tm, expr, op_node_idx, depth + 1)
                tm.update_edge(op_node_idx, e)
        else:
            # Binary expression: ( e1 op e2 ) — operator sits at index 2.
            op = inp[2]
            op_node_idx = tm.create_tree(op, parent, depth)
            tm, e1 = self.parseExpr(tm, inp[1], op_node_idx, depth + 1)
            tm, e2 = self.parseExpr(tm, inp[3], op_node_idx, depth + 1)
            tm.update_edge(op_node_idx, e1)
            tm.update_edge(op_node_idx, e2)
        return tm, op_node_idx

class jspDependencyParser(object):
    """Builds an initial job schedule (EJF/SJF/random) as a dependency DAG."""
    def __init__(self, args):
        self.num_res = args.num_res
        self.max_time_horizon = args.max_time_horizon
        self.max_job_len = args.max_job_len
        self.max_resource_size = args.max_resource_size
        self.job_horizon = args.job_horizon
        self.base_alg = args.base_alg

    def ejf(self, dm):
        """Earliest-job-first: schedule each job at its earliest feasible time.
        NOTE(review): unlike sjf/random_schedule, this does not update
        dm.max_ed_time — confirm the manager maintains it elsewhere."""
        for job_idx in range(1, dm.num_nodes):
            job = dm.get_node(job_idx)
            min_schedule_time = dm.calc_min_schedule_time(job.st_time, job_idx)
            dm.add_job(job_idx, min_schedule_time)
        return dm

    def sjf(self, dm):
        """Shortest-job-first over a sliding window of `job_horizon` pending
        jobs; ties on schedule time break toward shorter jobs."""
        pending_job = []
        for job_idx in range(1, dm.num_nodes + 1):
            if job_idx == dm.num_nodes:
                pending_job_cap = 1
            else:
                pending_job_cap = self.job_horizon

            while len(pending_job) >= pending_job_cap:
                schedule_idx = -1
                schedule_time = -1
                min_job_len = -1
                for pending_job_idx in pending_job:
                    cur_pending_job = dm.get_node(pending_job_idx)
                    cur_min_schedule_time = dm.calc_min_schedule_time(cur_pending_job.st_time, pending_job_idx)
                    if schedule_idx == -1 or cur_min_schedule_time < schedule_time or cur_min_schedule_time == schedule_time and cur_pending_job.job_len < min_job_len:
                        schedule_idx = pending_job_idx
                        schedule_time = cur_min_schedule_time
                        min_job_len = cur_pending_job.job_len
                dm.add_job(schedule_idx, schedule_time)
                dm.max_ed_time = max(dm.max_ed_time, dm.get_node(schedule_idx).ed_time)
                pending_job.remove(schedule_idx)

            if job_idx == dm.num_nodes:
                break
            pending_job.append(job_idx)

        return dm

    def random_schedule(self, dm):
        """Schedule a uniformly random pending job whenever the window fills."""
        pending_job = []
        for job_idx in range(1, dm.num_nodes + 1):
            if job_idx == dm.num_nodes:
                pending_job_cap = 1
            else:
                pending_job_cap = self.job_horizon

            while len(pending_job) >= pending_job_cap:
                schedule_idx = np.random.choice(pending_job)
                cur_schedule_job = dm.get_node(schedule_idx)
                schedule_time = dm.calc_min_schedule_time(cur_schedule_job.st_time, schedule_idx)
                dm.add_job(schedule_idx, schedule_time)
                dm.max_ed_time = max(dm.max_ed_time, dm.get_node(schedule_idx).ed_time)
                pending_job.remove(schedule_idx)

            if job_idx == dm.num_nodes:
                break
            pending_job.append(job_idx)

        return dm

    def parse(self, job_seq, debug=False):
        """Build a JobScheduleManager for `job_seq`, schedule with the chosen
        base algorithm, then wire dependency edges and statistics."""
        self.is_debug = debug
        dm = JobScheduleManager(self.num_res, self.max_time_horizon, self.max_job_len, self.max_resource_size)
        for inp in job_seq:
            dm.nodes.append(JobNode(resource_size=inp['resource_size'], start_time=inp['start_time'], job_len=inp['job_len']))
        dm.num_nodes = len(dm.nodes)
        dm.num_jobs = len(job_seq)
        if self.base_alg == 'EJF':
            dm = self.ejf(dm)
        elif self.base_alg == 'SJF':
            dm = self.sjf(dm)
        else:
            dm = self.random_schedule(dm)

        for job_idx in range(1, dm.num_nodes):
            job = dm.get_node(job_idx)
            dm.update_embedding(job_idx)
            if job.schedule_time == job.st_time:
                # Scheduled immediately: depends only on the root.
                dm.add_edge(dm.root, job_idx)
            else:
                # Otherwise depends on every job terminating at its start.
                for i in dm.terminate[job.schedule_time]:
                    dm.add_edge(i, job_idx)
        dm.update_stat()
        return dm


class vrpParser(object):
    """Builds an initial VRP route greedily (nearest feasible customer)."""
    def parse(self, problem, debug=False):
        self.is_debug = debug
        dm = VrpManager(problem['capacity'])
        dm.nodes.append(VrpNode(x=problem['depot'][0], y=problem['depot'][1], demand=0, px=problem['depot'][0], py=problem['depot'][1], capacity=problem['capacity'], dis=0.0))
        for customer in problem['customers']:
            dm.nodes.append(VrpNode(x=customer['position'][0], y=customer['position'][1], demand=customer['demand'], px=customer['position'][0], py=customer['position'][1], capacity=problem['capacity'], dis=0.0))
        dm.num_nodes = len(dm.nodes)
        cur_capacity = problem['capacity']
        pending_nodes = [i for i in range(0, dm.num_nodes)]
        dm.add_route_node(0)
        cur_capacity = dm.vehicle_state[-1][1]
        # Node 0 (the depot) stays pending forever, hence `> 1`.
        while len(pending_nodes) > 1:
            dis = []
            demands = []
            pre_node_idx = dm.vehicle_state[-1][0]
            pre_node = dm.get_node(pre_node_idx)
            for i in pending_nodes:
                cur_node = dm.get_node(i)
                dis.append(dm.get_dis(pre_node, cur_node))
                demands.append(cur_node.demand)
            # Order pending nodes by (distance, demand), ascending.
            for i in range(len(pending_nodes)):
                for j in range(i + 1, len(pending_nodes)):
                    if dis[i] > dis[j] or dis[i] == dis[j] and demands[i] > demands[j]:
                        pending_nodes[i], pending_nodes[j] = pending_nodes[j], pending_nodes[i]
                        dis[i], dis[j] = dis[j], dis[i]
                        demands[i], demands[j] = demands[j], demands[i]
            for i in pending_nodes:
                if i == 0:
                    # Return to the depot only when some capacity was used.
                    if cur_capacity == problem['capacity']:
                        continue
                    dm.add_route_node(0)
                    break
                else:
                    cur_node = dm.get_node(i)
                    if cur_node.demand > cur_capacity:
                        continue
                    dm.add_route_node(i)
                    pending_nodes.remove(i)
                    break
            cur_capacity = dm.vehicle_state[-1][1]
        dm.add_route_node(0)
        return dm

# ----------------------------------------------------------------------------
# src/models/data_utils/utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


def is_int(s):
    """Return True iff `s` is a valid integer literal with at most one
    optional leading sign.

    Fix: the original used s.lstrip('+-'), which accepted strings like
    '+-5' or '--3' that int() cannot parse.
    """
    if s and s[0] in '+-':
        s = s[1:]
    return s.isdigit()


def is_var(s):
    """Return True iff `s` names a variable token (starts with 'v')."""
    return s.startswith('v')


def gcd(x, y):
    """Greatest common divisor via the Euclidean algorithm."""
    while y != 0:
        x, y = y, x % y
    return x


def calc(op, c1, c2):
    """Evaluate `c1 op c2` on integer string operands; returns a string.

    Comparison and boolean operators yield '1'/'0'. Division is floor
    division, matching the original semantics. Raises ValueError on an
    unknown operator.
    """
    c1 = int(c1)
    c2 = int(c2)
    if op == '+':
        return str(c1 + c2)
    elif op == '-':
        return str(c1 - c2)
    elif op == '*':
        return str(c1 * c2)
    elif op == '/':
        return str(c1 // c2)
    elif op == '%':
        return str(c1 % c2)
    elif op == '<':
        return str(int(c1 < c2))
    elif op == '<=':
        return str(int(c1 <= c2))
    elif op == '==':
        return str(int(c1 == c2))
    elif op == '&&':
        return str(int(bool(c1 and c2)))
    elif op == '||':
        return str(int(bool(c1 or c2)))
    else:
        raise ValueError('undefined op: ' + op)

# ----------------------------------------------------------------------------
# src/models/jspModel.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import operator
import random
import time
from multiprocessing.pool import ThreadPool

import torch
import torch.nn as nn
import torch.optim as optim
from torch import cuda
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from .data_utils import data_utils
from .modules import jspInputEncoder, mlp
from .rewriter import jspRewriter
from .BaseModel import BaseModel

# Floor on action probabilities used when deciding which samples contribute
# to the policy loss.
eps = 1e-3
log_eps = np.log(eps)


class jspModel(BaseModel):
    """
    Model for job scheduling.
    """
    def __init__(self, args):
        super(jspModel, self).__init__(args)
        self.input_format = args.input_format
        self.max_resource_size = args.max_resource_size
        self.job_horizon = args.job_horizon
        self.num_res = args.num_res
        self.max_time_horizon = args.max_time_horizon
        self.max_job_len = args.max_job_len
        self.embedding_size = args.embedding_size
        # One action per parent slot plus one per child slot in the window.
        self.num_actions = self.job_horizon * 2
        self.reward_thres = -0.01
        if self.input_format == 'seq':
            self.input_encoder = jspInputEncoder.SeqLSTM(args)
        else:
            self.input_encoder = jspInputEncoder.DagLSTM(args)
        # Embeds (current node, neighbor) pairs before the policy head.
        self.policy_embedding = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * 2, self.MLP_hidden_size, self.LSTM_hidden_size, self.cuda_flag, self.dropout_rate)
        self.policy = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * (self.job_horizon * 2), self.MLP_hidden_size, self.num_actions, self.cuda_flag, self.dropout_rate)
        self.value_estimator = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size, self.MLP_hidden_size, 1, self.cuda_flag, self.dropout_rate)
        self.rewriter = jspRewriter()

        if args.optimizer == 'adam':
            self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
        elif args.optimizer == 'sgd':
            self.optimizer = optim.SGD(self.parameters(), lr=self.lr)
        elif args.optimizer == 'rmsprop':
            self.optimizer = optim.RMSprop(self.parameters(), lr=self.lr)
        else:
            raise ValueError('optimizer undefined: ', args.optimizer)


    def rewrite(self, dm, trace_rec, candidate_rewrite_pos, eval_flag, max_search_pos, reward_thres=None):
        """Pick rewrite positions (greedy in eval, sampled in training),
        choose a swap neighbor with the policy network, and apply the move.

        candidate_rewrite_pos: list of (pred_reward, pred_reward_tensor,
        node_idx) triples. Returns (new managers, updated node idx lists,
        rewrite records, predicted-reward tensors of inactive positions).
        """
        # Highest predicted reward first.
        candidate_rewrite_pos.sort(reverse=True, key=operator.itemgetter(0))
        if not eval_flag:
            sample_exp_reward_tensor = []
            for idx, (cur_pred_reward, cur_pred_reward_tensor, rewrite_pos) in enumerate(candidate_rewrite_pos):
                sample_exp_reward_tensor.append(cur_pred_reward_tensor)
            sample_exp_reward_tensor = torch.cat(sample_exp_reward_tensor, 0)
            # Temperature-sharpened softmax weights over positions.
            sample_exp_reward_tensor = torch.exp(sample_exp_reward_tensor * 10)
            sample_exp_reward = sample_exp_reward_tensor.data.cpu().numpy()

        candidate_dag_managers = []
        candidate_update_node_idxes = []
        candidate_rewrite_rec = []
        extra_reward_rec = []

        if not eval_flag:
            # Sample positions without replacement (dedup keeps first draws),
            # capped at num_sample_rewrite_pos.
            sample_rewrite_pos_dist = Categorical(sample_exp_reward_tensor)
            sample_rewrite_pos = sample_rewrite_pos_dist.sample(sample_shape=[len(candidate_rewrite_pos)])
            #sample_rewrite_pos = torch.multinomial(sample_exp_reward_tensor, len(candidate_rewrite_pos))
            sample_rewrite_pos = sample_rewrite_pos.data.cpu().numpy()
            indexes = np.unique(sample_rewrite_pos, return_index=True)[1]
            sample_rewrite_pos = [sample_rewrite_pos[i] for i in sorted(indexes)]
            sample_rewrite_pos = sample_rewrite_pos[:self.num_sample_rewrite_pos]
            sample_exp_reward = [sample_exp_reward[i] for i in sample_rewrite_pos]
            sample_rewrite_pos = [candidate_rewrite_pos[i] for i in sample_rewrite_pos]
        else:
            sample_rewrite_pos = candidate_rewrite_pos.copy()

        for idx, (pred_reward, cur_pred_reward_tensor, rewrite_pos) in enumerate(sample_rewrite_pos):
            if len(candidate_dag_managers) > 0 and idx >= max_search_pos:
                break
            if reward_thres is not None and pred_reward < reward_thres:
                if eval_flag:
                    # Positions are sorted, so everything after is worse too.
                    break
                elif np.random.random() > self.cont_prob:
                    continue
            parent_idxes = dm.get_parent_idxes(rewrite_pos, self.job_horizon)
            children_idxes = dm.get_children_idxes(rewrite_pos, self.job_horizon)
            # Build (neighbor, current) embedding pairs, zero-padded to
            # job_horizon parents followed by job_horizon children.
            policy_embedding_inputs = []
            cur_input = dm.get_node(rewrite_pos).state[0]
            cur_inputs = []
            for i in parent_idxes:
                policy_embedding_inputs.append(dm.get_node(i).state[0])
                cur_inputs.append(cur_input.clone())
            while len(policy_embedding_inputs) < self.job_horizon:
                zero_state = Variable(torch.zeros(1, self.LSTM_hidden_size))
                if self.cuda_flag:
                    zero_state = zero_state.cuda()
                policy_embedding_inputs.append(zero_state)
                cur_inputs.append(zero_state.clone())
            for i in children_idxes:
                policy_embedding_inputs.append(dm.get_node(i).state[0])
                cur_inputs.append(cur_input.clone())
            while len(policy_embedding_inputs) < self.job_horizon * 2:
                zero_state = Variable(torch.zeros(1, self.LSTM_hidden_size))
                if self.cuda_flag:
                    zero_state = zero_state.cuda()
                policy_embedding_inputs.append(zero_state)
                cur_inputs.append(zero_state.clone())
            policy_embedding_inputs = torch.cat(policy_embedding_inputs, 0)
            cur_inputs = torch.cat(cur_inputs, 0)
            policy_embedding_inputs = torch.cat([cur_inputs, policy_embedding_inputs], 1)
            policy_inputs = self.policy_embedding(policy_embedding_inputs)
            policy_inputs = policy_inputs.view(1, self.LSTM_hidden_size * (self.job_horizon * 2))
            ac_logits = self.policy(policy_inputs)
            ac_logprobs = nn.LogSoftmax()(ac_logits)
            ac_probs = nn.Softmax()(ac_logits)
            ac_logits = ac_logits.squeeze(0)
            ac_logprobs = ac_logprobs.squeeze(0)
            ac_probs = ac_probs.squeeze(0)
            if eval_flag:
                # Deterministic: try actions from most to least likely.
                _, candidate_acs = torch.sort(ac_logprobs, descending=True)
                candidate_acs = candidate_acs.data.cpu().numpy()
            else:
                candidate_acs_dist = Categorical(ac_probs)
                candidate_acs = candidate_acs_dist.sample(sample_shape=[ac_probs.size()[0]])
                #candidate_acs = torch.multinomial(ac_probs, ac_probs.size()[0])
                candidate_acs = candidate_acs.data.cpu().numpy()
                indexes = np.unique(candidate_acs, return_index=True)[1]
                candidate_acs = [candidate_acs[i] for i in sorted(indexes)]
            cur_active = False
            for i, op_idx in enumerate(candidate_acs):
                # Actions [0, job_horizon) target parents; the rest children.
                if op_idx < self.job_horizon:
                    if op_idx >= len(parent_idxes):
                        continue
                    neighbor_idx = parent_idxes[op_idx]
                else:
                    if op_idx - self.job_horizon >= len(children_idxes):
                        continue
                    neighbor_idx = children_idxes[op_idx - self.job_horizon]
                # Skip swaps already performed (in either direction).
                if (rewrite_pos, neighbor_idx) in trace_rec or (neighbor_idx, rewrite_pos) in trace_rec:
                    continue
                new_dm, cur_update_node_idxes = self.rewriter.move(dm, rewrite_pos, neighbor_idx)
                if len(cur_update_node_idxes) == 0:
                    continue
                candidate_update_node_idxes.append(cur_update_node_idxes)
                candidate_dag_managers.append(new_dm)
                candidate_rewrite_rec.append((ac_logprobs, pred_reward, cur_pred_reward_tensor, rewrite_pos, op_idx, neighbor_idx))
                cur_active = True
                if len(candidate_dag_managers) >= max_search_pos:
                    break
            if not cur_active:
                # No applicable action at this position; its value estimate
                # is later trained toward zero.
                extra_reward_rec.append(cur_pred_reward_tensor)
        return candidate_dag_managers, candidate_update_node_idxes, candidate_rewrite_rec, extra_reward_rec


    def batch_rewrite(self, dag_managers, trace_rec, candidate_rewrite_pos, eval_flag, max_search_pos, reward_thres):
        """Apply `rewrite` to each manager in the batch; per-manager results
        stay grouped, inactive-position tensors are flattened together."""
        candidate_dag_managers = []
        candidate_update_node_idxes = []
        candidate_rewrite_rec = []
        extra_reward_rec = []
        for i in range(len(dag_managers)):
            cur_candidate_dag_managers, cur_candidate_update_node_idxes, cur_candidate_rewrite_rec, cur_extra_reward_rec = self.rewrite(dag_managers[i], trace_rec[i], candidate_rewrite_pos[i], eval_flag, max_search_pos, reward_thres)
            candidate_dag_managers.append(cur_candidate_dag_managers)
            candidate_update_node_idxes.append(cur_candidate_update_node_idxes)
            candidate_rewrite_rec.append(cur_candidate_rewrite_rec)
            extra_reward_rec = extra_reward_rec + cur_extra_reward_rec
        return candidate_dag_managers, candidate_update_node_idxes, candidate_rewrite_rec, extra_reward_rec


    def forward(self, batch_data, eval_flag=False):
        """Run iterative rewriting on a batch of schedule DAGs and compute
        the actor-critic loss.

        Returns (total_loss, avg best reward, avg min completion time,
        per-sample manager history).
        """
        dag_managers = []
        batch_size = len(batch_data)
        for dm in batch_data:
            dag_managers.append(dm)
        dag_managers = self.input_encoder.calc_embedding(dag_managers, eval_flag)

        active = True
        reduce_steps = 0

        trace_rec = [[] for _ in range(batch_size)]
        rewrite_rec = [[] for _ in range(batch_size)]
        dm_rec = [[] for _ in range(batch_size)]
        extra_reward_rec = []

        for idx in range(batch_size):
            dm_rec[idx].append(dag_managers[idx])

        # Main rewriting loop: one rewrite per sample per iteration, until
        # nothing changes or max_reduce_steps is hit.
        while active and ((self.max_reduce_steps is None) or reduce_steps < self.max_reduce_steps):
            active = False
            reduce_steps += 1
            node_idxes = []
            node_embeddings = []
            root_embeddings = []
            for dm_idx in range(batch_size):
                dm = dag_managers[dm_idx]
                root_embedding = dm.get_node(0).state[0]
                for i in range(1, dm.num_nodes):
                    cur_node = dm.get_node(i)
                    node_idxes.append((dm_idx, i))
                    node_embeddings.append(cur_node.state[0])
                    root_embeddings.append(root_embedding.clone())
            # Score every node with the value estimator, in mini-batches.
            pred_rewards = []
            for st in range(0, len(node_idxes), self.batch_size):
                cur_node_embeddings = node_embeddings[st: st + self.batch_size]
                cur_node_embeddings = torch.cat(cur_node_embeddings, 0)
                cur_pred_rewards = self.value_estimator(cur_node_embeddings)
                pred_rewards.append(cur_pred_rewards)
            pred_rewards = torch.cat(pred_rewards, 0)
            candidate_rewrite_pos = [[] for _ in range(batch_size)]
            for idx, (dm_idx, node_idx) in enumerate(node_idxes):
                candidate_rewrite_pos[dm_idx].append((pred_rewards[idx].data[0], pred_rewards[idx], node_idx))

            update_node_idxes = [[] for _ in range(batch_size)]
            # NOTE(review): cur_extra_reward_rec is never folded into
            # extra_reward_rec here, so the zero-target value samples from
            # inactive positions are dropped — confirm this is intended.
            candidate_dag_managers, candidate_update_node_idxes, candidate_rewrite_rec, cur_extra_reward_rec = self.batch_rewrite(dag_managers, trace_rec, candidate_rewrite_pos, eval_flag, max_search_pos=1, reward_thres=self.reward_thres)
            for dm_idx in range(batch_size):
                cur_candidate_dag_managers = candidate_dag_managers[dm_idx]
                cur_candidate_update_node_idxes = candidate_update_node_idxes[dm_idx]
                cur_candidate_rewrite_rec = candidate_rewrite_rec[dm_idx]
                if len(cur_candidate_dag_managers) > 0:
                    active = True
                    cur_dag_manager = cur_candidate_dag_managers[0]
                    cur_update_node_idxes = cur_candidate_update_node_idxes[0]
                    cur_rewrite_rec = cur_candidate_rewrite_rec[0]
                    dag_managers[dm_idx] = cur_dag_manager
                    update_node_idxes[dm_idx] = cur_update_node_idxes
                    ac_logprob, pred_reward, cur_pred_reward_tensor, rewrite_pos, applied_op, neighbor_idx = cur_rewrite_rec
                    rewrite_rec[dm_idx].append(cur_rewrite_rec)
                    trace_rec[dm_idx].append((rewrite_pos, neighbor_idx))
            if not active:
                break

            # Re-embed the modified DAGs before the next iteration.
            updated_dm = self.input_encoder.calc_embedding(dag_managers, eval_flag)

            for i in range(batch_size):
                dag_managers[i] = updated_dm[i]
                if len(update_node_idxes[i]) > 0:
                    dm_rec[i].append(updated_dm[i])

        total_policy_loss = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)
        total_value_loss = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)

        pred_actions_rec = []
        pred_actions_logprob_rec = []
        pred_value_rec = []
        value_target_rec = []
        total_reward = 0
        total_completion_time = 0
        total_slow_down = 0
        for dm_idx, cur_dm_rec in enumerate(dm_rec):
            pred_avg_slow_down = []
            pred_avg_completion_time = []
            for dm in cur_dm_rec:
                pred_avg_slow_down.append(dm.avg_slow_down)
                pred_avg_completion_time.append(dm.avg_completion_time)
            min_slow_down = pred_avg_slow_down[0]
            min_completion_time = pred_avg_completion_time[0]
            best_reward = min_slow_down
            for idx, (ac_logprob, pred_reward, cur_pred_reward_tensor, rewrite_pos, applied_op, neighbor_idx) in enumerate(rewrite_rec[dm_idx]):
                # Per-step reward: slow-down improvement minus a 0.01 step
                # penalty; rewards are later normalized by the initial value.
                cur_reward = pred_avg_slow_down[idx] - pred_avg_slow_down[idx + 1] - 0.01
                best_reward = min(best_reward, pred_avg_slow_down[idx + 1])
                min_slow_down = min(min_slow_down, pred_avg_slow_down[idx + 1])
                min_completion_time = min(min_completion_time, pred_avg_completion_time[idx + 1])

                if self.gamma > 0.0:
                    # Look ahead: credit the best discounted improvement over
                    # the remaining rollout.
                    decay_coef = 1.0
                    num_rollout_steps = len(cur_dm_rec) - idx - 1
                    for i in range(idx + 1, idx + 1 + num_rollout_steps):
                        cur_reward = max(decay_coef * (pred_avg_slow_down[idx] - pred_avg_slow_down[i] - (i - idx) * 0.01), cur_reward)
                        decay_coef *= self.gamma

                cur_reward = cur_reward * 1.0 / pred_avg_slow_down[0]
                cur_reward_tensor = data_utils.np_to_tensor(np.array([cur_reward], dtype=np.float32), 'float', self.cuda_flag, eval_flag)
                # Train on actions that were reasonably probable, or whose
                # realized reward beat the critic's prediction.
                if ac_logprob.data.cpu().numpy()[0] > log_eps or cur_reward - pred_reward > 0:
                    ac_mask = np.zeros(self.num_actions)
                    ac_mask[applied_op] = cur_reward - pred_reward
                    ac_mask = data_utils.np_to_tensor(ac_mask, 'float', self.cuda_flag, eval_flag)
                    ac_mask = ac_mask.unsqueeze(0)
                    pred_actions_rec.append(ac_mask)
                    pred_actions_logprob_rec.append(ac_logprob.unsqueeze(0))
                    pred_value_rec.append(cur_pred_reward_tensor)
                    value_target_rec.append(cur_reward_tensor)
            total_reward += best_reward
            total_completion_time += min_completion_time
            total_slow_down += min_slow_down

        for cur_pred_reward in extra_reward_rec:
            pred_value_rec.append(cur_pred_reward)
            value_target = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)
            value_target_rec.append(value_target)

        if len(pred_actions_rec) > 0:
            pred_actions_rec = torch.cat(pred_actions_rec, 0)
            pred_actions_logprob_rec = torch.cat(pred_actions_logprob_rec, 0)
            pred_value_rec = torch.cat(pred_value_rec, 0)
            value_target_rec = torch.cat(value_target_rec, 0)
            pred_value_rec = pred_value_rec.unsqueeze(1)
            value_target_rec = value_target_rec.unsqueeze(1)
            total_policy_loss = -torch.sum(pred_actions_logprob_rec * pred_actions_rec)
            total_value_loss = F.smooth_l1_loss(pred_value_rec, value_target_rec, size_average=False)
        total_policy_loss /= batch_size
        total_value_loss /= batch_size
        total_loss = total_policy_loss + total_value_loss * self.value_loss_coef
        total_reward = total_reward * 1.0 / batch_size
        total_completion_time = total_completion_time * 1.0 / batch_size
        # NOTE(review): total_slow_down is computed but not returned — the
        # caller only receives the completion-time metric; confirm intended.
        total_slow_down = total_slow_down * 1.0 / batch_size
        return total_loss, total_reward, total_completion_time, dm_rec

# ----------------------------------------------------------------------------
# src/models/model_utils/__init__.py
from .logger import *
from .supervisor import *
# ----------------------------------------------------------------------------
# src/models/model_utils/logger.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
class Logger(object):
    """
    Keeps a running record of per-interval training summaries and mirrors
    them to a CSV file under ../logs/ after every update.
    """
    def __init__(self, args):
        # Ensure the log directory exists before any summary is written.
        if not os.path.exists("../logs/"):
            os.makedirs("../logs/")
        self.log_interval = args.log_interval
        self.log_name = "../logs/" + args.log_name
        self.best_reward = 0
        self.records = []


    def write_summary(self, summary):
        """Append one summary dict, rewrite the CSV log, track best reward."""
        print("global-step: %(global_step)d, avg-reward: %(avg_reward).3f" % summary)
        self.records.append(summary)
        # Rewrite the whole CSV so the file always reflects every record.
        pd.DataFrame(self.records).to_csv(self.log_name, index=False)
        if summary['avg_reward'] > self.best_reward:
            self.best_reward = summary['avg_reward']
29 | """ 30 | def __init__(self, model, args): 31 | self.processes = args.processes 32 | self.model = model 33 | self.keep_last_n = args.keep_last_n 34 | self.dropout_rate = args.dropout_rate 35 | self.global_step = args.resume 36 | self.batch_size = args.batch_size 37 | self.model_dir = args.model_dir 38 | 39 | 40 | def load_pretrained(self, load_model): 41 | print("Read model parameters from %s." % load_model) 42 | checkpoint = torch.load(load_model) 43 | self.model.load_state_dict(checkpoint) 44 | 45 | 46 | def save_model(self): 47 | if not os.path.exists(self.model_dir): 48 | os.makedirs(self.model_dir) 49 | global_step_padded = format(self.global_step, '08d') 50 | ckpt_name = 'ckpt-' + global_step_padded 51 | path = os.path.join(self.model_dir, ckpt_name) 52 | ckpt = self.model.state_dict() 53 | torch.save(ckpt, path) 54 | 55 | if self.keep_last_n is not None: 56 | ckpts = [] 57 | for file_name in os.listdir(self.model_dir): 58 | matched_name = CKPT_PATTERN.match(file_name) 59 | if matched_name is None or matched_name == ckpt_name: 60 | continue 61 | step = int(matched_name.group(1)) 62 | ckpts.append((step, file_name)) 63 | if len(ckpts) > self.keep_last_n: 64 | ckpts.sort() 65 | os.unlink(os.path.join(self.model_dir, ckpts[0][1])) 66 | 67 | 68 | class HalideSupervisor(Supervisor): 69 | """ 70 | Management class for expression simplification. 
71 | """ 72 | def __init__(self, model, args, term_vocab, term_vocab_list, op_vocab, op_vocab_list): 73 | super(HalideSupervisor, self).__init__(model, args) 74 | self.DataProcessor = data_utils.HalideDataProcessor() 75 | self.parser = HalideParser() 76 | self.term_vocab = term_vocab 77 | self.term_vocab_list = term_vocab_list 78 | self.op_vocab = op_vocab 79 | self.op_vocab_list = op_vocab_list 80 | 81 | 82 | def train(self, batch_data): 83 | self.model.dropout_rate = self.dropout_rate 84 | self.model.optimizer.zero_grad() 85 | avg_loss, avg_reward, trace_rec, tm_rec = self.model(batch_data) 86 | self.global_step += 1 87 | if avg_reward != 0: 88 | avg_loss.backward() 89 | self.model.train() 90 | return avg_loss.item(), avg_reward 91 | 92 | 93 | def batch_eval(self, eval_data, output_trace_flag, output_trace_option, process_idx): 94 | cum_loss = 0 95 | cum_expr_reward = 0 96 | cum_gt_expr_reward = 0 97 | cum_tree_reward = 0 98 | cum_gt_tree_reward = 0 99 | data_size = len(eval_data) 100 | trace_rec = [] 101 | for batch_idx in range(0, data_size, self.batch_size): 102 | batch_data = self.DataProcessor.get_batch(eval_data, self.batch_size, batch_idx) 103 | cur_avg_loss, cur_avg_expr_reward, cur_trace_rec, cur_tm_rec = self.model(batch_data, eval_flag=True) 104 | cum_loss += cur_avg_loss.item() * len(batch_data) 105 | cum_expr_reward += cur_avg_expr_reward * len(batch_data) 106 | cur_gt_expr_reward = 0 107 | cur_gt_tree_reward = 0 108 | for idx, (trace, tm) in enumerate(batch_data): 109 | gt = len(trace[0]) - len(trace[-1]) 110 | cur_gt_expr_reward += gt 111 | num_nodes_0 = tm.num_trees 112 | final_tm = self.parser.parse(trace[-1]) 113 | num_nodes_1 = final_tm.num_trees 114 | cur_gt_tree_reward += num_nodes_0 - num_nodes_1 115 | init_expr = cur_trace_rec[idx][0][0] 116 | pred_expr = cur_trace_rec[idx][-1][0] 117 | pred_reward = len(init_expr) - len(pred_expr) 118 | if output_trace_flag == 'complete' or output_trace_flag == 'fail' and len(trace[-1]) < len(pred_expr) \ 
119 | or output_trace_flag == 'succeed' and len(pred_expr) < len(trace[-1]): 120 | if output_trace_option != 'pred': 121 | self.DataProcessor.print_gt_trace(trace) 122 | self.DataProcessor.print_pred_trace(cur_trace_rec[idx]) 123 | print('end of a sample') 124 | print('') 125 | 126 | cur_cum_tree_reward = 0 127 | for tm_rec in cur_tm_rec: 128 | cur_tree_reward = 0 129 | num_nodes_0 = tm_rec[0].num_trees 130 | for final_tm in tm_rec[1:]: 131 | num_nodes_1 = final_tm.num_valid_nodes() 132 | if num_nodes_0 - num_nodes_1 > cur_tree_reward: 133 | cur_tree_reward = num_nodes_0 - num_nodes_1 134 | cur_cum_tree_reward += cur_tree_reward 135 | 136 | trace_rec = trace_rec + cur_trace_rec 137 | cum_tree_reward += cur_cum_tree_reward 138 | cum_gt_expr_reward += cur_gt_expr_reward 139 | cum_gt_tree_reward += cur_gt_tree_reward 140 | print('process start idx: %d batch idx: %d pred expr reward: %.4f pred tree reward: %.4f gt expr reward: %.4f gt tree reward: %.4f' % \ 141 | (process_idx, batch_idx, cur_avg_expr_reward, cur_cum_tree_reward * 1.0 / len(batch_data), cur_gt_expr_reward * 1.0 / len(batch_data), cur_gt_tree_reward * 1.0 / len(batch_data))) 142 | return cum_loss, cum_expr_reward, cum_tree_reward, cum_gt_expr_reward, cum_gt_tree_reward, trace_rec 143 | 144 | 145 | def eval(self, data, output_trace_flag, output_trace_option, output_trace_file, max_eval_size=None): 146 | data_size = len(data) 147 | if max_eval_size is not None: 148 | data_size = min(data_size, max_eval_size) 149 | eval_data = data[:data_size] 150 | if self.processes == 1: 151 | cum_loss, cum_expr_reward, cum_tree_reward, cum_gt_expr_reward, cum_gt_tree_reward, trace_rec = self.batch_eval(eval_data, output_trace_flag, output_trace_option, 0) 152 | else: 153 | cum_loss = 0 154 | cum_expr_reward = 0 155 | cum_tree_reward = 0 156 | cum_gt_expr_reward = 0 157 | cum_gt_tree_reward = 0 158 | trace_rec = [] 159 | try: 160 | mp.set_start_method('spawn') 161 | except RuntimeError: 162 | pass 163 | pool = 
mp.Pool(processes=self.processes) 164 | res = [] 165 | batch_per_process = data_size // self.processes 166 | if data_size % self.processes > 0: 167 | batch_per_process += 1 168 | for st in range(0, data_size, batch_per_process): 169 | res += [pool.apply_async(self.batch_eval, (eval_data[st: st + batch_per_process], output_trace_flag, output_trace_option, st))] 170 | for i in range(len(res)): 171 | cur_cum_loss, cur_cum_expr_reward, cur_cum_tree_reward, cur_cum_gt_expr_reward, cur_cum_gt_tree_reward, cur_trace_rec = res[i].get() 172 | cum_loss += cur_cum_loss 173 | cum_expr_reward += cur_cum_expr_reward 174 | cum_tree_reward += cur_cum_tree_reward 175 | cum_gt_expr_reward += cur_cum_gt_expr_reward 176 | cum_gt_tree_reward += cur_cum_gt_tree_reward 177 | trace_rec = trace_rec + cur_trace_rec 178 | 179 | avg_loss = cum_loss / data_size 180 | avg_expr_reward = cum_expr_reward * 1.0 / data_size 181 | avg_tree_reward = cum_tree_reward * 1.0 / data_size 182 | gt_expr_reward = cum_gt_expr_reward * 1.0 / data_size 183 | gt_tree_reward = cum_gt_tree_reward * 1.0 / data_size 184 | print('average: pred expr reward: %.4f pred tree reward: %.4f gt expr reward: %.4f gt tree reward: %.4f' % (avg_expr_reward, avg_tree_reward, gt_expr_reward, gt_tree_reward)) 185 | if output_trace_file is not None: 186 | fout_res = open(output_trace_file, 'w') 187 | json.dump(trace_rec, fout_res) 188 | return avg_loss, avg_expr_reward 189 | 190 | 191 | class jspSupervisor(Supervisor): 192 | """ 193 | Management class for job scheduling. 
194 | """ 195 | def __init__(self, model, args): 196 | super(jspSupervisor, self).__init__(model, args) 197 | self.DataProcessor = data_utils.jspDataProcessor(args) 198 | 199 | 200 | def train(self, batch_data): 201 | self.model.dropout_rate = self.dropout_rate 202 | self.model.optimizer.zero_grad() 203 | avg_loss, avg_reward, avg_completion_time, dm_rec = self.model(batch_data) 204 | self.global_step += 1 205 | if avg_reward != 0: 206 | avg_loss.backward() 207 | self.model.train() 208 | return avg_loss.item(), avg_reward 209 | 210 | 211 | def batch_eval(self, eval_data, output_trace_flag, process_idx): 212 | cum_loss = 0 213 | cum_reward = 0 214 | cum_completion_time = 0 215 | cum_gt_reward = 0 216 | data_size = len(eval_data) 217 | 218 | for batch_idx in range(0, data_size, self.batch_size): 219 | batch_data = self.DataProcessor.get_batch(eval_data, self.batch_size, batch_idx) 220 | cur_avg_loss, cur_avg_reward, cur_avg_completion_time, dm_rec = self.model(batch_data, eval_flag=True) 221 | cum_loss += cur_avg_loss.item() * len(batch_data) 222 | cum_reward += cur_avg_reward * len(batch_data) 223 | cum_completion_time += cur_avg_completion_time * len(batch_data) 224 | if output_trace_flag == 'complete': 225 | for cur_dm_rec in dm_rec: 226 | for i, job in enumerate(cur_dm_rec[-1].nodes[1:]): 227 | print(i) 228 | print(job.st_time) 229 | print(job.job_len) 230 | print(job.resource_size) 231 | print(job.schedule_time) 232 | print('') 233 | print('process start idx: %d batch idx: %d pred reward: %.4f pred completion time: %.4f' \ 234 | % (process_idx, batch_idx, cur_avg_reward, cur_avg_completion_time)) 235 | return cum_loss, cum_reward, cum_completion_time 236 | 237 | 238 | def eval(self, data, output_trace_flag, max_eval_size=None): 239 | data_size = len(data) 240 | if max_eval_size is not None: 241 | data_size = min(data_size, max_eval_size) 242 | eval_data = data[:data_size] 243 | if self.processes == 1: 244 | cum_loss, cum_reward, cum_completion_time = 
self.batch_eval(eval_data, output_trace_flag, 0) 245 | else: 246 | cum_loss = 0 247 | cum_reward = 0 248 | cum_completion_time = 0 249 | try: 250 | mp.set_start_method('spawn') 251 | except RuntimeError: 252 | pass 253 | pool = mp.Pool(processes=self.processes) 254 | res = [] 255 | batch_per_process = data_size // self.processes 256 | if data_size % batch_per_process > 0: 257 | batch_per_process += 1 258 | for st in range(0, data_size, batch_per_process): 259 | res += [pool.apply_async(self.batch_eval, (eval_data[st: st + batch_per_process], output_trace_flag, st))] 260 | for i in range(len(res)): 261 | cur_cum_loss, cur_cum_reward, cur_cum_completion_time = res[i].get() 262 | cum_loss += cur_cum_loss 263 | cum_reward += cur_cum_reward 264 | cum_completion_time += cur_cum_completion_time 265 | 266 | avg_loss = cum_loss / data_size 267 | avg_reward = cum_reward / data_size 268 | avg_completion_time = cum_completion_time * 1.0 / data_size 269 | print('average pred reward: %.4f' % avg_reward) 270 | print('average completion time: %.4f' % avg_completion_time) 271 | return avg_loss, avg_reward 272 | 273 | 274 | class vrpSupervisor(Supervisor): 275 | """ 276 | Management class for vehicle routing. 
277 | """ 278 | def __init__(self, model, args): 279 | super(vrpSupervisor, self).__init__(model, args) 280 | self.DataProcessor = data_utils.vrpDataProcessor() 281 | 282 | 283 | def train(self, batch_data): 284 | self.model.dropout_rate = self.dropout_rate 285 | self.model.optimizer.zero_grad() 286 | avg_loss, avg_reward, dm_rec = self.model(batch_data) 287 | self.global_step += 1 288 | if type(avg_loss) != float: 289 | avg_loss.backward() 290 | self.model.train() 291 | return avg_loss.item(), avg_reward 292 | 293 | 294 | def batch_eval(self, eval_data, output_trace_flag, process_idx): 295 | cum_loss = 0 296 | cum_reward = 0 297 | data_size = len(eval_data) 298 | 299 | for batch_idx in range(0, data_size, self.batch_size): 300 | batch_data = self.DataProcessor.get_batch(eval_data, self.batch_size, batch_idx) 301 | cur_avg_loss, cur_avg_reward, dm_rec = self.model(batch_data, eval_flag=True) 302 | cum_loss += cur_avg_loss.item() * len(batch_data) 303 | cum_reward += cur_avg_reward * len(batch_data) 304 | if output_trace_flag == 'complete': 305 | for cur_dm_rec in dm_rec: 306 | for i in range(len(cur_dm_rec)): 307 | print('step ' + str(i)) 308 | dm = cur_dm_rec[i] 309 | print(dm.tot_dis[-1]) 310 | for j in range(len(dm.vehicle_state)): 311 | cur_pos, cur_capacity = dm.vehicle_state[j] 312 | cur_node = dm.get_node(cur_pos) 313 | print(cur_node.x, cur_node.y, cur_node.demand, cur_capacity, dm.tot_dis[j]) 314 | print('') 315 | print('process start idx: %d batch idx: %d pred reward: %.4f' \ 316 | % (process_idx, batch_idx, cur_avg_reward)) 317 | return cum_loss, cum_reward 318 | 319 | 320 | def eval(self, data, output_trace_flag, max_eval_size=None): 321 | data_size = len(data) 322 | if max_eval_size is not None: 323 | data_size = min(data_size, max_eval_size) 324 | eval_data = data[:data_size] 325 | if self.processes == 1: 326 | cum_loss, cum_reward = self.batch_eval(eval_data, output_trace_flag, 0) 327 | else: 328 | cum_loss = 0 329 | cum_reward = 0 330 | try: 331 | 
mp.set_start_method('spawn') 332 | except RuntimeError: 333 | pass 334 | pool = mp.Pool(processes=self.processes) 335 | res = [] 336 | batch_per_process = data_size // self.processes 337 | if data_size % batch_per_process > 0: 338 | batch_per_process += 1 339 | for st in range(0, data_size, batch_per_process): 340 | res += [pool.apply_async(self.batch_eval, (eval_data[st: st + batch_per_process], output_trace_flag, st))] 341 | for i in range(len(res)): 342 | cur_cum_loss, cur_cum_reward = res[i].get() 343 | cum_loss += cur_cum_loss 344 | cum_reward += cur_cum_reward 345 | 346 | avg_loss = cum_loss / data_size 347 | avg_reward = cum_reward / data_size 348 | print('average pred reward: %.4f' % avg_reward) 349 | return avg_loss, avg_reward 350 | -------------------------------------------------------------------------------- /src/models/modules/HalideInputEncoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import random 9 | import numpy as np 10 | import time 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | import torch.nn.functional as F 16 | 17 | from ..data_utils import data_utils 18 | 19 | 20 | class InputEmbedding(nn.Module): 21 | """ 22 | Component to compute the embedding of each terminal node. 
23 | """ 24 | def __init__(self, args, term_vocab, term_vocab_list): 25 | super(InputEmbedding, self).__init__() 26 | self.dataProcessor = data_utils.HalideDataProcessor() 27 | self.term_vocab = term_vocab 28 | self.term_vocab_list = term_vocab_list 29 | self.term_vocab_size = args.term_vocab_size 30 | self.embedding_size = args.embedding_size 31 | self.hidden_size = args.LSTM_hidden_size 32 | self.cuda_flag = args.cuda 33 | 34 | self.char_embedding = nn.Embedding(self.term_vocab_size, self.embedding_size) 35 | self.token_embedding = nn.LSTM(input_size=self.embedding_size, 36 | hidden_size=self.hidden_size, 37 | num_layers=1, 38 | batch_first=True) 39 | 40 | def forward(self, raw_input_tokens, eval_mode=False): 41 | input_tokens = [] 42 | for raw_inp in raw_input_tokens: 43 | input_tokens.append(self.dataProcessor.token_to_ids(raw_inp, self.term_vocab)) 44 | input_tokens = np.array(input_tokens) 45 | input_tokens = data_utils.np_to_tensor(input_tokens, 'int', self.cuda_flag, eval_mode) 46 | if len(input_tokens.size()) < 2: 47 | input_tokens = input_tokens.unsqueeze(0) 48 | init_embedding = self.char_embedding(input_tokens) 49 | batch_size = input_tokens.size()[0] 50 | init_h = Variable(torch.zeros(1, batch_size, self.hidden_size)) 51 | init_c = Variable(torch.zeros(1, batch_size, self.hidden_size)) 52 | if self.cuda_flag: 53 | init_h = init_h.cuda() 54 | init_c = init_c.cuda() 55 | init_state = (init_h, init_c) 56 | embedding_outputs, embedding_states = self.token_embedding(init_embedding, init_state) 57 | return embedding_states 58 | 59 | 60 | class TreeLSTM(nn.Module): 61 | """ 62 | Tree LSTM to embed each node in the tree. It is used for expression simplification. 
63 | """ 64 | def __init__(self, args, term_vocab, term_vocab_list, op_vocab, op_vocab_list): 65 | super(TreeLSTM, self).__init__() 66 | self.batch_size = args.batch_size 67 | self.hidden_size = args.LSTM_hidden_size 68 | self.embedding_size = args.embedding_size 69 | self.dropout_rate = args.dropout_rate 70 | self.cuda_flag = args.cuda 71 | self.term_vocab_size = args.term_vocab_size 72 | self.term_vocab = term_vocab 73 | self.term_vocab_list = term_vocab_list 74 | self.op_vocab_size = args.op_vocab_size 75 | self.op_vocab = op_vocab 76 | self.op_vocab_list = op_vocab_list 77 | self.input_embedding = InputEmbedding(args, term_vocab, term_vocab_list) 78 | 79 | self.ih = nn.ModuleList([nn.ModuleList([nn.Linear(self.hidden_size, self.hidden_size, bias=True) for _ in range(3)]) for _ in range(self.op_vocab_size)]) 80 | self.fh = nn.ModuleList([nn.ModuleList([nn.Linear(self.hidden_size, self.hidden_size, bias=True) for _ in range(3)]) for _ in range(self.op_vocab_size)]) 81 | self.oh = nn.ModuleList([nn.ModuleList([nn.Linear(self.hidden_size, self.hidden_size, bias=True) for _ in range(3)]) for _ in range(self.op_vocab_size)]) 82 | self.uh = nn.ModuleList([nn.ModuleList([nn.Linear(self.hidden_size, self.hidden_size, bias=True) for _ in range(3)]) for _ in range(self.op_vocab_size)]) 83 | 84 | 85 | def calc_root(self, ops, child_h, child_c): 86 | i = [] 87 | o = [] 88 | u = [] 89 | fc = [] 90 | h = [] 91 | c = [] 92 | for idx in range(len(ops)): 93 | op = self.op_vocab[ops[idx]] 94 | i_cur = Variable(torch.zeros(1, self.hidden_size)) 95 | o_cur = Variable(torch.zeros(1, self.hidden_size)) 96 | u_cur = Variable(torch.zeros(1, self.hidden_size)) 97 | f_cur = [] 98 | if self.cuda_flag: 99 | i_cur = i_cur.cuda() 100 | o_cur = o_cur.cuda() 101 | u_cur = u_cur.cuda() 102 | for child_idx in range(len(child_h[idx])): 103 | i_cur += self.ih[op][child_idx](child_h[idx][child_idx]) 104 | o_cur += self.oh[op][child_idx](child_h[idx][child_idx]) 105 | u_cur += 
self.uh[op][child_idx](child_h[idx][child_idx]) 106 | f_cur.append(self.fh[op][child_idx](child_h[idx][child_idx])) 107 | i_cur = F.sigmoid(i_cur) 108 | o_cur = F.sigmoid(o_cur) 109 | u_cur = F.tanh(u_cur) 110 | f_cur = torch.cat(f_cur, 0) 111 | f_cur = F.sigmoid(f_cur) 112 | fc_cur = F.torch.mul(f_cur, torch.cat(child_c[idx], 0)) 113 | fc_cur = F.torch.sum(fc_cur, 0) 114 | i.append(i_cur.unsqueeze(0)) 115 | o.append(o_cur.unsqueeze(0)) 116 | u.append(u_cur.unsqueeze(0)) 117 | fc.append(fc_cur.unsqueeze(0).unsqueeze(0)) 118 | i = torch.cat(i, 0) 119 | o = torch.cat(o, 0) 120 | u = torch.cat(u, 0) 121 | fc = torch.cat(fc, 0) 122 | c = F.torch.mul(i, u) + fc 123 | h = F.torch.mul(o, F.tanh(c)) 124 | return h, c 125 | 126 | 127 | def calc_embedding(self, tree_managers, eval_mode=False): 128 | queue_term = [] 129 | queue_nonterm = [] 130 | head_term = 0 131 | head_nonterm = 0 132 | max_num_trees = 0 133 | 134 | for tree_manager_idx in range(len(tree_managers)): 135 | tree_manager = tree_managers[tree_manager_idx] 136 | max_num_trees = max(max_num_trees, tree_manager.num_trees) 137 | for idx in range(tree_manager.num_trees): 138 | cur_tree = tree_manager.get_tree(idx) 139 | canCompute = True 140 | children_h = [] 141 | children_c = [] 142 | for child_idx in cur_tree.children: 143 | child = tree_manager.get_tree(child_idx) 144 | if child.state is None: 145 | canCompute = False 146 | break 147 | else: 148 | child_h, child_c = child.state 149 | children_h.append(child_h) 150 | children_c.append(child_c) 151 | if canCompute: 152 | if len(children_h) == 0: 153 | queue_term.append((tree_manager_idx, idx, cur_tree.root)) 154 | else: 155 | queue_nonterm.append((tree_manager_idx, idx, cur_tree.root, children_h, children_c)) 156 | 157 | while head_term < len(queue_term): 158 | encoder_inputs = [] 159 | tree_idxes = [] 160 | while head_term < len(queue_term): 161 | tree_manager_idx, idx, root = queue_term[head_term] 162 | tree_idxes.append((tree_manager_idx, idx)) 163 | 
encoder_inputs.append(root) 164 | head_term += 1 165 | if len(encoder_inputs) == 0: 166 | break 167 | encoder_outputs = self.input_embedding(encoder_inputs, eval_mode) 168 | for i in range(len(tree_idxes)): 169 | tree_manager_idx, cur_idx = tree_idxes[i] 170 | tree_manager = tree_managers[tree_manager_idx] 171 | child_h = encoder_outputs[0][:, i, :] 172 | child_c = encoder_outputs[1][:, i, :] 173 | tree_managers[tree_manager_idx].trees[cur_idx].state = child_h, child_c 174 | cur_tree = tree_manager.get_tree(cur_idx) 175 | if cur_tree.parent != -1: 176 | parent_tree = tree_manager.get_tree(cur_tree.parent) 177 | canCompute = True 178 | children_h = [] 179 | children_c = [] 180 | for child_idx in parent_tree.children: 181 | child = tree_manager.get_tree(child_idx) 182 | if child.state is None: 183 | canCompute = False 184 | break 185 | else: 186 | child_h, child_c = child.state 187 | children_h.append(child_h) 188 | children_c.append(child_c) 189 | if canCompute: 190 | queue_nonterm.append((tree_manager_idx, cur_tree.parent, parent_tree.root, children_h, children_c)) 191 | 192 | while head_nonterm < len(queue_nonterm): 193 | encoder_inputs = [] 194 | children_h = [] 195 | children_c = [] 196 | tree_idxes = [] 197 | while head_nonterm < len(queue_nonterm): 198 | tree_manager_idx, idx, root, child_h, child_c = queue_nonterm[head_nonterm] 199 | cur_tree = tree_managers[tree_manager_idx].get_tree(idx) 200 | if cur_tree.state is None: 201 | tree_idxes.append((tree_manager_idx, idx)) 202 | encoder_inputs.append(root) 203 | children_h.append(child_h) 204 | children_c.append(child_c) 205 | head_nonterm += 1 206 | if len(encoder_inputs) == 0: 207 | break 208 | encoder_outputs = self.calc_root(encoder_inputs, children_h, children_c) 209 | for i in range(len(tree_idxes)): 210 | tree_manager_idx, cur_idx = tree_idxes[i] 211 | tree_manager = tree_managers[tree_manager_idx] 212 | child_h = encoder_outputs[0][i] 213 | child_c = encoder_outputs[1][i] 214 | 
tree_managers[tree_manager_idx].trees[cur_idx].state = child_h, child_c 215 | cur_tree = tree_manager.get_tree(cur_idx) 216 | if cur_tree.parent != -1: 217 | parent_tree = tree_manager.get_tree(cur_tree.parent) 218 | canCompute = True 219 | children_h = [] 220 | children_c = [] 221 | for child_idx in parent_tree.children: 222 | child = tree_manager.get_tree(child_idx) 223 | if child.state is None: 224 | canCompute = False 225 | break 226 | else: 227 | child_h, child_c = child.state 228 | children_h.append(child_h) 229 | children_c.append(child_c) 230 | if canCompute: 231 | queue_nonterm.append((tree_manager_idx, cur_tree.parent, parent_tree.root, children_h, children_c)) 232 | return tree_managers 233 | 234 | 235 | def update_embedding(self, tree_managers, init_queues, eval_mode=False): 236 | queue_term = [] 237 | queue_nonterm = [] 238 | head_term = 0 239 | head_nonterm = 0 240 | 241 | for tree_manager_idx in range(len(tree_managers)): 242 | tree_manager = tree_managers[tree_manager_idx] 243 | init_queue = init_queues[tree_manager_idx] 244 | for idx in init_queue: 245 | if idx == -1: 246 | continue 247 | cur_tree = tree_manager.get_tree(idx) 248 | canCompute = True 249 | children_h = [] 250 | children_c = [] 251 | for child_idx in cur_tree.children: 252 | child = tree_manager.get_tree(child_idx) 253 | if child.state is None: 254 | canCompute = False 255 | break 256 | else: 257 | child_h, child_c = child.state 258 | children_h.append(child_h) 259 | children_c.append(child_c) 260 | if len(cur_tree.children) == 0: 261 | queue_term.append((tree_manager_idx, idx, cur_tree.root)) 262 | elif canCompute: 263 | queue_nonterm.append((tree_manager_idx, idx, cur_tree.root, children_h, children_c)) 264 | 265 | while head_term < len(queue_term): 266 | encoder_inputs = [] 267 | tree_idxes = [] 268 | while head_term < len(queue_term): 269 | tree_manager_idx, idx, root = queue_term[head_term] 270 | tree_idxes.append((tree_manager_idx, idx)) 271 | encoder_inputs.append(root) 272 | 
head_term += 1 273 | if len(encoder_inputs) == 0: 274 | break 275 | encoder_outputs = self.input_embedding(encoder_inputs, eval_mode) 276 | for i in range(len(tree_idxes)): 277 | tree_manager_idx, cur_idx = tree_idxes[i] 278 | tree_manager = tree_managers[tree_manager_idx] 279 | child_h = encoder_outputs[0][:, i, :] 280 | child_c = encoder_outputs[1][:, i, :] 281 | tree_managers[tree_manager_idx].trees[cur_idx].state = child_h, child_c 282 | cur_tree = tree_manager.get_tree(cur_idx) 283 | if cur_tree.parent != -1: 284 | parent_tree = tree_manager.get_tree(cur_tree.parent) 285 | canCompute = True 286 | children_h = [] 287 | children_c = [] 288 | for child_idx in parent_tree.children: 289 | child = tree_manager.get_tree(child_idx) 290 | if child.state is None: 291 | canCompute = False 292 | break 293 | else: 294 | child_h, child_c = child.state 295 | children_h.append(child_h) 296 | children_c.append(child_c) 297 | if canCompute: 298 | queue_nonterm.append((tree_manager_idx, cur_tree.parent, parent_tree.root, children_h, children_c)) 299 | 300 | while head_nonterm < len(queue_nonterm): 301 | encoder_inputs = [] 302 | children_h = [] 303 | children_c = [] 304 | tree_idxes = [] 305 | while head_nonterm < len(queue_nonterm): 306 | tree_manager_idx, idx, root, child_h, child_c = queue_nonterm[head_nonterm] 307 | cur_tree = tree_managers[tree_manager_idx].get_tree(idx) 308 | tree_idxes.append((tree_manager_idx, idx)) 309 | encoder_inputs.append(root) 310 | children_h.append(child_h) 311 | children_c.append(child_c) 312 | head_nonterm += 1 313 | if len(encoder_inputs) == 0: 314 | break 315 | encoder_outputs = self.calc_root(encoder_inputs, children_h, children_c) 316 | for i in range(len(tree_idxes)): 317 | tree_manager_idx, cur_idx = tree_idxes[i] 318 | tree_manager = tree_managers[tree_manager_idx] 319 | child_h = encoder_outputs[0][i] 320 | child_c = encoder_outputs[1][i] 321 | tree_managers[tree_manager_idx].trees[cur_idx].state = child_h, child_c 322 | cur_tree = 
class SeqLSTM(nn.Module):
    """
    LSTM to embed the input as a sequence.

    Flattens each job-scheduling DAG into the per-timestep schedule order,
    runs a unidirectional LSTM over it, and writes the resulting hidden
    vectors back onto the DAG nodes as their `state`.
    """
    def __init__(self, args):
        super(SeqLSTM, self).__init__()
        self.batch_size = args.batch_size
        self.hidden_size = args.LSTM_hidden_size
        self.embedding_size = args.embedding_size
        self.num_layers = args.num_LSTM_layers
        self.dropout_rate = args.dropout_rate
        self.cuda_flag = args.cuda
        # batch_first=True: inputs are (batch, seq, embedding_size).
        self.encoder = nn.LSTM(input_size=self.embedding_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True, dropout=self.dropout_rate)


    def calc_embedding(self, dag_managers, eval_mode=False):
        """
        Compute node embeddings for a batch of DAG managers.

        Args:
            dag_managers: batch of job-scheduling DAG managers; each exposes
                `num_jobs`, `max_schedule_time`, `schedule[t]` (job idxs
                scheduled at time t), `get_node(idx)` and `nodes`.
            eval_mode: forwarded to tensor construction (no-grad / volatile).

        Returns:
            The same list of managers with `nodes[idx].state` set to an
            (h, c) pair for every scheduled job, and node 0 (the root) set
            to a zero state.
        """
        encoder_input = []
        encoder_input_idx = []
        max_node_cnt = 0
        batch_size = len(dag_managers)

        # Serialize each DAG in schedule-time order; remember the original
        # node index of every sequence position so outputs can be scattered back.
        for dag_manager in dag_managers:
            cur_encoder_input = []
            cur_encoder_input_idx = []
            max_node_cnt = max(max_node_cnt, dag_manager.num_jobs)
            for st in range(dag_manager.max_schedule_time + 1):
                for idx in dag_manager.schedule[st]:
                    cur_encoder_input_idx.append(idx)
                    cur_encoder_input.append(dag_manager.get_node(idx).embedding)
            encoder_input.append(cur_encoder_input)
            encoder_input_idx.append(cur_encoder_input_idx)

        # Zero-pad shorter sequences so the batch is rectangular.
        for i in range(batch_size):
            while len(encoder_input[i]) < max_node_cnt:
                encoder_input[i].append([0.0 for _ in range(self.embedding_size)])

        encoder_input = np.array(encoder_input)
        encoder_input = data_utils.np_to_tensor(encoder_input, 'float', self.cuda_flag, eval_mode)
        init_h = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
        init_c = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
        if self.cuda_flag:
            init_h = init_h.cuda()
            init_c = init_c.cuda()
        init_state = (init_h, init_c)
        encoder_output, encoder_state = self.encoder(encoder_input, init_state)

        for dag_manager_idx, dag_manager in enumerate(dag_managers):
            for i, node_idx in enumerate(encoder_input_idx[dag_manager_idx]):
                # NOTE(review): the same encoder output vector is stored as
                # BOTH h and c of the node state — presumably intentional so
                # downstream code can treat (h, c) uniformly; confirm.
                dag_managers[dag_manager_idx].nodes[node_idx].state = (encoder_output[dag_manager_idx, i, :].unsqueeze(0), encoder_output[dag_manager_idx, i, :].unsqueeze(0))
            # Root node (index 0) gets an all-zero state rather than an LSTM output.
            init_h = Variable(torch.zeros(1, self.hidden_size))
            init_c = Variable(torch.zeros(1, self.hidden_size))
            if self.cuda_flag:
                init_h = init_h.cuda()
                init_c = init_c.cuda()
            init_state = (init_h, init_c)
            dag_managers[dag_manager_idx].nodes[0].state = init_state
        return dag_managers
class DagLSTM(nn.Module):
    """
    LSTM to embed the input as a DAG.

    A child-sum style DAG-LSTM: every node's (h, c) state is computed by
    `calc_root` from its own embedding plus the summed states of its parents
    in the DAG, propagated in topological waves from the root (node 0).
    """
    def __init__(self, args):
        super(DagLSTM, self).__init__()
        self.batch_size = args.batch_size
        self.hidden_size = args.LSTM_hidden_size
        self.embedding_size = args.embedding_size
        self.dropout_rate = args.dropout_rate
        self.cuda_flag = args.cuda

        # Gate projections: *x act on the node embedding, *h on predecessor
        # hidden states (i=input, f=forget, o=output, u=candidate).
        self.ix = nn.Linear(self.embedding_size, self.hidden_size, bias=True)
        self.ih = nn.Linear(self.hidden_size, self.hidden_size)
        self.fx = nn.Linear(self.embedding_size, self.hidden_size, bias=True)
        self.fh = nn.Linear(self.hidden_size, self.hidden_size)
        self.ox = nn.Linear(self.embedding_size, self.hidden_size, bias=True)
        self.oh = nn.Linear(self.hidden_size, self.hidden_size)
        self.ux = nn.Linear(self.embedding_size, self.hidden_size, bias=True)
        self.uh = nn.Linear(self.hidden_size, self.hidden_size)


    def calc_root(self, inputs, child_h, child_c):
        """
        One child-sum LSTM cell step for a batch of nodes.

        Args:
            inputs: (batch, embedding_size) node embeddings.
            child_h: (batch, num_children, hidden_size) predecessor h states.
            child_c: (batch, num_children, hidden_size) predecessor c states.

        Returns:
            (h, c), each (batch, hidden_size).
        """
        child_h_sum = torch.sum(child_h, 1)
        i = F.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = F.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = F.tanh(self.ux(inputs) + self.uh(child_h_sum))

        # A separate forget gate per child, as in the child-sum Tree-LSTM.
        fx = self.fx(inputs)
        fx = fx.unsqueeze(1)
        fx = fx.repeat(1, child_h.size()[1], 1)
        f = self.fh(child_h)
        f = f + fx
        f = F.sigmoid(f)
        fc = F.torch.mul(f, child_c)
        fc = torch.sum(fc, 1)
        c = F.torch.mul(i, u) + fc
        h = F.torch.mul(o, F.tanh(c))
        return h, c


    def calc_embedding(self, dag_managers, eval_mode=False):
        """
        Compute states for every node of every DAG from scratch.

        Performs a breadth-first wave propagation: node 0 is seeded with no
        predecessors; a node is enqueued once ALL of its parents have states.
        Returns the managers with `nodes[idx].state` filled in.
        """
        queue = []
        head = 0

        # Invalidate all states so "parent.state is None" means "not yet computed".
        for dag_manager_idx in range(len(dag_managers)):
            for idx in range(dag_managers[dag_manager_idx].num_nodes):
                dag_managers[dag_manager_idx].nodes[idx].state = None

        for dag_manager_idx in range(len(dag_managers)):
            dag_manager = dag_managers[dag_manager_idx]
            root_node = dag_manager.get_node(0)
            children_h = []
            children_c = []
            queue.append((dag_manager_idx, 0, root_node.embedding, children_h, children_c))

        while head < len(queue):
            # Drain the current wave of ready nodes into one batched cell step.
            encoder_inputs = []
            children_h = []
            children_c = []
            dag_idxes = []
            max_children_size = 1
            while head < len(queue):
                dag_manager_idx, idx, embedding, child_h, child_c = queue[head]
                cur_node = dag_managers[dag_manager_idx].get_node(idx)
                dag_idxes.append((dag_manager_idx, idx))
                encoder_inputs.append(embedding)
                children_h.append(child_h)
                children_c.append(child_c)
                max_children_size = max(max_children_size, len(child_h))
                head += 1
            if len(encoder_inputs) == 0:
                break
            encoder_inputs = np.array(encoder_inputs)
            encoder_inputs = data_utils.np_to_tensor(encoder_inputs, 'float', self.cuda_flag, eval_mode)

            # Pad every node's predecessor list with zero states so the batch
            # stacks into (batch, max_children_size, hidden_size).
            for idx in range(len(children_h)):
                while len(children_h[idx]) < max_children_size:
                    init_child_h = Variable(torch.zeros(1, self.hidden_size))
                    init_child_c = Variable(torch.zeros(1, self.hidden_size))
                    if self.cuda_flag:
                        init_child_h = init_child_h.cuda()
                        init_child_c = init_child_c.cuda()
                    children_h[idx].append(init_child_h)
                    children_c[idx].append(init_child_c)
                children_h[idx] = torch.cat(children_h[idx], 0).unsqueeze(0)
                children_c[idx] = torch.cat(children_c[idx], 0).unsqueeze(0)

            children_h = torch.cat(children_h, 0)
            children_c = torch.cat(children_c, 0)

            encoder_outputs = self.calc_root(encoder_inputs, children_h, children_c)
            for i in range(len(dag_idxes)):
                dag_manager_idx, cur_idx = dag_idxes[i]
                dag_manager = dag_managers[dag_manager_idx]
                child_h = encoder_outputs[0][i].unsqueeze(0)
                child_c = encoder_outputs[1][i].unsqueeze(0)
                dag_managers[dag_manager_idx].nodes[cur_idx].state = child_h, child_c
                cur_node = dag_manager.get_node(cur_idx)
                # Enqueue each successor whose parents all have states now.
                if len(cur_node.children) > 0:
                    for child_idx in cur_node.children:
                        child_node = dag_manager.get_node(child_idx)
                        canCompute = True
                        children_h = []
                        children_c = []
                        for parent_idx in child_node.parents:
                            parent = dag_manager.get_node(parent_idx)
                            if parent.state is None:
                                canCompute = False
                                break
                            else:
                                # NOTE: reuses child_h/child_c names for the
                                # parent's state — local scratch only.
                                child_h, child_c = parent.state
                                children_h.append(child_h)
                                children_c.append(child_c)
                        if canCompute:
                            queue.append((dag_manager_idx, child_idx, child_node.embedding, children_h, children_c))
        return dag_managers


    def update_embedding(self, dag_managers, init_queues, eval_mode=False):
        """
        Recompute states only for the nodes listed in `init_queues` (one list
        of node indexes per manager) and everything downstream of them.

        Same wave propagation as `calc_embedding`, but seeded from the dirty
        nodes instead of the root; nodes outside the dirty region keep their
        previous states and feed them into the recomputation.
        """
        queue = []
        head = 0
        for dag_manager_idx in range(len(dag_managers)):
            init_queue = init_queues[dag_manager_idx]
            for idx in init_queue:
                dag_managers[dag_manager_idx].nodes[idx].state = None
        for dag_manager_idx in range(len(dag_managers)):
            dag_manager = dag_managers[dag_manager_idx]
            init_queue = init_queues[dag_manager_idx]
            for idx in init_queue:
                cur_node = dag_manager.get_node(idx)
                canCompute = True
                children_h = []
                children_c = []
                for parent_idx in cur_node.parents:
                    parent = dag_manager.get_node(parent_idx)
                    if parent.state is None:
                        canCompute = False
                        break
                    else:
                        child_h, child_c = parent.state
                        children_h.append(child_h)
                        children_c.append(child_c)
                if canCompute:
                    queue.append((dag_manager_idx, idx, cur_node.embedding, children_h, children_c))

        while head < len(queue):
            encoder_inputs = []
            children_h = []
            children_c = []
            dag_idxes = []
            max_children_size = 1
            while head < len(queue):
                dag_manager_idx, idx, embedding, child_h, child_c = queue[head]
                cur_node = dag_managers[dag_manager_idx].get_node(idx)
                # Skip nodes already recomputed in an earlier wave (a node can
                # be enqueued once per parent that finishes).
                if cur_node.state is None:
                    dag_idxes.append((dag_manager_idx, idx))
                    encoder_inputs.append(embedding)
                    children_h.append(child_h)
                    children_c.append(child_c)
                    max_children_size = max(max_children_size, len(child_h))
                head += 1
            if len(encoder_inputs) == 0:
                break
            encoder_inputs = np.array(encoder_inputs)
            encoder_inputs = data_utils.np_to_tensor(encoder_inputs, 'float', self.cuda_flag, eval_mode)

            for idx in range(len(children_h)):
                while len(children_h[idx]) < max_children_size:
                    init_child_h = Variable(torch.zeros(1, self.hidden_size))
                    init_child_c = Variable(torch.zeros(1, self.hidden_size))
                    if self.cuda_flag:
                        init_child_h = init_child_h.cuda()
                        init_child_c = init_child_c.cuda()
                    children_h[idx].append(init_child_h)
                    children_c[idx].append(init_child_c)
                children_h[idx] = torch.cat(children_h[idx], 0).unsqueeze(0)
                children_c[idx] = torch.cat(children_c[idx], 0).unsqueeze(0)

            children_h = torch.cat(children_h, 0)
            children_c = torch.cat(children_c, 0)
            encoder_outputs = self.calc_root(encoder_inputs, children_h, children_c)
            for i in range(len(dag_idxes)):
                dag_manager_idx, cur_idx = dag_idxes[i]
                dag_manager = dag_managers[dag_manager_idx]
                child_h = encoder_outputs[0][i].unsqueeze(0)
                child_c = encoder_outputs[1][i].unsqueeze(0)
                dag_managers[dag_manager_idx].nodes[cur_idx].state = child_h, child_c
                cur_node = dag_manager.get_node(cur_idx)
                if len(cur_node.children) > 0:
                    for child_idx in cur_node.children:
                        child_node = dag_manager.get_node(child_idx)
                        canCompute = True
                        children_h = []
                        children_c = []
                        for parent_idx in child_node.parents:
                            parent = dag_manager.get_node(parent_idx)
                            if parent.state is None:
                                canCompute = False
                                break
                            else:
                                child_h, child_c = parent.state
                                children_h.append(child_h)
                                children_c.append(child_c)
                        if canCompute:
                            queue.append((dag_manager_idx, child_idx, child_node.embedding, children_h, children_c))

        return dag_managers
21 | """ 22 | def __init__(self, num_layers, input_size, hidden_size, output_size, cuda_flag, dropout_rate=0.0, activation=None): 23 | super(MLPModel, self).__init__() 24 | self.num_layers = num_layers 25 | self.input_size = input_size 26 | self.hidden_size = hidden_size 27 | self.output_size = output_size 28 | self.dropout_rate = dropout_rate 29 | self.cuda_flag = cuda_flag 30 | self.dropout = nn.Dropout(p=self.dropout_rate) 31 | self.model = nn.Sequential( 32 | nn.Linear(self.input_size, self.hidden_size), 33 | nn.ReLU()) 34 | for _ in range(self.num_layers): 35 | self.model = nn.Sequential( 36 | self.model, 37 | nn.Linear(self.hidden_size, self.hidden_size), 38 | nn.ReLU()) 39 | self.model = nn.Sequential( 40 | self.model, 41 | nn.Linear(self.hidden_size, self.output_size)) 42 | if activation is not None: 43 | self.model = nn.Sequential( 44 | self.model, 45 | activation) 46 | 47 | 48 | def forward(self, inputs): 49 | return self.model(inputs) -------------------------------------------------------------------------------- /src/models/modules/vrpInputEncoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import random 9 | import numpy as np 10 | import time 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | import torch.nn.functional as F 16 | 17 | from ..data_utils import data_utils 18 | 19 | class SeqLSTM(nn.Module): 20 | """ 21 | LSTM to embed the input as a sequence. 
22 | """ 23 | def __init__(self, args): 24 | super(SeqLSTM, self).__init__() 25 | self.batch_size = args.batch_size 26 | self.hidden_size = args.LSTM_hidden_size 27 | self.embedding_size = args.embedding_size 28 | self.num_layers = args.num_LSTM_layers 29 | self.dropout_rate = args.dropout_rate 30 | self.cuda_flag = args.cuda 31 | self.encoder = nn.LSTM(input_size=self.embedding_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True, dropout=self.dropout_rate, bidirectional=True) 32 | 33 | 34 | def calc_embedding(self, seq_managers, eval_mode=False): 35 | encoder_input = [] 36 | max_node_cnt = 0 37 | batch_size = len(seq_managers) 38 | 39 | for seq_manager in seq_managers: 40 | encoder_input.append(seq_manager.route[:]) 41 | max_node_cnt = max(max_node_cnt, len(seq_manager.route)) 42 | 43 | for i in range(batch_size): 44 | while len(encoder_input[i]) < max_node_cnt: 45 | encoder_input[i].append([0.0 for _ in range(self.embedding_size)]) 46 | 47 | encoder_input = np.array(encoder_input) 48 | encoder_input = data_utils.np_to_tensor(encoder_input, 'float', self.cuda_flag, eval_mode) 49 | encoder_output, encoder_state = self.encoder(encoder_input) 50 | 51 | for seq_manager_idx, seq_manager in enumerate(seq_managers): 52 | seq_managers[seq_manager_idx].encoder_outputs = encoder_output[seq_manager_idx] 53 | 54 | return seq_managers -------------------------------------------------------------------------------- /src/models/rewriter/__init__.py: -------------------------------------------------------------------------------- 1 | from .HalideRewriter import HalideRewriter 2 | from .jspRewriter import jspRewriter 3 | from .vrpRewriter import vrpRewriter -------------------------------------------------------------------------------- /src/models/rewriter/jspRewriter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
class jspRewriter(object):
    """
    Rewriter for job scheduling.
    """
    def move(self, dm, cur_idx, neighbor_idx):
        """
        Reschedule job `cur_idx` to start right after job `neighbor_idx`
        finishes (or at its own earliest start time, whichever is later),
        then repair the rest of the schedule and the DAG edges.

        Args:
            dm: job-scheduling DAG manager (not replaced; see NOTE below).
            cur_idx: index of the job being moved.
            neighbor_idx: index of the job to schedule after.

        Returns:
            (res, updated_node_idxes): the rewritten clone of `dm` and the
            indexes of jobs whose schedule slot may have changed — or
            (dm, []) when the move is a no-op.
        """
        cur_job = dm.get_node(cur_idx)
        neighbor_job = dm.get_node(neighbor_idx)
        new_schedule_time = max(neighbor_job.ed_time, cur_job.st_time)
        old_schedule_time = cur_job.schedule_time
        if new_schedule_time == old_schedule_time:
            # Move would leave the job where it is — nothing to rewrite.
            return dm, []
        min_stop_time = min(new_schedule_time, neighbor_job.ed_time)
        res = dm.clone()
        # Release the resources the job held at its old slot, then re-add it.
        for t in range(cur_job.schedule_time, cur_job.ed_time):
            res.resource_map[t] -= cur_job.resource_size
        res.add_job(cur_idx, new_schedule_time)
        res.max_schedule_time = max(res.max_schedule_time, new_schedule_time)
        res.max_ed_time = max(res.max_ed_time, new_schedule_time + cur_job.job_len)

        updated_schedule_time = min(min_stop_time, old_schedule_time)
        scheduled_node_idxes = [cur_idx]

        # Greedily re-pack every job scheduled at or after the earliest
        # affected time step into its minimal feasible slot.
        time_step = updated_schedule_time
        while time_step <= res.max_schedule_time:
            temp_old_schedule = res.schedule[time_step].copy()
            for temp_job_idx in temp_old_schedule:
                temp_job = res.get_node(temp_job_idx)
                temp_old_schedule_time = temp_job.schedule_time
                new_schedule_time = temp_job.st_time
                for t in range(temp_job.schedule_time, temp_job.ed_time):
                    res.resource_map[t] -= temp_job.resource_size
                new_schedule_time = res.calc_min_schedule_time(new_schedule_time, temp_job_idx)
                res.add_job(temp_job_idx, new_schedule_time)
                if not temp_job_idx in scheduled_node_idxes:
                    scheduled_node_idxes.append(temp_job_idx)
                res.max_schedule_time = max(res.max_schedule_time, new_schedule_time)
            time_step += 1

        # Rewire DAG edges so each rescheduled job hangs off either the root
        # (if it starts at its earliest time) or its predecessor in the slot.
        for idx in scheduled_node_idxes:
            job = res.get_node(idx)
            old_parents = job.parents.copy()
            old_children = job.children.copy()
            for parent_idx in old_parents:
                res.del_edge(parent_idx, idx)
            if job.schedule_time == job.st_time:
                res.add_edge(res.root, idx)
            else:
                schedule_idx = res.schedule[job.schedule_time].index(idx)
                if schedule_idx == 0:
                    res.add_edge(res.terminate[job.schedule_time][-1], idx)
                else:
                    res.add_edge(res.schedule[job.schedule_time][schedule_idx - 1], idx)
            res.update_embedding(idx)
            res.nodes[idx].parents.sort()
            # NOTE(review): this sorts the parents of the ORIGINAL dm in place
            # so the comparison below is order-insensitive — a deliberate
            # (if surprising) mutation of the input; confirm callers rely on it.
            dm.nodes[idx].parents.sort()
            if res.nodes[idx].parents != dm.nodes[idx].parents or res.nodes[idx].embedding != dm.nodes[idx].embedding:
                updated_schedule_time = min(updated_schedule_time, res.nodes[idx].schedule_time)
        res.update_stat()

        # Everything scheduled at or after the earliest changed slot is
        # reported so the encoder can refresh its embeddings.
        updated_node_idxes = []
        for time_step in range(updated_schedule_time, res.max_schedule_time + 1):
            for idx in res.schedule[time_step]:
                updated_node_idxes.append(idx)
        return res, updated_node_idxes
20 | """ 21 | def move(self, dm, cur_route_idx, neighbor_route_idx): 22 | min_update_idx = min(cur_route_idx, neighbor_route_idx) 23 | res = dm.clone() 24 | old_vehicle_state = res.vehicle_state[:] 25 | old_vehicle_state[cur_route_idx], old_vehicle_state[neighbor_route_idx] = old_vehicle_state[neighbor_route_idx], old_vehicle_state[cur_route_idx] 26 | if old_vehicle_state[neighbor_route_idx][0] == 0: 27 | del old_vehicle_state[neighbor_route_idx] 28 | res.vehicle_state = res.vehicle_state[:min_update_idx] 29 | res.route = res.route[:min_update_idx] 30 | res.tot_dis = res.tot_dis[:min_update_idx] 31 | cur_node_idx, cur_capacity = res.vehicle_state[-1] 32 | for t in range(min_update_idx, len(old_vehicle_state)): 33 | new_node_idx, new_capacity = old_vehicle_state[t] 34 | new_node = res.get_node(new_node_idx) 35 | if new_node_idx != 0 and cur_capacity < new_node.demand: 36 | res.add_route_node(0) 37 | res.add_route_node(new_node_idx) 38 | cur_capacity = res.vehicle_state[-1][1] 39 | return res 40 | -------------------------------------------------------------------------------- /src/models/vrpModel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
class vrpModel(BaseModel):
    """
    Model architecture for vehicle routing.

    Combines a sequence encoder over the current route, an attention-based
    rewrite policy, and a value estimator that scores candidate rewrite
    positions; training follows an actor-critic scheme.
    """
    def __init__(self, args):
        super(vrpModel, self).__init__(args)
        self.input_format = args.input_format
        self.embedding_size = args.embedding_size
        self.attention_size = args.attention_size
        self.sqrt_attention_size = int(np.sqrt(self.attention_size))
        # Rewrites predicted to improve less than this are (mostly) skipped.
        self.reward_thres = -0.01
        self.input_encoder = vrpInputEncoder.SeqLSTM(args)
        # policy_embedding keys candidate actions; policy keys the query state.
        self.policy_embedding = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * 6 + self.embedding_size * 2, self.MLP_hidden_size, self.attention_size, self.cuda_flag, self.dropout_rate)
        self.policy = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * 4, self.MLP_hidden_size, self.attention_size, self.cuda_flag, self.dropout_rate)
        self.value_estimator = mlp.MLPModel(self.num_MLP_layers, self.LSTM_hidden_size * 4, self.MLP_hidden_size, 1, self.cuda_flag, self.dropout_rate)
        self.rewriter = vrpRewriter()

        if args.optimizer == 'adam':
            self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
        elif args.optimizer == 'sgd':
            self.optimizer = optim.SGD(self.parameters(), lr=self.lr)
        elif args.optimizer == 'rmsprop':
            self.optimizer = optim.RMSprop(self.parameters(), lr=self.lr)
        else:
            raise ValueError('optimizer undefined: ', args.optimizer)


    def rewrite(self, dm, trace_rec, candidate_rewrite_pos, eval_flag, max_search_pos, reward_thres=None):
        """
        Pick rewrite positions for one instance, score candidate moves with
        the attention policy, and apply the chosen moves.

        Args:
            dm: current route manager for the instance.
            trace_rec: dict of total distances already visited (cycle guard).
            candidate_rewrite_pos: list of (pred_reward, reward_tensor, pos).
            eval_flag: greedy selection when True, sampling when False.
            max_search_pos: cap on the number of rewrites returned.
            reward_thres: skip positions predicted below this reward.

        Returns:
            (candidate_dm, candidate_rewrite_rec): rewritten managers and the
            bookkeeping tuples used later for the policy/value losses.
        """
        # Highest predicted reward first.
        candidate_rewrite_pos.sort(reverse=True, key=operator.itemgetter(0))
        if not eval_flag:
            sample_exp_reward_tensor = []
            for idx, (cur_pred_reward, cur_pred_reward_tensor, rewrite_pos) in enumerate(candidate_rewrite_pos):
                sample_exp_reward_tensor.append(cur_pred_reward_tensor)
            sample_exp_reward_tensor = torch.cat(sample_exp_reward_tensor, 0)
            # Temperature-sharpened softmax weights over predicted rewards.
            sample_exp_reward_tensor = torch.exp(sample_exp_reward_tensor * 10)
            sample_exp_reward = sample_exp_reward_tensor.data.cpu()

        candidate_dm = []
        candidate_rewrite_rec = []
        candidate_trace_rec = []
        candidate_scores = []

        if not eval_flag:
            # Sample positions without replacement (dedup preserves first draw order).
            sample_rewrite_pos_dist = Categorical(sample_exp_reward_tensor)
            sample_rewrite_pos = sample_rewrite_pos_dist.sample(sample_shape=[len(candidate_rewrite_pos)])
            #sample_rewrite_pos = torch.multinomial(sample_exp_reward_tensor, len(candidate_rewrite_pos))
            sample_rewrite_pos = sample_rewrite_pos.data.cpu().numpy()
            indexes = np.unique(sample_rewrite_pos, return_index=True)[1]
            sample_rewrite_pos = [sample_rewrite_pos[i] for i in sorted(indexes)]
            sample_rewrite_pos = sample_rewrite_pos[:self.num_sample_rewrite_pos]
            sample_exp_reward = [sample_exp_reward[i] for i in sample_rewrite_pos]
            sample_rewrite_pos = [candidate_rewrite_pos[i] for i in sample_rewrite_pos]
        else:
            sample_rewrite_pos = candidate_rewrite_pos.copy()

        for idx, (pred_reward, cur_pred_reward_tensor, rewrite_pos) in enumerate(sample_rewrite_pos):
            if len(candidate_dm) > 0 and idx >= max_search_pos:
                break
            if reward_thres is not None and pred_reward < reward_thres:
                if eval_flag:
                    # Positions are sorted, so everything after is worse too.
                    break
                elif np.random.random() > self.cont_prob:
                    continue
            candidate_neighbor_idxes = dm.get_neighbor_idxes(rewrite_pos)
            cur_node_idx = dm.vehicle_state[rewrite_pos][0]
            cur_node = dm.get_node(cur_node_idx)
            pre_node_idx = dm.vehicle_state[rewrite_pos - 1][0]
            pre_node = dm.get_node(pre_node_idx)
            pre_capacity = dm.vehicle_state[rewrite_pos - 1][1]
            depot = dm.get_node(0)
            depot_state = dm.encoder_outputs[0].unsqueeze(0)
            cur_state = dm.encoder_outputs[rewrite_pos].unsqueeze(0)
            cur_states_0 = []
            cur_states_1 = []
            cur_states_2 = []
            new_embeddings_0 = []
            new_embeddings_1 = []
            # Build one attention key per candidate neighbor: encoder states
            # plus hand-crafted features of the hypothetical move (the second
            # branch accounts for a forced depot detour when capacity is short).
            for i in candidate_neighbor_idxes:
                neighbor_idx = dm.vehicle_state[i][0]
                neighbor_node = dm.get_node(neighbor_idx)
                cur_states_0.append(depot_state.clone())
                cur_states_1.append(cur_state.clone())
                cur_states_2.append(dm.encoder_outputs[i].unsqueeze(0))
                if pre_capacity >= neighbor_node.demand:
                    new_embedding = [neighbor_node.x, neighbor_node.y, neighbor_node.demand * 1.0 / dm.capacity, pre_node.x, pre_node.y, neighbor_node.demand * 1.0 / pre_capacity, dm.get_dis(pre_node, neighbor_node)]
                else:
                    new_embedding = [neighbor_node.x, neighbor_node.y, neighbor_node.demand * 1.0 / dm.capacity, pre_node.x, pre_node.y, neighbor_node.demand * 1.0 / dm.capacity, dm.get_dis(pre_node, depot) + dm.get_dis(depot, neighbor_node)]
                new_embeddings_0.append(new_embedding[:])
                if pre_capacity >= neighbor_node.demand:
                    new_embedding = [(neighbor_node.x - depot.x) * (pre_node.x - depot.x), (neighbor_node.y - depot.y) * (pre_node.y - depot.y), (neighbor_node.demand - cur_node.demand) * 1.0 / pre_capacity, pre_node.px, pre_node.py, \
                    (neighbor_node.demand - cur_node.demand) * 1.0 / dm.capacity, dm.get_dis(pre_node, depot) + dm.get_dis(depot, neighbor_node)]
                else:
                    new_embedding = [(neighbor_node.x - depot.x) * (pre_node.x - depot.x), (neighbor_node.y - depot.y) * (pre_node.y - depot.y), (neighbor_node.demand - cur_node.demand) * 1.0 / dm.capacity, pre_node.px, pre_node.py, \
                    (neighbor_node.demand - cur_node.demand) * 1.0 / dm.capacity, dm.get_dis(pre_node, depot) + dm.get_dis(depot, neighbor_node)]
                new_embeddings_1.append(new_embedding[:])
            cur_states_0 = torch.cat(cur_states_0, 0)
            cur_states_1 = torch.cat(cur_states_1, 0)
            cur_states_2 = torch.cat(cur_states_2, 0)
            new_embeddings_0 = data_utils.np_to_tensor(new_embeddings_0, 'float', self.cuda_flag)
            new_embeddings_1 = data_utils.np_to_tensor(new_embeddings_1, 'float', self.cuda_flag)
            policy_inputs = torch.cat([cur_states_0, cur_states_1, cur_states_2, new_embeddings_0, new_embeddings_1], 1)
            ctx_embeddings = self.policy_embedding(policy_inputs)
            cur_state_key = self.policy(torch.cat([cur_state, depot_state], dim=1))
            # Scaled dot-product attention over candidate moves.
            ac_logits = torch.matmul(cur_state_key, torch.transpose(ctx_embeddings, 0, 1)) / self.sqrt_attention_size
            ac_logprobs = nn.LogSoftmax()(ac_logits)
            ac_probs = nn.Softmax()(ac_logits)
            ac_logits = ac_logits.squeeze(0)
            ac_logprobs = ac_logprobs.squeeze(0)
            ac_probs = ac_probs.squeeze(0)
            if eval_flag:
                _, candidate_acs = torch.sort(ac_logprobs, descending=True)
                candidate_acs = candidate_acs.data.cpu().numpy()
            else:
                candidate_acs_dist = Categorical(ac_probs)
                candidate_acs = candidate_acs_dist.sample(sample_shape=[ac_probs.size()[0]])
                #candidate_acs = torch.multinomial(ac_probs, ac_probs.size()[0])
                candidate_acs = candidate_acs.data.cpu().numpy()
                indexes = np.unique(candidate_acs, return_index=True)[1]
                candidate_acs = [candidate_acs[i] for i in sorted(indexes)]

            for i in candidate_acs:
                neighbor_idx = candidate_neighbor_idxes[i]
                new_dm = self.rewriter.move(dm, rewrite_pos, neighbor_idx)
                if new_dm.tot_dis[-1] in trace_rec:
                    # Already visited this total distance — avoid cycling.
                    continue
                candidate_dm.append(new_dm)
                candidate_rewrite_rec.append((ac_logprobs, pred_reward, cur_pred_reward_tensor, rewrite_pos, i, new_dm.tot_dis[-1]))
                if len(candidate_dm) >= max_search_pos:
                    break

        return candidate_dm, candidate_rewrite_rec


    def batch_rewrite(self, dm, trace_rec, candidate_rewrite_pos, eval_flag, max_search_pos, reward_thres):
        """Apply `rewrite` independently to every instance in the batch."""
        candidate_dm = []
        candidate_rewrite_rec = []
        for i in range(len(dm)):
            cur_candidate_dm, cur_candidate_rewrite_rec = self.rewrite(dm[i], trace_rec[i], candidate_rewrite_pos[i], eval_flag, max_search_pos, reward_thres)
            candidate_dm.append(cur_candidate_dm)
            candidate_rewrite_rec.append(cur_candidate_rewrite_rec)
        return candidate_dm, candidate_rewrite_rec


    def forward(self, batch_data, eval_flag=False):
        """
        Run up to `max_reduce_steps` rewrite iterations on a batch and
        compute the actor-critic losses.

        Returns:
            (total_loss, total_reward, dm_rec) where total_reward is the
            average best improvement over the batch and dm_rec records the
            sequence of route managers per instance.
        """
        torch.set_grad_enabled(not eval_flag)
        dm_list = []
        batch_size = len(batch_data)
        for dm in batch_data:
            dm_list.append(dm)
        dm_list = self.input_encoder.calc_embedding(dm_list, eval_flag)

        active = True
        reduce_steps = 0

        trace_rec = [{} for _ in range(batch_size)]
        rewrite_rec = [[] for _ in range(batch_size)]
        dm_rec = [[] for _ in range(batch_size)]

        for idx in range(batch_size):
            dm_rec[idx].append(dm_list[idx])
            trace_rec[idx][dm_list[idx].tot_dis[-1]] = 0

        while active and (self.max_reduce_steps is None or reduce_steps < self.max_reduce_steps):
            active = False
            reduce_steps += 1
            # Score every interior route position with the value estimator.
            node_idxes = []
            node_states = []
            depot_states = []
            for dm_idx in range(batch_size):
                dm = dm_list[dm_idx]
                for i in range(1, len(dm.vehicle_state) - 1):
                    cur_node_idx = dm.vehicle_state[i][0]
                    cur_node = dm.get_node(cur_node_idx)
                    node_idxes.append((dm_idx, i))
                    node_states.append(dm.encoder_outputs[i].unsqueeze(0))
                    depot_states.append(dm.encoder_outputs[0].clone().unsqueeze(0))
            pred_rewards = []
            # Chunk the value-estimator input to bound memory usage.
            for st in range(0, len(node_idxes), self.batch_size):
                cur_node_states = node_states[st: st + self.batch_size]
                cur_node_states = torch.cat(cur_node_states, 0)
                cur_depot_states = depot_states[st: st + self.batch_size]
                cur_depot_states = torch.cat(cur_depot_states, 0)
                cur_pred_rewards = self.value_estimator(torch.cat([cur_node_states, cur_depot_states], dim=1))
                pred_rewards.append(cur_pred_rewards)
            pred_rewards = torch.cat(pred_rewards, 0)
            candidate_rewrite_pos = [[] for _ in range(batch_size)]
            for idx, (dm_idx, node_idx) in enumerate(node_idxes):
                candidate_rewrite_pos[dm_idx].append((pred_rewards[idx].data[0], pred_rewards[idx], node_idx))

            candidate_dm, candidate_rewrite_rec = self.batch_rewrite(dm_list, trace_rec, candidate_rewrite_pos, eval_flag, max_search_pos=1, reward_thres=self.reward_thres)
            for dm_idx in range(batch_size):
                cur_candidate_dm = candidate_dm[dm_idx]
                cur_candidate_rewrite_rec = candidate_rewrite_rec[dm_idx]
                if len(cur_candidate_dm) > 0:
                    active = True
                    cur_dm = cur_candidate_dm[0]
                    cur_rewrite_rec = cur_candidate_rewrite_rec[0]
                    dm_list[dm_idx] = cur_dm
                    rewrite_rec[dm_idx].append(cur_rewrite_rec)
                    trace_rec[dm_idx][cur_dm.tot_dis[-1]] = 0
            if not active:
                break

            # Refresh encoder outputs for the rewritten routes before the next step.
            updated_dm = self.input_encoder.calc_embedding(dm_list, eval_flag)
            for i in range(batch_size):
                if updated_dm[i].tot_dis[-1] != dm_rec[i][-1].tot_dis[-1]:
                    dm_rec[i].append(updated_dm[i])

        total_policy_loss = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)
        total_value_loss = data_utils.np_to_tensor(np.zeros(1), 'float', self.cuda_flag)

        pred_value_rec = []
        value_target_rec = []
        total_reward = 0
        total_rewrite_steps = 0
        for dm_idx, cur_dm_rec in enumerate(dm_rec):
            pred_dis = []
            for dm in cur_dm_rec:
                pred_dis.append(dm.tot_dis[-1])
            best_reward = pred_dis[0]

            for idx, (ac_logprob, pred_reward, cur_pred_reward_tensor, rewrite_pos, applied_op, new_dis) in enumerate(rewrite_rec[dm_idx]):
                # Immediate reward: distance reduction of this rewrite step.
                cur_reward = pred_dis[idx] - pred_dis[idx + 1]
                best_reward = min(best_reward, pred_dis[idx + 1])

                if self.gamma > 0.0:
                    # Credit the best discounted future improvement instead,
                    # if it beats the one-step reward.
                    decay_coef = 1.0
                    num_rollout_steps = len(cur_dm_rec) - idx - 1
                    for i in range(idx + 1, idx + 1 + num_rollout_steps):
                        cur_reward = max(decay_coef * (pred_dis[idx] - pred_dis[i]), cur_reward)
                        decay_coef *= self.gamma

                cur_reward_tensor = data_utils.np_to_tensor(np.array([cur_reward], dtype=np.float32), 'float', self.cuda_flag, volatile_flag=True)
                if ac_logprob.data.cpu().numpy()[0] > log_eps or cur_reward - pred_reward > 0:
                    # Advantage-weighted policy gradient on the applied action.
                    ac_mask = np.zeros(ac_logprob.size()[0])
                    ac_mask[applied_op] = cur_reward - pred_reward
                    ac_mask = data_utils.np_to_tensor(ac_mask, 'float', self.cuda_flag, eval_flag)
                    total_policy_loss -= ac_logprob[applied_op] * ac_mask[applied_op]
                    pred_value_rec.append(cur_pred_reward_tensor)
                    value_target_rec.append(cur_reward_tensor)

            total_reward += best_reward

        if len(pred_value_rec) > 0:
            pred_value_rec = torch.cat(pred_value_rec, 0)
            value_target_rec = torch.cat(value_target_rec, 0)
            pred_value_rec = pred_value_rec.unsqueeze(1)
            value_target_rec = value_target_rec.unsqueeze(1)
            total_value_loss = F.smooth_l1_loss(pred_value_rec, value_target_rec, size_average=False)
        total_policy_loss /= batch_size
        total_value_loss /= batch_size
        total_loss = total_policy_loss * self.value_loss_coef + total_value_loss
        total_reward = total_reward * 1.0 / batch_size

        return total_loss, total_reward, dm_rec
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import random
import numpy as np
import pandas as pd
import torch
import arguments
import models.data_utils.data_utils as data_utils
import models.model_utils as model_utils
from models.HalideModel import HalideModel


def create_model(args, term_vocab=None, term_vocab_list=None, op_vocab=None, op_vocab_list=None):
    """Build a HalideModel wrapped in a HalideSupervisor.

    Weights come from one of three sources, in priority order:
    ``args.load_model`` (explicit checkpoint path), ``args.resume``
    (checkpoint ``ckpt-<step>`` under ``args.model_dir``), or fresh
    initialization with ``args.param_init``.
    """
    model = HalideModel(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)
    if model.cuda_flag:
        model = model.cuda()
    model.share_memory()
    model_supervisor = model_utils.HalideSupervisor(
        model, args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)
    if args.load_model:
        model_supervisor.load_pretrained(args.load_model)
    elif args.resume:
        pretrained = 'ckpt-' + str(args.resume).zfill(8)
        print('Resume from {} iterations.'.format(args.resume))
        model_supervisor.load_pretrained(args.model_dir + '/' + pretrained)
    else:
        print('Created model with fresh parameters.')
        model_supervisor.model.init_weights(args.param_init)
    return model_supervisor


def train(args):
    """Train on the Halide dataset, evaluating and checkpointing periodically."""
    print('Training:')

    train_data = data_utils.load_dataset(args.train_dataset, args)
    eval_data = data_utils.load_dataset(args.val_dataset, args)

    DataProcessor = data_utils.HalideDataProcessor()

    if args.train_proportion < 1.0:
        random.shuffle(train_data)
        # BUG FIX: the original line was
        #     train_data_size = int(train_data_size * args.train_proportion)
        # but train_data_size is first assigned only AFTER this branch, so any
        # run with train_proportion < 1.0 crashed with NameError. Derive the
        # subset size from len(train_data), as run_jsp.py and run_vrp.py do.
        train_data_size = int(len(train_data) * args.train_proportion)
        train_data = train_data[:train_data_size]

    if args.train_max_len is not None:
        train_data = DataProcessor.prune_dataset(train_data, max_len=args.train_max_len)

    train_data_size = len(train_data)
    term_vocab, term_vocab_list = DataProcessor.load_term_vocab()
    op_vocab, op_vocab_list = DataProcessor.load_ops()
    args.term_vocab_size = len(term_vocab)
    args.op_vocab_size = len(op_vocab)
    model_supervisor = create_model(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)

    # When resuming, the first epoch starts at the saved batch offset
    # (resume_step gates that offset to the first epoch only).
    if args.resume:
        resume_step = True
    else:
        resume_step = False
    resume_idx = args.resume * args.batch_size

    logger = model_utils.Logger(args)
    if args.resume:
        # Replay the previous run's log so the summary curve stays continuous.
        logs = pd.read_csv("../logs/" + args.log_name)
        for index, log in logs.iterrows():
            val_summary = {'avg_reward': log['avg_reward'], 'global_step': log['global_step']}
            logger.write_summary(val_summary)

    for epoch in range(resume_idx // train_data_size, args.num_epochs):
        random.shuffle(train_data)
        # (resume_step * resume_idx) % train_data_size == saved offset on the
        # resumed epoch, 0 on every later epoch.
        for batch_idx in range(0 + resume_step * resume_idx % train_data_size, train_data_size, args.batch_size):
            resume_step = False
            print(epoch, batch_idx)
            batch_data = DataProcessor.get_batch(train_data, args.batch_size, batch_idx)
            train_loss, train_reward = model_supervisor.train(batch_data)
            print('train loss: %.4f train reward: %.4f' % (train_loss, train_reward))

            if model_supervisor.global_step % args.eval_every_n == 0:
                eval_loss, eval_reward = model_supervisor.eval(eval_data, args.output_trace_flag, args.max_eval_size)
                val_summary = {'avg_reward': eval_reward, 'global_step': model_supervisor.global_step}
                logger.write_summary(val_summary)
                model_supervisor.save_model()

            if args.lr_decay_steps and model_supervisor.global_step % args.lr_decay_steps == 0:
                model_supervisor.model.lr_decay(args.lr_decay_rate)
                if model_supervisor.model.cont_prob > 0.01:
                    model_supervisor.model.cont_prob *= 0.5


def evaluate(args):
    """Evaluate a trained model on the test dataset and print loss/reward."""
    print('Evaluation:')

    test_data = data_utils.load_dataset(args.test_dataset, args)

    # Disable dropout for deterministic evaluation.
    args.dropout_rate = 0.0

    DataProcessor = data_utils.HalideDataProcessor()

    if args.test_min_len is not None:
        test_data = DataProcessor.prune_dataset(test_data, min_len=args.test_min_len)

    term_vocab, term_vocab_list = DataProcessor.load_term_vocab()
    op_vocab, op_vocab_list = DataProcessor.load_ops()
    args.term_vocab_size = len(term_vocab)
    args.op_vocab_size = len(op_vocab)
    model_supervisor = create_model(args, term_vocab, term_vocab_list, op_vocab, op_vocab_list)
    test_loss, test_reward = model_supervisor.eval(
        test_data, args.output_trace_flag, args.output_trace_option, args.output_trace_file)

    print('test loss: %.4f test reward: %.4f' % (test_loss, test_reward))


if __name__ == "__main__":
    argParser = arguments.get_arg_parser("Halide")
    args = argParser.parse_args()
    args.cuda = not args.cpu and torch.cuda.is_available()
    random.seed(args.seed)
    np.random.seed(args.seed)
    if args.eval:
        evaluate(args)
    else:
        train(args)
#

import random
import numpy as np
import pandas as pd
import torch
import arguments
import models.data_utils.data_utils as data_utils
import models.model_utils as model_utils
from models.jspModel import jspModel


def create_model(args, term_vocab=None, term_vocab_list=None, op_vocab=None, op_vocab_list=None):
    """Assemble a jspModel under a jspSupervisor and restore weights if requested.

    The vocab arguments are accepted for signature parity with the other run
    scripts but are not used by the job-scheduling model.
    """
    # Embedding width is fixed by the JSP instance shape (resources x job slots).
    args.embedding_size = args.num_res * (args.max_job_len + 1) + 1
    model = jspModel(args)

    if model.cuda_flag:
        model = model.cuda()
    model.share_memory()
    supervisor = model_utils.jspSupervisor(model, args)
    if args.load_model:
        supervisor.load_pretrained(args.load_model)
    elif args.resume:
        ckpt_name = 'ckpt-' + str(args.resume).zfill(8)
        print('Resume from {} iterations.'.format(args.resume))
        supervisor.load_pretrained(args.model_dir + '/' + ckpt_name)
    else:
        print('Created model with fresh parameters.')
        supervisor.model.init_weights(args.param_init)
    return supervisor


def train(args):
    """Training loop for the job-scheduling task."""
    print('Training:')

    train_data = data_utils.load_dataset(args.train_dataset, args)
    train_data_size = len(train_data)
    if args.train_proportion < 1.0:
        random.shuffle(train_data)
        train_data_size = int(train_data_size * args.train_proportion)
        train_data = train_data[:train_data_size]

    eval_data = data_utils.load_dataset(args.val_dataset, args)

    processor = data_utils.jspDataProcessor(args)
    supervisor = create_model(args)

    # resume_step gates the saved batch offset to the first (resumed) epoch.
    resume_step = bool(args.resume)
    resume_idx = args.resume * args.batch_size

    logger = model_utils.Logger(args)
    if args.resume:
        # Re-emit past summaries so the logged curve stays continuous.
        for _, row in pd.read_csv("../logs/" + args.log_name).iterrows():
            logger.write_summary({'avg_reward': row['avg_reward'], 'global_step': row['global_step']})

    start_epoch = resume_idx // train_data_size
    for epoch in range(start_epoch, args.num_epochs):
        random.shuffle(train_data)
        first_batch = resume_idx % train_data_size if resume_step else 0
        for batch_idx in range(first_batch, train_data_size, args.batch_size):
            resume_step = False
            print(epoch, batch_idx)
            batch_data = processor.get_batch(train_data, args.batch_size, batch_idx)
            train_loss, train_reward = supervisor.train(batch_data)
            print('train loss: %.4f train reward: %.4f' % (train_loss, train_reward))

            if supervisor.global_step % args.eval_every_n == 0:
                eval_loss, eval_reward = supervisor.eval(eval_data, args.output_trace_flag, args.max_eval_size)
                logger.write_summary({'avg_reward': eval_reward, 'global_step': supervisor.global_step})
                supervisor.save_model()

            if args.lr_decay_steps and supervisor.global_step % args.lr_decay_steps == 0:
                supervisor.model.lr_decay(args.lr_decay_rate)
                if supervisor.model.cont_prob > 0.01:
                    supervisor.model.cont_prob *= 0.5


def evaluate(args):
    """Evaluate a trained job-scheduling model on the test set."""
    print('Evaluation:')

    test_data = data_utils.load_dataset(args.test_dataset, args)
    test_data_size = len(test_data)
    # Disable dropout for deterministic evaluation.
    args.dropout_rate = 0.0

    # Constructed as in the original; kept even though eval() is called on the
    # supervisor directly, in case the constructor has side effects.
    dataProcessor = data_utils.jspDataProcessor(args)
    supervisor = create_model(args)
    test_loss, test_reward = supervisor.eval(test_data, args.output_trace_flag)

    print('test loss: %.4f test reward: %.4f' % (test_loss, test_reward))


if __name__ == "__main__":
    arg_parser = arguments.get_arg_parser("jsp")
    args = arg_parser.parse_args()
    args.cuda = not args.cpu and torch.cuda.is_available()
    random.seed(args.seed)
    np.random.seed(args.seed)
    if args.eval:
        evaluate(args)
    else:
        train(args)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import random
import numpy as np
import pandas as pd
import torch
import arguments
import models.data_utils.data_utils as data_utils
import models.model_utils as model_utils
from models.vrpModel import vrpModel


def create_model(args):
    """Assemble a vrpModel under a vrpSupervisor and restore weights if requested."""
    model = vrpModel(args)

    if model.cuda_flag:
        model = model.cuda()
    model.share_memory()
    supervisor = model_utils.vrpSupervisor(model, args)
    if args.load_model:
        supervisor.load_pretrained(args.load_model)
    elif args.resume:
        ckpt_name = 'ckpt-' + str(args.resume).zfill(8)
        print('Resume from {} iterations.'.format(args.resume))
        supervisor.load_pretrained(args.model_dir + '/' + ckpt_name)
    else:
        print('Created model with fresh parameters.')
        supervisor.model.init_weights(args.param_init)
    return supervisor


def train(args):
    """Training loop for the vehicle-routing task."""
    print('Training:')

    train_data = data_utils.load_dataset(args.train_dataset, args)
    train_data_size = len(train_data)
    if args.train_proportion < 1.0:
        random.shuffle(train_data)
        train_data_size = int(train_data_size * args.train_proportion)
        train_data = train_data[:train_data_size]

    eval_data = data_utils.load_dataset(args.val_dataset, args)

    processor = data_utils.vrpDataProcessor()
    supervisor = create_model(args)

    # resume_step gates the saved batch offset to the first (resumed) epoch.
    resume_step = bool(args.resume)
    resume_idx = args.resume * args.batch_size

    logger = model_utils.Logger(args)
    if args.resume:
        # Re-emit past summaries so the logged curve stays continuous.
        for _, row in pd.read_csv("../logs/" + args.log_name).iterrows():
            logger.write_summary({'avg_reward': row['avg_reward'], 'global_step': row['global_step']})

    start_epoch = resume_idx // train_data_size
    for epoch in range(start_epoch, args.num_epochs):
        random.shuffle(train_data)
        first_batch = resume_idx % train_data_size if resume_step else 0
        for batch_idx in range(first_batch, train_data_size, args.batch_size):
            resume_step = False
            print(epoch, batch_idx)
            batch_data = processor.get_batch(train_data, args.batch_size, batch_idx)
            train_loss, train_reward = supervisor.train(batch_data)
            print('train loss: %.4f train reward: %.4f' % (train_loss, train_reward))

            if supervisor.global_step % args.eval_every_n == 0:
                eval_loss, eval_reward = supervisor.eval(eval_data, args.output_trace_flag, args.max_eval_size)
                logger.write_summary({'avg_reward': eval_reward, 'global_step': supervisor.global_step})
                supervisor.save_model()

            if args.lr_decay_steps and supervisor.global_step % args.lr_decay_steps == 0:
                supervisor.model.lr_decay(args.lr_decay_rate)
                if supervisor.model.cont_prob > 0.01:
                    supervisor.model.cont_prob *= 0.5


def evaluate(args):
    """Evaluate a trained vehicle-routing model on the test set."""
    print('Evaluation:')

    test_data = data_utils.load_dataset(args.test_dataset, args)
    test_data_size = len(test_data)
    # Disable dropout for deterministic evaluation.
    args.dropout_rate = 0.0

    # Constructed as in the original; kept even though eval() is called on the
    # supervisor directly, in case the constructor has side effects.
    dataProcessor = data_utils.vrpDataProcessor()
    supervisor = create_model(args)
    test_loss, test_reward = supervisor.eval(test_data, args.output_trace_flag)

    print('test loss: %.4f test reward: %.4f' % (test_loss, test_reward))


if __name__ == "__main__":
    arg_parser = arguments.get_arg_parser("vrp")
    args = arg_parser.parse_args()
    args.cuda = not args.cpu and torch.cuda.is_available()
    random.seed(args.seed)
    np.random.seed(args.seed)
    if args.eval:
        evaluate(args)
    else:
        train(args)