├── README.md ├── LICENSE └── get_checkpoint.py /README.md: -------------------------------------------------------------------------------- 1 | # dl-tf-get-checkpoint 2 | Get checkpoint from checkpoints dir using the provided condition 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 floydhub 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /get_checkpoint.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import argparse 4 | import tensorflow as tf 5 | import shutil 6 | 7 | """ 8 | Get a checkpoint from the checkpoints dir using the provided 9 | condition. 10 | 11 | Currently only 'latest' checkpoint can be retrieved 12 | TODO: Add other choices 13 | """ 14 | 15 | # Parse command line args 16 | # ================================================== 17 | parser = argparse.ArgumentParser(description='Get particular checkpoint from dir') 18 | 19 | parser.add_argument('-i', '--checkpoints_dir', required=True, 20 | help='Checkpoints dir') 21 | parser.add_argument('-c', '--choice', required=True, default='latest', 22 | choices=['latest'], help='Method to choose checkpoint from dir') 23 | parser.add_argument('-o', '--checkpoint', required=True, 24 | help='Path to chosen checkpoint') 25 | 26 | args = parser.parse_args() 27 | 28 | # Convert args to dict 29 | vargs = vars(args) 30 | 31 | print("\nArguments:") 32 | for arg in vargs: 33 | print("{}={}".format(arg, getattr(args, arg))) 34 | 35 | if args.choice == 'latest': 36 | try: 37 | checkpoint_file = tf.train.latest_checkpoint(args.checkpoints_dir) 38 | shutil.copy2(checkpoint_file, args.checkpoint) 39 | except Exception as exc: 40 | raise Exception("Could not locate or copy latest checkpoint. Error: {}".format(exc)) 41 | else: 42 | raise NotImplementedError("Choice {} is not implemented yet".format(args.choice)) 43 | 44 | print("\n{} checkpoint written to {}".format(args.choice, args.checkpoint)) --------------------------------------------------------------------------------