├── README.md
└── slurm_manager.py

/README.md:
--------------------------------------------------------------------------------
# tensorflow_slurm_manager
Working on a cluster manager for TF.

Since graduating from the university that provided a SLURM cluster, I no longer have access to one. Therefore, I haven't been maintaining this code base since September 2016.

Additionally, I have not tested this on any SLURM clusters other than https://marylou.byu.edu/.
There may be environment variable differences that you need to look into on your specific cluster.

-----------

Long story short, this script parses the environment variables that SLURM sets up on each node it gives you and builds the ClusterSpec from those variables (the generated spec needs to be consistent across all nodes). Furthermore, each process is responsible for deciding its own job name and task_id, and those decisions have to line up with the ClusterSpec generated on every other node. Those were the guiding thoughts while writing it.

This script will attempt to assign only one parameter server (`ps`) per physical node (I do not believe it supports more than one per node, but it has been a while since I wrote the code). From there on out, it keeps assigning worker tasks to the remaining processes until enough workers have been created; see the usage sketch below.
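A rough usage sketch (untested, and assuming the TF 1.x distributed API this code was written against): every process launched by `srun` builds the same ClusterSpec, figures out its own role, and starts the matching server.

```python
# Hypothetical training-script skeleton (not part of this repo).
# Assumes this file is importable as slurm_manager and TensorFlow 1.x is installed.
import tensorflow as tf
from slurm_manager import SlurmClusterManager

manager = SlurmClusterManager(num_param_servers=1)
cluster_spec, job_name, task_id = manager.build_cluster_spec()

# Start the tf.train.Server that fills this process's slot in the ClusterSpec.
server = tf.train.Server(cluster_spec, job_name=job_name, task_index=task_id)

if job_name == 'ps':
    server.join()  # parameter servers just host variables
else:
    # Place variables on the parameter servers, ops on this worker.
    with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
        pass  # ... build the model and training op here ...
    # ... then run the training loop against server.target ...
```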
--------------------------------------------------------------------------------
/slurm_manager.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import re

import hostlist  # from the python-hostlist package
import tensorflow as tf


## It may be useful to know that SLURM_NODEID tells you which node you are on
## (in case there is more than one task on any given node).
## Perhaps parameter servers could be better assigned by distributing them
## across all nodes before doubling up on any one node.
class SlurmClusterManager(object):
    def __init__(self, num_param_servers=1, num_workers=None, starting_port=None):

        # Check the environment for all needed SLURM variables
        assert 'SLURM_JOB_NODELIST' in os.environ  # SLURM_NODELIST for backwards compatibility if needed.
        assert 'SLURM_TASKS_PER_NODE' in os.environ
        assert 'SLURM_PROCID' in os.environ
        assert 'SLURM_NPROCS' in os.environ
        assert 'SLURM_NNODES' in os.environ

        # Grab SLURM variables
        self.hostnames = hostlist.expand_hostlist(os.environ['SLURM_JOB_NODELIST'])  # expands 'node[1-2],other' -> ['node1', 'node2', 'other']
        self.num_tasks_per_host = self._parse_slurm_tasks_per_node(os.environ['SLURM_TASKS_PER_NODE'])  # expands '1,2(x2)' -> [1, 2, 2]
        self.my_proc_id = int(os.environ['SLURM_PROCID'])  # this process's global rank; used to index proc_info
        self.num_processes = int(os.environ['SLURM_NPROCS'])
        self.nnodes = int(os.environ['SLURM_NNODES'])

        # Sanity check that everything has been parsed correctly
        assert len(self.hostnames) == len(self.num_tasks_per_host)
        assert len(self.hostnames) == self.nnodes
        assert self.num_processes == sum(self.num_tasks_per_host)

        # Number of PSs/workers
        # Note: I'm making the assumption that having more than one PS per node
        # doesn't add any benefit. It also keeps self.build_cluster_spec() simpler.
        self.num_parameter_servers = min(num_param_servers, len(self.hostnames))
        if num_workers is None:
            # TODO What happens to num_workers if fewer PSs were allocated than requested?
            self.num_workers = self.num_processes - self.num_parameter_servers  # default: every remaining process is a worker
        else:
            # Note: num_workers is currently not used by build_cluster_spec();
            # every non-PS process becomes a worker regardless.
            self.num_workers = num_workers

        # Default port to use
        if starting_port is not None:
            self.starting_port = starting_port  # use the user-specified port
        else:
            self.starting_port = 2222


    def build_cluster_spec(self):
        # tuples of (str('hostname:port'), job_name, task_id) for each process
        proc_info = []
        for _ in range(self.num_processes):
            proc_info.append([None, None, None])

        # Assign a port number to each process according to hostname.
        # Note: if there are multiple processes on the same hostname,
        # each one needs its own port number (hence the name starting_port).
        pid = 0
        first_pid_per_host = {}  # reverse-lookup map: hostname -> first process id on that host
        for cnt, hostname in zip(self.num_tasks_per_host, self.hostnames):
            first_pid_per_host[hostname] = pid
            for i in range(cnt):
                proc_info[pid][0] = "{}:{}".format(hostname, self.starting_port + i)
                pid += 1

        # Assign PSs to different physical hosts
        # TODO Maybe sorting hostnames/num_tasks_per_host by tasks per host may increase performance?
        # NOTE: this code requires num_parameter_servers to be less than or equal to the number of individual physical nodes
        ps_strings = []
        for ps_id in range(self.num_parameter_servers):
            pid = first_pid_per_host[self.hostnames[ps_id]]
            ps_strings.append(proc_info[pid][0])
            proc_info[pid][1] = 'ps'
            proc_info[pid][2] = ps_id

        # Assign workers to the remaining open spots
        wk_id = 0
        wk_strings = []
        for info in proc_info:
            if info[1] is None:  # it's not a ps
                wk_strings.append(info[0])
                info[1] = 'worker'
                info[2] = wk_id
                wk_id += 1

        # Each process: grab your job name / task id
        job = proc_info[self.my_proc_id][1]
        task_id = proc_info[self.my_proc_id][2]

        # Return it all! :D
        cluster_spec = tf.train.ClusterSpec({'worker': wk_strings, 'ps': ps_strings})
        return cluster_spec, job, task_id

    def _parse_slurm_tasks_per_node(self, num_tasks_per_nodes):
        '''
        SLURM_TASKS_PER_NODE comes in compressed, so we need to uncompress it.
        e.g. if SLURM gave us the following setup:
            Host 1: 1 process
            Host 2: 3 processes
            Host 3: 3 processes
            Host 4: 4 processes
        then the environment variable SLURM_TASKS_PER_NODE = '1,3(x2),4',
        but we need it to become [1, 3, 3, 4].
        '''
        final_list = []
        num_tasks_per_nodes = num_tasks_per_nodes.split(',')

        for node in num_tasks_per_nodes:
            if 'x' in node:  # "n(xN)"; n=tasks, N=repeats
                n_tasks, n_nodes = [int(n) for n in re.findall(r'\d+', node)]
                final_list += [n_tasks] * n_nodes
            else:
                final_list.append(int(node))
        return final_list
--------------------------------------------------------------------------------
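Since the role assignment is deterministic given the SLURM environment, one way to sanity-check it without a real allocation is to fake the variables and run every rank locally. The sketch below is hypothetical (not part of the repo) and assumes `python-hostlist` and TensorFlow 1.x are installed, with the class above importable as `slurm_manager`.

```python
# Hypothetical smoke test: fake a two-node, four-task allocation and print the
# role each process would pick.
import os
from slurm_manager import SlurmClusterManager

fake_env = {
    'SLURM_JOB_NODELIST': 'node[01-02]',  # expands to node01, node02
    'SLURM_TASKS_PER_NODE': '2(x2)',      # two tasks per node -> [2, 2]
    'SLURM_NPROCS': '4',
    'SLURM_NNODES': '2',
}

for proc_id in range(4):
    os.environ.update(fake_env)
    os.environ['SLURM_PROCID'] = str(proc_id)
    cluster_spec, job, task_id = SlurmClusterManager(num_param_servers=1).build_cluster_spec()
    print(proc_id, job, task_id)

# Expected: process 0 becomes 'ps' 0 on node01:2222; processes 1-3 become
# workers 0-2 on node01:2223, node02:2222, and node02:2223.
```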