├── DLNest ├── __init__.py ├── Common │ ├── __init__.py │ ├── ModelBase.py │ ├── ModelBaseTorch.py │ ├── DatasetBase.py │ ├── RunningStatus.py │ ├── RunnerBase.py │ ├── RunnerBaseTorch.py │ └── LifeCycleBase.py ├── Executor │ ├── __init__.py │ ├── AnalyzeProcess.py │ ├── TrainProcess.py │ └── TaskProcess.py ├── Output │ ├── __init__.py │ ├── TrainStdout.py │ ├── DLNestBuffer.py │ ├── AnalyzerBuffer.py │ └── OutputLayerBase.py ├── Plugins │ ├── __init__.py │ ├── Utils │ │ ├── __init__.py │ │ ├── CheckPlugins.py │ │ └── SendMailTools.py │ ├── LogInit.py │ ├── DLNestPluginBase.py │ ├── AutoTensorboardScalar.py │ ├── SimpleCMDVisualize.py │ └── MailsNote.py ├── Scheduler │ ├── __init__.py │ ├── SchedulerStrategyBase.py │ ├── Scheduler.py │ └── DefaultStrategy.py ├── Information │ ├── __init__.py │ ├── test │ │ ├── __init__.py │ │ ├── FakeSubprocess.py │ │ ├── test_AnalyzeTask.py │ │ ├── test_TaskInformation.py │ │ └── test_DeviceInformation.py │ ├── CPUInformation.py │ ├── TaskInformation.py │ ├── GPUInformation.py │ ├── AnalyzeTask.py │ ├── DeviceInformation.py │ ├── TrainTask.py │ └── InfoCenter.py ├── Operations │ ├── __init__.py │ ├── SafeExit.py │ ├── ClearDLNestOutput.py │ ├── DelATask.py │ ├── ChangeDelay.py │ ├── GetTasksInformation.py │ ├── GetDevicesInformation.py │ ├── ChangeMaxTaskPerDevice.py │ ├── ChangeDevices.py │ ├── GetDLNestOutput.py │ ├── AnalyzeIndependent.py │ ├── Analyze.py │ ├── RunExp.py │ ├── RunIndependent.py │ ├── GetAnalyzeOutput.py │ ├── ContinueTrain.py │ ├── Run.py │ ├── GetPluginConfig.py │ ├── New.py │ └── UsePlugin.py ├── SavePackage │ ├── __init__.py │ └── SavePackage.py ├── ShellClient │ ├── __init__.py │ ├── Windows │ │ ├── __init__.py │ │ ├── Utils │ │ │ ├── __init__.py │ │ │ └── Completers.py │ │ ├── SCTextArea.py │ │ ├── CommandInput.py │ │ ├── DevicesInfoShower.py │ │ ├── TaskInfoShower.py │ │ └── ResultsOutput.py │ ├── Client.py │ └── Communicator.py ├── TornadoServer │ └── __init__.py ├── FactoryFiles │ ├── FactoryClean │ │ 
├── freq_config.json │ │ ├── dataset_config.json │ │ ├── model_config.json │ │ ├── plugins_config.json │ │ ├── AnalyzeScripts │ │ │ └── test.py │ │ ├── Dataset │ │ │ └── Dataset.py │ │ ├── root_config.json │ │ ├── Model │ │ │ └── Runner.py │ │ └── LifeCycle.py │ └── FactoryMNIST │ │ ├── freq_config.json │ │ ├── AnalyzeScripts │ │ ├── showModel.py │ │ └── getAcc.py │ │ ├── model_config.json │ │ ├── common_config.json │ │ ├── dataset_config.json │ │ ├── root_config.json │ │ ├── plugins_config.json │ │ ├── Model │ │ ├── MNISTCNN.py │ │ └── Runner.py │ │ ├── LifeCycle.py │ │ └── Dataset │ │ └── Dataset.py ├── Client.py ├── pytest.ini ├── Server.py ├── Simple.py ├── Run.py ├── Analyze.py └── Local.py ├── manifest.in ├── setup.py └── .gitignore /DLNest/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Executor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Output/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Plugins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Information/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Operations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/Plugins/Utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/SavePackage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/ShellClient/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/TornadoServer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /manifest.in: -------------------------------------------------------------------------------- 1 | global-exclude *.pyc -------------------------------------------------------------------------------- /DLNest/Information/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/ShellClient/Windows/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DLNest/ShellClient/Windows/Utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/DLNest/FactoryFiles/FactoryClean/freq_config.json: -------------------------------------------------------------------------------- 1 | { 2 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/freq_config.json: -------------------------------------------------------------------------------- 1 | { 2 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryClean/dataset_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_config" : { 3 | } 4 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryClean/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config" : { 3 | } 4 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryClean/plugins_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "plugins" : [], 3 | "plugins_config" : {} 4 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/AnalyzeScripts/showModel.py: -------------------------------------------------------------------------------- 1 | def experience(self): 2 | print(self.runner.model) -------------------------------------------------------------------------------- /DLNest/Client.py: -------------------------------------------------------------------------------- 1 | from DLNest.ShellClient.Client import startClient 2 | 3 | if __name__ == "__main__": 4 | startClient() -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryClean/AnalyzeScripts/test.py: -------------------------------------------------------------------------------- 1 | def experience(self): 2 | 
print(self.runner,self.dataset,self.args) -------------------------------------------------------------------------------- /DLNest/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | Information : 'Tests for Information module' 4 | 5 | addopts = --strict-markers -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config" : { 3 | "feats" : [1,64,128,1024] 4 | } 5 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/common_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "other_file_paths":[ 3 | "./Model/MNISTCNN.py" 4 | ], 5 | "epochs" : 5 6 | } -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/dataset_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_config" : { 3 | "data_root" : "", 4 | "batch_size" : 64 5 | } 6 | } -------------------------------------------------------------------------------- /DLNest/Common/ModelBase.py: -------------------------------------------------------------------------------- 1 | from DLNest.Common.RunnerBase import RunnerBase 2 | 3 | print("ModelBase is deprecated, please use RunnerBase instead") 4 | ModelBase = RunnerBase -------------------------------------------------------------------------------- /DLNest/Operations/SafeExit.py: -------------------------------------------------------------------------------- 1 | from DLNest.Information.InfoCenter import InfoCenter 2 | 3 | def safeExit(): 4 | infoCenter = InfoCenter() 5 | infoCenter.delAllTask() -------------------------------------------------------------------------------- 
class FakeSubprocess:
    """Test double that mimics a subprocess handle's liveness API."""

    def __init__(self, alive):
        # Fixed liveness flag, reported verbatim by is_alive().
        self.alive = alive

    def is_alive(self):
        """Return the liveness flag supplied at construction time."""
        return self.alive
infoCenter = InfoCenter() 5 | return infoCenter.getTasksInformation() -------------------------------------------------------------------------------- /DLNest/Operations/GetDevicesInformation.py: -------------------------------------------------------------------------------- 1 | from DLNest.Information.InfoCenter import InfoCenter 2 | 3 | def getDevicesInformation(): 4 | infoCenter = InfoCenter() 5 | return infoCenter.getAvailableDevicesInformation() -------------------------------------------------------------------------------- /DLNest/Operations/ChangeMaxTaskPerDevice.py: -------------------------------------------------------------------------------- 1 | from DLNest.Scheduler.Scheduler import Scheduler 2 | 3 | def changeMaxTaskPerDevice( 4 | newMax : int 5 | ): 6 | scheduler = Scheduler() 7 | scheduler.changeMaxTaskPerDevice(newMax) -------------------------------------------------------------------------------- /DLNest/Operations/ChangeDevices.py: -------------------------------------------------------------------------------- 1 | from DLNest.Information.InfoCenter import InfoCenter 2 | 3 | def changeDevices( 4 | newDevicesIDList : [int] 5 | ): 6 | infoCenter = InfoCenter() 7 | infoCenter.changeDevices(newDevicesIDList) -------------------------------------------------------------------------------- /DLNest/Server.py: -------------------------------------------------------------------------------- 1 | from DLNest.TornadoServer.Server import DLNestServer 2 | 3 | 4 | if __name__ == "__main__": 5 | import sys 6 | if sys.path[0] != '': 7 | sys.path[0] = '' 8 | server = DLNestServer() 9 | server.start() -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryClean/Dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | from DLNest.Common.DatasetBase import DatasetBase 2 | 3 | 4 | class Dataset(DatasetBase): 5 | def init(self,args : dict): 6 | pass 7 | #return 
def experience(self):
    """Compute and print classification accuracy of the model on the test loader."""
    loader = self.dataset.testLoader
    correct, seen = 0, 0
    # Device placement is fixed for the whole run, so read it once.
    useGPU = self.runner._status.env != "CPU"
    for batch in loader:
        x, y = batch
        if useGPU:
            x, y = x.cuda(), y.cuda()
        with torch.no_grad():
            logits = self.runner.model(x)
            _, pred = torch.max(logits, 1)
            correct += sum(pred == y).item()
            seen += x.shape[0]
    print("correct count:", correct, "total count:", seen, "accuracy:", correct / seen)
class SchedulerStrategyBase(metaclass=ABCMeta):
    """Abstract base class for task-scheduling strategies.

    Subclasses implement ``decide`` to choose which pending task (if any)
    should be dispatched to which device.

    Fix: the module imported ``ABCMeta``/``abstractmethod`` but never set the
    metaclass, so ``@abstractmethod`` had no effect and this incomplete base
    class could be instantiated silently.
    """

    def __init__(self):
        """
        WARNING: before a run requirement send to scheduler, infoCenter.deviceLock should be released.
        """
        # Shared information hub for tasks and devices.
        self.infoCenter = InfoCenter()
        # Back-reference set by the Scheduler that owns this strategy.
        self.scheduler = None

    @abstractmethod
    def decide(self, maxTaskPerDevice : int, maxTimeDelay : int):
        """Pick and dispatch runnable tasks.

        Args:
            maxTaskPerDevice: upper bound of concurrent tasks per device.
            maxTimeDelay: maximum tolerated scheduling delay.
        """
        pass
class TrainStdout:
    """File-like stdout proxy that logs to a file and optionally mirrors to the screen.

    Installed in place of ``sys.stdout`` during training so that everything a
    task prints lands in its log file, optionally echoed to the original stdout.
    """

    def __init__(self, fp, showOnScreen = False, originalStdout = None):
        super(TrainStdout, self).__init__()
        self.stdout = fp                    # log file handle (primary sink)
        self.screenout = originalStdout     # original stdout, used only when mirroring
        self.showOnScreen = showOnScreen    # whether to echo writes to the screen

    def write(self, s):
        """Write to the log (flushed immediately), mirror if requested, return the log's count."""
        ret = self.stdout.write(s)
        # Flush per write so the log stays current even if the task crashes.
        self.stdout.flush()
        # Fix: guard against showOnScreen=True with no original stdout supplied
        # (previously raised AttributeError on None).
        if self.showOnScreen and self.screenout is not None:
            self.screenout.write(s)
        return ret

    def flush(self):
        if self.showOnScreen and self.screenout is not None:
            self.screenout.flush()
        return self.stdout.flush()

    def isatty(self):
        return self.stdout.isatty()

    def fileno(self):
        return self.stdout.fileno()
-------------------------------------------------------------------------------- 1 | from DLNest.Information.InfoCenter import InfoCenter 2 | from DLNest.Information.AnalyzeTask import AnalyzeTask 3 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer 4 | 5 | def runExp( 6 | taskID : str, 7 | command : str 8 | ): 9 | infoCenter = InfoCenter() 10 | analyzeTask : AnalyzeTask = infoCenter.getTaskByID(taskID) 11 | 12 | assert analyzeTask != None # Wrong task 13 | assert isinstance(analyzeTask,AnalyzeTask) # Not an analyze task 14 | assert isinstance(analyzeTask.outputBuffer,AnalyzerBuffer) # don't have analyze buffer 15 | assert analyzeTask.process != None # not running 16 | assert analyzeTask.process.is_alive() # not alive 17 | assert analyzeTask.commandQueue != None # don't have command queue 18 | 19 | analyzeTask.commandQueue.put(command,block = False) -------------------------------------------------------------------------------- /DLNest/Operations/RunIndependent.py: -------------------------------------------------------------------------------- 1 | from DLNest.Executor.TrainProcess import TrainProcess 2 | from DLNest.Information.TrainTask import TrainTask 3 | 4 | def runIndependent( 5 | configPath : str, 6 | freqPath : str = "", 7 | devices : [int] = [-1], 8 | description : str = "", 9 | DDP : bool = False, 10 | noSave : bool = False, 11 | useDescriptionToSave : bool = False, 12 | showOnScreen : bool = True 13 | ): 14 | task = TrainTask.fromConfigFile( 15 | configPath = configPath, 16 | freqPath = freqPath, 17 | devices = devices, 18 | description = description, 19 | DDP = DDP, 20 | noSave = noSave, 21 | useDescriptionToSave = useDescriptionToSave 22 | ) 23 | try: 24 | TProcess = TrainProcess(task,showOnScreen) 25 | TProcess.run() 26 | except KeyboardInterrupt: 27 | exit(0) -------------------------------------------------------------------------------- /DLNest/Operations/GetAnalyzeOutput.py: 
-------------------------------------------------------------------------------- 1 | from DLNest.Information.InfoCenter import InfoCenter 2 | from DLNest.Information.AnalyzeTask import AnalyzeTask 3 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer 4 | 5 | def getAnalyzeOutput( 6 | taskID : str, 7 | style : bool = True 8 | ): 9 | infoCenter = InfoCenter() 10 | analyzeTask : AnalyzeTask = infoCenter.getTaskByID(taskID) 11 | 12 | assert analyzeTask != None # Wrong task 13 | assert isinstance(analyzeTask,AnalyzeTask) # Not an analyze task 14 | assert isinstance(analyzeTask.outputBuffer,AnalyzerBuffer) # don't have analyze buffer 15 | assert analyzeTask.process != None # not running 16 | assert analyzeTask.process.is_alive() # not alive 17 | 18 | if style: 19 | return analyzeTask.outputBuffer.getStyledText() 20 | else: 21 | return analyzeTask.outputBuffer.getPlainText() -------------------------------------------------------------------------------- /DLNest/Operations/ContinueTrain.py: -------------------------------------------------------------------------------- 1 | from DLNest.Scheduler.Scheduler import Scheduler 2 | from DLNest.Information.TrainTask import TrainTask 3 | 4 | def continueTrain( 5 | recordPath : str, 6 | checkpointID : int = -1, 7 | memoryConsumption : int = -1, 8 | CPU : bool = False, 9 | DDP : bool = False, 10 | multiGPU : bool = False, 11 | description : str = "", 12 | otherArgs : dict = {} 13 | ): 14 | devices = [] 15 | if CPU: 16 | devices = [-1] 17 | 18 | task = TrainTask.fromRecord( 19 | recordPath = recordPath, 20 | checkpointID = checkpointID, 21 | devices = devices, 22 | memoryConsumption = memoryConsumption, 23 | multiGPU = multiGPU, 24 | DDP = DDP, 25 | description = description 26 | ) 27 | 28 | scheduler = Scheduler() 29 | scheduler.giveATask(task,otherArgs = otherArgs) -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryClean/LifeCycle.py: 
-------------------------------------------------------------------------------- 1 | from DLNest.Common.DatasetBase import DatasetBase 2 | from DLNest.Common.RunnerBase import RunnerBase 3 | from DLNest.Common.LifeCycleBase import LifeCycleBase 4 | 5 | 6 | class LifeCycle(LifeCycleBase): 7 | def needVisualize(self, epoch : int, iter : int, logdict : dict, args : dict): 8 | return False 9 | 10 | def needValidation(self, epoch : int, logdict : dict, args : dict): 11 | return False 12 | 13 | def commandLineOutput(self,epoch : int, logdict : dict, args : dict): 14 | print("Epoch #" + str(epoch) + " finished!") 15 | 16 | def needSaveModel(self, epoch : int, logdict : dict, args : dict): 17 | return False 18 | 19 | def holdThisCheckpoint(self, epoch : int, logdict : dict, args : dict): 20 | return False 21 | 22 | def needContinueTrain(self, epoch : int, logdict : dict, args : dict): 23 | return False -------------------------------------------------------------------------------- /DLNest/Plugins/LogInit.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB 3 | 4 | class DLNestPlugin(DPB): 5 | _NAME = "LogInit" 6 | _config = { 7 | "level" : "INFO", 8 | "format" : "[%(asctime)s][%(levelname)s] %(message)s", 9 | "datefmt" : "%Y-%m-%d %H:%M:%S" 10 | } 11 | _defaultKeys = ["level"] 12 | def BAll(self): 13 | import sys 14 | log = logging.getLogger() 15 | while len(log.handlers) > 0: 16 | log.removeHandler(log.handlers[0]) 17 | 18 | args = self.taskProcess.task.args 19 | pluginName = "LogInit" 20 | level = DPB.getArgs(self, pluginName, "level", "INFO") 21 | format = DPB.getArgs(self, pluginName, "format", "[%(asctime)s][%(levelname)s] %(message)s") 22 | datefmt= DPB.getArgs(self, pluginName, "datefmt", "%Y-%m-%d %H:%M:%S") 23 | logging.basicConfig(level = level, format = format,datefmt = datefmt) 
class MNISTModel(nn.Module):
    """Small CNN classifier for 28x28 MNIST digits.

    Channel/feature widths come from ``args["model_config"]["feats"]`` as
    ``[in_channels, conv1_out, conv2_out, hidden]``.
    """

    def __init__(self, args : dict):
        super(MNISTModel, self).__init__()
        feats = args["model_config"]["feats"]
        self.conv1 = nn.Sequential(
            nn.Conv2d(feats[0], feats[1], kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(feats[1], feats[2], kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(stride=2, kernel_size=2))
        # 28x28 input pooled once by 2 -> 14x14 grid with feats[2] channels.
        self._flatDim = 14 * 14 * feats[2]
        self.dense = nn.Sequential(
            nn.Linear(self._flatDim, feats[3]),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            # Fix: hidden width was hard-coded as 1024 instead of feats[3],
            # breaking any config that changes the last entry of "feats".
            nn.Linear(feats[3], 10))

    def forward(self, x):
        x = self.conv1(x)
        # Fix: flatten using the configured channel count (was hard-coded 128).
        x = x.view(-1, self._flatDim)
        x = self.dense(x)
        return x
class DLNestPluginBase:
    """Base class for DLNest plugins: naming, default config, and per-plugin arg lookup."""

    # Fix: was misspelled ``_NANE``, so getName() raised AttributeError for the
    # base class and for any plugin that did not define _NAME itself.
    _NAME = "DLNestPluginBase"
    _config = {}
    _defaultKeys = []

    @classmethod
    def getName(cls):
        """Return the plugin's display name."""
        return cls._NAME

    @classmethod
    def getDefaultConfig(cls):
        """Return only the config entries listed in ``_defaultKeys``."""
        return {key: cls._config[key] for key in cls._defaultKeys}

    @classmethod
    def getFullConfig(cls):
        """Return the plugin's complete default config dict."""
        return cls._config

    @classmethod
    def getArgs(cls, self, pluginName : str, name : str, default : any):
        """Look up ``args["plugins_config"][pluginName][name]`` from the task args.

        ``self`` is the task-process-side object exposing ``getArgs()``.
        Falls back to ``default`` at any missing level.
        """
        args = self.getArgs()
        try:
            return args["plugins_config"][pluginName][name]
        except KeyError:
            return default
checkAndRun -------------------------------------------------------------------------------- /DLNest/Simple.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | from prompt_toolkit import PromptSession,HTML 4 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory 5 | import traceback 6 | import json 7 | from DLNest.ShellClient.Communicator import Communicator 8 | 9 | 10 | class DLNestSimple: 11 | def __init__(self, url : str = "127.0.0.1", port : int = 9999): 12 | self.communicator = Communicator(url = url, port = port) 13 | def run(self): 14 | self.session = PromptSession(auto_suggest = AutoSuggestFromHistory()) 15 | while True: 16 | try: 17 | command = self.session.prompt(HTML("DLNest>>")) 18 | commandWordList = command.strip().split(' ') 19 | output = self.communicator.giveACommand(commandWordList) 20 | print(output) 21 | if "exit" in output: 22 | exit(0) 23 | except KeyboardInterrupt: 24 | exit(0) 25 | except Exception as e: 26 | s = traceback.format_exc() 27 | listS = s.split("\n")[:-1] 28 | s = "\n".join(listS[-3:]) 29 | print(s) 30 | 31 | if __name__ == "__main__": 32 | main = DLNestSimple() 33 | main.run() -------------------------------------------------------------------------------- /DLNest/Plugins/AutoTensorboardScalar.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import logging 3 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB 4 | 5 | class DLNestPlugin(DPB): 6 | _NAME = "AutoTensorboardScalar" 7 | _config = {} 8 | _defaultKeys = [] 9 | def runnerInit(self,args : dict, datasetInfo : dict = None): 10 | if self._status.rank == -1 or self._status.rank == 0: 11 | self.writer = SummaryWriter(".") 12 | logging.debug("[AutoTensorboardScalar] Finish modelInit") 13 | 14 | def visualize(self, log : dict, epoch : int, iter : int): 15 | if self._status.rank == -1 
or self._status.rank == 0: 16 | if not "_lastLen" in dir(self): 17 | self._lastLen = {key : 0 for key in log} 18 | 19 | for key in log: 20 | if isinstance(log[key], list) and len(log[key]) != self._lastLen[key]: 21 | try: 22 | self.writer.add_scalar(key,log[key][-1],len(log[key]) - 1) 23 | self._lastLen[key] = len(log[key]) 24 | except Exception as e: 25 | logging.debug("[AutoTensorboardScalar]" + str(e)) 26 | logging.debug("[AutoTensorboardScalar] Finish visualize") -------------------------------------------------------------------------------- /DLNest/Information/test/test_AnalyzeTask.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from DLNest.Information.AnalyzeTask import AnalyzeTask 3 | import time 4 | from pathlib import Path 5 | 6 | @pytest.mark.Information 7 | class TestAnalyzeTask: 8 | def test_init(self): 9 | AT = AnalyzeTask( 10 | recordPath = "/root/code/DLNestTest/Saves/Some Description", 11 | scriptPath = "scriptPath", 12 | checkpointID = 233 13 | ) 14 | assert isinstance(AT.recordPath,Path) 15 | assert isinstance(AT.scriptPath,Path) 16 | assert str(AT.recordPath) == "/root/code/DLNestTest/Saves/Some Description" 17 | assert str(AT.scriptPath) == "scriptPath" 18 | assert AT.checkpointID == 233 19 | assert AT.type == "Analyze" 20 | assert AT.commandPipe == None 21 | assert AT.outputQueue == None 22 | 23 | def test_getDict(self): 24 | AT = AnalyzeTask( 25 | recordPath = "/root/code/DLNestTest/Saves/Some Description", 26 | scriptPath = "scriptPath", 27 | checkpointID = 233 28 | ) 29 | retDict = AT.getDict() 30 | assert retDict["record_path"] == "/root/code/DLNestTest/Saves/Some Description" 31 | assert retDict["script_path"] == "scriptPath" 32 | assert retDict["checkpoint_ID"] == 233 -------------------------------------------------------------------------------- /DLNest/Operations/GetPluginConfig.py: -------------------------------------------------------------------------------- 1 | from 
pathlib import Path 2 | import sys 3 | import importlib 4 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase 5 | 6 | def _loadAModule(filePath : Path,name : str): 7 | spec = importlib.util.spec_from_file_location( 8 | name, 9 | filePath 10 | ) 11 | module = importlib.util.module_from_spec(spec) 12 | dirName = str(filePath.parent) 13 | if not dirName in sys.path: 14 | sys.path.append(dirName) 15 | spec.loader.exec_module(module) 16 | return module 17 | 18 | def _loadAPlugin(pluginName : str): 19 | pluginPath = Path(__file__).parent.parent / "Plugins" / (pluginName + '.py') 20 | tmpPath = Path(pluginName) 21 | if tmpPath.is_absolute(): 22 | pluginPath = tmpPath 23 | pluginName = tmpPath.name 24 | pluginModule = _loadAModule(filePath = pluginPath,name = pluginName) 25 | pluginClass = pluginModule.__getattribute__("DLNestPlugin") 26 | return pluginClass 27 | 28 | def getPluginConfig(pluginName, full = False): 29 | try: 30 | pluginClass : DLNestPluginBase = _loadAPlugin(pluginName) 31 | config = {} 32 | name = pluginClass.getName() 33 | if full: 34 | config[name] = pluginClass.getFullConfig() 35 | else: 36 | config[name] = pluginClass.getDefaultConfig() 37 | 38 | return name,config 39 | except Exception: 40 | return None, None -------------------------------------------------------------------------------- /DLNest/ShellClient/Windows/SCTextArea.py: -------------------------------------------------------------------------------- 1 | from apscheduler.schedulers.background import BackgroundScheduler 2 | from typing import Callable, Iterable, List, Optional 3 | 4 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box 5 | from prompt_toolkit.buffer import Buffer 6 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory 7 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl 8 | from prompt_toolkit.key_binding import KeyBindings 9 | from prompt_toolkit.document import Document 10 | from prompt_toolkit.layout.containers 
import VSplit, Window,HSplit 11 | from prompt_toolkit.lexers import Lexer 12 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples 13 | 14 | class SCTextArea(TextArea): 15 | def __init__(self,lexer,wrap_lines = True): 16 | super(SCTextArea,self).__init__( 17 | lexer = lexer, 18 | read_only=True, 19 | focusable=True, 20 | scrollbar=True, 21 | wrap_lines=wrap_lines 22 | ) 23 | 24 | @property 25 | def text(self) -> str: 26 | """ 27 | The `Buffer` text. 28 | """ 29 | return self.buffer.text 30 | 31 | @text.setter 32 | def text(self, value: str) -> None: 33 | oldPos = self.document.cursor_position 34 | if len(self.document.text) == oldPos: 35 | self.document = Document(value)#, 0) 36 | else: 37 | self.document = Document(value,oldPos) -------------------------------------------------------------------------------- /DLNest/Information/test/test_TaskInformation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from DLNest.Information.TaskInformation import TaskInformation 3 | import time 4 | 5 | @pytest.mark.Information 6 | class TestTaskInformation: 7 | def test_init(self): 8 | Task = TaskInformation( 9 | devices=[1,2,3], 10 | memoryConsumption = 1234, 11 | multiGPU = True, 12 | DDP = True 13 | ) 14 | assert Task.DDP == True 15 | assert Task.devices == [1,2,3] 16 | assert Task.multiGPU == True 17 | assert Task.memoryConsumption == 1234 18 | assert abs(Task.startTime - time.time()) < 1.0 19 | assert Task.type == "None" 20 | assert Task.status == "Pending" 21 | assert Task.process == None 22 | 23 | def test_getDict(self): 24 | Task = TaskInformation( 25 | devices=[1,2,3], 26 | memoryConsumption = 1234, 27 | multiGPU = True, 28 | DDP = True 29 | ) 30 | targetDict = { 31 | "devices" : [1,2,3], 32 | "memory_consumption" : 1234, 33 | "multi_GPU" : True, 34 | "DDP" : True, 35 | "type" : "None", 36 | "status" : "Pending" 37 | } 38 | retDict = Task.getDict() 39 | for key in retDict: 40 | assert key in 
targetDict or key == "ID" 41 | if key in targetDict: 42 | assert targetDict[key] == retDict[key] 43 | 44 | -------------------------------------------------------------------------------- /DLNest/Common/DatasetBase.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | except ImportError: 4 | pass 5 | from functools import wraps 6 | import logging 7 | from DLNest.Plugins.Utils.CheckPlugins import checkPlugins,checkDictOutputPlugins 8 | from DLNest.Common.RunningStatus import RunningStatus 9 | 10 | class DatasetBase: 11 | def __init__(self, args : dict = {}, plugins : list = [], status = RunningStatus()): 12 | self._args = args 13 | self._plugins = plugins 14 | self._status = status 15 | 16 | def getArgs(self): 17 | return self._args 18 | 19 | @checkPlugins 20 | def _datasetInit(self, args : dict): 21 | return self.init(args = args) 22 | 23 | def init(self,args : dict): 24 | """ 25 | input: 26 | args : dict 27 | output: 28 | dict to runner 29 | train loader 30 | val loader 31 | """ 32 | return {},None,None 33 | 34 | def getSampler(self, dataset): 35 | if self._status.env == "DDP": 36 | return torch.utils.data.distributed.DistributedSampler(dataset) 37 | else: 38 | return None 39 | 40 | @checkDictOutputPlugins 41 | def _getSaveDict(self): 42 | return self.getSaveDict() 43 | 44 | def getSaveDict(self): 45 | return {} 46 | 47 | @checkPlugins 48 | def _loadSaveDict(self, saveDict): 49 | return self.loadSaveDict(saveDict = saveDict) 50 | 51 | def loadSaveDict(self,saveDict): 52 | pass -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/LifeCycle.py: -------------------------------------------------------------------------------- 1 | from DLNest.Common.DatasetBase import DatasetBase 2 | from DLNest.Common.RunnerBase import RunnerBase 3 | from DLNest.Common.LifeCycleBase import LifeCycleBase 4 | from DLNest.Plugins.MailsNote import DLNestPlugin as 
MailsNotePlugin 5 | 6 | 7 | class LifeCycle(LifeCycleBase): 8 | def BAll(self): 9 | self.maxAcc = 0.0 10 | 11 | def needVisualize(self, epoch : int, iter : int, logdict : dict, args : dict): 12 | return True 13 | 14 | def needValidation(self, epoch : int, logdict : dict, args : dict): 15 | return True 16 | 17 | def commandLineOutput(self,epoch : int, logdict : dict, args : dict): 18 | print("Epoch #" + str(epoch) + " finished!") 19 | 20 | def needSaveModel(self, epoch : int, logdict : dict, args : dict): 21 | return True 22 | 23 | def holdThisCheckpoint(self, epoch : int, logdict : dict, args : dict): 24 | if logdict["acc"][-1] > self.maxAcc: 25 | self.maxAcc = logdict["acc"][-1] 26 | MailsNotePlugin.giveResultValues({'acc' : self.maxAcc}) 27 | return True 28 | return False 29 | 30 | def getSaveDict(self): 31 | return { 32 | "max_acc" : self.maxAcc 33 | } 34 | 35 | def AOneEpoch(self): 36 | # step the optimizer after every epoch 37 | self.runner.optimizer.step() 38 | 39 | def loadSaveDict(self,saveDict): 40 | self.maxAcc = saveDict["max_acc"] 41 | 42 | def needContinueTrain(self, epoch : int, logdict : dict, args : dict): 43 | if epoch >= args["epochs"]: 44 | return False 45 | return True -------------------------------------------------------------------------------- /DLNest/Common/RunningStatus.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | DLNestStatus = Enum("DLNestStatus", ("Training","Validating","Analyzing","Waiting")) 4 | 5 | class RunningStatus: 6 | def __init__(self): 7 | self.__epoch = 0 8 | self.__iter = 0 9 | self.__status = DLNestStatus.Waiting 10 | self.__env = "CPU" 11 | self.__rank = -1 12 | self.__worldSize = 0 13 | 14 | @property 15 | def epoch(self): 16 | return self.__epoch 17 | 18 | @property 19 | def iter(self): 20 | return self.__iter 21 | 22 | @property 23 | def status(self): 24 | return self.__status 25 | 26 | @property 27 | def env(self): 28 | return self.__env 29 | 
30 | @property 31 | def rank(self): 32 | return self.__rank 33 | 34 | @property 35 | def worldSize(self): 36 | return self.__worldSize 37 | 38 | def isTraining(self): 39 | return self.status == DLNestStatus.Training 40 | 41 | def isValidating(self): 42 | return self.status == DLNestStatus.Validating 43 | 44 | def isAnalyzing(self): 45 | return self.status == DLNestStatus.Analyzing 46 | 47 | def isWaiting(self): 48 | return self.status == DLNestStatus.Waiting 49 | 50 | def startTraining(self): 51 | self.__status = DLNestStatus.Training 52 | 53 | def startValidating(self): 54 | self.__status = DLNestStatus.Validating 55 | 56 | def startAnalyzing(self): 57 | self.__status = DLNestStatus.Analyzing 58 | 59 | def startWaiting(self): 60 | self.__status = DLNestStatus.Waiting -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/Dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | from DLNest.Common.DatasetBase import DatasetBase 2 | 3 | import torchvision 4 | import torch 5 | from torch.utils import data 6 | from torchvision import datasets, transforms 7 | 8 | class MNISTDataset(DatasetBase): 9 | def init(self,args : dict): 10 | transform = transforms.Compose([transforms.ToTensor(), 11 | transforms.Normalize(mean=[0.5],std=[0.5])]) 12 | self.dataTrain = datasets.MNIST(root = args["dataset_config"]["data_root"], 13 | transform=transform, 14 | train = True, 15 | download = True) 16 | 17 | self.dataTest = datasets.MNIST(root = args["dataset_config"]["data_root"], 18 | transform = transform, 19 | train = False) 20 | 21 | trainSampler = self.getSampler(self.dataTrain) 22 | testSampler = self.getSampler(self.dataTest) 23 | if trainSampler is None: # DDP 24 | self.trainLoader = data.DataLoader(self.dataTrain,batch_size = args["dataset_config"]["batch_size"],shuffle = True) 25 | self.testLoader = data.DataLoader(self.dataTest,batch_size = 
args["dataset_config"]["batch_size"],shuffle = False) 26 | else: 27 | self.trainLoader = data.DataLoader(self.dataTrain,batch_size = args["dataset_config"]["batch_size"],sampler=trainSampler) 28 | self.testLoader = data.DataLoader(self.dataTest,batch_size = args["dataset_config"]["batch_size"],sampler=testSampler) 29 | 30 | return {},self.trainLoader,self.testLoader 31 | -------------------------------------------------------------------------------- /DLNest/Plugins/Utils/SendMailTools.py: -------------------------------------------------------------------------------- 1 | import smtplib 2 | from email.mime.text import MIMEText 3 | from email.header import Header 4 | import logging 5 | 6 | _fromTitle = "DLNest Mail Plugin" 7 | _toTitle = "DLNest User" 8 | def sendSelfMail(username : str, password : str, message : str, subject : str, host : str, port : int = 25, fromTitle : str = _fromTitle, toTitle : str = _toTitle): 9 | message = MIMEText(message,'plain', 'utf-8') 10 | message["From"] = Header(fromTitle, 'utf-8') 11 | message["To"] = Header(toTitle, 'utf-8') 12 | message["subject"] = Header(subject,'utf-8') 13 | try: 14 | smtpObj = smtplib.SMTP(host = host, port = port) 15 | smtpObj.login(username, password) 16 | smtpObj.sendmail(username, [username], message.as_string()) 17 | logging.info("[SendSelfMail] Send mail successfully to {}".format(username)) 18 | except Exception as e: 19 | logging.debug("[SendSelfMail]" + str(e)) 20 | 21 | def sendMail(username : str, password : str, toName : str, message : str, subject : str, host : str, port : int = 25, fromTitle : str = _fromTitle, toTitle : str = _toTitle): 22 | message = MIMEText(message,'plain', 'utf-8') 23 | message["From"] = Header(fromTitle, 'utf-8') 24 | message["To"] = Header(toTitle, 'utf-8') 25 | message["subject"] = Header(subject,'utf-8') 26 | 27 | receiver = toName if isinstance(toName, list) else [toName] 28 | 29 | try: 30 | smtpObj = smtplib.SMTP(host = host, port = port) 31 | smtpObj.login(username, 
password) 32 | smtpObj.sendmail(username, receiver, message.as_string()) 33 | logging.info("[SendMail] Send mail successfully to {}".format(username)) 34 | except Exception as e: 35 | logging.debug("[SendMail]" + str(e)) -------------------------------------------------------------------------------- /DLNest/Output/DLNestBuffer.py: -------------------------------------------------------------------------------- 1 | from DLNest.Output.OutputLayerBase import OutputLayerBase 2 | 3 | class Singleton(object): 4 | def __new__(cls, *args, **kwargs): 5 | if not hasattr(cls, '_instance'): 6 | orig = super(Singleton, cls) 7 | cls._instance = orig.__new__(cls, *args, **kwargs) 8 | return cls._instance 9 | 10 | class DLNestBuffer(OutputLayerBase,Singleton): 11 | def __init__(self): 12 | if hasattr(self,"styledText"): 13 | return 14 | OutputLayerBase.__init__(self) 15 | self.styleDict = { 16 | '' : '#efefef', 17 | 'app' : '#ff0f9f bold', 18 | 'ignored':'#ff7070', 19 | 'error' : '#ff0000', 20 | 'message' : '#efefef', 21 | } 22 | self.appName = "DLNest" 23 | 24 | def logMessage(self,message : str): 25 | self.putStyledText('app',"["+ self.appName +"] ") 26 | self.putStyledText('message',message) 27 | self.putStyledText('message','\n') 28 | 29 | def logIgnError(self,message : str): 30 | self.putStyledText('app',"["+ self.appName +"] ") 31 | self.putStyledText('ignored' , "[Ignored] ") 32 | self.putStyledText('message',message) 33 | self.putStyledText('message','\n') 34 | 35 | def logError(self,message : str): 36 | self.putStyledText('app',"["+ self.appName +"] ") 37 | self.putStyledText('error' , "[ERROR] ") 38 | self.putStyledText('message',message) 39 | self.putStyledText('message','\n') 40 | 41 | def logDebug(self,message : str): 42 | self.putStyledText('app',"["+ self.appName +"] ") 43 | self.putStyledText('ignored' , "[DEBUG] ") 44 | self.putStyledText('message',message) 45 | self.putStyledText('message','\n') 46 | 
-------------------------------------------------------------------------------- /DLNest/Plugins/SimpleCMDVisualize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB 3 | 4 | class DLNestPlugin(DPB): 5 | _NAME = "SimpleCMDVisualize" 6 | _config = { 7 | "stride" : 1, 8 | "keys" : [], 9 | "format" : {} 10 | } 11 | _defaultKeys = ["stride", "keys", "format"] 12 | 13 | def runnerInit(self,args : dict, datasetInfo : dict = None): 14 | pluginName = "SimpleCMDVisualize" 15 | self._visualStride = DPB.getArgs(self, pluginName, "stride", 1) 16 | self._visualKeys = DPB.getArgs(self, pluginName, "keys" , []) 17 | self._visualFormat = DPB.getArgs(self, pluginName, "format", {}) 18 | fmtDict = {key : "\t| {}: {}" for key in self._visualKeys} 19 | fmtDict.update(self._visualFormat) 20 | self._visualFormat = fmtDict 21 | 22 | def visualize(self, epoch : int, iter : int, log : dict): 23 | try: 24 | if iter % self._visualStride == 0: 25 | infoStr = "" 26 | for key in self._visualKeys: 27 | if len(log[key]) > 0: 28 | try: 29 | infoStr = infoStr + self._visualFormat[key].format(key, log[key][-1]) 30 | except Exception as e: 31 | logging.debug("[SimpleCMDVisualize]" + str(e)) 32 | else: 33 | try: 34 | infoStr = infoStr + self._visualFormat[key].format(key, None) 35 | except Exception as e: 36 | logging.debug("[SimpleCMDVisualize]" + str(e)) 37 | logging.info("Iter : " + str(iter) + infoStr) 38 | except Exception as e: 39 | logging.debug("[SimpleCMDVisualize]" + str(e)) -------------------------------------------------------------------------------- /DLNest/Information/TaskInformation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | try: 4 | from ..SavePackage.SavePackage import SavePackage 5 | except ImportError: 6 | from DLNest.SavePackage.SavePackage import SavePackage 7 | 8 | class 
TaskInformation: 9 | def __init__( 10 | self, 11 | savePackage : SavePackage, 12 | devices : list = [], 13 | memoryConsumption : int = -1, 14 | multiGPU : bool = False, 15 | DDP : bool = False, 16 | loadCkpt : bool = False, 17 | checkpointID : int = -1 18 | ): 19 | self.ID = ("%.4f" % time.time())[6:] + "_" + str(random.randint(0,9)) 20 | self.args = savePackage.args 21 | self.devices = devices 22 | self.memoryConsumption = memoryConsumption 23 | self.multiGPU = multiGPU 24 | self.DDP = DDP 25 | if self.DDP: 26 | self.multiGPU = True 27 | self.startTime = time.time() 28 | 29 | self.type = "None" 30 | self.status = "Pending" 31 | self.process = None 32 | 33 | self.port = str(45535 + random.randint(0,20000)) # default port 34 | self.address = 'localhost' # default address 35 | self.savePackage = savePackage 36 | self.loadCkpt = loadCkpt 37 | self.checkpointID = checkpointID 38 | 39 | self.extraInfo = {} 40 | 41 | def getDict(self): 42 | return { 43 | "ID" : self.ID, 44 | "args" : self.args, 45 | "devices" : self.devices, 46 | "memory_consumption" : self.memoryConsumption, 47 | "multi_GPU" : self.multiGPU, 48 | "DDP" : self.DDP, 49 | "type" : self.type, 50 | "status" : self.status, 51 | "load_ckpt" : self.loadCkpt, 52 | "checkpoint_ID" : self.checkpointID 53 | } -------------------------------------------------------------------------------- /DLNest/Information/GPUInformation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pynvml 3 | try: 4 | from DeviceInformation import DeviceInformation 5 | except ImportError: 6 | from .DeviceInformation import DeviceInformation 7 | 8 | class GPUInformation(DeviceInformation): 9 | def __init__(self,ID): 10 | super(GPUInformation,self).__init__("GPU") 11 | self.ID = ID 12 | try: 13 | # 若卡信息获取失败,当作卡失效,设置isBreak = True 14 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(self.ID) 15 | self.meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.handle) 16 | self.totalMemory = 
self.meminfo.total / 1024 / 1024 17 | self.isBreak = False 18 | except Exception as e: 19 | # 卡失效,totalMemory设为0 20 | self.totalMemory = 0 21 | self.isBreak = True 22 | 23 | def restartGPU(self): 24 | """ 25 | 重新尝试得到显卡信息,若成功,则返回True,同时修改isBreak,若失败则返回False 26 | """ 27 | try: 28 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(self.ID) 29 | self.meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.handle) 30 | self.totalMemory = self.meminfo.total / 1024 / 1024 31 | self.isBreak = False 32 | return True 33 | except Exception: 34 | self.isBreak = True 35 | return False 36 | 37 | def getFreeMemory(self): 38 | """ 39 | return in MB 40 | """ 41 | if self.isBreak: 42 | self.restartGPU() 43 | if self.isBreak: 44 | return 0 45 | try: 46 | self.meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.handle) 47 | return self.meminfo.free / 1024 / 1024 48 | except Exception as e: 49 | self.isBreak = True 50 | return 0 51 | 52 | def getDeviceStr(self): 53 | return "cuda:" + str(self.ID) -------------------------------------------------------------------------------- /DLNest/FactoryFiles/FactoryMNIST/Model/Runner.py: -------------------------------------------------------------------------------- 1 | from DLNest.Common.RunnerBaseTorch import RunnerBaseTorch 2 | 3 | import torch 4 | import torch.nn as nn 5 | from MNISTCNN import MNISTModel 6 | 7 | class Runner(RunnerBaseTorch): 8 | def init(self,args : dict,datasetInfo : dict = None): 9 | self.model = MNISTModel(args) # if BN layers need to be sync, use the following code 10 | # self.model = self.register(model, syncBN=True) 11 | self.cost = nn.CrossEntropyLoss() 12 | 13 | def initLog(self): 14 | return { 15 | "loss" : [], 16 | "acc" : [], 17 | } 18 | 19 | def initOptimizer(self): 20 | self.optimizer = torch.optim.Adam(self.model.parameters()) 21 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 3) 22 | 23 | def runOneStep(self,data, log : dict, iter : int, epoch : int): 24 | self.model.zero_grad() 25 | x,y = data 26 | pred = 
self.model(x) 27 | loss = self.cost(pred,y) 28 | loss.backward() 29 | self.optimizer.step() 30 | 31 | loss = self._reduceMean(loss) 32 | log["loss"].append(loss.detach().item()) 33 | 34 | def visualize(self,epoch : int, iter : int, log : dict): 35 | pass 36 | 37 | def validationInit(self): 38 | self.totalCorrect = 0 39 | self.total = 0 40 | 41 | def validateABatch(self,data, iter : int): 42 | x,y = data 43 | with torch.no_grad(): 44 | output = self.model(x) 45 | _,pred = torch.max(output, 1) 46 | correct = (pred == y).sum() / y.shape[0] 47 | correct = self._reduceMean(correct) 48 | self.totalCorrect += correct 49 | self.total += 1 50 | 51 | def validationAnalyze(self, log : dict): 52 | acc = self.totalCorrect / self.total 53 | log["acc"].append(acc.item()) -------------------------------------------------------------------------------- /DLNest/Operations/New.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | import json 4 | from DLNest.Operations.UsePlugin import usePlugin 5 | 6 | def new(targetDir : str, MNIST : bool = False, pluginsName = []): 7 | projectPath = Path(targetDir).absolute() 8 | # If target path already exists, return 9 | if projectPath.exists(): 10 | print("Path already exists") 11 | return 12 | 13 | # Get the factory file 14 | factoryPath = Path(__file__).parent.parent / "FactoryFiles" 15 | if MNIST: 16 | factoryPath = factoryPath / "FactoryMNIST" 17 | else: 18 | factoryPath = factoryPath / "FactoryClean" 19 | 20 | # Copy to target 21 | shutil.copytree(factoryPath,projectPath) 22 | 23 | # Modify save_root and root_file_path in root_config.json 24 | rootConfigPath = projectPath / "root_config.json" 25 | rootConfig = {} 26 | with rootConfigPath.open("r") as fp: 27 | rootConfig = json.load(fp) 28 | rootConfig["save_root"] = str(projectPath / "Saves") 29 | rootConfig["root_file_path"] = str(projectPath) 30 | with rootConfigPath.open("w") as fp: 31 | 
json.dump(rootConfig,fp,indent = 4,separators = (',',':')) 32 | 33 | # If MNIST, modify the data_root in dataset_config.json 34 | if MNIST: 35 | datasetConfigPath = projectPath / "dataset_config.json" 36 | datasetConfig = {} 37 | with datasetConfigPath.open("r") as fp: 38 | datasetConfig = json.load(fp) 39 | datasetConfig["dataset_config"]["data_root"] = str(projectPath / "MNIST") 40 | with datasetConfigPath.open("w") as fp: 41 | json.dump(datasetConfig,fp,indent = 4, separators = (',',':')) 42 | 43 | 44 | # Add plugins 45 | for pluginName in pluginsName: 46 | usePlugin(targetDir, pluginName = pluginName, full = False) 47 | 48 | print("Create a project in " + targetDir + ".") 49 | -------------------------------------------------------------------------------- /DLNest/Information/AnalyzeTask.py: -------------------------------------------------------------------------------- 1 | try: 2 | from TaskInformation import TaskInformation 3 | from ..SavePackage.SavePackage import SavePackage 4 | from ..Output.AnalyzerBuffer import AnalyzerBuffer 5 | except ImportError: 6 | from DLNest.Information.TaskInformation import TaskInformation 7 | from DLNest.SavePackage.SavePackage import SavePackage 8 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer 9 | from pathlib import Path 10 | from multiprocessing import Queue 11 | 12 | class AnalyzeTask(TaskInformation): 13 | def __init__( 14 | self, 15 | recordPath : str, 16 | scriptPath : str = "", 17 | checkpointID : int = -1, 18 | devices : list = [], 19 | memoryConsumption : int = -1, 20 | DDP : bool = False 21 | ): 22 | savePackage = SavePackage() 23 | savePackage.initFromAnExistSavePackage(recordPath) 24 | super(AnalyzeTask,self).__init__( 25 | savePackage = savePackage, 26 | devices = devices, 27 | memoryConsumption = memoryConsumption, 28 | multiGPU = False, 29 | DDP = DDP, 30 | loadCkpt = True if checkpointID != -2 else False, 31 | checkpointID = checkpointID 32 | ) 33 | self.ID = "A_" + self.ID 34 | self.recordPath = 
Path(recordPath) 35 | if scriptPath == "": 36 | self.scriptPath = self.recordPath.parent.parent / "AnalyzeScripts" 37 | else: 38 | self.scriptPath = Path(scriptPath) 39 | 40 | self.type = "Analyze" 41 | self.commandQueue = Queue() 42 | self.outputBuffer : AnalyzerBuffer = None 43 | 44 | def getDict(self): 45 | ret = super().getDict() 46 | ret["record_path"] = str(self.recordPath) 47 | ret["script_path"] = str(self.scriptPath) 48 | ret["checkpoint_ID"] = self.checkpointID 49 | return ret 50 | -------------------------------------------------------------------------------- /DLNest/Operations/UsePlugin.py: -------------------------------------------------------------------------------- 1 | from DLNest.Operations.GetPluginConfig import getPluginConfig 2 | from pathlib import Path 3 | import json 4 | 5 | def usePlugin(targetDir : str, pluginName : str, full : bool = False): 6 | """ 7 | Only return False when targetDir is wrong 8 | """ 9 | name, config = getPluginConfig(pluginName, full = full) 10 | if not config: 11 | print("Wrong plugin name: {}".format(pluginName)) 12 | return True 13 | 14 | root_dir = Path(targetDir).absolute() 15 | 16 | if not root_dir.exists(): 17 | print("Wrong target dir: {}".format(targetDir)) 18 | return False 19 | 20 | plugins_config_path = root_dir / "plugins_config.json" 21 | if plugins_config_path.exists(): 22 | with plugins_config_path.open("r") as f: 23 | pluginsConfig = json.load(f) 24 | if name in pluginsConfig["plugins"]: 25 | return True 26 | 27 | pluginsConfig["plugins"].append(name) 28 | pluginsConfig["plugins_config"].update(config) 29 | with plugins_config_path.open("w") as f: 30 | json.dump(pluginsConfig, f, indent = 4,separators = (',',':')) 31 | else: 32 | root_config_path = root_dir / "root_config.json" 33 | 34 | if not root_config_path.exists(): 35 | print("Wrong target dir: {}".format(targetDir)) 36 | return False 37 | 38 | with root_config_path.open("r") as f: 39 | allConfig = json.load(f) 40 | if name in 
allConfig["plugins"]: 41 | return True 42 | allConfig["plugins"].append(name) 43 | allConfig["plugins_config"].update(config) 44 | with root_config_path.open("w") as f: 45 | json.dump(allConfig, f, indent = 4,separators = (',',':')) 46 | 47 | return True 48 | 49 | def usePlugins(targetDir, pluginsName : list, full : bool = False): 50 | for name in pluginsName: 51 | if not usePlugin(targetDir, name, full): 52 | return -------------------------------------------------------------------------------- /DLNest/ShellClient/Windows/CommandInput.py: -------------------------------------------------------------------------------- 1 | from apscheduler.schedulers.background import BackgroundScheduler 2 | from typing import Callable, Iterable, List, Optional 3 | 4 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box 5 | from prompt_toolkit.buffer import Buffer 6 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory 7 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl 8 | from prompt_toolkit.key_binding import KeyBindings 9 | from prompt_toolkit.document import Document 10 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit 11 | from prompt_toolkit.lexers import Lexer 12 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples 13 | 14 | from DLNest.ShellClient.Windows.Utils.Completers import getCommandCompleter 15 | 16 | class CommandInput: 17 | def __init__(self,title: str = "DLNest Command Line",onAccept : Callable[[str],None] = None): 18 | self.kb = KeyBindings() 19 | self.onAccept = onAccept 20 | 21 | def accept_text(buf : Buffer): 22 | if not self.onAccept is None: 23 | self.onAccept(buf.text) 24 | return False 25 | 26 | self.completer = getCommandCompleter() 27 | self.title = title 28 | 29 | self.text = TextArea( 30 | height=3, 31 | auto_suggest=AutoSuggestFromHistory(), 32 | completer=self.completer, 33 | prompt=[("#fff5ee","DLNest>>")], 34 | accept_handler=accept_text, 35 | scrollbar=True, 36 | 
class DeviceInformation:
    """Describes one compute device (CPU/GPU) and the DLNest tasks on it.

    This base class models an unconstrained device (infinite free memory,
    no device string); concrete device classes are expected to override
    the memory/identity queries.
    """

    def __init__(self,type : str = ""):
        self.ID = -1            # device index; -1 marks "no numbered device"
        self.type = type        # human-readable device kind, e.g. "CPU"/"GPU"
        self.nowTask = 0        # count of tasks believed to be running here
        self.totalMemory = 0
        self.runningTask = []   # task information objects scheduled here
        self.isBreak = False    # True when the device is marked unusable

    def getFreeMemory(self):
        """Base devices report unlimited free memory; subclasses override."""
        return float('inf')

    def checkTasks(self):
        """Prune tasks whose subprocess is missing or dead.

        Keeps ``nowTask`` in sync with the surviving entries.
        """
        survivors = [
            task for task in self.runningTask
            if task.process is not None and task.process.is_alive()
        ]
        self.nowTask -= len(self.runningTask) - len(survivors)
        self.runningTask = survivors

    def addATask(self,newTask : TaskInformation):
        """Register a task as running on this device."""
        self.runningTask.append(newTask)
        self.nowTask += 1

    def lastUseTime(self):
        """Start time of the most recently added live task, or 0 when idle."""
        self.checkTasks()
        return self.runningTask[-1].startTime if self.runningTask else 0

    def getTaskNum(self):
        """Number of live tasks after pruning dead ones."""
        self.checkTasks()
        return self.nowTask

    def getDict(self):
        """Snapshot of the device state as a plain, serializable dict."""
        self.checkTasks()
        return {
            "ID" : self.ID,
            "type" : self.type,
            "is_break" : self.isBreak,
            "num_tasks" : self.nowTask,
            "running_tasks" : [task.getDict() for task in self.runningTask],
            "total_memory" : self.totalMemory,
            "free_memory" : self.getFreeMemory(),
            "last_use_time" : self.lastUseTime()
        }

    def getDeviceStr(self):
        """Printable device identifier; the base device has none."""
        return ""
def runTrain(args):
    """Resolve the device list from parsed CLI args and launch the train task.

    DDP and CPU modes are mutually exclusive.  CPU mode maps to device -1;
    otherwise an empty ``-devices`` argument means "use every visible GPU".
    """
    assert not (args.DDP and args.CPU)

    if args.CPU:
        deviceList = [-1]
    else:
        deviceList = args.devices
        if deviceList == []:
            # No explicit devices: enumerate all GPUs on this machine.
            pynvml.nvmlInit()
            deviceList = list(range(pynvml.nvmlDeviceGetCount()))

    runIndependent(
        configPath = args.c,
        freqPath = args.f,
        devices = deviceList,
        description = args.d,
        DDP = args.DDP,
        noSave = args.ns,
        useDescriptionToSave = args.sd,
        showOnScreen = args.ss
    )
self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.") 25 | 26 | def _loadAScript(filePath : Path,name : str): 27 | # load a script by name 28 | spec = importlib.util.spec_from_file_location( 29 | name, 30 | filePath 31 | ) 32 | module = importlib.util.module_from_spec(spec) 33 | spec.loader.exec_module(module) 34 | return module 35 | 36 | def _analyze(args): 37 | if args.CPU: 38 | devices = [-1] 39 | else: 40 | devices = args.devices 41 | if devices == []: 42 | devices = [0] 43 | 44 | scriptModule = _loadAScript(Path(args.s),"AScript") 45 | expFunc = scriptModule.__getattribute__("experience") 46 | 47 | AnalyzeIndependent( 48 | recordPath = args.r, 49 | expFunc = expFunc, 50 | scriptPath = "", 51 | checkpointID = args.c, 52 | devices = devices 53 | ) 54 | 55 | def analyze( 56 | recordPath : str, 57 | checkpointID : int = -1, 58 | scriptPath : str = "", 59 | devices = [0], 60 | expFunc = None 61 | ): 62 | AnalyzeIndependent( 63 | recordPath = recordPath, 64 | expFunc = expFunc, 65 | scriptPath = scriptPath, 66 | devices = devices, 67 | checkpointID = checkpointID 68 | ) 69 | 70 | if __name__ == "__main__": 71 | import sys 72 | if sys.path[0] != '': 73 | sys.path[0] = '' 74 | argparser = AnalyzeArguments() 75 | parser = argparser.parser() 76 | args = parser.parse_args() 77 | _analyze(args) -------------------------------------------------------------------------------- /DLNest/Information/test/test_DeviceInformation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from DLNest.Information.DeviceInformation import DeviceInformation 3 | from DLNest.Information.TaskInformation import TaskInformation 4 | import time 5 | class FakeSubprocess: 6 | def __init__(self,alive): 7 | self.alive = alive 8 | 9 | def is_alive(self): 10 | return self.alive 11 | 12 | @pytest.mark.Information 13 | class TestDeviceInformation: 14 | def test_init(self): 15 | self.DI = DeviceInformation("test") 16 | assert 
self.DI.type == "test" 17 | assert self.DI.nowTask == 0 18 | assert self.DI.totalMemory == 0 19 | assert self.DI.runningTask == [] 20 | assert self.DI.isBreak == False 21 | 22 | def test_getFreeMemory(self): 23 | self.DI = DeviceInformation("test") 24 | assert self.DI.getFreeMemory() > 10000000000000000000000000 25 | 26 | def initTasks(self): 27 | DI = DeviceInformation("test") 28 | fakeAliveSubprocess = FakeSubprocess(True) 29 | fakeDeadSubprocess = FakeSubprocess(False) 30 | aliveTask = TaskInformation() 31 | aliveTask.process = fakeAliveSubprocess 32 | deadTask = TaskInformation() 33 | deadTask.process = fakeDeadSubprocess 34 | DI.addATask(aliveTask) 35 | DI.addATask(deadTask) 36 | assert DI.nowTask == 2 37 | assert len(DI.runningTask) == 2 38 | assert DI.runningTask[0] == aliveTask 39 | assert DI.runningTask[1] == deadTask 40 | return DI,aliveTask 41 | 42 | def test_AddATask(self): 43 | DI,_ = self.initTasks() 44 | 45 | def test_lastUseTime(self): 46 | t = time.time() 47 | DI,_ = self.initTasks() 48 | lt = DI.lastUseTime() 49 | assert lt != 0 50 | assert abs(lt - t) < 0.1 51 | 52 | def test_checkTasks(self): 53 | DI,aliveTask = self.initTasks() 54 | DI.checkTasks() 55 | assert DI.nowTask == 1 56 | assert len(DI.runningTask) == 1 57 | assert DI.runningTask[-1] == aliveTask 58 | 59 | def test_getTaskNum(self): 60 | DI,_ = self.initTasks() 61 | assert DI.getTaskNum() == 1 62 | 63 | def test_getDict(self): 64 | DI,aliveTask = self.initTasks() 65 | retDict = DI.getDict() 66 | assert retDict["type"] == "test" 67 | assert retDict["is_break"] == False 68 | assert retDict["num_tasks"] == 1 69 | for item in retDict["running_tasks"]: 70 | for key in item: 71 | assert item[key] == aliveTask.getDict()[key] 72 | assert isinstance(retDict["total_memory"],float) or isinstance(retDict["total_memory"],int) 73 | assert isinstance(retDict["free_memory"],float) or isinstance(retDict["total_memory"],int) 74 | t = time.time() 75 | assert abs(t - retDict["last_use_time"]) < 0.1 
    def __init__(self):
        """Set up the analyzer-side output buffer.

        Extends OutputLayerBase with a style table for the log tags, a
        multiprocessing queue used to ship styled text across processes,
        and an ``isSend`` flag selecting sender/receiver behaviour.
        """
        super(AnalyzerBuffer,self).__init__()
        # Style tokens used by the log* methods -> prompt_toolkit style strings.
        self.styleDict = {
            '' : '#efefef',
            'app' : '#2afa38 bold',
            'ignored':'#ff7070',
            'error' : '#ff0000',
            'message' : '#efefef',
        }
        # Prefix shown in brackets before every logged line.
        self.appName = "DLNest Analyzer"
        # Cross-process channel carrying the latest styled-text snapshot.
        self.outputQueue = Queue()
        # False: this side receives snapshots; True: this side sends them.
        self.isSend = False
class Singleton(object):
    """Mixin giving subclasses process-wide singleton semantics.

    The first instantiation stores the instance on the class as
    ``_instance``; later calls return that same object (``__init__`` still
    runs on every call, which Scheduler relies on to update its settings).
    """
    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, '_instance'):
            orig = super(Singleton, cls)
            # BUGFIX: pass only ``cls`` to object.__new__.  Forwarding
            # *args/**kwargs raises "object.__new__() takes exactly one
            # argument" in Python 3 whenever the singleton is constructed
            # with arguments; constructor args belong to __init__.
            cls._instance = orig.__new__(cls)
        return cls._instance
: int = 60, 24 | maxTaskPerDevice : int = 10000, 25 | ): 26 | self.infoCenter = InfoCenter() 27 | self.timeDelay = timeDelay 28 | self.maxTaskPerDevice = maxTaskPerDevice 29 | 30 | if hasattr(self,"strategy"): 31 | if strategy != None: 32 | # change strategy 33 | self.strategy = strategy 34 | self.strategy.scheduler = self 35 | else: 36 | if strategy is None: 37 | # set default strategy 38 | strategy = DefaultStrategy() 39 | # set strategy 40 | self.strategy = strategy 41 | self.strategy.scheduler = self 42 | 43 | if not hasattr(self,"routineScheduler"): 44 | self.startRoutineTask() 45 | 46 | def giveATask(self,task : TaskInformation, otherArgs : dict): 47 | print("Received task " + task.ID + ".") 48 | 49 | task.extraInfo = otherArgs 50 | self.infoCenter.addATask(task) 51 | self.strategy.decide(self.maxTaskPerDevice,self.timeDelay) 52 | 53 | def runTask(self,task : TaskInformation,devices : [int]): 54 | task.devices = devices 55 | assert task.type == "Train" or task.type == "Analyze" 56 | if task.type == "Train": 57 | TProcess = TrainProcess(task) 58 | task.process = TProcess 59 | TProcess.start() 60 | elif task.type == "Analyze": 61 | ABuffer = AnalyzerBuffer() 62 | task.outputBuffer = ABuffer 63 | AProcess = AnalyzeProcess(task,outputBuffer = ABuffer) 64 | task.process = AProcess 65 | AProcess.start() 66 | 67 | print(task.type + " task " + task.ID + " is runing now.") 68 | 69 | self.infoCenter.runATask(task) 70 | 71 | def __routineRun(self): 72 | self.strategy.decide(self.maxTaskPerDevice,self.timeDelay) 73 | 74 | def changeTimeDelay(self,delay): 75 | self.timeDelay = delay 76 | self.routineScheduler.remove_job(self.routineJob.id) 77 | self.routineJob = self.routineScheduler.add_job(self.__routineRun,"interval",seconds = self.timeDelay) 78 | 79 | def changeMaxTaskPerDevice(self,newValue : int): 80 | self.maxTaskPerDevice = newValue 81 | 82 | def startRoutineTask(self): 83 | self.routineScheduler = BackgroundScheduler() 84 | self.routineJob = 
class AnalyzeWrapper:
    """Bundle handed to analyze scripts: config args, runner, dataset, log."""

    def __init__(self,args : dict, runner , dataset, log : dict):
        self.args = args
        self.dataset = dataset
        self.log = log
        self.runner = runner
        # Backward compatibility: older analyze scripts refer to the
        # runner through the ``model`` attribute.
        self.model = runner
class RunnerBase:
    """
    Base class for DLNest runners (the model/training logic of a task).

    User code subclasses this and overrides the public hooks (``init``,
    ``initOptimizer``, ``runOneStep``, ...).  The underscore-prefixed
    wrappers are the entry points used by the framework: each is wrapped
    by ``checkPlugins`` / ``checkDictOutputPlugins`` so registered plugins
    run around the corresponding user hook.
    """
    def __init__(self,args : dict, plugins : list = [], status = RunningStatus()):
        # NOTE(review): both defaults are evaluated once at definition time,
        # so runners constructed without explicit arguments share the same
        # list / RunningStatus instance — confirm this is intended.
        self._plugins = plugins
        self._args = args
        self._status = status

    # for backward compatibility: old code read Runner._rank directly
    @property
    def _rank(self):
        # Warn only once per instance about the deprecated accessor.
        if not "_warned_rank" in dir(self):
            self._warned_rank = True
            print("Runner._rank is deprecated, please use Runner._status.rank")
        return self._status.rank

    # for backward compatibility: old code read Runner._envType directly
    @property
    def _envType(self):
        if not "_warned_env" in dir(self):
            self._warned_env = True
            print("Runner._envType is deprecated, please use Runner._status.env")
        return self._status.env

    def getArgs(self):
        """Return the task's configuration dict."""
        return self._args

    # Framework entry point; plugins run around the user's ``init``.
    @checkPlugins
    def _runnerInit(self,args : dict, datasetInfo : dict = None):
        return self.init(args, datasetInfo)

    def init(self,args : dict, datasetInfo : dict = None):
        """User hook: build the model from *args* and dataset metadata."""
        # NOTE(review): stores ``self.args`` while __init__ stores
        # ``self._args`` — two similarly named attributes; confirm both
        # are needed.
        self.args = args
        pass

    @checkPlugins
    def _initOptimizer(self):
        return self.initOptimizer()

    def initOptimizer(self):
        """User hook: build optimizers/schedulers. Default: nothing."""
        pass

    def DDPOperation(self,rank : int):
        """User hook called per-rank when running under DDP."""
        pass

    def afterDDP(self, rank : int):
        """
        User hook run after DDP wrapping.
        For SyncBN or something else.
        """
        pass

    @checkDictOutputPlugins
    def _initLog(self):
        return self.initLog()

    def initLog(self):
        """User hook: return the initial log dict. Default: empty."""
        return {}

    @checkDictOutputPlugins
    def _getSaveDict(self):
        return self.getSaveDict()

    def getSaveDict(self):
        """User hook: return the state to checkpoint. Default: empty."""
        return {}

    @checkPlugins
    def _loadSaveDict(self,saveDict):
        return self.loadSaveDict(saveDict = saveDict)

    def loadSaveDict(self,saveDict):
        """User hook: restore state from a checkpoint dict."""
        pass

    @checkPlugins
    def _runOneStep(self,data,log : dict, iter : int, epoch : int):
        return self.runOneStep(data = data, log = log, iter = iter, epoch = epoch)

    def runOneStep(self,data,log : dict, iter : int, epoch : int):
        """User hook: one training step on *data*, writing into *log*."""
        pass

    @checkPlugins
    def _visualize(self,log : dict, iter : int, epoch : int):
        return self.visualize(log = log, iter = iter, epoch = epoch)

    def visualize(self,log : dict, iter : int, epoch : int):
        """User hook: emit visualizations for the current step."""
        pass

    # NOTE(review): no default ``validate`` method is defined on this
    # class — subclasses (or a framework caller) must provide one before
    # ``_validate`` is invoked.
    @checkPlugins
    def _validate(self, loader, log):
        return self.validate(loader, log)

    @checkPlugins
    def _validationInit(self):
        return self.validationInit()

    def validationInit(self):
        """User hook: prepare state before a validation pass."""
        pass

    @checkPlugins
    def _validateABatch(self,data, iter : int):
        return self.validateABatch(data = data, iter = iter)

    def validateABatch(self,data, iter : int):
        """User hook: validate a single batch."""
        pass

    @checkPlugins
    def _validationAnalyze(self, log : dict):
        return self.validationAnalyze(log = log)

    def validationAnalyze(self, log : dict):
        """User hook: aggregate/inspect validation results in *log*."""
        pass
class DevicesInfoShower:
    """Framed shell-client panel that periodically renders device status.

    A BackgroundScheduler calls *routineTask* every *freq* seconds with
    this shower as argument; the callback is expected to refresh
    ``self.lexer.devicesInfo`` so the text area re-renders.
    """
    def __init__(
        self,
        # NOTE(review): default title "Tasks" looks copy-pasted from
        # TaskInfoShower — "Devices" seems intended; confirm before changing.
        title : str = "Tasks",
        routineTask = None,
        freq : int = 1,
        style : str = "class:devices_info_shower"
    ):
        self.title = title
        self.routineTask = routineTask
        self.scheduler = BackgroundScheduler()
        # Start polling only when a refresh callback was supplied.
        if not routineTask is None:
            self.scheduler.add_job(self.routineTask,'interval',seconds=freq,args=[self])
            self.scheduler.start()
        self.lexer = DevicesLexer()

        # Read-only text area; the lexer supplies all visible content.
        self.shower = SCTextArea(
            lexer = self.lexer,
            wrap_lines=False
        )

        self.style = style
        self.window = Frame(
            self.shower,
            self.title,
            style = self.style,
            height=10,
            width=60
        )

    def getWindow(self):
        """Return the prompt_toolkit container to mount in a layout."""
        return self.window
4 | 5 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box 6 | from prompt_toolkit.buffer import Buffer 7 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory 8 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl 9 | from prompt_toolkit.key_binding import KeyBindings 10 | from prompt_toolkit.document import Document 11 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit 12 | from prompt_toolkit.lexers import Lexer 13 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples 14 | 15 | from DLNest.ShellClient.Windows.SCTextArea import SCTextArea 16 | 17 | class taskLexer(Lexer): 18 | def __init__(self): 19 | self.taskInfo = [] 20 | self.saveLength = 24 21 | self.devicesLength = 10 22 | 23 | def get_text(self): 24 | return "".join(["\n" for _ in self.taskInfo]) 25 | 26 | def lex_document(self,document : Document) -> Callable[[int], StyleAndTextTuples]: 27 | def get_line(lineno : int) -> StyleAndTextTuples: 28 | try: 29 | task = self.taskInfo[lineno] 30 | style_base = "class:pending_task" 31 | if task["status"] == "Running": 32 | style_base = "class:running_task" 33 | # elif task["status"] == "Suspend": 34 | # style_base = "class:suspend_task" 35 | ID = task["ID"] 36 | taskType = " Train " if task["type"] == "Train" else "Analyze" 37 | devices = " ".join([str(item) for item in task["devices"]]) 38 | description = task["description"] if "description" in task else "" 39 | save = Path(task["args"]["root_file_path"]).stem 40 | 41 | if len(save) < self.saveLength: 42 | save += " " * (self.saveLength - len(save)) 43 | elif len(save) > self.saveLength: 44 | save = save[:self.saveLength - 2] + ".." 
class TaskInfoShower:
    """Framed shell-client panel that periodically renders the task list.

    A BackgroundScheduler invokes *routineTask* every *freq* seconds with
    this shower as its argument; the callback refreshes the lexer's
    ``taskInfo`` so the text area re-renders.
    """

    def __init__(
        self,
        title : str = "Tasks",
        routineTask = None,
        freq : int = 1,
        style : str = "class:task_info_shower"
    ):
        self.title = title
        self.routineTask = routineTask

        # Poll for fresh task information only when a callback was given.
        self.scheduler = BackgroundScheduler()
        if routineTask is not None:
            self.scheduler.add_job(self.routineTask, 'interval', seconds=freq, args=[self])
            self.scheduler.start()

        self.lexer = taskLexer()
        self.shower = SCTextArea(
            lexer = self.lexer,
            wrap_lines=False
        )

        self.style = style
        self.window = Frame(
            self.shower,
            self.title,
            style = self.style,
            height=10
        )

    def getWindow(self):
        """Return the prompt_toolkit container to mount in a layout."""
        return self.window
try:
    from TaskInformation import TaskInformation
    from ..SavePackage.SavePackage import SavePackage
except ImportError:
    from DLNest.Information.TaskInformation import TaskInformation
    from DLNest.SavePackage.SavePackage import SavePackage

from multiprocessing import Queue

class TrainTask(TaskInformation):
    """Task record describing a training run (fresh, or resumed from a record)."""

    @classmethod
    def fromRecord(
        cls,
        recordPath: str,
        checkpointID: int = -1,
        devices=None,
        memoryConsumption: int = -1,
        multiGPU: bool = False,
        DDP: bool = False,
        description: str = ""
    ):
        """Build a task that continues training from an existing save package."""
        # Get save package from the existing package on disk.
        savePackage = SavePackage()
        savePackage.initFromAnExistSavePackage(recordPath)

        return cls(
            savePackage=savePackage,
            devices=devices,
            memoryConsumption=memoryConsumption,
            multiGPU=multiGPU,
            DDP=DDP,
            description=description,
            loadCkpt=True,
            checkpointID=checkpointID
        )

    @classmethod
    def fromConfigFile(
        cls,
        configPath: str,
        freqPath: str = "",
        devices=None,
        memoryConsumption: int = -1,
        multiGPU: bool = False,
        DDP: bool = False,
        description: str = "",
        noSave: bool = False,
        useDescriptionToSave: bool = False
    ):
        """Build a fresh task from a root config file (plus optional freq config)."""
        # Get save package by config path.
        savePackage = SavePackage(configPath=configPath, freqPath=freqPath)

        # Special save-dir name when not saving, or when saving by description.
        saveName = ""
        if noSave:
            saveName = "NOSAVE"
        elif useDescriptionToSave:
            saveName = description
        savePackage.saveToNewDir(saveName)

        return cls(
            savePackage=savePackage,
            devices=devices,
            memoryConsumption=memoryConsumption,
            multiGPU=multiGPU,
            DDP=DDP,
            description=description,
            noSave=noSave,
            useDescriptionToSave=useDescriptionToSave,
            loadCkpt=False
        )

    def __init__(
        self,
        savePackage: SavePackage,
        devices=None,
        memoryConsumption: int = -1,
        multiGPU: bool = False,
        DDP: bool = False,
        description: str = "",
        noSave: bool = False,
        useDescriptionToSave: bool = False,
        loadCkpt: bool = False,
        checkpointID: int = -1
    ):
        # BUGFIX: `devices=[]` was a shared mutable default argument; use a
        # None sentinel so each task gets its own list. Behavior unchanged.
        devices = [] if devices is None else devices
        super(TrainTask, self).__init__(
            savePackage=savePackage,
            devices=devices,
            memoryConsumption=memoryConsumption,
            multiGPU=multiGPU,
            DDP=DDP,
            loadCkpt=loadCkpt,
            checkpointID=checkpointID
        )
        self.ID = "T_" + self.ID  # prefix distinguishes train tasks from analyze tasks
        self.description = description
        self.noSave = noSave
        self.useDescriptionToSave = useDescriptionToSave

        self.type = "Train"

        # Channel used to send commands into the running train process.
        self.commandQueue = Queue()

        if description != "":
            savePackage.saveVisualString("desc: " + description)

    def getDict(self):
        """Serialize: base fields plus the human-readable description."""
        ret = super().getDict()
        ret["description"] = self.description
        return ret
from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB
from DLNest.Plugins.Utils.SendMailTools import sendSelfMail
import DLNest
import traceback
import logging

class DLNestPlugin(DPB):
    """Plugin that e-mails the user on train abort / finish / SOTA events."""

    _NAME = "MailsNote"
    _config = {
        "enable_list": ["Aborting", "TrainFinish"],
        "username": "",
        "password": "",
        "host": "mail.fudan.edu.cn",
        "port": 25
    }
    _defaultKeys = ["enable_list", "username", "password", "host"]

    # Latest result values pushed in via giveResultValues (class-wide).
    _result_values = {}

    def _getMailConfig(self):
        """Return (username, password, host, port) from the plugin config.

        Extracted helper: the same four getArgs reads were duplicated in
        trainAborting, ATrain and custom.
        """
        pluginName = DLNestPlugin._NAME
        return (
            DPB.getArgs(self, pluginName, "username", ""),
            DPB.getArgs(self, pluginName, "password", ""),
            DPB.getArgs(self, pluginName, "host", "mail.fudan.edu.cn"),
            DPB.getArgs(self, pluginName, "port", 25),
        )

    def trainAborting(self, exception: Exception):
        """Mail the traceback when the train task aborts (rank 0 / non-DDP only)."""
        if self._status.rank != 0 and self._status.rank != -1:
            return
        excStr = traceback.format_exc()

        if "Aborting" not in DPB.getArgs(self, DLNestPlugin._NAME, "enable_list", []):
            logging.debug("MailsNote is not enabled for aborting")
            return

        username, password, host, port = self._getMailConfig()
        message = "Train task in {} aborted with this exception\n".format(self.getArgs()["root_file_path"]) + excStr
        subject = "Train Aborted!"
        sendSelfMail(
            username=username,
            password=password,
            message=message,
            subject=subject,
            host=host,
            port=port
        )
        logging.info("[MailsNote] " + message)

    def ATrain(self):
        """Mail a notice (with final results, if any) when training finishes."""
        if self._status.rank != 0 and self._status.rank != -1:
            return
        if "TrainFinish" not in DPB.getArgs(self, DLNestPlugin._NAME, "enable_list", []):
            logging.debug("MailsNote is not enabled for train finish")
            return

        username, password, host, port = self._getMailConfig()
        message = "Train task in {} finished\n".format(self.getArgs()["root_file_path"])

        # Resolve through the module to get the correct class function and var.
        resultStr = DLNest.Plugins.MailsNote.DLNestPlugin._getResultsStr()
        if len(resultStr) != 0:
            message += "\nThe final results are: \n" + resultStr

        subject = "Train Finished!"
        sendSelfMail(
            username=username,
            password=password,
            message=message,
            subject=subject,
            host=host,
            port=port
        )
        logging.info("[MailsNote] " + message)

    @classmethod
    def giveResultValues(cls, resultsDict: dict):
        """Store the dict of final result values to be included in the mail."""
        cls._result_values = resultsDict

    @classmethod
    def _getResultsStr(cls):
        """Format the stored result values as one "| key : value" run."""
        # BUGFIX: the local accumulator was named `str`, shadowing the builtin.
        parts = ""
        for key in cls._result_values:
            parts += "| {} : {}".format(key, cls._result_values[key])
        return parts

    def custom(self, message, subject):
        """Send an arbitrary message with the given subject."""
        username, password, host, port = self._getMailConfig()
        message = "Train task in {} gives a message\n".format(self.getArgs()["root_file_path"]) + message

        sendSelfMail(
            username=username,
            password=password,
            message=message,
            subject=subject,
            host=host,
            port=port
        )
        logging.info("[MailsNote] " + message)

    def SOTA(self, key, value):
        """Mail a notice that a new best (SOTA) value was reached for `key`."""
        value = str(value)
        message = "Congratulations! SOTA performance in {} has been made by your train task {} with value {}!".format(key, self.getArgs()["root_file_path"], value)
        subject = "SOTA for {}!".format(key)
        DLNestPlugin.custom(self, message, subject)
SOTA performance in {} has been made by your train task {} with value {}!".format(key, self.getArgs()["root_file_path"], value) 109 | subject = "SOTA for {}!".format(key) 110 | DLNestPlugin.custom(self, message, subject) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### PyCharm ### 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 5 | # User-specific stuff 6 | .idea/**/workspace.xml 7 | .idea/**/tasks.xml 8 | .idea/**/usage.statistics.xml 9 | .idea/**/dictionaries 10 | .idea/**/shelf 11 | 12 | # Generated files 13 | .idea/**/contentModel.xml 14 | 15 | # Sensitive or high-churn files 16 | .idea/**/dataSources/ 17 | .idea/**/dataSources.ids 18 | .idea/**/dataSources.local.xml 19 | .idea/**/sqlDataSources.xml 20 | .idea/**/dynamic.xml 21 | .idea/**/uiDesigner.xml 22 | .idea/**/dbnavigator.xml 23 | 24 | # Gradle 25 | .idea/**/gradle.xml 26 | .idea/**/libraries 27 | 28 | # Gradle and Maven with auto-import 29 | # When using Gradle or Maven with auto-import, you should exclude module files, 30 | # since they will be recreated, and may cause churn. Uncomment if using 31 | # auto-import. 
32 | # .idea/modules.xml 33 | # .idea/*.iml 34 | # .idea/modules 35 | # *.iml 36 | # *.ipr 37 | 38 | # CMake 39 | cmake-build-*/ 40 | 41 | # Mongo Explorer plugin 42 | .idea/**/mongoSettings.xml 43 | 44 | # File-based project format 45 | *.iws 46 | 47 | # IntelliJ 48 | out/ 49 | 50 | # mpeltonen/sbt-idea plugin 51 | .idea_modules/ 52 | 53 | # JIRA plugin 54 | atlassian-ide-plugin.xml 55 | 56 | # Cursive Clojure plugin 57 | .idea/replstate.xml 58 | 59 | # Crashlytics plugin (for Android Studio and IntelliJ) 60 | com_crashlytics_export_strings.xml 61 | crashlytics.properties 62 | crashlytics-build.properties 63 | fabric.properties 64 | 65 | # Editor-based Rest Client 66 | .idea/httpRequests 67 | 68 | # Android studio 3.1+ serialized cache file 69 | .idea/caches/build_file_checksums.ser 70 | 71 | ### PyCharm Patch ### 72 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 73 | 74 | # *.iml 75 | # modules.xml 76 | # .idea/misc.xml 77 | # *.ipr 78 | 79 | # Sonarlint plugin 80 | .idea/**/sonarlint/ 81 | 82 | # SonarQube Plugin 83 | .idea/**/sonarIssues.xml 84 | 85 | # Markdown Navigator plugin 86 | .idea/**/markdown-navigator.xml 87 | .idea/**/markdown-navigator/ 88 | 89 | ### Python ### 90 | # Byte-compiled / optimized / DLL files 91 | __pycache__/ 92 | *.py[cod] 93 | *$py.class 94 | 95 | # C extensions 96 | *.so 97 | 98 | # Distribution / packaging 99 | .Python 100 | build/ 101 | develop-eggs/ 102 | dist/ 103 | downloads/ 104 | eggs/ 105 | .eggs/ 106 | lib/ 107 | lib64/ 108 | parts/ 109 | sdist/ 110 | var/ 111 | wheels/ 112 | pip-wheel-metadata/ 113 | share/python-wheels/ 114 | *.egg-info/ 115 | .installed.cfg 116 | *.egg 117 | MANIFEST 118 | 119 | # PyInstaller 120 | # Usually these files are written by a python script from a template 121 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
122 | *.manifest 123 | *.spec 124 | 125 | # Installer logs 126 | pip-log.txt 127 | pip-delete-this-directory.txt 128 | 129 | # Unit test / coverage reports 130 | htmlcov/ 131 | .tox/ 132 | .nox/ 133 | .coverage 134 | .coverage.* 135 | .cache 136 | nosetests.xml 137 | coverage.xml 138 | *.cover 139 | .hypothesis/ 140 | .pytest_cache/ 141 | 142 | # Translations 143 | *.mo 144 | *.pot 145 | 146 | # Scrapy stuff: 147 | .scrapy 148 | 149 | # Sphinx documentation 150 | docs/_build/ 151 | 152 | # PyBuilder 153 | target/ 154 | 155 | # pyenv 156 | .python-version 157 | 158 | # pipenv 159 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 160 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 161 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 162 | # install all needed dependencies. 163 | #Pipfile.lock 164 | 165 | # celery beat schedule file 166 | celerybeat-schedule 167 | 168 | # SageMath parsed files 169 | *.sage.py 170 | 171 | # Spyder project settings 172 | .spyderproject 173 | .spyproject 174 | 175 | # Rope project settings 176 | .ropeproject 177 | 178 | # Mr Developer 179 | .mr.developer.cfg 180 | .project 181 | .pydevproject 182 | 183 | # mkdocs documentation 184 | /site 185 | 186 | # mypy 187 | .mypy_cache/ 188 | .dmypy.json 189 | dmypy.json 190 | 191 | # Pyre type checker 192 | .pyre/ 193 | 194 | ### VisualStudioCode ### 195 | .vscode/ 196 | .vscode/* 197 | !.vscode/settings.json 198 | !.vscode/tasks.json 199 | !.vscode/launch.json 200 | !.vscode/extensions.json 201 | 202 | ### VisualStudioCode Patch ### 203 | # Ignore all local history of files 204 | .history 205 | 206 | # End of https://www.gitignore.io/api/python,pycharm,visualstudiocode -------------------------------------------------------------------------------- /DLNest/ShellClient/Windows/ResultsOutput.py: 
from apscheduler.schedulers.background import BackgroundScheduler
from typing import Callable, Iterable, List, Optional

from prompt_toolkit.widgets import Frame, TextArea, Label, Box
from prompt_toolkit.buffer import Buffer
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.document import Document
from prompt_toolkit.layout.containers import VSplit, Window, HSplit
from prompt_toolkit.lexers import Lexer
from prompt_toolkit.formatted_text.base import StyleAndTextTuples

from DLNest.ShellClient.Windows.SCTextArea import SCTextArea

class styledTextLexer(Lexer):
    """Lexer that renders pre-styled (style, text) fragments line by line."""

    def __init__(self, styled_text: StyleAndTextTuples = None):
        # BUGFIX: the old default `styled_text=[]` was a shared mutable
        # default; with a None sentinel every lexer gets its own list.
        self.styled_text = [] if styled_text is None else styled_text
        self.DEBUG = False

    def __get_styled_lines(self):
        """Regroup the flat fragment list into one fragment list per line."""
        self.styled_text_lines = []
        current = []
        for fragment in self.styled_text:
            ends_line = "\n" in fragment[1]
            if ends_line:
                # NOTE(review): this strips *all* newlines from the fragment
                # but closes only one line — assumes producers emit at most one
                # "\n" per fragment; confirm against the buffer writers.
                fragment = (fragment[0], fragment[1].replace("\n", ""))
            current.append(fragment)
            if ends_line:
                self.styled_text_lines.append(current)
                current = []
        self.styled_text_lines.append(current)

    def lex_document(self, document: Document) -> Callable[[int], StyleAndTextTuples]:
        self.__get_styled_lines()
        lines = document.lines

        if self.DEBUG:
            # Developer-only dump of the lexer state.
            with open("/root/STLexerDEBUG.txt", "w") as f:
                print(self.styled_text, file=f)
                print(self.styled_text_lines, file=f)
                print(lines, file=f)
                print(len(lines), len(self.styled_text_lines), file=f)

        def get_line(lineno: int) -> StyleAndTextTuples:
            try:
                return self.styled_text_lines[lineno]
            except Exception:
                # Lines beyond the styled content render empty.
                return [("", "")]

        return get_line

class ResultsOutput:
    """Framed output area, refreshed by an optional background job."""

    def __init__(
        self,
        title: str = "DLNest Output",
        routineTask=None,
        freq: int = 1,
        style: str = "class:results_output"
    ):
        self.title = title
        self.routineTask = routineTask
        self.scheduler = BackgroundScheduler()
        if routineTask is not None:
            # Poll routineTask(self) every `freq` seconds to refresh the view.
            self.scheduler.add_job(self.routineTask, 'interval', seconds=freq, args=[self])
            self.scheduler.start()
        self.lexer = styledTextLexer()

        self.shower = SCTextArea(
            self.lexer
        )

        self.style = style
        self.setKeyBinding()
        self.window = Frame(
            self.shower,
            self.title,
            style=self.style,
            modal=True,
            key_bindings=self.kb
        )

    def setKeyBinding(self):
        self.kb = KeyBindings()

        @self.kb.add("escape")
        def toEnd(event):
            # ESC jumps the cursor to the end, i.e. sticks to newest output.
            self.shower.buffer.cursor_position = len(self.shower.document.text)

    def getWindow(self):
        return self.window

class AnalyzeOutput(ResultsOutput):
    """ResultsOutput variant with an info label above the output area."""

    def __init__(
        self,
        title: str = "DLNest Output",
        routineTask=None,
        freq: int = 1,
        # NOTE(review): "class_analyzer_output" looks like a typo for
        # "class:analyzer_output"; kept as-is because it is a runtime string
        # that may be matched elsewhere — confirm against the style sheet.
        style: str = "class_analyzer_output"
    ):
        super(AnalyzeOutput, self).__init__(
            title,
            routineTask,
            freq,
            style
        )

        self.infoText = FormattedTextControl(
            [("", " No analyze task is running ")],
            focusable=False,
            show_cursor=False
        )

        self.infoWindow = Window(
            content=self.infoText
        )

        self.infoLabel = Box(
            body=self.infoWindow,
            height=3,
            padding_top=1,
            padding_bottom=1,
            style="class:analyzer_info_label"
        )

        # Rebuild the frame so the info label sits above the text area.
        self.window = Frame(
            HSplit([
                self.infoLabel,
                self.shower
            ]),
            title=self.title,
            style=self.style,
            modal=True,
            key_bindings=self.kb
        )
from DLNest.Information.InfoCenter import InfoCenter
from DLNest.Scheduler.SchedulerStrategyBase import SchedulerStrategyBase
from DLNest.Information.DeviceInformation import DeviceInformation
from DLNest.Information.TaskInformation import TaskInformation

import time

class DefaultStrategy(SchedulerStrategyBase):
    """Default scheduling strategy: repeatedly launch the oldest pending task
    on the first device (or set of devices) with enough free memory."""

    def __checkTaskNumAndLastUseTime(self, device: DeviceInformation):
        """Reject broken devices, devices already at the per-device task cap,
        and devices used too recently (launch cool-down)."""
        if device.isBreak:
            return False

        # Check task num.
        if device.getTaskNum() >= self.maxTaskPerDevice:
            return False

        # Enforce a minimum delay between consecutive launches on one device.
        if time.time() - device.lastUseTime() < self.maxTimeDelay:
            return False

        return True

    def __canTaskRunOnDevice(self, task: TaskInformation, device: DeviceInformation):
        """True when the device passes the basic checks AND has enough free
        memory (default request: 90% of the device's total memory)."""
        if not self.__checkTaskNumAndLastUseTime(device):
            return False

        memoryConsumption = device.totalMemory * 0.9 if task.memoryConsumption < 0 else task.memoryConsumption
        return device.getFreeMemory() > memoryConsumption

    def __canTaskRunOnDevices(self, task: TaskInformation, devices):
        """Multi-GPU placement: pick devices, biggest free memory first, until
        the summed free memory covers the task's request.

        Returns [] when impossible; CPUs and broken devices are ignored.
        NOTE(review): the default request reads devices[0].totalMemory —
        assumes the first entry is a GPU representative of the rest; confirm.
        """
        memoryConsumption = devices[0].totalMemory * 0.9 if task.memoryConsumption < 0 else task.memoryConsumption
        memInfo = []
        for device in devices:
            # Don't consider CPU or broken devices.
            if device.type == "CPU" or device.isBreak:
                continue
            if not self.__checkTaskNumAndLastUseTime(device):
                continue
            memInfo.append((device.getFreeMemory(), device))

        # Use the devices with the most free memory first.
        memInfo.sort(key=lambda x: x[0], reverse=True)

        ret = []
        memCount = 0.0
        for freeMemory, device in memInfo:
            ret.append(device.ID)
            memCount += freeMemory
            if memCount > memoryConsumption:
                break

        # Return [] if there is not enough free memory overall.
        if memCount <= memoryConsumption:
            return []
        return ret

    def __findARuningDevice(self, task: TaskInformation, devices):
        """Single-GPU placement: [ID] of the first eligible GPU, else []."""
        for device in devices:
            if device.type == "CPU":
                continue
            if self.__canTaskRunOnDevice(task, device):
                return [device.ID]
        return []

    def decide(self, maxTaskPerDevice: int, maxTimeDelay: int):
        """Keep launching the oldest pending task until nothing else fits."""
        self.maxTaskPerDevice = maxTaskPerDevice
        self.maxTimeDelay = maxTimeDelay

        continueCheck = True
        while continueCheck:
            continueCheck = False

            tasks = self.infoCenter.usingTasks()
            try:
                # Only the oldest pending task is considered per round.
                taskWait2Run = None
                for task in tasks:
                    if task.status == "Pending":
                        taskWait2Run = task
                        break

                if taskWait2Run is None:
                    break

                # Using device informations (must be released before running).
                devices = self.infoCenter.usingDevicesInformation()
                if taskWait2Run.devices != []:
                    # Devices fixed up-front (such as CPU): run as requested.
                    usingDevices = taskWait2Run.devices
                elif taskWait2Run.multiGPU:
                    # Multi GPU logic.
                    usingDevices = self.__canTaskRunOnDevices(taskWait2Run, devices)
                else:
                    # Single GPU logic.
                    usingDevices = self.__findARuningDevice(taskWait2Run, devices)
                # Must release the device information before running the task.
                self.infoCenter.releaseDeviceInformation()

                # If a proper device (set) was found, launch and re-check.
                if usingDevices != []:
                    self.scheduler.runTask(taskWait2Run, usingDevices)
                    continueCheck = True
            finally:
                self.infoCenter.releaseTasks()
class OutputLayerBase:
    """Thread-safe, bounded buffer of (style, text) fragments.

    Acts as a writable stream (write/flush/isatty/fileno) so it can stand in
    for sys.stdout, while clients poll text back with a token — the absolute
    index of the first fragment they want — to fetch output incrementally.
    """

    def __init__(self):
        self.styleDict = {
            '': '#efefef',
            'app': '#ff0f9f',
            'status': '#ff0000',
            'message': '#ffffff',
        }
        # A subclass may have pre-filled styledText before calling super().
        if not hasattr(self, 'styledText'):
            self.styledText = []
        self.maxBufferLen = 10000
        # Absolute index of styledText[0]; grows as old fragments are dropped.
        self.offset = 0
        self.lock = threading.Lock()
        self.isLineHead = True

    def getPlainText(self, from_token: int = -1, length: int = -1):
        """Return (offset, text) starting at from_token (-1 = from the start).

        text is None when the requested start was already discarded, or when
        there is no data at the requested start yet.
        """
        with self.lock:
            if len(self.styledText) == 0:
                return self.offset, ""

            if from_token == -1:
                from_token = self.offset

            if self.offset > from_token:
                # The requested start has already been dropped from the buffer.
                return self.offset, None

            start = from_token - self.offset
            if start >= len(self.styledText):
                # No data at the requested start yet.
                return self.offset, None

            if length == -1:
                end = len(self.styledText)
            else:
                end = min(len(self.styledText), start + length)

            return self.offset, "".join(item[1] for item in self.styledText[start:end])

    def getStyledText(self, from_token: int = -1, length: int = -1):
        """Like getPlainText, but returns the (style, text) fragments."""
        with self.lock:
            if len(self.styledText) == 0:
                return self.offset, [("", "")]

            if from_token == -1:
                from_token = self.offset

            if self.offset > from_token:
                # The requested start has already been dropped from the buffer.
                return self.offset, None

            start = from_token - self.offset
            if start >= len(self.styledText):
                # No data at the requested start yet.
                return self.offset, None

            if length == -1:
                end = len(self.styledText)
            else:
                end = min(len(self.styledText), start + length)

            return self.offset, self.styledText[start:end]

    def getStyledDict(self):
        return self.styleDict

    def maintainText(self):
        """Drop whole leading lines until the buffer is back under the cap."""
        if len(self.styledText) <= self.maxBufferLen:
            return
        while len(self.styledText) > self.maxBufferLen:
            # NOTE(review): assumes a "\n" fragment exists below the cap; a
            # single over-long line would raise IndexError here — confirm.
            while self.styledText[0][1] != "\n":
                self.styledText.pop(0)
                self.offset += 1
            self.styledText.pop(0)
            self.offset += 1

    def putPlainText(self, text: str):
        """Append text with the default style. Returns True."""
        with self.lock:
            self.styledText.append((self.styleDict[''], text))
            self.maintainText()
            return True

    def putStyledText(self, style_name: str, text: str):
        """Append text with a named style. False when the style is unknown."""
        with self.lock:
            if style_name not in self.styleDict:
                return False
            self.styledText.append((self.styleDict[style_name], text))
            self.maintainText()
            return True

    def clear(self):
        """Discard all buffered fragments and reset the token offset."""
        with self.lock:
            self.offset = 0
            self.styledText = []

    def write(self, message: str):
        """Stream interface: prefix every new line with "[<appName>] "."""
        if self.isLineHead:
            self.putStyledText('app', "[" + self.appName + "] ")
            self.isLineHead = False

        if "\n" not in message:
            self.putStyledText('message', message)
            return

        if message == '\n':
            self.putStyledText('message', message)
            self.isLineHead = True
            return

        lines = message.split("\n")
        endBy_n = message[-1] == "\n"
        for i in range(len(lines)):
            if self.isLineHead:
                self.putStyledText('app', "[" + self.appName + "] ")
                self.isLineHead = False
            self.putStyledText('message', lines[i])
            if i < len(lines) - 1:
                self.putStyledText('message', "\n")
                self.isLineHead = True
            elif endBy_n:
                self.putStyledText('message', "\n")
                self.isLineHead = True

    def flush(self):
        # No-op: fragments are visible as soon as they are appended.
        pass

    def isatty(self):
        return True

    def fileno(self):
        return 3  # Maybe hazard: a fake descriptor for callers that demand one.
class ModuleWrapper:
    """
    Contains at most two nn.Modules: the normal torch module and, optionally,
    the DDP shell wrapping it. Attribute access via `.` goes to the normal
    module; calling the wrapper calls the DDP shell when present. Use
    singleCard(...) to force a forward pass on the plain (non-DDP) module.
    """
    keywords = ["_ModuleWrapper__M", "_ModuleWrapper__DDP", "__dict__", "singleCard"]

    def __init__(self, module: nn.Module, DDP_wrapper):
        self.__M = module
        self.__DDP = DDP_wrapper

    def __getattr__(self, name):
        if name in ModuleWrapper.keywords:
            return super().__getattr__(name)
        return self.__M.__getattr__(name)

    def __setattr__(self, name, value):
        if name in ModuleWrapper.keywords:
            self.__dict__[name] = value
            return
        return self.__M.__setattr__(name, value)

    def __delattr__(self, name):
        return self.__M.__delattr__(name)

    def __getattribute__(self, name):
        if name in ModuleWrapper.keywords:
            return super().__getattribute__(name)
        return self.__M.__getattribute__(name)

    def __call__(self, *args, **kwargs):
        # Prefer the DDP shell when it exists.
        if self.__DDP:
            return self.__DDP(*args, **kwargs)
        return self.__M(*args, **kwargs)

    def __str__(self):
        target = self.__DDP if self.__DDP else self.__M
        return target.__str__() + " in DLNest ModuleWrapper"

    def singleCard(self, *args, **kwargs):
        # BUGFIX: the forward result was previously discarded (no return).
        return self.__M(*args, **kwargs)

class RunnerBaseTorch(RunnerBase):
    """Torch-specific runner base: tracks registered modules / optimizers /
    LR schedulers for checkpointing and wraps modules for DDP/CUDA/CPU."""

    def __init__(self, args: dict, plugins=None, status=None):
        # Avoid a shared mutable default for `plugins`.
        plugins = [] if plugins is None else plugins
        super(RunnerBaseTorch, self).__init__(args=args, plugins=plugins, status=status)
        self.__modelList = []
        self.__optimizerDict = {}
        self.__schedulerDict = {}

    def DDPOperation(self, rank: int):
        """Hook for extra per-rank setup in DDP mode; default: nothing."""
        pass

    def __setattr__(self, name, value):
        # Auto-register modules / optimizers / schedulers assigned as attributes.
        if isinstance(value, nn.Module):
            super().__setattr__(name, self.register(value))
        elif isinstance(value, Optimizer):
            super().__setattr__(name, self.__registerOptimizer(value, name))
        elif isinstance(value, _LRScheduler):
            super().__setattr__(name, self.__registerLRScheduler(value, name))
        else:
            super().__setattr__(name, value)

    def register(self, module: nn.Module, syncBN: bool = False):
        """Track `module` for checkpointing and wrap it for the current env
        (DDP / single-GPU / CPU). Returns a ModuleWrapper."""
        assert isinstance(module, nn.Module), "Only nn.Module can be registered. You tried to register a " + module.__class__.__name__
        self.__modelList.append(module)
        if self._status.env == "DDP":
            model = module.cuda()
            if syncBN:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            try:
                DDPmodel = nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self._status.rank],
                    output_device=self._status.rank,
                    find_unused_parameters=True
                )
            except AssertionError:
                # e.g. parameter-less modules cannot be DDP-wrapped.
                DDPmodel = None
            return ModuleWrapper(model, DDPmodel)
        else:
            # BUGFIX: on the CPU path `model` was previously unbound,
            # raising NameError at ModuleWrapper construction.
            model = module if self._status.env == "CPU" else module.cuda()
            return ModuleWrapper(model, None)

    def __registerOptimizer(self, optimizer: Optimizer, name):
        self.__optimizerDict[name] = optimizer
        return optimizer

    def __registerLRScheduler(self, scheduler: _LRScheduler, name):
        self.__schedulerDict[name] = scheduler
        return scheduler

    def getSaveDict(self):
        """Collect state_dicts of all registered modules/optimizers/schedulers."""
        stateDict = {}
        for i in range(len(self.__modelList)):
            stateDict[i] = self.__modelList[i].state_dict()

        stateDict["optimizer"] = {}
        for name in self.__optimizerDict:
            optimizer = self.__optimizerDict[name]
            try:
                stateDict["optimizer"][name] = optimizer.state_dict()
            except Exception:
                pass  # may add some warning

        stateDict["scheduler"] = {}
        for name in self.__schedulerDict:
            scheduler = self.__schedulerDict[name]
            try:
                stateDict["scheduler"][name] = scheduler.state_dict()
            except Exception:
                pass  # may add some warning
        return stateDict

    def loadSaveDict(self, saveDict):
        """Restore state_dicts saved by getSaveDict."""
        for i in range(len(self.__modelList)):
            self.__modelList[i].load_state_dict(saveDict[i])

        if "optimizer" in saveDict:
            for name in saveDict["optimizer"]:
                self.__optimizerDict[name].load_state_dict(saveDict["optimizer"][name])

        if "scheduler" in saveDict:
            for name in saveDict["scheduler"]:
                self.__schedulerDict[name].load_state_dict(saveDict["scheduler"][name])

                # To reset the learning rate after restoring the scheduler.
                self.__schedulerDict[name].step(
                    self.__schedulerDict[name].last_epoch
                )

    def _reduceSum(self, tensor):
        """
        ALPHA VERSION FUNCTION

        Reduce-sum the tensor onto rank 0 if using DDP; otherwise a no-op.
        """
        if self._status.env == "DDP":
            dist.reduce(tensor, 0)
        return tensor

    def _reduceMean(self, tensor):
        """
        ALPHA VERSION FUNCTION

        Reduce the tensor onto rank 0 and divide by the world size (mean)
        if using DDP; otherwise a no-op.
        """
        if self._status.env == "DDP":
            dist.reduce(tensor, 0)
            if self._status.rank == 0:
                tensor = tensor / self._status.worldSize
        return tensor
from abc import ABCMeta, abstractmethod
try:
    from .DatasetBase import DatasetBase
    from .RunnerBase import RunnerBase
except ImportError:
    from DLNest.Common.DatasetBase import DatasetBase
    from DLNest.Common.RunnerBase import RunnerBase
from DLNest.Plugins.Utils.CheckPlugins import checkPlugins
import traceback
from functools import wraps
import logging

class LifeCycleBase:
    """
    Base class defining every hook of the train/validate life cycle.

    User code overrides the public hooks (B* = "before", A* = "after"); the
    framework calls the underscore-prefixed wrappers, which run the plugin
    chain (@checkPlugins) around each hook.

    NOTE(review): several hooks carry @abstractmethod but the class does not
    use ABCMeta, so they are not enforced at instantiation — they only signal
    "you probably want to override this". Confirm before adding the metaclass,
    since doing so would break existing subclasses that don't override them.
    """

    def __init__(self, runner: RunnerBase = None, dataset: DatasetBase = None, taskProcess=None, plugins: list = None, status=None):
        self.runner = runner
        self.dataset = dataset
        self.taskProcess = taskProcess
        # BUGFIX: `plugins=[]` was a shared mutable default; use a None sentinel.
        self._plugins = [] if plugins is None else plugins
        self._status = status

    # for backward compatibility
    @property
    def rank(self):
        """Deprecated alias for `self._status.rank` (warns once per instance)."""
        if not hasattr(self, "_warned_rank"):
            self._warned_rank = True
            print("LifeCycle.rank is deprecated, please use LifeCycle.status.rank")
        return self._status.rank

    def getArgs(self):
        """Arguments of the task currently run by the owning task process."""
        return self.taskProcess.task.args

    # ---- overridable hooks; all default to no-ops ----

    def BAll(self):
        pass

    def BDatasetInit(self):
        pass

    def ADatasetInit(self):
        pass

    def BModelInit(self):
        pass

    def AModelInit(self):
        pass

    def BTrain(self):
        pass

    def ATrain(self):
        pass

    def BOneEpoch(self):
        pass

    def AOneEpoch(self):
        pass

    def BGetCommand(self):
        pass

    def AGetCommand(self, command):
        pass

    def BSuspend(self):
        pass

    def ASuspend(self):
        pass

    def BLoadFromSuspend(self):
        pass

    def ALoadFromSuspend(self):
        pass

    def BModelOneStep(self):
        pass

    def AModelOneStep(self):
        pass

    @abstractmethod
    def needVisualize(self, epoch: int, iter: int, logdict: dict, args: dict):
        """Decide whether to run the visualize hooks this step."""
        return False

    def BVisualize(self):
        pass

    def AVisualize(self):
        pass

    @abstractmethod
    def needValidation(self, epoch: int, logdict: dict, args: dict):
        """Decide whether to run validation after this epoch."""
        return False

    def BValidation(self):
        pass

    def BValidateABatch(self):
        pass

    def AValidateABatch(self):
        pass

    def BValidationAnalyze(self):
        pass

    def AValidationAnalyze(self):
        pass

    def AValidation(self):
        pass

    def commandLineOutput(self, epoch: int, logdict: dict, args: dict):
        """Per-epoch console line; override for custom progress output."""
        print("Epoch #" + str(epoch) + " finished!")

    @abstractmethod
    def needSaveModel(self, epoch: int, logdict: dict, args: dict):
        """Decide whether to checkpoint after this epoch."""
        return True

    def BSaveModel(self):
        pass

    def ASaveModel(self):
        pass

    @abstractmethod
    def holdThisCheckpoint(self, epoch: int, logdict: dict, args: dict):
        """Decide whether this checkpoint is protected from rotation."""
        return False

    @abstractmethod
    def needContinueTrain(self, epoch: int, logdict: dict, args: dict):
        """Decide whether training should continue after this epoch."""
        return False

    def getSaveDict(self):
        """Extra state to store in the checkpoint; merged by the framework."""
        return {}

    def loadSaveDict(self, saveDict):
        """Restore state produced by getSaveDict."""
        pass

    def AAll(self):
        pass

    def trainAborting(self, exception: Exception):
        """Called when training aborts; default just prints the traceback."""
        traceback.print_exc()

    # ---- framework entry points: plugin chain + the hook above ----

    @checkPlugins
    def _BAll(self):
        return self.BAll()

    @checkPlugins
    def _BDatasetInit(self):
        return self.BDatasetInit()

    @checkPlugins
    def _ADatasetInit(self):
        return self.ADatasetInit()

    @checkPlugins
    def _BModelInit(self):
        return self.BModelInit()

    @checkPlugins
    def _AModelInit(self):
        return self.AModelInit()

    @checkPlugins
    def _BTrain(self):
        return self.BTrain()

    @checkPlugins
    def _ATrain(self):
        return self.ATrain()

    @checkPlugins
    def _BOneEpoch(self):
        return self.BOneEpoch()

    @checkPlugins
    def _AOneEpoch(self):
        return self.AOneEpoch()

    @checkPlugins
    def _BGetCommand(self):
        return self.BGetCommand()

    @checkPlugins
    def _AGetCommand(self, command):
        return self.AGetCommand(command=command)

    @checkPlugins
    def _BSuspend(self):
        return self.BSuspend()

    @checkPlugins
    def _ASuspend(self):
        return self.ASuspend()

    @checkPlugins
    def _BLoadFromSuspend(self):
        return self.BLoadFromSuspend()

    @checkPlugins
    def _ALoadFromSuspend(self):
        return self.ALoadFromSuspend()

    @checkPlugins
    def _BModelOneStep(self):
        return self.BModelOneStep()

    @checkPlugins
    def _AModelOneStep(self):
        return self.AModelOneStep()

    @checkPlugins
    def _BVisualize(self):
        return self.BVisualize()

    @checkPlugins
    def _AVisualize(self):
        return self.AVisualize()

    @checkPlugins
    def _BValidation(self):
        return self.BValidation()

    @checkPlugins
    def _BValidateABatch(self):
        return self.BValidateABatch()

    @checkPlugins
    def _AValidateABatch(self):
        return self.AValidateABatch()

    @checkPlugins
    def _BValidationAnalyze(self):
        return self.BValidationAnalyze()

    @checkPlugins
    def _AValidationAnalyze(self):
        return self.AValidationAnalyze()

    @checkPlugins
    def _AValidation(self):
        return self.AValidation()

    @checkPlugins
    def _commandLineOutput(self, epoch: int, logdict: dict, args: dict):
        return self.commandLineOutput(epoch=epoch, logdict=logdict, args=args)

    @checkPlugins
    def _BSaveModel(self):
        return self.BSaveModel()

    @checkPlugins
    def _ASaveModel(self):
        return self.ASaveModel()

    @checkPlugins
    def _getSaveDict(self):
        return self.getSaveDict()

    @checkPlugins
    def _loadSaveDict(self, saveDict):
        return self.loadSaveDict(saveDict)

    @checkPlugins
    def _AAll(self):
        return self.AAll()

    @checkPlugins
    def _trainAborting(self, exception: Exception):
        return self.trainAborting(exception=exception)
KeyBindings 7 | from typing import Callable, Iterable, List, Optional 8 | from prompt_toolkit.document import Document 9 | import os 10 | 11 | class PartPathCompleter(Completer): 12 | """ 13 | Complete for Path variables. 14 | 15 | :param get_paths: Callable which returns a list of directories to look into 16 | when the user enters a relative path. 17 | :param file_filter: Callable which takes a filename and returns whether 18 | this file should show up in the completion. ``None`` 19 | when no filtering has to be done. 20 | :param min_input_len: Don't do autocompletion when the input string is shorter. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | only_directories: bool = False, 26 | get_paths: Optional[Callable[[], List[str]]] = None, 27 | file_filter: Optional[Callable[[str], bool]] = None, 28 | min_input_len: int = 0, 29 | expanduser: bool = False, 30 | ) -> None: 31 | 32 | self.only_directories = only_directories 33 | self.get_paths = get_paths or (lambda: ["."]) 34 | self.file_filter = file_filter or (lambda _: True) 35 | self.min_input_len = min_input_len 36 | self.expanduser = expanduser 37 | 38 | def get_completions( 39 | self, document: Document, complete_event: CompleteEvent 40 | ) -> Iterable[Completion]: 41 | text = document.text_before_cursor 42 | text = text.split(" ")[-1] 43 | 44 | # Complete only when we have at least the minimal input length, 45 | # otherwise, we can too many results and autocompletion will become too 46 | # heavy. 47 | if len(text) < self.min_input_len: 48 | return 49 | 50 | try: 51 | # Do tilde expansion. 52 | if self.expanduser: 53 | text = os.path.expanduser(text) 54 | 55 | # Directories where to look. 56 | dirname = os.path.dirname(text) 57 | if dirname: 58 | directories = [ 59 | os.path.dirname(os.path.join(p, text)) for p in self.get_paths() 60 | ] 61 | else: 62 | directories = self.get_paths() 63 | 64 | # Start of current file. 65 | prefix = os.path.basename(text) 66 | 67 | # Get all filenames. 
68 | filenames = [] 69 | for directory in directories: 70 | # Look for matches in this directory. 71 | if os.path.isdir(directory): 72 | for filename in os.listdir(directory): 73 | if filename.startswith(prefix): 74 | filenames.append((directory, filename)) 75 | 76 | # Sort 77 | filenames = sorted(filenames, key=lambda k: k[1]) 78 | 79 | # Yield them. 80 | for directory, filename in filenames: 81 | completion = filename[len(prefix) :] 82 | full_name = os.path.join(directory, filename) 83 | 84 | if os.path.isdir(full_name): 85 | # For directories, add a slash to the filename. 86 | # (We don't add them to the `completion`. Users can type it 87 | # to trigger the autocompletion themselves.) 88 | filename += "/" 89 | elif self.only_directories: 90 | continue 91 | 92 | if not self.file_filter(full_name): 93 | continue 94 | 95 | yield Completion(completion, 0, display=filename) 96 | except OSError: 97 | pass 98 | 99 | class CommandCompleter(Completer): 100 | def __init__(self, rules : dict): 101 | self.rules = rules 102 | 103 | def get_completions( 104 | self, document: Document, complete_event: CompleteEvent 105 | ): 106 | text = document.text_before_cursor.lstrip() 107 | tokens = text.replace("\n"," ").replace("\t"," ").split(" ") 108 | nowRules = self.rules 109 | if len(tokens) <= 1: 110 | completer = WordCompleter( 111 | list(self.rules.keys()) 112 | ) 113 | for c in completer.get_completions(document, complete_event): 114 | yield c 115 | else: 116 | title = tokens[0] 117 | if not title in self.rules: 118 | return 119 | else: 120 | try: 121 | rule_dict = self.rules[title] 122 | if len(tokens) > 2 and tokens[-2] in rule_dict: 123 | if not isinstance(rule_dict[tokens[-2]],Completer): 124 | if not isinstance(rule_dict[tokens[-2]],int): 125 | return 126 | else: 127 | for c in rule_dict[tokens[-2]].get_completions(document, complete_event): 128 | yield c 129 | return 130 | search_list = [] 131 | if rule_dict is None: 132 | return 133 | for key in rule_dict: 134 | if not 
key in tokens: 135 | search_list.append(key) 136 | completer = WordCompleter( 137 | search_list 138 | ) 139 | for c in completer.get_completions(document, complete_event): 140 | yield c 141 | except Exception as e: 142 | ... 143 | #print(e) 144 | 145 | def getCommandCompleter(): 146 | commands = { 147 | "run" : { 148 | "-c" : PartPathCompleter(), 149 | "-d" : None, 150 | "-f" : PartPathCompleter(), 151 | "-m" : None, 152 | "-ns" : -1, 153 | "-mc" : -1, 154 | "-sd" : -1, 155 | "-DDP" : -1, 156 | "-CPU" : -1 157 | }, 158 | "continue" : { 159 | "-r" : PartPathCompleter(), 160 | "-c" : None, 161 | "-d" : None, 162 | "-m" : None, 163 | "-DDP" : -1, 164 | "-CPU" : -1, 165 | "-mc" : -1 166 | }, 167 | "analyze" : { 168 | "-r" : PartPathCompleter(), 169 | "-s" : PartPathCompleter(), 170 | "-c" : None, 171 | "-m" : None, 172 | "-CPU" : -1 173 | }, 174 | "new" : { 175 | "-d" : PartPathCompleter(), 176 | "-MNIST" : None, 177 | "-p" : None, 178 | }, 179 | "changeDevices" : { 180 | "-d" : None 181 | }, 182 | "addP" : { 183 | "-d" : PartPathCompleter(), 184 | "-p" : None, 185 | "-F" : None, 186 | }, 187 | "changeDelay" : None, 188 | "runExp" : None, 189 | "del" : None, 190 | "exit" : None 191 | } 192 | return CommandCompleter(commands) -------------------------------------------------------------------------------- /DLNest/Information/InfoCenter.py: -------------------------------------------------------------------------------- 1 | try: 2 | from TaskInformation import TaskInformation 3 | from TrainTask import TrainTask 4 | from AnalyzeTask import AnalyzeTask 5 | from DeviceInformation import DeviceInformation 6 | from CPUInformation import CPUInformation 7 | except ImportError: 8 | from DLNest.Information.TaskInformation import TaskInformation 9 | from DLNest.Information.TrainTask import TrainTask 10 | from DLNest.Information.AnalyzeTask import AnalyzeTask 11 | from DLNest.Information.DeviceInformation import DeviceInformation 12 | from DLNest.Information.CPUInformation import 
CPUInformation 13 | 14 | HAVE_GPU = True 15 | 16 | try: 17 | from GPUInformation import GPUInformation 18 | except Exception: 19 | try: 20 | from .GPUInformation import GPUInformation 21 | except Exception: 22 | HAVE_GPU = False 23 | 24 | if HAVE_GPU: 25 | import pynvml 26 | 27 | import threading 28 | 29 | class Singleton(object): 30 | def __new__(cls, *args, **kwargs): 31 | if not hasattr(cls, '_instance'): 32 | orig = super(Singleton, cls) 33 | cls._instance = orig.__new__(cls, *args, **kwargs) 34 | return cls._instance 35 | 36 | 37 | class InfoCenter(Singleton): 38 | def __init__(self): 39 | if hasattr(self,'tasks'): 40 | return 41 | if HAVE_GPU: 42 | try: 43 | # Init GPUs information 44 | pynvml.nvmlInit() 45 | self.totalGPUsInSystem = pynvml.nvmlDeviceGetCount() 46 | self.devices = [GPUInformation(i) for i in range(self.totalGPUsInSystem)] + [CPUInformation()] 47 | self.availableDevices = [i for i in range(self.totalGPUsInSystem)] + [-1] 48 | except Exception as e: 49 | # using CPU only 50 | self.devices = [CPUInformation()] 51 | self.availableDevices = [0] 52 | else: 53 | # using CPU only 54 | self.devices = [CPUInformation()] 55 | self.availableDevices = [0] 56 | 57 | self.taskLock = threading.Lock() 58 | self.deviceLock = threading.Lock() 59 | self.tasks = [] 60 | 61 | def __getAvailableDevices(self): 62 | """ 63 | return the device information classes of available devices. 64 | """ 65 | return [self.devices[item] for item in self.availableDevices] 66 | 67 | def usingDevicesInformation(self): 68 | """ 69 | Occupy the devices information 70 | """ 71 | if self.deviceLock.acquire(): 72 | return self.__getAvailableDevices() 73 | 74 | def releaseDeviceInformation(self): 75 | self.deviceLock.release() 76 | 77 | def getDevicesInformation(self): 78 | """ 79 | Get the informaton of all devices. 
80 | """ 81 | if self.deviceLock.acquire(): 82 | try: 83 | return [item.getDict() for item in self.devices] 84 | finally: 85 | self.deviceLock.release() 86 | 87 | def getAvailableDevicesInformation(self): 88 | """ 89 | Get the information of available devices in dict. 90 | """ 91 | if self.deviceLock.acquire(): 92 | try: 93 | availableDevices = self.__getAvailableDevices() 94 | return [item.getDict() for item in availableDevices] 95 | finally: 96 | self.deviceLock.release() 97 | 98 | def changeDevices(self,newDevicesIDList : list): 99 | """ 100 | Modify available devices 101 | """ 102 | newList = [] 103 | for item in newDevicesIDList: 104 | if item < -1 or item >= len(self.devices): 105 | print("Wrong device " + str(item)) 106 | else: 107 | newList.append(item) 108 | self.availableDevices = newList 109 | return True 110 | 111 | def __checkTasks(self): 112 | """ 113 | Delete all the tasks who are supposed to be running but have no alive process. 114 | """ 115 | newList = [] 116 | for item in self.tasks: 117 | if item.status == "Running": 118 | if item.process == None or not item.process.is_alive(): 119 | continue 120 | else: 121 | newList.append(item) 122 | else: 123 | newList.append(item) 124 | self.tasks = newList 125 | 126 | def getTasksInformation(self): 127 | """ 128 | Get the information of all tasks in dict. 129 | """ 130 | if self.taskLock.acquire(): 131 | try: 132 | self.__checkTasks() 133 | return [item.getDict() for item in self.tasks] 134 | finally: 135 | self.taskLock.release() 136 | 137 | def usingTasks(self): 138 | """ 139 | Occupy the tasks information. 140 | """ 141 | if self.taskLock.acquire(): 142 | return self.tasks 143 | 144 | def releaseTasks(self): 145 | """ 146 | Release the tasks. 147 | """ 148 | self.taskLock.release() 149 | 150 | def addATask(self,task : TaskInformation): 151 | """ 152 | Add a task to the tasks list. 
153 | """ 154 | if self.taskLock.acquire(): 155 | try: 156 | self.__checkTasks() 157 | self.tasks.append(task) 158 | finally: 159 | self.taskLock.release() 160 | 161 | def runATask(self,task : TaskInformation): 162 | """ 163 | Run a task to the devices. Tasks are supposed to be in self.task already 164 | """ 165 | if self.deviceLock.acquire(): 166 | try: 167 | for device in task.devices: 168 | task.status = "Running" 169 | self.devices[device].addATask(task) 170 | finally: 171 | self.deviceLock.release() 172 | 173 | def getTaskByID(self,taskID : str): 174 | """ 175 | Get the task information class by a taskID. 176 | """ 177 | if self.taskLock.acquire(): 178 | try: 179 | for item in self.tasks: 180 | if item.ID == taskID: 181 | return item 182 | return None 183 | finally: 184 | self.taskLock.release() 185 | 186 | def delATask(self,taskID : str): 187 | """ 188 | Delete a task by taskID. 189 | """ 190 | if self.taskLock.acquire(): 191 | try: 192 | self.__checkTasks() 193 | newList = [] 194 | for item in self.tasks: 195 | if item.ID == taskID: 196 | if item.process != None: 197 | item.process.terminate() 198 | else: 199 | newList.append(item) 200 | self.tasks = newList 201 | finally: 202 | self.taskLock.release() 203 | 204 | def delAllTask(self): 205 | """ 206 | Delete all task 207 | """ 208 | if self.taskLock.acquire(): 209 | try: 210 | self.__checkTasks() 211 | for task in self.tasks: 212 | if task.process != None: 213 | task.process.terminate() 214 | self.tasks = [] 215 | finally: 216 | self.taskLock.release() 217 | 218 | if __name__ == "__main__": 219 | IC = InfoCenter() 220 | print(IC.getDevicesInformation()) -------------------------------------------------------------------------------- /DLNest/Executor/TrainProcess.py: -------------------------------------------------------------------------------- 1 | from DLNest.Information.TaskInformation import TaskInformation 2 | from DLNest.Information.TrainTask import TrainTask 3 | from DLNest.Executor.TaskProcess import 
TaskProcess 4 | from DLNest.Output.TrainStdout import TrainStdout 5 | 6 | from pathlib import Path 7 | import sys 8 | 9 | import time 10 | import os 11 | 12 | try: 13 | import torch 14 | import torch.distributed as dist 15 | except ImportError: 16 | pass 17 | 18 | class TrainProcess(TaskProcess): 19 | def __init__(self,task : TrainTask,showOnScreen = False): 20 | """ 21 | if showOnScreen, also output to the stdout 22 | """ 23 | super(TrainProcess,self).__init__(task) 24 | self.showOnScreen = showOnScreen 25 | self.commandQueue = task.commandQueue 26 | 27 | def initOutput(self,rank = -1): 28 | """ 29 | redirect the output 30 | """ 31 | os.chdir(self.task.args["root_file_path"]) # Change CWD to the save package 32 | outputFP = self.task.savePackage.getOutputFile(rank) 33 | self.outputDelegate = TrainStdout(outputFP,showOnScreen = self.showOnScreen,originalStdout = sys.stdout) 34 | sys.stdout = self.outputDelegate 35 | sys.stderr = self.outputDelegate 36 | 37 | def __saveModel(self): 38 | holdThisCheckpoint = self.lifeCycle.holdThisCheckpoint(self.finishedEpoch,self.logDict,self.task.args) 39 | self.saveCkpt(holdThisCheckpoint = holdThisCheckpoint) 40 | 41 | def __moveAData(self,data): 42 | try: 43 | if "cuda" in dir(data): 44 | ret_data = data.cuda() 45 | return ret_data 46 | elif "to" in dir(data): 47 | tmp = torch.tensor(1).cuda() 48 | ret_data = data.to(tmp.device) 49 | return ret_data 50 | else: 51 | return data 52 | except Exception as e: 53 | return data 54 | 55 | def __moveData(self,data): 56 | # move data to the proper location 57 | if self.status.env != "CPU": 58 | try: 59 | if isinstance(data,list): 60 | for index in range(len(data)): 61 | data[index] = self.__moveAData(data[index]) 62 | elif isinstance(data, tuple): 63 | ret = [] 64 | for index in range(len(data)): 65 | ret.append(self.__moveAData(data[index])) 66 | data = tuple(ret) 67 | elif isinstance(data,dict): 68 | for key in data: 69 | data[key] = self.__moveAData(data[key]) 70 | else: 71 | data = 
self.__moveAData(self, data) 72 | except Exception as e: 73 | pass 74 | return data 75 | 76 | def __train_an_epoch(self,start_epoch : int): 77 | nowEpoch = start_epoch 78 | for _iter,data in enumerate(self.trainLoader): 79 | # run one step 80 | self.status._RunningStatus__iter = _iter 81 | if self.lifeCycle._BModelOneStep() != "Skip": 82 | data = self.__moveData(data) 83 | self.runner._runOneStep(data,self.logDict,_iter,nowEpoch) 84 | self.lifeCycle._AModelOneStep() 85 | 86 | # visualize 87 | if self.lifeCycle.needVisualize(nowEpoch,_iter,self.logDict,self.task.args): 88 | if self.status.env == "DDP": 89 | if self.status.rank == 0 and self.lifeCycle._BVisualize() != "Skip": 90 | self.runner._visualize(epoch = nowEpoch, iter = _iter, log = self.logDict) 91 | else: 92 | if self.lifeCycle._BVisualize() != "Skip": 93 | self.runner._visualize(epoch = nowEpoch, iter = _iter, log = self.logDict) 94 | self.lifeCycle._AVisualize() 95 | 96 | 97 | def __validate(self): 98 | self.runner._validationInit() 99 | for _iter,data in enumerate(self.valLoader): 100 | self.status._RunningStatus__iter = _iter 101 | if self.lifeCycle._BValidateABatch() != "Skip": 102 | data = self.__moveData(data) 103 | self.runner._validateABatch(data,_iter) 104 | self.lifeCycle._AValidateABatch() 105 | 106 | if self.lifeCycle._BValidationAnalyze() != "Skip": 107 | self.runner._validationAnalyze(self.logDict) 108 | self.lifeCycle._AValidationAnalyze() 109 | 110 | def mainLoop(self): 111 | try: 112 | nowEpoch = self.finishedEpoch + 1 113 | self.status._RunningStatus__epoch = nowEpoch 114 | if self.lifeCycle._BTrain() == "Skip": 115 | self.lifeCycle._ATrain() 116 | return 117 | while True: 118 | if self.lifeCycle._BOneEpoch() != "Skip": 119 | # Training! 
120 | self.status.startTraining() 121 | 122 | if self.status.env == "DDP": 123 | self.trainLoader.sampler.set_epoch(nowEpoch) 124 | 125 | self.__train_an_epoch(nowEpoch) 126 | 127 | if self.status.env == "DDP": 128 | dist.barrier() # Sync before validation 129 | 130 | self.finishedEpoch = nowEpoch 131 | # output in command Line 132 | self.lifeCycle._commandLineOutput(self.finishedEpoch,self.logDict,self.task.args) 133 | 134 | # validation 135 | if self.lifeCycle.needValidation(self.finishedEpoch,self.logDict,self.task.args): 136 | if self.lifeCycle._BValidation() != "Skip": 137 | self.status.startValidating() 138 | if "validate" in dir(self.runner): 139 | self.runner._validate(self.valLoader,self.logDict) 140 | else: 141 | self.__validate() 142 | self.lifeCycle._AValidation() 143 | 144 | # start other operations 145 | self.status.startWaiting() 146 | 147 | if self.status.env == "DDP": 148 | dist.barrier() # Sync before saving 149 | 150 | #save checkpoint 151 | if self.status.env == "DDP": 152 | if self.status.rank == 0 and self.lifeCycle._BSaveModel() != "Skip": 153 | if self.lifeCycle.needSaveModel(self.finishedEpoch,self.logDict,self.task.args): 154 | self.__saveModel() 155 | else: 156 | if self.lifeCycle._BSaveModel() != "Skip": 157 | if self.lifeCycle.needSaveModel(self.finishedEpoch,self.logDict,self.task.args): 158 | self.__saveModel() 159 | self.lifeCycle._ASaveModel() 160 | 161 | self.lifeCycle._AOneEpoch() 162 | # break decision 163 | if self.lifeCycle.needContinueTrain(self.finishedEpoch,self.logDict,self.task.args): 164 | nowEpoch = self.finishedEpoch + 1 165 | self.status._RunningStatus__epoch = nowEpoch 166 | else: 167 | break 168 | except (Exception,SystemExit) as e: 169 | self.lifeCycle._trainAborting(e) 170 | else: 171 | # After Train 172 | self.lifeCycle._ATrain() 173 | 174 | def loadCkpt(self): 175 | super().loadCkpt() 176 | if self.task.loadCkpt: 177 | if self.task.checkpointID != -1: 178 | # -1 means the last one,which is default option. 
!= -1 needs to set the ckptID rather than default 179 | self.task.savePackage.setCkptID(self.task.checkpointID + 1) 180 | 181 | if __name__ == "__main__": 182 | TT = TrainTask.fromConfigFile("/root/code/DLNestTest/root_config.json",devices = [0,1,2,3],noSave = True,DDP = True) 183 | # TT = TrainTask.fromRecord("/root/code/DLNestTest/Saves/2021-03-29_16-04-46_146",checkpointID = 5,devices = [0,1,2,3],DDP = False) 184 | TP = TrainProcess(TT,True) 185 | TP.start() 186 | try: 187 | while True: 188 | pass 189 | except KeyboardInterrupt: 190 | TP.terminate() 191 | TP.join() 192 | print("terminate") 193 | -------------------------------------------------------------------------------- /DLNest/ShellClient/Client.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit import Application 2 | from prompt_toolkit.layout.containers import VSplit,HSplit 3 | from prompt_toolkit.layout.layout import Layout 4 | from prompt_toolkit.key_binding import KeyBindings 5 | from prompt_toolkit.styles import Style 6 | 7 | import json 8 | from pathlib import Path 9 | import argparse 10 | import os 11 | 12 | from DLNest.ShellClient.Windows.DevicesInfoShower import DevicesInfoShower 13 | from DLNest.ShellClient.Windows.CommandInput import CommandInput 14 | from DLNest.ShellClient.Windows.ResultsOutput import ResultsOutput,AnalyzeOutput 15 | from DLNest.ShellClient.Windows.TaskInfoShower import TaskInfoShower 16 | from DLNest.ShellClient.Communicator import Communicator 17 | 18 | class Client: 19 | def __init__(self, url : str = "127.0.0.1", port : int = "9999"): 20 | self.communicator = Communicator(url,port) 21 | 22 | self.CMDIN = CommandInput(title="DLNest Command Line(F1)",onAccept=self.onCommandAccept) 23 | self.w1 = self.CMDIN.getWindow() 24 | 25 | self.DLOutput = ResultsOutput(routineTask=self.routineTaskDLOutput,title = "DLNest Output (F2)",style="class:dlnest_output") 26 | self.w2 = self.DLOutput.getWindow() 27 | 28 | self.ANOutput = 
AnalyzeOutput(routineTask=self.routineTaskANOutput,title = "Analyzer Output (F3)",style="class:analyzer_output") 29 | self.w3 = self.ANOutput.getWindow() 30 | self.analyzeTaskID = "" 31 | 32 | self.TaskInfo = TaskInfoShower(routineTask = self.routineTaskInfo,title = "Tasks (F4)") 33 | self.w4 = self.TaskInfo.getWindow() 34 | 35 | self.DevicesInfo = DevicesInfoShower(routineTask = self.routineTaskDevices, title = "Devices (F5)") 36 | self.w5 = self.DevicesInfo.getWindow() 37 | 38 | self.container_fat = HSplit([ 39 | self.w1, 40 | VSplit([self.w2,self.w3]), 41 | VSplit([self.w4,self.w5]) 42 | ]) 43 | self.container_tall = HSplit([ 44 | self.w1, 45 | self.w2, 46 | self.w3, 47 | self.w4, 48 | self.w5 49 | ]) 50 | 51 | self.kb = KeyBindings() 52 | @self.kb.add('c-c') 53 | def exit_(event): 54 | event.app.exit() 55 | 56 | @self.kb.add('f1') 57 | def focus1(event): 58 | event.app.layout.focus(self.w1) 59 | 60 | @self.kb.add('f2') 61 | def focus2(event): 62 | event.app.layout.focus(self.w2) 63 | 64 | @self.kb.add('f3') 65 | def focus3(event): 66 | event.app.layout.focus(self.w3) 67 | 68 | @self.kb.add('f4') 69 | def focus4(event): 70 | event.app.layout.focus(self.w4) 71 | 72 | @self.kb.add('f5') 73 | def focus5(event): 74 | event.app.layout.focus(self.w5) 75 | 76 | self.style = Style.from_dict({ 77 | "frame.border" : "fg:#ffb6c1", 78 | "frame.title" : "fg:#1ef0ff", 79 | "command_frame" : "bg:#008b8b", 80 | "dlnest_output" : "bg:#451a4a", 81 | "analyzer_output" : "bg:#451a4a", 82 | "analyzer_info_label" : "bg:#da70d6", 83 | "analyzer_info_text1" : "bg:#3f3f00", 84 | "analyzer_info_text2" : "bg:#ff00ff", 85 | "running_task_status" : "bg:#a01010 bold", 86 | "running_task_id" : "bg:#303030", 87 | "running_task_gpu" : "bg:#556b2f", 88 | "running_task_des" : "bg:#c71585", 89 | "running_task_time" : "bg:#2e3b37", 90 | "pending_task_status" : "bg:#1010a0 bold", 91 | "pending_task_id" : "bg:#303030", 92 | "pending_task_gpu" : "bg:#556b2f", 93 | "pending_task_des" : "bg:#c71585", 94 
| "pending_task_time" : "bg:#2e3b37", 95 | "suspend_task_status" : "bg:#10a010 bold", 96 | "suspend_task_id" : "bg:#303030", 97 | "suspend_task_gpu" : "bg:#556b2f", 98 | "suspend_task_des" : "bg:#c71585", 99 | "suspend_task_time" : "bg:#2e3b37", 100 | "task_info_shower" : "bg:#008bc0", 101 | "devices_info_shower" : "bg:#008bc0", 102 | "devices_id" : "bg:#303030", 103 | "devices_status_valid" : "bg:#3cb371 bold", 104 | "devices_status_break" : "bg:#a01010 bold", 105 | "devices_free_memory" : "bg:#556b2f", 106 | "devices_tasks" : "bg:#c71585" 107 | }) 108 | 109 | self.layout = Layout(self.container_fat,focused_element=self.w1) 110 | self.app = Application(key_bindings=self.kb, layout=self.layout, full_screen=True,style=self.style) 111 | self.app._on_resize = self.on_resize 112 | 113 | def on_resize(self): 114 | cols, rows = os.get_terminal_size(0) 115 | focused_element = self.layout.current_window 116 | if cols >= 2 * rows: # fat 117 | self.app.layout = Layout(self.container_fat,focused_element=focused_element) 118 | else: # tall 119 | self.app.layout = Layout(self.container_tall,focused_element=focused_element) 120 | 121 | self.app.renderer.erase(leave_alternate_screen=False) 122 | self.app._request_absolute_cursor_position() 123 | self.app._redraw() 124 | 125 | def getApp(self): 126 | return self.app 127 | 128 | def onCommandAccept(self,s : str): 129 | commandWordList = s.split(" ") 130 | while "" in commandWordList: 131 | commandWordList.remove("") 132 | 133 | # no command 134 | if len(commandWordList) <= 0: 135 | return 136 | 137 | if commandWordList[0] == "watch": 138 | self.analyzeTaskID = commandWordList[1] 139 | elif commandWordList[0] == "withdraw": 140 | self.analyzeTaskID = "" 141 | 142 | if commandWordList[0] == "runExp": 143 | if len(commandWordList) != 3: 144 | if self.analyzeTaskID != "": 145 | commandWordList = [commandWordList[0],self.analyzeTaskID,commandWordList[1]] 146 | else: 147 | return 148 | 149 | ret = 
self.communicator.giveACommand(commandWordList) 150 | 151 | if commandWordList[0] == "del": 152 | if ret["status"] == "success" and commandWordList[1] == self.analyzeTaskID: 153 | self.analyzeTaskID = "" 154 | 155 | if "exit" in ret: 156 | self.app.exit() 157 | 158 | def routineTaskDLOutput(self, obj): 159 | #for buffer fresh 160 | if not hasattr(obj,"_count_"): 161 | obj._count_ = 0 162 | 163 | outStyledDict = self.communicator.giveACommand(["showDL","-s"]) 164 | outPlainDict = self.communicator.giveACommand(["showDL"]) 165 | if "text" in outStyledDict and "text" in outPlainDict: 166 | try: 167 | obj.lexer.styled_text = outStyledDict["text"] 168 | obj.shower.text = outPlainDict["text"] 169 | except Exception as e: 170 | pass 171 | 172 | def routineTaskANOutput(self, obj): 173 | #for buffer fresh 174 | if not hasattr(obj,"_count_"): 175 | obj._count_ = 0 176 | 177 | if self.analyzeTaskID == "": 178 | obj.lexer.styled_text = [] 179 | obj.shower.text = "" 180 | obj.infoText.text = [("","No valid analyzer task is running")] 181 | obj.infoWindow.width = 33 182 | return 183 | 184 | outStyledDict = self.communicator.giveACommand(["showAN","-t",self.analyzeTaskID,"-s"]) 185 | outPlainDict = self.communicator.giveACommand(["showAN","-t",self.analyzeTaskID]) 186 | if "text" in outStyledDict and "text" in outPlainDict: 187 | try: 188 | obj.lexer.styled_text = outStyledDict["text"] 189 | obj.shower.text = outPlainDict["text"] 190 | obj.infoText.text = [("class:analyzer_info_text1",self.analyzeTaskID)] 191 | obj.infoWindow.width = len(self.analyzeTaskID) 192 | except Exception as e: 193 | pass 194 | else: 195 | self.analyzeTaskID = "" 196 | 197 | def routineTaskInfo(self,obj): 198 | # for buffer fresh 199 | if not hasattr(obj,"_count_"): 200 | obj._count_ = 0 201 | 202 | r = self.communicator.giveACommand(["showTask"]) 203 | if r["status"] != "success": 204 | obj.lexer.taskInfo = [] 205 | obj.shower.text = obj.lexer.get_text() 206 | return 207 | taskInfo = r["info"] 208 | try: 
209 | obj.lexer.taskInfo = taskInfo 210 | obj.shower.text = obj.lexer.get_text() 211 | except Exception as e: 212 | pass 213 | 214 | def routineTaskDevices(self, obj): 215 | # for buffer fresh 216 | if not hasattr(obj,"_count_"): 217 | obj._count_ = 0 218 | 219 | r = self.communicator.giveACommand(["showDevice"]) 220 | if r["status"] != "success": 221 | obj.lexer.devicesInfo = [] 222 | obj.shower.text = obj.lexer.get_text() 223 | return 224 | obj.lexer.devicesInfo = r["info"] 225 | try: 226 | obj.shower.text = obj.lexer.get_text() 227 | except Exception as e: 228 | pass 229 | 230 | def startClient(): 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument("-u",type=str, default="127.0.0.1",help="DLNest server address") 233 | parser.add_argument("-p",type=int, default=9999, help = "DLNest server port") 234 | args=parser.parse_args() 235 | client = Client(url = args.u, port = args.p) 236 | client.getApp().run() -------------------------------------------------------------------------------- /DLNest/Local.py: -------------------------------------------------------------------------------- 1 | from DLNest.Operations.Analyze import analyze 2 | from DLNest.Operations.ChangeDelay import changeDelay 3 | from DLNest.Operations.ChangeDevices import changeDevices 4 | from DLNest.Operations.ChangeMaxTaskPerDevice import changeMaxTaskPerDevice 5 | from DLNest.Operations.ContinueTrain import continueTrain 6 | from DLNest.Operations.DelATask import delATask 7 | from DLNest.Operations.GetAnalyzeOutput import getAnalyzeOutput 8 | from DLNest.Operations.GetDevicesInformation import getDevicesInformation 9 | from DLNest.Operations.GetTasksInformation import getTasksInformation 10 | from DLNest.Operations.New import new 11 | from DLNest.Operations.Run import run 12 | from DLNest.Operations.RunExp import runExp 13 | from DLNest.Operations.SafeExit import safeExit 14 | from DLNest.Operations.UsePlugin import usePlugins 15 | 16 | import argparse 17 | from prompt_toolkit import 
PromptSession, HTML
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
import traceback

class Arguments:
    """Thin wrapper holding an argparse.ArgumentParser for one command family."""
    def __init__(self, desc : str = ""):
        self._parser = argparse.ArgumentParser(description=desc)

    def parser(self):
        return self._parser

# NOTE: argparse applies %-style formatting to help strings, so a literal
# percent sign must be written "%%" — the previous "90\%" made `-h` raise
# ValueError ("unsupported format character").

class TrainArguments(Arguments):
    """Arguments for starting a new training task."""
    def __init__(self):
        super(TrainArguments, self).__init__(desc="Arguments for DLNest task.")

        self._parser.add_argument("-c", type=str, help="root configuration json file for this task.", required = True)
        self._parser.add_argument("-d", type=str, default = "", help="description for this task.(default: None)")
        self._parser.add_argument("-f", type=str, default = "", help="frequently changing configuration json file for this task.(default:None)")
        self._parser.add_argument("-m", type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-ns", action='store_true', help="Set to save to the NOSAVE dir.")
        self._parser.add_argument("-mc", action='store_true', help="Set to use multi card.")
        self._parser.add_argument("-sd", action='store_true', help="Set to use description as the save dir name.(coverd by ns)")
        self._parser.add_argument("-DDP", action='store_true', help="Set to use DDP.")
        self._parser.add_argument("-CPU", action='store_true', help="Set to use CPU.")

class ProjectArguments(Arguments):
    """Arguments for creating a new DLNest project skeleton."""
    def __init__(self):
        super(ProjectArguments, self).__init__(desc="Arguments for create a DLNest project.")

        self._parser.add_argument("-d", type=str, help="Path to the directory you want to create the project.", required = True)
        self._parser.add_argument("-MNIST", action='store_true', help="Set to new a project with MNIST task.")
        self._parser.add_argument("-p", type = str, nargs='+', default=[], help = "Set plugins need to be used.")

class AnalyzeArguments(Arguments):
    """Arguments for launching an analyzer on a saved record."""
    def __init__(self):
        super(AnalyzeArguments, self).__init__(desc="Arguments for an Analyzer")

        self._parser.add_argument("-r", type=str, help = "path to the model record directory.", required = True)
        self._parser.add_argument("-s", type=str, default = "", help = "path to the analyze scripts.")
        self._parser.add_argument("-c", type=int, default = -1, help = "which epoch you want the model to load.(int)")
        self._parser.add_argument("-m", type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-CPU", action='store_true', help="Set to use CPU.")

class ContinueArguments(Arguments):
    """Arguments for resuming training from a saved record."""
    def __init__(self):
        super(ContinueArguments, self).__init__(desc="Arguments for an Analyzer")

        self._parser.add_argument("-r", type=str, help = "path to the model record directory.", required = True)
        self._parser.add_argument("-c", type=int, default = -1, help = "which epoch you want the model to load.(int)")
        self._parser.add_argument("-d", type=str, default = "", help="description for this task.(default: None)")
        self._parser.add_argument("-m", type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-CPU", action='store_true', help="Set to use CPU.")
        self._parser.add_argument("-DDP", action='store_true', help="Set to use DDP.")
        self._parser.add_argument("-mc", action='store_true', help="Set to use multi card.")

class DeviceChangeArguments(Arguments):
    """Arguments for changing the set of schedulable devices."""
    def __init__(self):
        super(DeviceChangeArguments, self).__init__(desc="Arguments for change valid cards.")
        self._parser.add_argument("-d", type=int, nargs='+', help='valid devices', required = True)

class AddPluginsArguents(Arguments):
    """Arguments for adding plugins to an existing project."""
    def __init__(self):
        super(AddPluginsArguents, self).__init__(desc="Arguments for add plugins.")

self._parser.add_argument("-d", type=str, help="Path to the directory you want to create the project.", required = True) 82 | self._parser.add_argument("-p", type = str, nargs='+', help = "Set plugins need to be used.",required = True) 83 | self._parser.add_argument("-F", action='store_true', help="Set to use full config") 84 | 85 | class DLNestLocal: 86 | def __init__(self): 87 | self.trainArgParser = TrainArguments() 88 | self.continueArgParser = ContinueArguments() 89 | self.projectArgParser = ProjectArguments() 90 | self.analyzeArgParser = AnalyzeArguments() 91 | self.deviceChangeArgParser = DeviceChangeArguments() 92 | self.addPluginsArgParser = AddPluginsArguents() 93 | 94 | def runTrain(self,commandWordList : list): 95 | args,otherArgs = self.trainArgParser.parser().parse_known_args(commandWordList[1:]) 96 | run( 97 | configPath = args.c, 98 | freqPath = args.f, 99 | description = args.d, 100 | memoryConsumption = args.m, 101 | CPU = args.CPU, 102 | DDP = args.DDP, 103 | multiGPU = args.mc, 104 | noSave = args.ns, 105 | useDescriptionToSave = args.sd, 106 | otherArgs = {} 107 | ) 108 | 109 | def newProject(self,commandWordList : list): 110 | args,otherArgs = self.projectArgParser.parser().parse_known_args(commandWordList[1:]) 111 | new( 112 | targetDir = args.d, 113 | MNIST = args.MNIST, 114 | pluginsName = args.p 115 | ) 116 | 117 | def runAnalyze(self,commandWordList : list): 118 | args,otherArgs = self.analyzeArgParser.parser().parse_known_args(commandWordList[1:]) 119 | analyze( 120 | recordPath = args.r, 121 | scriptPath = args.s, 122 | checkpointID = args.c, 123 | CPU = args.CPU, 124 | memoryConsumption = args.m, 125 | otherArgs = {} 126 | ) 127 | 128 | def continueTrain(self,commandWordList : list): 129 | args,otherArgs = self.continueArgParser.parser().parse_known_args(commandWordList[1:]) 130 | continueTrain( 131 | recordPath = args.r, 132 | checkpointID = args.c, 133 | memoryConsumption = args.m, 134 | CPU = args.CPU, 135 | DDP = args.DDP, 136 | 
multiGPU = args.mc, 137 | description = args.d, 138 | otherArgs = {} 139 | ) 140 | 141 | def changeDevices(self,commandWordList : list): 142 | args,otherArgs = self.deviceChangeArgParser.parser().parse_known_args(commandWordList[1:]) 143 | changeDevices(args.d) 144 | 145 | def runExp(self,commandWordList : list): 146 | runExp(commandWordList[1],commandWordList[2]) 147 | 148 | def addPlugin(self, commandWordList : list): 149 | args, otherArgs = self.addPluginsArgParser.parser().parse_known_args(commandWordList[1:]) 150 | usePlugins( 151 | targetDir = args.d, 152 | pluginsName = args.p, 153 | full = args.F 154 | ) 155 | 156 | def run(self): 157 | self.session = PromptSession(auto_suggest = AutoSuggestFromHistory()) 158 | while True: 159 | try: 160 | command = self.session.prompt(HTML("DLNest>>")) 161 | commandWordList = command.strip().split(' ') 162 | if commandWordList[0] == "run": 163 | self.runTrain(commandWordList) 164 | elif commandWordList[0] == "continue": 165 | self.continueTrain(commandWordList) 166 | elif commandWordList[0] == "new": 167 | self.newProject(commandWordList) 168 | elif commandWordList[0] == "analyze": 169 | self.runAnalyze(commandWordList) 170 | elif commandWordList[0] == "runExp": 171 | self.runExp(commandWordList) 172 | elif commandWordList[0] == "del": 173 | delATask(commandWordList[1]) 174 | elif commandWordList[0] == "showAN": 175 | print(getAnalyzeOutput(commandWordList[1],False)[1]) 176 | elif commandWordList[0] == "showTask": 177 | print(getTasksInformation()) 178 | elif commandWordList[0] == "showDevice": 179 | print(getDevicesInformation()) 180 | elif commandWordList[0] == 'changeDevices': 181 | self.changeDevices(commandWordList) 182 | elif commandWordList[0] == "addP": 183 | self.addPlugin(commandWordList) 184 | elif commandWordList[0] == "exit": 185 | safeExit() 186 | exit(0) 187 | else: 188 | print("Wrong command") 189 | except KeyboardInterrupt: 190 | safeExit() 191 | exit(0) 192 | except Exception as e: 193 | s = 
traceback.format_exc() 194 | listS = s.split("\n")[:-1] 195 | s = "\n".join(listS[-3:]) 196 | print(s) 197 | 198 | if __name__ == "__main__": 199 | import sys 200 | if sys.path[0] != '': 201 | sys.path[0] = '' 202 | main = DLNestLocal() 203 | main.run() -------------------------------------------------------------------------------- /DLNest/Executor/TaskProcess.py: -------------------------------------------------------------------------------- 1 | from DLNest.Information.TaskInformation import TaskInformation 2 | 3 | import multiprocessing 4 | import os 5 | import sys 6 | from multiprocessing import Process 7 | import importlib 8 | import shutil 9 | import json 10 | import random 11 | from pathlib import Path 12 | import numpy as np 13 | 14 | from abc import ABCMeta, abstractmethod 15 | try: 16 | import torch 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | import torch.distributed as dist 19 | import torch.multiprocessing as mp 20 | USINGTORCH = True 21 | except ImportError: 22 | USINGTORCH = False 23 | 24 | from DLNest.Common.RunningStatus import RunningStatus 25 | 26 | class TaskProcess(Process): 27 | def __init__(self,task : TaskInformation): 28 | super(TaskProcess,self).__init__() 29 | self.task = task 30 | self.status = RunningStatus() 31 | self.finishedEpoch = -1 32 | self.plugins = [] 33 | 34 | def __loadAModule(self,filePath : Path,name : str): 35 | if not filePath.is_absolute(): 36 | filePath = Path(self.task.args["root_file_path"]) / filePath 37 | spec = importlib.util.spec_from_file_location( 38 | name, 39 | filePath 40 | ) 41 | module = importlib.util.module_from_spec(spec) 42 | dirName = str(filePath.parent) 43 | if not dirName in sys.path: 44 | sys.path.append(dirName) 45 | spec.loader.exec_module(module) 46 | return module 47 | 48 | def __loadLifeCycle(self): 49 | self.lifeCycleModule = self.__loadAModule(Path(self.task.args["life_cycle_file_path"]),"LifeCycle") 50 | 51 | def __initLifeCycle(self,rank : int = -1): 52 | 
lifeCycleName = self.task.args['life_cycle_name'] 53 | if lifeCycleName in dir(self.lifeCycleModule): 54 | self.lifeCycle = self.lifeCycleModule.__getattribute__(lifeCycleName)( 55 | taskProcess = self, 56 | plugins = self.plugins, 57 | status = self.status 58 | ) 59 | if self.task.loadCkpt: 60 | self.lifeCycle._loadSaveDict(self.stateDict["life_cycle"]) 61 | else: 62 | raise Exception("Cannot find lifeCycle class") 63 | 64 | def __loadAPlugin(self,pluginName : str): 65 | pluginPath = Path(__file__).parent.parent / "Plugins" / (pluginName + '.py') 66 | tmpPath = Path(pluginName) 67 | if tmpPath.is_absolute(): 68 | pluginPath = tmpPath 69 | pluginName = tmpPath.name 70 | pluginModule = self.__loadAModule(filePath = pluginPath,name = pluginName) 71 | pluginClass = pluginModule.__getattribute__("DLNestPlugin") 72 | pluginClass._status = self.status 73 | return pluginClass 74 | 75 | def __loadPlugins(self): 76 | if not "plugins" in self.task.args: 77 | return [] 78 | pluginNames = self.task.args["plugins"] 79 | pluginList = [] 80 | for name in pluginNames: 81 | pluginList.append(self.__loadAPlugin(name)) 82 | self.plugins = pluginList 83 | 84 | def __loadRunnerAndDataset(self): 85 | runnerPath = Path(self.task.args["runner_file_path"] if "runner_file_path" in self.task.args else self.task.args["model_file_path"]) # need to be deprecated 86 | datasetPath = Path(self.task.args["dataset_file_path"]) 87 | self.runnerModule = self.__loadAModule(runnerPath,runnerPath.stem) 88 | self.datasetModule = self.__loadAModule(datasetPath,datasetPath.stem) 89 | sys.modules[datasetPath.stem] = self.datasetModule # Why? 
90 | 91 | def __initDataset(self): 92 | datasetName = self.task.args['dataset_name'] 93 | if datasetName in dir(self.datasetModule): 94 | datasetClass = self.datasetModule.__getattribute__(datasetName) 95 | self.dataset = datasetClass( 96 | args = self.task.args, 97 | plugins = self.plugins, 98 | status = self.status 99 | ) 100 | self.datasetInfo,self.trainLoader,self.valLoader = self.dataset._datasetInit(self.task.args) 101 | # load from ckpt is needed. 102 | if self.task.loadCkpt: 103 | self.dataset._loadSaveDict(self.stateDict["dataset"]) 104 | else: 105 | raise Exception("Cannot find dataset class") 106 | 107 | def __initRunner(self): 108 | runnerName = self.task.args['runner_name'] if "runner_name" in self.task.args else self.task.args["model_name"] # need to be deprecated 109 | if runnerName in dir(self.runnerModule): 110 | runnerClass = self.runnerModule.__getattribute__(runnerName) 111 | self.runner = runnerClass( 112 | args = self.task.args, 113 | plugins = self.plugins, 114 | status = self.status 115 | ) 116 | self.runner._runnerInit(self.task.args,self.datasetInfo) 117 | if self.status.env != "DDP": 118 | self.runner._initOptimizer() 119 | else: 120 | self.runner.DDPOperation(rank = self.status.rank) 121 | self.runner._initOptimizer() 122 | self.runner.afterDDP(rank = self.status.rank) 123 | # load from ckpt is needed. 124 | if self.task.loadCkpt: 125 | self.runner._loadSaveDict(self.stateDict["runner"] if "runner" in self.stateDict else self.stateDict["model"]) 126 | else: 127 | # if load from ckpt, logDict has been loaded in self.loadCkpt 128 | self.logDict = self.runner._initLog() 129 | 130 | def checkDeviceEnviroment(self): 131 | """ 132 | make correct environ params and 133 | return "CPU","GPU","GPUs","DDP" 134 | """ 135 | assert len(self.task.devices) > 0 136 | if self.task.devices[0] == -1: 137 | return "CPU" 138 | 139 | ids = [str(item) for item in self.task.devices] 140 | assert not ("-1" in ids) #no CPU in the devices list. 
Preventing world size error 141 | 142 | if self.task.DDP: 143 | assert USINGTORCH 144 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(ids) 145 | self.deviceNum = len(ids) 146 | return "DDP" 147 | elif self.task.multiGPU: 148 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(ids) 149 | return "GPUs" 150 | else: 151 | torch.cuda.set_device(self.task.devices[0]) 152 | os.environ["CUDA_VISIBLE_DEVICES"] = ids[0] 153 | return "GPU" 154 | 155 | def setupSeed(self): 156 | seed = self.seed 157 | np.random.seed(seed) 158 | random.seed(seed) 159 | if USINGTORCH: 160 | torch.manual_seed(seed) 161 | torch.cuda.manual_seed_all(seed) 162 | torch.backends.cudnn.deterministic = True 163 | 164 | def initBeforeDDP(self): 165 | self.seed = random.randint(0,2147483647) 166 | 167 | def initAfterDDP(self,rank,worldSize): 168 | os.environ["MASTER_ADDR"] = self.task.address 169 | os.environ["MASTER_PORT"] = self.task.port 170 | torch.cuda.set_device(rank) 171 | dist.init_process_group("nccl", rank = rank, world_size = worldSize) 172 | self.status._RunningStatus__rank = rank 173 | self.status._RunningStatus__worldSize = worldSize 174 | self.setupSeed() 175 | 176 | def runDDP(self,rank): 177 | self.initAfterDDP(rank,self.deviceNum) 178 | self.initOutput(rank = rank) 179 | self.loadCkpt() 180 | if self.initModules(rank = rank) != "Skip": 181 | dist.barrier() # start together 182 | self.mainLoop() 183 | 184 | self.lifeCycle._AAll() 185 | dist.destroy_process_group() 186 | exit(0) 187 | 188 | @abstractmethod 189 | def mainLoop(self): 190 | """ 191 | needs to be override by other subprocess. 192 | """ 193 | pass 194 | 195 | @abstractmethod 196 | def initOutput(self,rank = -1): 197 | """ 198 | needs to be override by other subprocess. 199 | """ 200 | pass 201 | 202 | def loadCkpt(self): 203 | """ 204 | If task have ckpt to load, load it to the pointed device. 
205 | """ 206 | deviceStr = "" 207 | if self.task.devices[0] == -1: 208 | deviceStr = "cpu" 209 | else: 210 | deviceStr = "cuda" 211 | if self.task.loadCkpt: 212 | self.stateDict = self.task.savePackage.getStateDict(self.task.checkpointID,deviceStr) 213 | self.logDict = self.stateDict["log_dict"] 214 | self.finishedEpoch = self.stateDict["finished_epoch"] 215 | 216 | def saveCkpt(self,otherDict : dict = {}, holdThisCheckpoint = False): 217 | """ 218 | Save the states of life cycle, dataset and runner. 219 | """ 220 | stateDict = {} 221 | stateDict["life_cycle"] = self.lifeCycle._getSaveDict() 222 | stateDict["dataset"] = self.dataset._getSaveDict() 223 | stateDict["runner"] = self.runner._getSaveDict() 224 | stateDict["log_dict"] = self.logDict 225 | stateDict["finished_epoch"] = self.finishedEpoch 226 | for key in otherDict: 227 | stateDict[key] = otherDict[key] 228 | self.task.savePackage.saveACheckpoint(stateDict,holdThisCheckpoint=holdThisCheckpoint) 229 | 230 | def initModules(self,rank : int = -1): 231 | """ 232 | Init all modules 233 | """ 234 | self.__loadLifeCycle() 235 | self.__loadPlugins() 236 | self.__initLifeCycle(rank = rank) 237 | 238 | self.__loadRunnerAndDataset() 239 | 240 | if self.lifeCycle._BAll() != "Skip": 241 | if self.lifeCycle._BDatasetInit() != "Skip": 242 | self.__initDataset() 243 | self.lifeCycle.dataset = self.dataset 244 | self.lifeCycle._ADatasetInit() 245 | 246 | if self.lifeCycle._BModelInit() != "Skip": 247 | self.__initRunner() 248 | self.lifeCycle.runner = self.runner 249 | self.lifeCycle._AModelInit() 250 | return 251 | else: 252 | return "Skip" 253 | 254 | def run(self): 255 | self.status._RunningStatus__env = self.checkDeviceEnviroment() 256 | if self.status.env == "DDP": 257 | # run DDP 258 | self.ppid = os.getpid() 259 | self.initBeforeDDP() 260 | context = mp.spawn( 261 | self.runDDP, 262 | nprocs=self.deviceNum, 263 | join=False, 264 | daemon = True 265 | ) 266 | self.initOutput() 267 | while not context.join(): 268 | 
pass 269 | else: 270 | self.initOutput() 271 | self.loadCkpt() 272 | if self.initModules() != "Skip": 273 | self.mainLoop() 274 | self.lifeCycle._AAll() 275 | -------------------------------------------------------------------------------- /DLNest/SavePackage/SavePackage.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import json 3 | from pathlib import Path 4 | import time 5 | import random 6 | try: 7 | import torch 8 | except ImportError: 9 | # need to costomize the doSaveOperation and doLoadOperation functions 10 | pass 11 | 12 | class SavePackage: 13 | def __init__( 14 | self, 15 | configPath : str = "", 16 | freqPath : str = "" 17 | ): 18 | self.args = {} 19 | haveArgs = False 20 | if configPath != "": 21 | configPath = Path(configPath) 22 | self.__loadArgs(configPath,self.args) 23 | haveArgs = True 24 | if freqPath != "": 25 | freqPath = Path(freqPath) 26 | self.__loadArgs(freqPath,self.args) 27 | haveArgs = True 28 | 29 | self.ckpts = {} 30 | self.nowCkptID = 0 31 | self.ckptSlow = [] 32 | self.ckptFast = [] 33 | self.ckptConsistent = [] 34 | self.checkpointsDir = None 35 | if haveArgs: 36 | self.maxCkptSlow = self.args["checkpoint_args"]["max_ckpt_in_slow_track"] 37 | self.maxCkptFast = self.args["checkpoint_args"]["max_ckpt_in_fast_track"] 38 | self.maxCkptConsisitent = self.args["checkpoint_args"]["max_ckpt_in_consistent_track"] 39 | self.slowDilation = self.args["checkpoint_args"]["dilation_in_slow_track"] 40 | 41 | self.root = None 42 | self.checkpointsDir = None 43 | self.prefix = "state_" 44 | 45 | def giveArgs(self,args : dict): # Not used 46 | self.args = args 47 | self.maxCkptSlow = self.args["checkpoint_args"]["max_ckpt_in_slow_track"] 48 | self.maxCkptFast = self.args["checkpoint_args"]["max_ckpt_in_fast_track"] 49 | self.maxCkptConsisitent = self.args["checkpoint_args"]["max_ckpt_in_consistent_track"] 50 | self.slowDilation = self.args["checkpoint_args"]["dilation_in_slow_track"] 51 | 52 
| def __replaceArgs(self,newArgName,newArgValue,args): 53 | """ 54 | newArgName : 应该是args下的名称 55 | 若新参数不是一个dict,或新参数不存在覆盖问题,则直接新建该参数或覆盖 56 | 若新参数是一个dict且存在覆盖问题,则递归dict调用 57 | """ 58 | if not newArgName in args: 59 | args[newArgName] = newArgValue 60 | return 61 | 62 | if not isinstance(newArgValue,dict): 63 | args[newArgName] = newArgValue 64 | return 65 | 66 | for key in newArgValue: 67 | self.__replaceArgs(key,newArgValue[key],args[newArgName]) 68 | return 69 | 70 | def __loadArgs(self,filePath : Path,args : dict): 71 | """ 72 | load configs to the args dict 73 | """ 74 | # 若文件不存在或不是一个文件,直接报错 75 | if not filePath.is_file(): 76 | raise BaseException("Config file doesn't exists. " + str(filePath)) 77 | fp = filePath.open('r') 78 | newArgs = json.load(fp) 79 | 80 | # 对除了child_jsons的每一个key尝试覆盖或新建args 81 | for key in newArgs: 82 | if key == "child_jsons": 83 | continue 84 | 85 | # 为避免dict类型的参数被完全覆盖,使用__replaceArgs方法新建或覆盖args 86 | self.__replaceArgs(key,newArgs[key],args) 87 | 88 | # 递归查找子json,子json覆盖父json的参数(按照DFS序) 89 | if "child_jsons" in newArgs: 90 | for item in newArgs["child_jsons"]: 91 | path = Path(item) 92 | 93 | # 若子json路径不是绝对路径,则按照当前json路径寻找相对路径 94 | if not path.is_absolute(): 95 | path = filePath.parent / item 96 | 97 | # 载入子json 98 | self.__loadArgs(path,args) 99 | 100 | fp.close() 101 | return 102 | 103 | def __copyAFile(self,filePath : Path, saveDir : Path): 104 | ''' 105 | 若filePath为相对路径,则复制到其对应文件夹 106 | 若filePath为绝对路径,则复制到储存包的根 107 | ''' 108 | if filePath.is_absolute(): 109 | if filePath.is_dir(): 110 | shutil.copytree(filePath,saveDir / filePath.stem) 111 | else: 112 | shutil.copy(filePath,saveDir / (filePath.stem + filePath.suffix)) 113 | else: 114 | abFilePath = Path(self.args["root_file_path"]) / filePath 115 | if abFilePath.is_dir(): 116 | shutil.copytree(abFilePath,saveDir / filePath) 117 | else: 118 | target = saveDir / filePath 119 | target_dir = target.parent 120 | if not target_dir.exists(): 121 | target_dir.mkdir(parents = True, exist_ok = 
True) 122 | shutil.copy(abFilePath,target) 123 | 124 | def __savePackageInformation(self): 125 | packageInfoPath = self.root / "_package.json" 126 | packageInfo = { 127 | "ckpt_slow" : self.ckptSlow, 128 | "ckpt_fast" : self.ckptFast, 129 | "ckpt_consistent" : self.ckptConsistent, 130 | "prefix" : self.prefix 131 | } 132 | with packageInfoPath.open("w") as fp: 133 | json.dump(packageInfo,fp,sort_keys=True, indent=4, separators=(',', ':')) 134 | 135 | def __loadPackageInformation(self): 136 | packageInfoPath = self.root / "_package.json" 137 | packageInfo = {} 138 | with packageInfoPath.open("r") as fp: 139 | packageInfo = json.load(fp) 140 | self.ckptSlow = packageInfo["ckpt_slow"] 141 | self.ckptFast = packageInfo["ckpt_fast"] 142 | self.ckptConsistent = packageInfo["ckpt_consistent"] 143 | self.prefix = packageInfo["prefix"] 144 | 145 | def saveToNewDir(self,overrideSaveName : str = "",copyFiles = True): 146 | """ 147 | Make a new save package. If having the override save name, use it. If not, use timestamp to save. 148 | """ 149 | saveRoot = Path(self.args["save_root"]) 150 | if not saveRoot.is_absolute(): 151 | saveRoot = Path(self.args["root_file_path"]) / saveRoot 152 | self.args["save_root"] = str(saveRoot) 153 | nowTime = time.time() 154 | saveName = "" 155 | if overrideSaveName != "": 156 | saveName = overrideSaveName 157 | else: 158 | saveName = time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(nowTime)) + "_" + str(random.randint(100,999)) # avoid conflict. 
May be changed later 159 | saveDir = saveRoot / saveName 160 | self.root = saveDir 161 | 162 | if saveDir.exists(): 163 | shutil.rmtree(saveDir) 164 | saveDir.mkdir(parents=True,exist_ok=True) 165 | 166 | #checkpoints save dir 167 | self.checkpointsDir = saveDir / "Checkpoints" 168 | self.checkpointsDir.mkdir(parents=True,exist_ok=True) 169 | self.ckpts = {} 170 | 171 | # copy python files into dir if copy Files is TRUE 172 | if copyFiles: 173 | self.__copyAFile(Path(self.args["runner_file_path"] if "runner_file_path" in self.args else self.args["model_file_path"]),saveDir) # need to be deprecated 174 | self.__copyAFile(Path(self.args["dataset_file_path"]),saveDir) 175 | self.__copyAFile(Path(self.args["life_cycle_file_path"]),saveDir) 176 | for item in self.args["other_file_paths"]: 177 | self.__copyAFile(Path(item),saveDir) 178 | 179 | # save args 180 | argsPath = saveDir / "args.json" 181 | argsFP = argsPath.open('w') 182 | self.args["root_file_path"] = str(saveDir) 183 | json.dump(self.args,argsFP,sort_keys=True, indent=4, separators=(',', ':')) 184 | argsFP.close() 185 | 186 | # save package info 187 | self.__savePackageInformation() 188 | 189 | def saveACheckpoint(self,stateDict : dict, holdThisCheckpoint : bool = False): 190 | """ 191 | Save a new checkpoint and manage the storage. 
192 | """ 193 | idsNeed2Delete = [] 194 | 195 | # save this state dict 196 | saveFile = self.checkpointsDir / (self.prefix + str(self.nowCkptID) + ".ckpt") 197 | saveName = str(saveFile) 198 | self.doSaveOperation(stateDict,saveName) 199 | 200 | # add to ckpts dict 201 | self.ckpts[self.nowCkptID] = saveFile 202 | 203 | # append in fast track 204 | self.ckptFast.append(self.nowCkptID) 205 | if len(self.ckptFast) > self.maxCkptFast: 206 | w2did = self.ckptFast.pop(0) 207 | idsNeed2Delete.append(w2did) 208 | 209 | # append in slow track 210 | if self.nowCkptID % self.slowDilation == 0: 211 | self.ckptSlow.append(self.nowCkptID) 212 | if len(self.ckptSlow) > self.maxCkptSlow: 213 | w2did = self.ckptSlow.pop(0) 214 | if not w2did in idsNeed2Delete: 215 | idsNeed2Delete.append(w2did) 216 | 217 | # append in consistent track 218 | if holdThisCheckpoint: 219 | self.ckptConsistent.append(self.nowCkptID) 220 | if len(self.ckptConsistent) > self.maxCkptConsisitent: 221 | w2did = self.ckptConsistent.pop(0) 222 | if not w2did in idsNeed2Delete: 223 | idsNeed2Delete.append(w2did) 224 | 225 | #delete useless checkpoints on disk 226 | for id in idsNeed2Delete: 227 | if not (id in self.ckptFast or 228 | id in self.ckptSlow or 229 | id in self.ckptConsistent): 230 | path = self.ckpts[id] 231 | path.unlink() 232 | self.ckpts.pop(id) 233 | 234 | self.nowCkptID += 1 235 | # save package info 236 | self.__savePackageInformation() 237 | 238 | def doSaveOperation(self,stateDict,fileName): 239 | """ 240 | pytorch version of save ckpt. If no torch in the environment, please override this function. 241 | """ 242 | torch.save(stateDict,fileName) 243 | 244 | def doLoadOperation(self,fileName : str,device): 245 | return torch.load(fileName,map_location = device) 246 | 247 | def saveVisualString(self, visualString : str): 248 | """ 249 | Make a new file with this visual string which makes the information noticeable. 
250 | """ 251 | visualFile = self.root / ("_" + visualString) 252 | visualFile.touch() 253 | with open(visualFile,"w") as f: 254 | f.write(visualString) 255 | 256 | def initCkptsFromExist(self): 257 | self.checkpointsDir = self.root / "Checkpoints" 258 | filelist = self.checkpointsDir.iterdir() 259 | for item in filelist: 260 | if item.suffix == ".ckpt": 261 | id = int(item.stem.split("_")[-1]) 262 | self.ckpts[id] = item 263 | self.nowCkptID = max(self.nowCkptID,id) 264 | self.nowCkptID += 1 265 | 266 | def initFromAnExistSavePackage(self,packagePath : str): 267 | self.root = Path(packagePath) 268 | configPath = self.root / "args.json" 269 | self.args = {} 270 | self.__loadArgs(configPath,self.args) 271 | 272 | self.maxCkptSlow = self.args["checkpoint_args"]["max_ckpt_in_slow_track"] 273 | self.maxCkptFast = self.args["checkpoint_args"]["max_ckpt_in_fast_track"] 274 | self.maxCkptConsisitent = self.args["checkpoint_args"]["max_ckpt_in_consistent_track"] 275 | self.slowDilation = self.args["checkpoint_args"]["dilation_in_slow_track"] 276 | 277 | # load package info 278 | self.__loadPackageInformation() 279 | 280 | # load ckpt 281 | self.initCkptsFromExist() 282 | 283 | def getStateDict(self,id = -1,device = "cuda:0"): 284 | if id == -1: 285 | id = self.nowCkptID - 1 286 | if not id in self.ckpts: 287 | raise BaseException("The checkpoint doesn't exist.") 288 | return self.doLoadOperation(str(self.ckpts[id]),device) 289 | 290 | def getOutputFile(self,rank = -1): 291 | outputFileName = "_output.txt" 292 | if rank != -1: 293 | outputFileName = "_output_" + str(rank) + ".txt" 294 | outputFilePath = self.root / outputFileName 295 | if outputFilePath.exists(): 296 | return outputFilePath.open("a") 297 | else: 298 | return outputFilePath.open("w") 299 | 300 | def setCkptID(self,ckptID : int): 301 | self.nowCkptID = ckptID 302 | 303 | newslow = [] 304 | for item in self.ckptSlow: 305 | if item >= self.nowCkptID: 306 | pass 307 | else: 308 | newslow.append(item) 309 | 
self.ckptSlow = newslow 310 | 311 | newfast = [] 312 | for item in self.ckptFast: 313 | if item >= self.nowCkptID: 314 | pass 315 | else: 316 | newfast.append(item) 317 | self.ckptFast = newfast 318 | 319 | newconsistent = [] 320 | for item in self.ckptConsistent: 321 | if item >= self.nowCkptID: 322 | pass 323 | else: 324 | newconsistent.append(item) 325 | self.ckptConsistent = newconsistent 326 | 327 | popList = [] 328 | for key in self.ckpts: 329 | if key >= self.nowCkptID: 330 | popList.append(key) 331 | for item in popList: 332 | self.ckpts.pop(item) -------------------------------------------------------------------------------- /DLNest/ShellClient/Communicator.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | import traceback 4 | import json 5 | from functools import wraps 6 | 7 | def raiseError(self): 8 | raise Exception("command error") 9 | 10 | class Arguments: 11 | def __init__(self,desc : str = ""): 12 | self._parser = argparse.ArgumentParser(description=desc) 13 | self._parser.error = raiseError 14 | 15 | def parser(self): 16 | return self._parser 17 | 18 | class TrainArguments(Arguments): 19 | def __init__(self): 20 | super(TrainArguments, self).__init__(desc="Arguments for DLNest task.") 21 | 22 | self._parser.add_argument("-c",type=str, help="root configuration json file for this task.",required = True) 23 | self._parser.add_argument("-d",type=str, default = "", help="description for this task.(default: None)") 24 | self._parser.add_argument("-f",type=str, default = "", help="frequently changing configuration json file for this task.(default:None)") 25 | self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90\% of the total memory)") 26 | self._parser.add_argument("-ns",action='store_true', help="Set to save to the NOSAVE dir.") 27 | self._parser.add_argument("-mc",action='store_true',help="Set to use 
multi card.") 28 | self._parser.add_argument("-sd",action='store_true',help="Set to use description as the save dir name.(coverd by ns)") 29 | self._parser.add_argument("-DDP",action='store_true',help="Set to use DDP.") 30 | self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.") 31 | 32 | class ProjectArguments(Arguments): 33 | def __init__(self): 34 | super(ProjectArguments, self).__init__(desc="Arguments for create a DLNest project.") 35 | 36 | self._parser.add_argument("-d",type=str, help="Path to the directory you want to create the project.", required = True) 37 | self._parser.add_argument("-MNIST",action='store_true', help="Set to new a project with MNIST task.") 38 | self._parser.add_argument("-p", type = str, nargs='+', help = "Set plugins need to be used.") 39 | 40 | class AnalyzeArguments(Arguments): 41 | def __init__(self): 42 | super(AnalyzeArguments, self).__init__(desc="Arguments for an Analyzer") 43 | 44 | self._parser.add_argument("-r",type=str, help = "path to the model record directory.", required = True) 45 | self._parser.add_argument("-s",type=str, default = "", help = "path to the analyze scripts.") 46 | self._parser.add_argument("-c",type=int, default = -1, help = "which epoch you want the model to load.(int)") 47 | self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90\% of the total memory)") 48 | self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.") 49 | 50 | class ContinueArguments(Arguments): 51 | def __init__(self): 52 | super(ContinueArguments, self).__init__(desc="Arguments for an Analyzer") 53 | 54 | self._parser.add_argument("-r",type=str, help = "path to the model record directory.", required = True) 55 | self._parser.add_argument("-c",type=int, default = -1, help = "which epoch you want the model to load.(int)") 56 | self._parser.add_argument("-d",type=str, default = "", help="description for this task.(default: 
None)") 57 | self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90\% of the total memory)") 58 | self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.") 59 | self._parser.add_argument("-DDP",action='store_true',help="Set to use DDP.") 60 | self._parser.add_argument("-mc",action='store_true',help="Set to use multi card.") 61 | 62 | class DeviceChangeArguments(Arguments): 63 | def __init__(self): 64 | super(DeviceChangeArguments, self).__init__(desc="Arguments for change valid cards.") 65 | self._parser.add_argument("-d",type=int, nargs='+', help='valid devices', required = True) 66 | 67 | class OutputArguments(Arguments): 68 | def __init__(self): 69 | super(OutputArguments, self).__init__(desc="Argumetns for get styled or not styled outputs.") 70 | self._parser.add_argument("-t",type=str, default = "", help="task ID") 71 | self._parser.add_argument("-s",action="store_true",help = "set to get styled") 72 | 73 | class AddPluginsArguents(Arguments): 74 | def __init__(self): 75 | super(AddPluginsArguents, self).__init__(desc="Arguments for add plugins.") 76 | 77 | self._parser.add_argument("-d", type=str, help="Path to the directory you want to create the project.", required = True) 78 | self._parser.add_argument("-p", type = str, nargs='+', help = "Set plugins need to be used.",required = True) 79 | self._parser.add_argument("-F", action='store_true', help="Set to use full config") 80 | 81 | def stableRun(f): 82 | wraps(f) 83 | def doStableRun(*args, **kwargs): 84 | try: 85 | return f(*args, **kwargs) 86 | except Exception as e: 87 | return {"status" : "error", "error" : str(e)} 88 | return doStableRun 89 | 90 | 91 | class Communicator: 92 | def __init__(self, url : str = "127.0.0.1", port : int = 9999): 93 | self.trainArgParser = TrainArguments() 94 | self.continueArgParser = ContinueArguments() 95 | self.projectArgParser = ProjectArguments() 96 | self.analyzeArgParser = 
AnalyzeArguments() 97 | self.deviceChangeArgParser = DeviceChangeArguments() 98 | self.outputArgParser = OutputArguments() 99 | self.addPluginsArgParser = AddPluginsArguents() 100 | self.url = "http://" + url + ":" + str(port) 101 | 102 | def shortenList(self,commandWordList : list): 103 | newList = [] 104 | now = "" 105 | for item in commandWordList: 106 | if item[0] == "-": 107 | if now != "": 108 | newList.append(now) 109 | now = "" 110 | newList.append(item) 111 | continue 112 | elif item[0] == "\"": 113 | if now != "": 114 | newList.append(now) 115 | now = item 116 | else: 117 | if now == "": 118 | now = now + item 119 | else: 120 | now = now + " " + item 121 | if now != "": 122 | newList.append(now) 123 | for i in range(len(newList)): 124 | item = newList[i] 125 | if item[0] == "\"" and item[-1] == "\"": 126 | newList[i] = newList[i][1:-1] 127 | return newList 128 | 129 | @stableRun 130 | def runTrain(self,commandWordList : list): 131 | commandWordList = self.shortenList(commandWordList) 132 | args,otherArgs = self.trainArgParser.parser().parse_known_args(commandWordList[1:]) 133 | r = requests.post(self.url + "/run_train",{ 134 | "config_path" : args.c, 135 | "freq_path" : args.f, 136 | "description" : args.d, 137 | "memory_consumption" : args.m, 138 | "CPU" : args.CPU, 139 | "DDP" : args.DDP, 140 | "multi_GPU" : args.mc, 141 | "no_save" : args.ns, 142 | "use_description" : args.sd, 143 | }) 144 | return json.loads(r.text) 145 | 146 | @stableRun 147 | def newProject(self,commandWordList : list): 148 | commandWordList = self.shortenList(commandWordList) 149 | args,otherArgs = self.projectArgParser.parser().parse_known_args(commandWordList[1:]) 150 | r = requests.post(self.url + "/new_proj",{ 151 | "target_dir" : args.d, 152 | "MNIST" : args.MNIST, 153 | "plugins" : args.p 154 | }) 155 | return json.loads(r.text) 156 | 157 | @stableRun 158 | def runAnalyze(self,commandWordList : list): 159 | commandWordList = self.shortenList(commandWordList) 160 | 
args,otherArgs = self.analyzeArgParser.parser().parse_known_args(commandWordList[1:]) 161 | r = requests.post(self.url + "/analyze",{ 162 | "record_path" : args.r, 163 | "script_path" : args.s, 164 | "checkpoint_ID" : args.c, 165 | "CPU" : args.CPU, 166 | "memory_consumption" : args.m 167 | }) 168 | return json.loads(r.text) 169 | 170 | @stableRun 171 | def continueTrain(self,commandWordList : list): 172 | commandWordList = self.shortenList(commandWordList) 173 | args,otherArgs = self.continueArgParser.parser().parse_known_args(commandWordList[1:]) 174 | r = requests.post(self.url + "/continue_train",{ 175 | "record_path" : args.r, 176 | "checkpoint_ID" : args.c, 177 | "memory_consumption" : args.m, 178 | "CPU" : args.CPU, 179 | "DDP" : args.DDP, 180 | "multi_GPU" : args.mc, 181 | "description" : args.d 182 | }) 183 | return json.loads(r.text) 184 | 185 | @stableRun 186 | def changeDevices(self,commandWordList : list): 187 | args,otherArgs = self.deviceChangeArgParser.parser().parse_known_args(commandWordList[1:]) 188 | r = requests.post(self.url + "/change_devices",{ 189 | "new_devices_IDs" : args.d 190 | }) 191 | return json.loads(r.text) 192 | 193 | @stableRun 194 | def runExp(self,commandWordList : list): 195 | """ 196 | commandWordList : taskID , command 197 | """ 198 | r = requests.post(self.url + "/run_exp",{ 199 | "task_ID" : commandWordList[1], 200 | "command" : commandWordList[2] 201 | }) 202 | return json.loads(r.text) 203 | 204 | @stableRun 205 | def getAnalyzeOutput(self,commandWordList : list): 206 | """ 207 | commandWordList : -t task_ID -s 208 | """ 209 | args,otherArgs = self.outputArgParser.parser().parse_known_args(commandWordList[1:]) 210 | r = requests.get(self.url + "/get_analyze_output",{ 211 | "task_ID" : args.t, 212 | "styled" : args.s 213 | }) 214 | return json.loads(r.text) 215 | 216 | @stableRun 217 | def getDLNestOutput(self,commandWordList : list): 218 | """ 219 | commandWordList : -s 220 | """ 221 | args,otherArgs = 
self.outputArgParser.parser().parse_known_args(commandWordList[1:]) 222 | r = requests.get(self.url + "/get_DLNest_output",{ 223 | "styled" : args.s 224 | }) 225 | return json.loads(r.text) 226 | 227 | @stableRun 228 | def getTasksInformation(self): 229 | r = requests.get(self.url + "/get_task_info",{}) 230 | return json.loads(r.text) 231 | 232 | @stableRun 233 | def getDevicesInformation(self): 234 | r = requests.get(self.url + "/get_devices_info",{}) 235 | return json.loads(r.text) 236 | 237 | @stableRun 238 | def delATask(self,commandWordList : list): 239 | """ 240 | commandWordList: taskID 241 | """ 242 | r = requests.post(self.url + "/del_task",{ 243 | "task_ID" : commandWordList 244 | }) 245 | return json.loads(r.text) 246 | 247 | @stableRun 248 | def clear(self): 249 | r = requests.post(self.url + "/clear",{}) 250 | return json.loads(r.text) 251 | 252 | @stableRun 253 | def addPlugins(self, commandWordList : list): 254 | commandWordList = self.shortenList(commandWordList) 255 | args, otherArgs = self.addPluginsArgParser.parser().parse_known_args(commandWordList[1:]) 256 | r = requests.post(self.url + "/add_plugins", { 257 | "target_dir" : args.d, 258 | "plugins" : args.p, 259 | "full" : args.F 260 | }) 261 | return json.loads(r.text) 262 | 263 | def giveACommand(self, commandWordList : list): 264 | if commandWordList[0] == "run": 265 | return self.runTrain(commandWordList) 266 | elif commandWordList[0] == "continue": 267 | return self.continueTrain(commandWordList) 268 | elif commandWordList[0] == "new": 269 | return self.newProject(commandWordList) 270 | elif commandWordList[0] == "analyze": 271 | return self.runAnalyze(commandWordList) 272 | elif commandWordList[0] == "runExp": 273 | return self.runExp(commandWordList) 274 | elif commandWordList[0] == "del": 275 | return self.delATask(commandWordList[1]) 276 | elif commandWordList[0] == "showAN": 277 | return self.getAnalyzeOutput(commandWordList) 278 | elif commandWordList[0] == "showDL": 279 | return 
self.getDLNestOutput(commandWordList) 280 | elif commandWordList[0] == "showTask": 281 | return self.getTasksInformation() 282 | elif commandWordList[0] == "showDevice": 283 | return self.getDevicesInformation() 284 | elif commandWordList[0] == "changeDevices": 285 | return self.changeDevices(commandWordList) 286 | elif commandWordList[0] == "clear": 287 | return self.clear() 288 | elif commandWordList[0] == "addP": 289 | return self.addPlugins(commandWordList) 290 | elif commandWordList[0] == "exit": 291 | return {"exit" : True} 292 | else: 293 | return {"error" : "Wrong command"} --------------------------------------------------------------------------------