├── DLNest
├── __init__.py
├── Common
│ ├── __init__.py
│ ├── ModelBase.py
│ ├── ModelBaseTorch.py
│ ├── DatasetBase.py
│ ├── RunningStatus.py
│ ├── RunnerBase.py
│ ├── RunnerBaseTorch.py
│ └── LifeCycleBase.py
├── Executor
│ ├── __init__.py
│ ├── AnalyzeProcess.py
│ ├── TrainProcess.py
│ └── TaskProcess.py
├── Output
│ ├── __init__.py
│ ├── TrainStdout.py
│ ├── DLNestBuffer.py
│ ├── AnalyzerBuffer.py
│ └── OutputLayerBase.py
├── Plugins
│ ├── __init__.py
│ ├── Utils
│ │ ├── __init__.py
│ │ ├── CheckPlugins.py
│ │ └── SendMailTools.py
│ ├── LogInit.py
│ ├── DLNestPluginBase.py
│ ├── AutoTensorboardScalar.py
│ ├── SimpleCMDVisualize.py
│ └── MailsNote.py
├── Scheduler
│ ├── __init__.py
│ ├── SchedulerStrategyBase.py
│ ├── Scheduler.py
│ └── DefaultStrategy.py
├── Information
│ ├── __init__.py
│ ├── test
│ │ ├── __init__.py
│ │ ├── FakeSubprocess.py
│ │ ├── test_AnalyzeTask.py
│ │ ├── test_TaskInformation.py
│ │ └── test_DeviceInformation.py
│ ├── CPUInformation.py
│ ├── TaskInformation.py
│ ├── GPUInformation.py
│ ├── AnalyzeTask.py
│ ├── DeviceInformation.py
│ ├── TrainTask.py
│ └── InfoCenter.py
├── Operations
│ ├── __init__.py
│ ├── SafeExit.py
│ ├── ClearDLNestOutput.py
│ ├── DelATask.py
│ ├── ChangeDelay.py
│ ├── GetTasksInformation.py
│ ├── GetDevicesInformation.py
│ ├── ChangeMaxTaskPerDevice.py
│ ├── ChangeDevices.py
│ ├── GetDLNestOutput.py
│ ├── AnalyzeIndependent.py
│ ├── Analyze.py
│ ├── RunExp.py
│ ├── RunIndependent.py
│ ├── GetAnalyzeOutput.py
│ ├── ContinueTrain.py
│ ├── Run.py
│ ├── GetPluginConfig.py
│ ├── New.py
│ └── UsePlugin.py
├── SavePackage
│ ├── __init__.py
│ └── SavePackage.py
├── ShellClient
│ ├── __init__.py
│ ├── Windows
│ │ ├── __init__.py
│ │ ├── Utils
│ │ │ ├── __init__.py
│ │ │ └── Completers.py
│ │ ├── SCTextArea.py
│ │ ├── CommandInput.py
│ │ ├── DevicesInfoShower.py
│ │ ├── TaskInfoShower.py
│ │ └── ResultsOutput.py
│ ├── Client.py
│ └── Communicator.py
├── TornadoServer
│ └── __init__.py
├── FactoryFiles
│ ├── FactoryClean
│ │ ├── freq_config.json
│ │ ├── dataset_config.json
│ │ ├── model_config.json
│ │ ├── plugins_config.json
│ │ ├── AnalyzeScripts
│ │ │ └── test.py
│ │ ├── Dataset
│ │ │ └── Dataset.py
│ │ ├── root_config.json
│ │ ├── Model
│ │ │ └── Runner.py
│ │ └── LifeCycle.py
│   └── FactoryMNIST
│       ├── freq_config.json
│       ├── AnalyzeScripts
│       │   ├── showModel.py
│       │   └── getAcc.py
│       ├── model_config.json
│       ├── common_config.json
│       ├── dataset_config.json
│       ├── root_config.json
│       ├── plugins_config.json
│       ├── Model
│       │   ├── MNISTCNN.py
│       │   └── Runner.py
│       ├── LifeCycle.py
│       └── Dataset
│           └── Dataset.py
├── Client.py
├── pytest.ini
├── Server.py
├── Simple.py
├── Run.py
├── Analyze.py
└── Local.py
├── manifest.in
├── setup.py
└── .gitignore
/DLNest/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Executor/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Output/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Plugins/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Information/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Operations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/Plugins/Utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/SavePackage/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/ShellClient/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/TornadoServer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/manifest.in:
--------------------------------------------------------------------------------
1 | global-exclude *.pyc
--------------------------------------------------------------------------------
/DLNest/Information/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/Utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/freq_config.json:
--------------------------------------------------------------------------------
1 | {
2 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/freq_config.json:
--------------------------------------------------------------------------------
1 | {
2 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/dataset_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_config" : {
3 | }
4 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/model_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_config" : {
3 | }
4 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/plugins_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "plugins" : [],
3 | "plugins_config" : {}
4 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/AnalyzeScripts/showModel.py:
--------------------------------------------------------------------------------
def experience(self):
    """Analyze script: print the loaded model's structure.

    NOTE(review): `self` is the analyze context object; assumes it exposes
    `self.runner.model` (set up by AnalyzeProcess) — confirm against caller.
    """
    print(self.runner.model)
--------------------------------------------------------------------------------
/DLNest/Client.py:
--------------------------------------------------------------------------------
from DLNest.ShellClient.Client import startClient

# Entry point: launch the interactive DLNest shell client.
if __name__ == "__main__":
    startClient()
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/AnalyzeScripts/test.py:
--------------------------------------------------------------------------------
def experience(self):
    """Analyze script template: print the analyze context's main members.

    NOTE(review): assumes the context exposes `runner`, `dataset` and `args`.
    """
    print(self.runner,self.dataset,self.args)
--------------------------------------------------------------------------------
/DLNest/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 | Information : 'Tests for Information module'
4 |
5 | addopts = --strict-markers
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/model_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_config" : {
3 | "feats" : [1,64,128,1024]
4 | }
5 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/common_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "other_file_paths":[
3 | "./Model/MNISTCNN.py"
4 | ],
5 | "epochs" : 5
6 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/dataset_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_config" : {
3 | "data_root" : "",
4 | "batch_size" : 64
5 | }
6 | }
--------------------------------------------------------------------------------
/DLNest/Common/ModelBase.py:
--------------------------------------------------------------------------------
import warnings

from DLNest.Common.RunnerBase import RunnerBase

# Fix: emit a proper DeprecationWarning (filterable, attributed to the
# importer via stacklevel) instead of unconditionally printing to stdout.
warnings.warn(
    "ModelBase is deprecated, please use RunnerBase instead",
    DeprecationWarning,
    stacklevel=2,
)

# Backward-compatible alias: old user code importing ModelBase keeps working.
ModelBase = RunnerBase
--------------------------------------------------------------------------------
/DLNest/Operations/SafeExit.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 |
def safeExit():
    """Tear down every registered task before the server process exits."""
    InfoCenter().delAllTask()
--------------------------------------------------------------------------------
/DLNest/Operations/ClearDLNestOutput.py:
--------------------------------------------------------------------------------
1 | from DLNest.Output.DLNestBuffer import DLNestBuffer
2 |
def clearDLNestOutput():
    """Drop all buffered DLNest global output."""
    DLNestBuffer().clear()
--------------------------------------------------------------------------------
/DLNest/Information/test/FakeSubprocess.py:
--------------------------------------------------------------------------------
class FakeSubprocess:
    """Minimal stand-in for a subprocess handle used by the tests."""

    def __init__(self, alive):
        # Fixed liveness flag reported by is_alive().
        self.alive = alive

    def is_alive(self):
        """Return the liveness flag supplied at construction."""
        return self.alive
--------------------------------------------------------------------------------
/DLNest/Operations/DelATask.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 |
def delATask(taskID : str):
    """Remove the task with the given ID from the info center."""
    InfoCenter().delATask(taskID)
--------------------------------------------------------------------------------
/DLNest/Common/ModelBaseTorch.py:
--------------------------------------------------------------------------------
import warnings

from DLNest.Common.RunnerBaseTorch import RunnerBaseTorch

# Fix: emit a proper DeprecationWarning (filterable, attributed to the
# importer via stacklevel) instead of unconditionally printing to stdout.
warnings.warn(
    "ModelBaseTorch is deprecated, please use RunnerBaseTorch instead",
    DeprecationWarning,
    stacklevel=2,
)

# Backward-compatible alias: old user code importing ModelBaseTorch keeps working.
ModelBaseTorch = RunnerBaseTorch
--------------------------------------------------------------------------------
/DLNest/Operations/ChangeDelay.py:
--------------------------------------------------------------------------------
1 | from DLNest.Scheduler.Scheduler import Scheduler
2 |
def changeDelay(newDelay : int):
    """Set the scheduler's time delay (unit per Scheduler.changeTimeDelay —
    presumably seconds; confirm there)."""
    Scheduler().changeTimeDelay(newDelay)
--------------------------------------------------------------------------------
/DLNest/Operations/GetTasksInformation.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 |
def getTasksInformation():
    """Return information about every task known to the info center."""
    return InfoCenter().getTasksInformation()
--------------------------------------------------------------------------------
/DLNest/Operations/GetDevicesInformation.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 |
def getDevicesInformation():
    """Return information about the currently available devices."""
    return InfoCenter().getAvailableDevicesInformation()
--------------------------------------------------------------------------------
/DLNest/Operations/ChangeMaxTaskPerDevice.py:
--------------------------------------------------------------------------------
1 | from DLNest.Scheduler.Scheduler import Scheduler
2 |
def changeMaxTaskPerDevice(newMax : int):
    """Set how many tasks may be scheduled onto one device at once."""
    Scheduler().changeMaxTaskPerDevice(newMax)
--------------------------------------------------------------------------------
/DLNest/Operations/ChangeDevices.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 |
def changeDevices(
    newDevicesIDList : list
):
    """Replace the set of devices DLNest may schedule onto.

    Args:
        newDevicesIDList: device IDs (ints; -1 conventionally selects the
            CPU elsewhere in this codebase — confirm in InfoCenter).

    Fix: the parameter was annotated `[int]`, a list *literal* rather than a
    type — invalid for every type checker. Runtime behavior is unchanged.
    """
    infoCenter = InfoCenter()
    infoCenter.changeDevices(newDevicesIDList)
--------------------------------------------------------------------------------
/DLNest/Server.py:
--------------------------------------------------------------------------------
# NOTE(review): the tree listing shows only __init__.py under TornadoServer;
# confirm DLNest/TornadoServer/Server.py exists in the full repository.
from DLNest.TornadoServer.Server import DLNestServer


if __name__ == "__main__":
    import sys
    # Blank out the script-directory entry of sys.path — presumably so local
    # files cannot shadow the installed DLNest package; TODO confirm intent.
    if sys.path[0] != '':
        sys.path[0] = ''
    server = DLNestServer()
    server.start()
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/Dataset/Dataset.py:
--------------------------------------------------------------------------------
1 | from DLNest.Common.DatasetBase import DatasetBase
2 |
3 |
class Dataset(DatasetBase):
    """Dataset template for a new DLNest project: build loaders in init()."""

    def init(self,args : dict):
        pass
        #return {dict to model},self.trainLoader,self.testLoader
8 |
--------------------------------------------------------------------------------
/DLNest/Operations/GetDLNestOutput.py:
--------------------------------------------------------------------------------
1 | from DLNest.Output.DLNestBuffer import DLNestBuffer
2 |
def getDLNestOutput(
    style : bool = True
):
    """Return the contents of DLNest's global output buffer.

    Args:
        style: when True return the styled text, otherwise plain text.
    """
    buf = DLNestBuffer()
    return buf.getStyledText() if style else buf.getPlainText()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
setup(
    name='DLNest',
    version='0.3.2',
    packages=find_packages(),
    # Ship the non-.py template files declared in package_data below.
    include_package_data=True,
    zip_safe=False,
    install_requires=[
        'nvidia-ml-py3',
        'APScheduler>=3.6.3',
        'torch',
        'prompt-toolkit>=3.0.7',
        'tornado>=6.0'
    ],
    # Bundle the project-template tree; the three glob levels cover every
    # depth currently present under FactoryFiles/.
    package_data={
        "DLNest":[
            "FactoryFiles/*",
            "FactoryFiles/*/*",
            "FactoryFiles/*/*/*"
        ]
    },
    python_requires=">=3.6"
)
25 |
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/AnalyzeScripts/getAcc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def experience(self):
    """Analyze script: compute and print test-set top-1 accuracy.

    NOTE(review): assumes self.dataset.testLoader yields (image, label)
    batches and self.runner.model is the trained network — confirm against
    the MNIST Dataset/Runner templates.
    """
    valLoader = self.dataset.testLoader
    acc = 0
    total = 0
    for _iter,data in enumerate(valLoader):
        x,y = data
        # Move the batch to GPU unless the runner reports a CPU environment.
        if self.runner._status.env != "CPU":
            x,y = x.cuda(),y.cuda()
        with torch.no_grad():
            output = self.runner.model(x)
        # Count correct top-1 predictions in this batch.
        _,pred = torch.max(output,1)
        acc += sum(pred == y).item()
        total += x.shape[0]
    print("correct count:",acc,"total count:",total,"accuracy:",acc / total)
17 |
--------------------------------------------------------------------------------
/DLNest/Information/CPUInformation.py:
--------------------------------------------------------------------------------
1 | try:
2 | from DeviceInformation import DeviceInformation
3 | except ImportError:
4 | from .DeviceInformation import DeviceInformation
5 | import psutil
6 |
class CPUInformation(DeviceInformation):
    """Device-information provider for the host CPU / system RAM."""

    def __init__(self):
        super(CPUInformation, self).__init__("CPU")
        # Total system memory, recorded once at construction, in MiB.
        self.totalMemory = psutil.virtual_memory().total / 1024 ** 2

    def getFreeMemory(self):
        """Return the currently available system memory in MiB."""
        return psutil.virtual_memory().available / 1024 ** 2

    def getDeviceStr(self):
        """Return the torch-style device string for the CPU."""
        return "cpu"
--------------------------------------------------------------------------------
/DLNest/Scheduler/SchedulerStrategyBase.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 | from DLNest.Information.TaskInformation import TaskInformation
3 | from DLNest.Information.DeviceInformation import DeviceInformation
4 |
5 | from abc import ABCMeta, abstractmethod
6 |
class SchedulerStrategyBase(metaclass=ABCMeta):
    """Abstract base class for task-scheduling strategies.

    Fix: `@abstractmethod` was used without an ABCMeta metaclass, so the
    abstractness of `decide` was never enforced; subclasses could silently
    omit it. The base was never meant to be instantiated directly.
    """

    def __init__(self):
        """
        WARNING: before a run requirement send to scheduler, infoCenter.deviceLock should be released.
        """
        self.infoCenter = InfoCenter()
        # Presumably assigned by the owning Scheduler when the strategy is
        # installed — TODO confirm in Scheduler.py.
        self.scheduler = None

    @abstractmethod
    def decide(self, maxTaskPerDevice : int,maxTimeDelay : int):
        """Decide which pending tasks to launch; must be overridden."""
--------------------------------------------------------------------------------
/DLNest/Operations/AnalyzeIndependent.py:
--------------------------------------------------------------------------------
1 | from DLNest.Executor.AnalyzeProcess import AnalyzeProcess
2 | from DLNest.Information.AnalyzeTask import AnalyzeTask
3 |
def AnalyzeIndependent(
    recordPath : str,
    expFunc = None,
    scriptPath : str = "",
    checkpointID : int = -1,
    devices : list = None
):
    """Run an analyze task directly in this process (bypassing the scheduler).

    Args:
        recordPath: path of the saved training record to analyze.
        expFunc: optional callable used as the experiment body.
        scriptPath: analyze-script path (used when expFunc is None).
        checkpointID: checkpoint to load; -1 presumably means the latest —
            confirm in AnalyzeTask.
        devices: device ID list; None behaves like the old [] default.
    """
    # Fix: the default was a shared mutable list ([]); use a None sentinel.
    if devices is None:
        devices = []
    task = AnalyzeTask(
        recordPath = recordPath,
        scriptPath = scriptPath,
        checkpointID = checkpointID,
        devices = devices
    )
    try:
        AProcess = AnalyzeProcess(task,expFunc=expFunc)
        AProcess.run()
    except KeyboardInterrupt:
        exit(0)
22 |
23 |
--------------------------------------------------------------------------------
/DLNest/Operations/Analyze.py:
--------------------------------------------------------------------------------
1 | from DLNest.Scheduler.Scheduler import Scheduler
2 | from DLNest.Information.AnalyzeTask import AnalyzeTask
3 |
def analyze(
    recordPath : str,
    scriptPath : str = "",
    checkpointID : int = -1,
    CPU : bool = False,
    memoryConsumption : int = -1,
    otherArgs : dict = None
):
    """Queue an analyze task on the scheduler.

    Args:
        recordPath: path of the saved training record to analyze.
        scriptPath: analyze-script path.
        checkpointID: checkpoint to load; -1 presumably means the latest.
        CPU: force the task onto the CPU (device ID -1).
        memoryConsumption: expected memory use; -1 means unspecified.
        otherArgs: extra arguments forwarded to the scheduler; None behaves
            like the old {} default.
    """
    # Fix: the default was a shared mutable dict ({}); use a None sentinel.
    if otherArgs is None:
        otherArgs = {}
    devices = []
    if CPU:
        devices = [-1]
    task = AnalyzeTask(
        recordPath = recordPath,
        scriptPath = scriptPath,
        checkpointID = checkpointID,
        memoryConsumption = memoryConsumption,
        devices = devices
    )

    scheduler = Scheduler()
    scheduler.giveATask(task,otherArgs = otherArgs)
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/root_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "save_root":"./Saves",
3 | "runner_name":"Runner",
4 | "dataset_name":"Dataset",
5 | "life_cycle_name":"LifeCycle",
6 | "checkpoint_args":{
7 | "max_ckpt_in_slow_track":0,
8 | "dilation_in_slow_track":100,
9 | "max_ckpt_in_fast_track":1,
10 | "max_ckpt_in_consistent_track":1
11 | },
12 | "root_file_path":"",
13 | "runner_file_path":"./Model/Runner.py",
14 | "dataset_file_path":"./Dataset/Dataset.py",
15 | "life_cycle_file_path":"./LifeCycle.py",
16 | "other_file_paths":[],
17 | "child_jsons":[
18 | "./model_config.json",
19 | "./dataset_config.json",
20 | "./plugins_config.json"
21 | ]
22 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/Model/Runner.py:
--------------------------------------------------------------------------------
1 | from DLNest.Common.RunnerBase import RunnerBase
2 |
class Runner(RunnerBase):
    """Runner template for a new DLNest project: fill in model construction,
    optimization, a single training step, and the validation hooks."""

    def init(self,args : dict,datasetInfo : dict = None):
        #Init models
        pass

    def initLog(self):
        """Return the initial (empty) log dict used during training."""
        return {}

    def initOptimizer(self):
        # init optimizers
        pass

    def runOneStep(self,data, log : dict, iter : int, epoch : int):
        """Run one training step on one batch; update `log` as needed."""
        pass

    def visualize(self,epoch : int, iter : int, log : dict):
        """Visualize current training state (console, tensorboard, ...)."""
        pass

    def validationInit(self):
        """Prepare state before a validation pass starts."""
        pass

    def validateABatch(self,data, iter : int):
        """Validate a single batch."""
        pass

    def validationAnalyze(self, log : dict):
        """Aggregate and report validation results after all batches."""
        pass
--------------------------------------------------------------------------------
/DLNest/Output/TrainStdout.py:
--------------------------------------------------------------------------------
class TrainStdout:
    """A stdout-like tee: every write goes to a log file object and is
    optionally echoed to the original stdout.

    Args:
        fp: primary writable file object (the training log).
        showOnScreen: when True, echo writes/flushes to `originalStdout`.
        originalStdout: the real stdout to echo to; may be None.
    """

    def __init__(self,fp,showOnScreen = False,originalStdout = None):
        super(TrainStdout,self).__init__()
        self.stdout = fp
        self.screenout = originalStdout
        self.showOnScreen = showOnScreen

    def write(self,s):
        """Write `s` to the log (flushed immediately), echo if enabled."""
        ret = self.stdout.write(s)
        self.stdout.flush()
        # Fix: guard against showOnScreen=True with originalStdout=None,
        # which previously raised AttributeError on the first write.
        if self.showOnScreen and self.screenout is not None:
            self.screenout.write(s)
        return ret

    def flush(self):
        if self.showOnScreen and self.screenout is not None:
            self.screenout.flush()
        return self.stdout.flush()

    def isatty(self):
        return self.stdout.isatty()

    def fileno(self):
        return self.stdout.fileno()
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/root_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "save_root":"./Saves",
3 | "runner_name":"Runner",
4 | "dataset_name":"MNISTDataset",
5 | "life_cycle_name":"LifeCycle",
6 | "checkpoint_args":{
7 | "max_ckpt_in_slow_track":0,
8 | "dilation_in_slow_track":100,
9 | "max_ckpt_in_fast_track":1,
10 | "max_ckpt_in_consistent_track":1
11 | },
12 | "root_file_path":"",
13 | "runner_file_path":"./Model/Runner.py",
14 | "dataset_file_path":"./Dataset/Dataset.py",
15 | "life_cycle_file_path":"./LifeCycle.py",
16 | "other_file_paths":[],
17 | "child_jsons":[
18 | "./common_config.json",
19 | "./model_config.json",
20 | "./dataset_config.json",
21 | "./plugins_config.json"
22 | ]
23 | }
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/plugins_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "plugins" : [
3 | "LogInit",
4 | "AutoTensorboardScalar",
5 | "SimpleCMDVisualize",
6 | "MailsNote"
7 | ],
8 | "plugins_config" : {
9 | "SimpleCMDVisualize" : {
10 | "stride" : 10,
11 | "keys" : ["loss","acc"],
12 | "format" : {
13 | "loss" : "\t| {} : {:.6f}",
14 | "acc" : "\t| {} : {:.6f}"
15 | }
16 | },
17 | "LogInit" : {
18 | "level" : "INFO"
19 | },
20 | "MailsNote" : {
21 | "enable_list" : ["Aborting","TrainFinish"],
22 | "username" : "",
23 | "password" : "",
24 | "host" : "mail.fudan.edu.cn",
25 | "port" : 25
26 | }
27 | }
28 | }
--------------------------------------------------------------------------------
/DLNest/Operations/RunExp.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 | from DLNest.Information.AnalyzeTask import AnalyzeTask
3 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer
4 |
def runExp(
    taskID : str,
    command : str
):
    """Send a command line to a live analyze task's command queue.

    Validation mirrors the project's assert-based style: the task must
    exist, be an analyze task with an AnalyzerBuffer, and be running.
    """
    task = InfoCenter().getTaskByID(taskID)

    assert task != None # Wrong task
    assert isinstance(task,AnalyzeTask) # Not an analyze task
    assert isinstance(task.outputBuffer,AnalyzerBuffer) # don't have analyze buffer
    assert task.process != None # not running
    assert task.process.is_alive() # not alive
    assert task.commandQueue != None # don't have command queue

    task.commandQueue.put(command,block = False)
--------------------------------------------------------------------------------
/DLNest/Operations/RunIndependent.py:
--------------------------------------------------------------------------------
1 | from DLNest.Executor.TrainProcess import TrainProcess
2 | from DLNest.Information.TrainTask import TrainTask
3 |
def runIndependent(
    configPath : str,
    freqPath : str = "",
    devices : list = None,
    description : str = "",
    DDP : bool = False,
    noSave : bool = False,
    useDescriptionToSave : bool = False,
    showOnScreen : bool = True
):
    """Run a training task directly in this process (bypassing the scheduler).

    Args:
        configPath: path to the root config JSON.
        freqPath: path to the frequently-changed config JSON.
        devices: device ID list; None behaves like the old [-1] (CPU) default.
        description: free-text description recorded with the run.
        DDP: use distributed data parallel.
        noSave: do not create a save record.
        useDescriptionToSave: name the save directory after the description.
        showOnScreen: echo training output to the console.
    """
    # Fix: the default was a shared mutable list annotated with the invalid
    # list-literal type `[int]`; use a None sentinel with the same meaning.
    if devices is None:
        devices = [-1]
    task = TrainTask.fromConfigFile(
        configPath = configPath,
        freqPath = freqPath,
        devices = devices,
        description = description,
        DDP = DDP,
        noSave = noSave,
        useDescriptionToSave = useDescriptionToSave
    )
    try:
        TProcess = TrainProcess(task,showOnScreen)
        TProcess.run()
    except KeyboardInterrupt:
        exit(0)
--------------------------------------------------------------------------------
/DLNest/Operations/GetAnalyzeOutput.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 | from DLNest.Information.AnalyzeTask import AnalyzeTask
3 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer
4 |
def getAnalyzeOutput(
    taskID : str,
    style : bool = True
):
    """Return the buffered output of a live analyze task.

    Args:
        taskID: ID of the analyze task.
        style: when True return styled text, otherwise plain text.
    """
    task = InfoCenter().getTaskByID(taskID)

    assert task != None # Wrong task
    assert isinstance(task,AnalyzeTask) # Not an analyze task
    assert isinstance(task.outputBuffer,AnalyzerBuffer) # don't have analyze buffer
    assert task.process != None # not running
    assert task.process.is_alive() # not alive

    buf = task.outputBuffer
    return buf.getStyledText() if style else buf.getPlainText()
--------------------------------------------------------------------------------
/DLNest/Operations/ContinueTrain.py:
--------------------------------------------------------------------------------
1 | from DLNest.Scheduler.Scheduler import Scheduler
2 | from DLNest.Information.TrainTask import TrainTask
3 |
def continueTrain(
    recordPath : str,
    checkpointID : int = -1,
    memoryConsumption : int = -1,
    CPU : bool = False,
    DDP : bool = False,
    multiGPU : bool = False,
    description : str = "",
    otherArgs : dict = None
):
    """Queue a continue-training task (resume from a record) on the scheduler.

    Args:
        recordPath: path of the saved training record to resume.
        checkpointID: checkpoint to resume from; -1 presumably the latest.
        memoryConsumption: expected memory use; -1 means unspecified.
        CPU: force the task onto the CPU (device ID -1).
        DDP: use distributed data parallel.
        multiGPU: allow multiple GPUs.
        description: free-text description recorded with the run.
        otherArgs: extra arguments forwarded to the scheduler; None behaves
            like the old {} default.
    """
    # Fix: the default was a shared mutable dict ({}); use a None sentinel.
    if otherArgs is None:
        otherArgs = {}
    devices = []
    if CPU:
        devices = [-1]

    task = TrainTask.fromRecord(
        recordPath = recordPath,
        checkpointID = checkpointID,
        devices = devices,
        memoryConsumption = memoryConsumption,
        multiGPU = multiGPU,
        DDP = DDP,
        description = description
    )

    scheduler = Scheduler()
    scheduler.giveATask(task,otherArgs = otherArgs)
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryClean/LifeCycle.py:
--------------------------------------------------------------------------------
1 | from DLNest.Common.DatasetBase import DatasetBase
2 | from DLNest.Common.RunnerBase import RunnerBase
3 | from DLNest.Common.LifeCycleBase import LifeCycleBase
4 |
5 |
class LifeCycle(LifeCycleBase):
    """LifeCycle template: policy hooks queried by the training loop.

    Exact call timing is defined by LifeCycleBase / TrainProcess — confirm
    there; the template defaults keep every optional behavior off.
    """

    def needVisualize(self, epoch : int, iter : int, logdict : dict, args : dict):
        # Whether to call the runner's visualize() at this iteration.
        return False

    def needValidation(self, epoch : int, logdict : dict, args : dict):
        # Whether to run a validation pass for this epoch.
        return False

    def commandLineOutput(self,epoch : int, logdict : dict, args : dict):
        # Console summary emitted at the end of each epoch.
        print("Epoch #" + str(epoch) + " finished!")

    def needSaveModel(self, epoch : int, logdict : dict, args : dict):
        # Whether to save a checkpoint after this epoch.
        return False

    def holdThisCheckpoint(self, epoch : int, logdict : dict, args : dict):
        # Whether this checkpoint should be kept (presumably exempt from
        # rotation — confirm in SavePackage).
        return False

    def needContinueTrain(self, epoch : int, logdict : dict, args : dict):
        # Whether training should continue beyond the configured end.
        return False
--------------------------------------------------------------------------------
/DLNest/Plugins/LogInit.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB
3 |
class DLNestPlugin(DPB):
    """Plugin that resets and configures the root logger for a task.

    NOTE(review): hook methods appear to be invoked unbound via
    checkPlugins, so `self` inside BAll is the host task process, not a
    DLNestPlugin instance — confirm against CheckPlugins.py.
    """
    _NAME = "LogInit"
    # Built-in defaults; overridable per task via plugins_config["LogInit"].
    _config = {
        "level" : "INFO",
        "format" : "[%(asctime)s][%(levelname)s] %(message)s",
        "datefmt" : "%Y-%m-%d %H:%M:%S"
    }
    # Keys emitted into the template config generated for new projects.
    _defaultKeys = ["level"]
    def BAll(self):
        import sys
        log = logging.getLogger()
        # Strip handlers installed by a previous run so basicConfig applies.
        while len(log.handlers) > 0:
            log.removeHandler(log.handlers[0])

        args = self.taskProcess.task.args  # NOTE(review): unused below
        pluginName = "LogInit"
        level = DPB.getArgs(self, pluginName, "level", "INFO")
        format = DPB.getArgs(self, pluginName, "format", "[%(asctime)s][%(levelname)s] %(message)s")
        datefmt= DPB.getArgs(self, pluginName, "datefmt", "%Y-%m-%d %H:%M:%S")
        logging.basicConfig(level = level, format = format,datefmt = datefmt)
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/Model/MNISTCNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
class MNISTModel(nn.Module):
    """Small CNN classifier for MNIST.

    Expects args["model_config"]["feats"] = [in_ch, mid_ch, out_ch, hidden]
    and 28x28 inputs (the single MaxPool halves them to 14x14).
    """

    def __init__(self,args : dict):
        super(MNISTModel, self).__init__()
        feats = args["model_config"]["feats"]
        self.conv1 = nn.Sequential(nn.Conv2d(feats[0],feats[1],kernel_size=3,stride=1,padding=1),
                                   nn.ReLU(),
                                   nn.Conv2d(feats[1],feats[2],kernel_size=3,stride=1,padding=1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(stride=2,kernel_size=2))
        # Fix: the flatten size (14*14*128) and first dense output (1024)
        # were hard-coded, silently breaking any feats other than
        # [1,64,128,1024]; derive them from the config instead.
        self._flatDim = 14 * 14 * feats[2]
        self.dense = nn.Sequential(nn.Linear(self._flatDim,feats[3]),
                                   nn.ReLU(),
                                   nn.Dropout(p=0.5),
                                   nn.Linear(feats[3], 10))
    def forward(self, x):
        x = self.conv1(x)
        x = x.view(-1, self._flatDim)
        x = self.dense(x)
        return x
--------------------------------------------------------------------------------
/DLNest/Operations/Run.py:
--------------------------------------------------------------------------------
1 | from DLNest.Scheduler.Scheduler import Scheduler
2 | from DLNest.Information.TrainTask import TrainTask
3 |
def run(
    configPath : str,
    freqPath : str = "",
    description : str = "",
    memoryConsumption : int = -1,
    CPU : bool = False,
    DDP : bool = False,
    multiGPU : bool = False,
    noSave : bool = False,
    useDescriptionToSave : bool = False,
    otherArgs : dict = None
):
    """Queue a new training task (from config files) on the scheduler.

    Args:
        configPath: path to the root config JSON.
        freqPath: path to the frequently-changed config JSON.
        description: free-text description recorded with the run.
        memoryConsumption: expected memory use; -1 means unspecified.
        CPU: force the task onto the CPU (device ID -1).
        DDP: use distributed data parallel.
        multiGPU: allow multiple GPUs.
        noSave: do not create a save record.
        useDescriptionToSave: name the save directory after the description.
        otherArgs: extra arguments forwarded to the scheduler; None behaves
            like the old {} default.
    """
    # Fix: the default was a shared mutable dict ({}); use a None sentinel.
    if otherArgs is None:
        otherArgs = {}
    scheduler = Scheduler()

    devices = []
    if CPU:
        devices = [-1]

    task = TrainTask.fromConfigFile(
        configPath = configPath,
        freqPath = freqPath,
        description = description,
        devices = devices,
        memoryConsumption = memoryConsumption,
        multiGPU = multiGPU,
        DDP = DDP,
        noSave = noSave,
        useDescriptionToSave = useDescriptionToSave
    )

    scheduler.giveATask(task,otherArgs = otherArgs)
--------------------------------------------------------------------------------
/DLNest/Plugins/DLNestPluginBase.py:
--------------------------------------------------------------------------------
1 |
class DLNestPluginBase:
    """Base class for DLNest plugins: naming, default config, and a helper
    to read per-plugin values out of a host's args."""

    # Fix: was misspelled `_NANE`, so getName() raised AttributeError on any
    # class (including this base) that did not redefine _NAME itself.
    _NAME = "DLNestPluginBase"
    _config = {}
    # Subset of _config keys exposed in generated template configs.
    _defaultKeys = []

    @classmethod
    def getName(cls):
        """Return the plugin's display name."""
        return cls._NAME

    @classmethod
    def getDefaultConfig(cls):
        """Return the subset of _config listed in _defaultKeys."""
        return {key: cls._config[key] for key in cls._defaultKeys}

    @classmethod
    def getFullConfig(cls):
        """Return the plugin's complete default configuration."""
        return cls._config

    @classmethod
    def getArgs(cls, self, pluginName : str, name : str, default : any):
        """Look up args["plugins_config"][pluginName][name] on the host.

        `self` is the host object (it must provide getArgs()); returns
        `default` when any level of the lookup is missing.
        """
        args = self.getArgs()
        if "plugins_config" in args:
            pArgs = args["plugins_config"]
            if pluginName in pArgs and name in pArgs[pluginName]:
                return pArgs[pluginName][name]
        return default
--------------------------------------------------------------------------------
/DLNest/Plugins/Utils/CheckPlugins.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import logging
3 |
def checkPlugins(func):
    """Decorator for DLNest `_xxx` hook methods.

    Before running the wrapped method, call the same-named hook (without the
    leading underscore) on every plugin in `self._plugins`, passing the same
    arguments. Plugin failures are logged at debug level and ignored; the
    wrapped method's return value is returned unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        hookName = func.__name__[1:]
        host = args[0]
        for plugin in host._plugins:
            if hookName not in dir(plugin):
                continue
            try:
                getattr(plugin, hookName)(*args, **kwargs)
            except Exception as e:
                logging.debug(str(e))
        return func(*args, **kwargs)

    return wrapper
17 |
def checkDictOutputPlugins(func):
    """Decorator for `_xxx` hooks whose results are dicts.

    Collects the dict returned by each plugin's same-named hook (without the
    leading underscore) and merges them with the wrapped method's own dict;
    the wrapped method's entries win on key conflicts. Plugin failures are
    logged at debug level and ignored.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        hookName = func.__name__[1:]
        merged = {}
        for plugin in args[0]._plugins:
            if hookName not in dir(plugin):
                continue
            try:
                merged.update(getattr(plugin, hookName)(*args, **kwargs))
            except Exception as e:
                logging.debug(str(e))
        merged.update(func(*args, **kwargs))
        return merged

    return wrapper
--------------------------------------------------------------------------------
/DLNest/Simple.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import argparse
3 | from prompt_toolkit import PromptSession,HTML
4 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
5 | import traceback
6 | import json
7 | from DLNest.ShellClient.Communicator import Communicator
8 |
9 |
class DLNestSimple:
    """Minimal line-based shell that forwards commands to a DLNest server."""

    def __init__(self, url : str = "127.0.0.1", port : int = 9999):
        self.communicator = Communicator(url = url, port = port)

    def run(self):
        """Read-eval loop: prompt, forward the command, print the reply.

        Exits on Ctrl-C or when the server reply contains "exit"; any other
        error prints the last few lines of the traceback and keeps looping.
        """
        self.session = PromptSession(auto_suggest = AutoSuggestFromHistory())
        while True:
            try:
                line = self.session.prompt(HTML("DLNest>>"))
                words = line.strip().split(' ')
                reply = self.communicator.giveACommand(words)
                print(reply)
                if "exit" in reply:
                    exit(0)
            except KeyboardInterrupt:
                exit(0)
            except Exception:
                # Show only the tail of the traceback to keep the shell readable.
                tail = traceback.format_exc().split("\n")[:-1]
                print("\n".join(tail[-3:]))
30 |
if __name__ == "__main__":
    # Launch the interactive shell against the default server (127.0.0.1:9999).
    main = DLNestSimple()
    main.run()
--------------------------------------------------------------------------------
/DLNest/Plugins/AutoTensorboardScalar.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | import logging
3 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB
4 |
class DLNestPlugin(DPB):
    """Automatically mirrors list-valued log entries to TensorBoard scalars."""
    _NAME = "AutoTensorboardScalar"
    _config = {}
    _defaultKeys = []

    def runnerInit(self,args : dict, datasetInfo : dict = None):
        """Create a SummaryWriter in the CWD on the main process only (rank -1
        means non-distributed; rank 0 is the DDP master)."""
        if self._status.rank == -1 or self._status.rank == 0:
            self.writer = SummaryWriter(".")
        # Fixed log text: this is runnerInit, not modelInit.
        logging.debug("[AutoTensorboardScalar] Finish runnerInit")

    def visualize(self, log : dict, epoch : int, iter : int):
        """For every list entry in `log` that grew since the last call, write
        its newest value as a scalar at step len(list)-1."""
        if self._status.rank == -1 or self._status.rank == 0:
            if not "_lastLen" in dir(self):
                self._lastLen = {}
            for key in log:
                # .get(..., 0) tolerates keys added to the log after the
                # first visualize call (the original raised KeyError here).
                if isinstance(log[key], list) and len(log[key]) != self._lastLen.get(key, 0):
                    try:
                        self.writer.add_scalar(key,log[key][-1],len(log[key]) - 1)
                        self._lastLen[key] = len(log[key])
                    except Exception as e:
                        logging.debug("[AutoTensorboardScalar]" + str(e))
            logging.debug("[AutoTensorboardScalar] Finish visualize")
--------------------------------------------------------------------------------
/DLNest/Information/test/test_AnalyzeTask.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from DLNest.Information.AnalyzeTask import AnalyzeTask
3 | import time
4 | from pathlib import Path
5 |
@pytest.mark.Information
class TestAnalyzeTask:
    """Unit tests for AnalyzeTask construction and serialization."""

    def test_init(self):
        AT = AnalyzeTask(
            recordPath = "/root/code/DLNestTest/Saves/Some Description",
            scriptPath = "scriptPath",
            checkpointID = 233
        )
        assert isinstance(AT.recordPath,Path)
        assert isinstance(AT.scriptPath,Path)
        assert str(AT.recordPath) == "/root/code/DLNestTest/Saves/Some Description"
        assert str(AT.scriptPath) == "scriptPath"
        assert AT.checkpointID == 233
        assert AT.type == "Analyze"
        # Fixed stale attribute names: AnalyzeTask exposes `commandQueue`
        # (a multiprocessing Queue) and `outputBuffer` (initially None),
        # not `commandPipe`/`outputQueue`.
        assert AT.commandQueue is not None
        assert AT.outputBuffer is None

    def test_getDict(self):
        AT = AnalyzeTask(
            recordPath = "/root/code/DLNestTest/Saves/Some Description",
            scriptPath = "scriptPath",
            checkpointID = 233
        )
        retDict = AT.getDict()
        assert retDict["record_path"] == "/root/code/DLNestTest/Saves/Some Description"
        assert retDict["script_path"] == "scriptPath"
        assert retDict["checkpoint_ID"] == 233
--------------------------------------------------------------------------------
/DLNest/Operations/GetPluginConfig.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import sys
3 | import importlib
4 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase
5 |
6 | def _loadAModule(filePath : Path,name : str):
7 | spec = importlib.util.spec_from_file_location(
8 | name,
9 | filePath
10 | )
11 | module = importlib.util.module_from_spec(spec)
12 | dirName = str(filePath.parent)
13 | if not dirName in sys.path:
14 | sys.path.append(dirName)
15 | spec.loader.exec_module(module)
16 | return module
17 |
def _loadAPlugin(pluginName : str):
    """Resolve a plugin by name and return its DLNestPlugin class.

    An absolute path loads that file directly; otherwise the name is looked
    up in DLNest's bundled Plugins directory.
    """
    candidate = Path(pluginName)
    if candidate.is_absolute():
        pluginPath = candidate
        pluginName = candidate.name
    else:
        pluginPath = Path(__file__).parent.parent / "Plugins" / (pluginName + '.py')
    pluginModule = _loadAModule(filePath = pluginPath,name = pluginName)
    return getattr(pluginModule, "DLNestPlugin")
27 |
def getPluginConfig(pluginName, full = False):
    """Return (name, {name: config}) for a plugin, or (None, None) on any
    failure (bad name, unloadable module, missing DLNestPlugin class).

    `full` selects the complete config over just the default keys.
    """
    try:
        pluginClass : DLNestPluginBase = _loadAPlugin(pluginName)
        name = pluginClass.getName()
        pick = pluginClass.getFullConfig if full else pluginClass.getDefaultConfig
        return name, {name : pick()}
    except Exception:
        return None, None
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/SCTextArea.py:
--------------------------------------------------------------------------------
1 | from apscheduler.schedulers.background import BackgroundScheduler
2 | from typing import Callable, Iterable, List, Optional
3 |
4 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box
5 | from prompt_toolkit.buffer import Buffer
6 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
7 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
8 | from prompt_toolkit.key_binding import KeyBindings
9 | from prompt_toolkit.document import Document
10 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit
11 | from prompt_toolkit.lexers import Lexer
12 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples
13 |
class SCTextArea(TextArea):
    """Read-only, scrollable TextArea that "follows" appended output.

    When the cursor sits at the very end of the current document, setting
    `text` re-creates the Document without an explicit cursor position;
    otherwise the old cursor position is preserved so the user's scroll
    location is kept.
    """
    def __init__(self,lexer,wrap_lines = True):
        super(SCTextArea,self).__init__(
            lexer = lexer,
            read_only=True,
            focusable=True,
            scrollbar=True,
            wrap_lines=wrap_lines
        )

    @property
    def text(self) -> str:
        """
        The `Buffer` text.
        """
        return self.buffer.text

    @text.setter
    def text(self, value: str) -> None:
        # NOTE(review): relies on prompt_toolkit's Document default cursor
        # placement when none is given — presumably the end of the text,
        # which keeps the view pinned to fresh output; confirm.
        oldPos = self.document.cursor_position
        if len(self.document.text) == oldPos:
            self.document = Document(value)#, 0)
        else:
            # User has scrolled/moved the cursor: keep their position.
            self.document = Document(value,oldPos)
--------------------------------------------------------------------------------
/DLNest/Information/test/test_TaskInformation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from DLNest.Information.TaskInformation import TaskInformation
3 | import time
4 |
@pytest.mark.Information
class TestTaskInformation:
    """Unit tests for TaskInformation.

    NOTE(review): TaskInformation.__init__ takes a required `savePackage`
    argument (it reads savePackage.args); the constructions below omit it
    and would raise TypeError — confirm against the current
    TaskInformation signature and update these tests.
    """
    def test_init(self):
        Task = TaskInformation(
            devices=[1,2,3],
            memoryConsumption = 1234,
            multiGPU = True,
            DDP = True
        )
        assert Task.DDP == True
        assert Task.devices == [1,2,3]
        assert Task.multiGPU == True
        assert Task.memoryConsumption == 1234
        # startTime is set from time.time() at construction.
        assert abs(Task.startTime - time.time()) < 1.0
        assert Task.type == "None"
        assert Task.status == "Pending"
        assert Task.process == None

    def test_getDict(self):
        Task = TaskInformation(
            devices=[1,2,3],
            memoryConsumption = 1234,
            multiGPU = True,
            DDP = True
        )
        targetDict = {
            "devices" : [1,2,3],
            "memory_consumption" : 1234,
            "multi_GPU" : True,
            "DDP" : True,
            "type" : "None",
            "status" : "Pending"
        }
        retDict = Task.getDict()
        # Every serialized key must be expected (ID is random, only presence
        # is checked); expected keys must round-trip exactly.
        for key in retDict:
            assert key in targetDict or key == "ID"
            if key in targetDict:
                assert targetDict[key] == retDict[key]
43 |
44 |
--------------------------------------------------------------------------------
/DLNest/Common/DatasetBase.py:
--------------------------------------------------------------------------------
1 | try:
2 | import torch
3 | except ImportError:
4 | pass
5 | from functools import wraps
6 | import logging
7 | from DLNest.Plugins.Utils.CheckPlugins import checkPlugins,checkDictOutputPlugins
8 | from DLNest.Common.RunningStatus import RunningStatus
9 |
class DatasetBase:
    """Base class for DLNest datasets.

    Subclasses override init() to build loaders, and optionally
    getSaveDict()/loadSaveDict() to participate in checkpoints.
    """
    def __init__(self, args : dict = None, plugins : list = None, status = None):
        # Avoid mutable/shared default arguments: the original defaults
        # ({}, [] and a module-level RunningStatus()) were created once at
        # import time and shared by every instance constructed without
        # explicit values.
        self._args = {} if args is None else args
        self._plugins = [] if plugins is None else plugins
        self._status = RunningStatus() if status is None else status

    def getArgs(self):
        """Return the argument dict this dataset was constructed with."""
        return self._args

    @checkPlugins
    def _datasetInit(self, args : dict):
        # Framework entry point: plugin `datasetInit` hooks run first
        # (via the decorator), then the user's init().
        return self.init(args = args)

    def init(self,args : dict):
        """
        input:
            args : dict
        output:
            dict to runner
            train loader
            val loader
        """
        return {},None,None

    def getSampler(self, dataset):
        """Return a DistributedSampler under DDP, otherwise None."""
        if self._status.env == "DDP":
            return torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            return None

    @checkDictOutputPlugins
    def _getSaveDict(self):
        return self.getSaveDict()

    def getSaveDict(self):
        """Override to contribute entries to the checkpoint dict."""
        return {}

    @checkPlugins
    def _loadSaveDict(self, saveDict):
        return self.loadSaveDict(saveDict = saveDict)

    def loadSaveDict(self,saveDict):
        """Override to restore state from a checkpoint dict."""
        pass
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/LifeCycle.py:
--------------------------------------------------------------------------------
1 | from DLNest.Common.DatasetBase import DatasetBase
2 | from DLNest.Common.RunnerBase import RunnerBase
3 | from DLNest.Common.LifeCycleBase import LifeCycleBase
4 | from DLNest.Plugins.MailsNote import DLNestPlugin as MailsNotePlugin
5 |
6 |
class LifeCycle(LifeCycleBase):
    """Training life-cycle hooks for the MNIST demo project."""

    def BAll(self):
        # Hook run before everything (presumably "Before All"): reset best accuracy.
        self.maxAcc = 0.0

    def needVisualize(self, epoch : int, iter : int, logdict : dict, args : dict):
        # Visualize on every iteration.
        return True

    def needValidation(self, epoch : int, logdict : dict, args : dict):
        # Validate after every epoch.
        return True

    def commandLineOutput(self,epoch : int, logdict : dict, args : dict):
        print("Epoch #" + str(epoch) + " finished!")

    def needSaveModel(self, epoch : int, logdict : dict, args : dict):
        # Save a checkpoint after every epoch.
        return True

    def holdThisCheckpoint(self, epoch : int, logdict : dict, args : dict):
        # Keep the checkpoint whenever validation accuracy reaches a new best,
        # and report the new best to the mail plugin.
        if logdict["acc"][-1] > self.maxAcc:
            self.maxAcc = logdict["acc"][-1]
            MailsNotePlugin.giveResultValues({'acc' : self.maxAcc})
            return True
        return False

    def getSaveDict(self):
        # Persist the best accuracy so a continued run keeps comparing against it.
        return {
            "max_acc" : self.maxAcc
        }

    def AOneEpoch(self):
        # step the optimizer after every epoch
        # NOTE(review): Runner.initOptimizer also builds a StepLR scheduler;
        # stepping the *optimizer* here (rather than self.runner.scheduler)
        # looks like it may be a copy-paste mistake — confirm intent.
        self.runner.optimizer.step()

    def loadSaveDict(self,saveDict):
        self.maxAcc = saveDict["max_acc"]

    def needContinueTrain(self, epoch : int, logdict : dict, args : dict):
        # Stop once the configured number of epochs has been reached.
        if epoch >= args["epochs"]:
            return False
        return True
--------------------------------------------------------------------------------
/DLNest/Common/RunningStatus.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
DLNestStatus = Enum("DLNestStatus", ("Training","Validating","Analyzing","Waiting"))

class RunningStatus:
    """Read-mostly snapshot of the runtime state DLNest exposes to user code.

    Counters and environment fields are exposed through read-only properties;
    the phase is switched with the start*() methods.
    """

    def __init__(self):
        self.__epoch = 0
        self.__iter = 0
        self.__status = DLNestStatus.Waiting
        self.__env = "CPU"
        self.__rank = -1
        self.__worldSize = 0

    # ---- read-only accessors -------------------------------------------
    @property
    def epoch(self):
        """Current epoch counter."""
        return self.__epoch

    @property
    def iter(self):
        """Current iteration counter."""
        return self.__iter

    @property
    def status(self):
        """Current DLNestStatus phase."""
        return self.__status

    @property
    def env(self):
        """Execution environment string (e.g. "CPU", "DDP")."""
        return self.__env

    @property
    def rank(self):
        """Process rank; -1 when not distributed."""
        return self.__rank

    @property
    def worldSize(self):
        """Number of distributed workers; 0 when not distributed."""
        return self.__worldSize

    # ---- phase predicates ----------------------------------------------
    def isTraining(self):
        return self.__status is DLNestStatus.Training

    def isValidating(self):
        return self.__status is DLNestStatus.Validating

    def isAnalyzing(self):
        return self.__status is DLNestStatus.Analyzing

    def isWaiting(self):
        return self.__status is DLNestStatus.Waiting

    # ---- phase transitions ---------------------------------------------
    def startTraining(self):
        self.__status = DLNestStatus.Training

    def startValidating(self):
        self.__status = DLNestStatus.Validating

    def startAnalyzing(self):
        self.__status = DLNestStatus.Analyzing

    def startWaiting(self):
        self.__status = DLNestStatus.Waiting
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/Dataset/Dataset.py:
--------------------------------------------------------------------------------
1 | from DLNest.Common.DatasetBase import DatasetBase
2 |
3 | import torchvision
4 | import torch
5 | from torch.utils import data
6 | from torchvision import datasets, transforms
7 |
class MNISTDataset(DatasetBase):
    """MNIST train/test data provider for the DLNest demo project."""
    def init(self,args : dict):
        """Build the MNIST train/test loaders.

        Returns (info_dict_for_runner, train_loader, test_loader).
        """
        # Normalize single-channel images to roughly [-1, 1].
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.5],std=[0.5])])
        self.dataTrain = datasets.MNIST(root = args["dataset_config"]["data_root"],
                                        transform=transform,
                                        train = True,
                                        download = True)

        self.dataTest = datasets.MNIST(root = args["dataset_config"]["data_root"],
                                       transform = transform,
                                       train = False)

        # getSampler returns a DistributedSampler under DDP, otherwise None.
        trainSampler = self.getSampler(self.dataTrain)
        testSampler = self.getSampler(self.dataTest)
        if trainSampler is None: # no sampler -> NOT DDP: plain shuffled loaders
            self.trainLoader = data.DataLoader(self.dataTrain,batch_size = args["dataset_config"]["batch_size"],shuffle = True)
            self.testLoader = data.DataLoader(self.dataTest,batch_size = args["dataset_config"]["batch_size"],shuffle = False)
        else:
            # DDP: the samplers shard (and shuffle) the data, so `shuffle` is omitted.
            self.trainLoader = data.DataLoader(self.dataTrain,batch_size = args["dataset_config"]["batch_size"],sampler=trainSampler)
            self.testLoader = data.DataLoader(self.dataTest,batch_size = args["dataset_config"]["batch_size"],sampler=testSampler)

        return {},self.trainLoader,self.testLoader
31 |
--------------------------------------------------------------------------------
/DLNest/Plugins/Utils/SendMailTools.py:
--------------------------------------------------------------------------------
1 | import smtplib
2 | from email.mime.text import MIMEText
3 | from email.header import Header
4 | import logging
5 |
_fromTitle = "DLNest Mail Plugin"
_toTitle = "DLNest User"

def sendSelfMail(username : str, password : str, message : str, subject : str, host : str, port : int = 25, fromTitle : str = _fromTitle, toTitle : str = _toTitle):
    """Send `message` from the SMTP account to itself (best-effort).

    Failures are logged at debug level and swallowed — mail notification
    must never break the training run.
    """
    mail = MIMEText(message,'plain', 'utf-8')
    mail["From"] = Header(fromTitle, 'utf-8')
    mail["To"] = Header(toTitle, 'utf-8')
    mail["subject"] = Header(subject,'utf-8')
    try:
        # Context manager closes the SMTP connection on all paths
        # (the original leaked the socket: no quit()/close()).
        with smtplib.SMTP(host = host, port = port) as smtpObj:
            smtpObj.login(username, password)
            smtpObj.sendmail(username, [username], mail.as_string())
            logging.info("[SendSelfMail] Send mail successfully to {}".format(username))
    except Exception as e:
        logging.debug("[SendSelfMail]" + str(e))
20 |
def sendMail(username : str, password : str, toName : str, message : str, subject : str, host : str, port : int = 25, fromTitle : str = _fromTitle, toTitle : str = _toTitle):
    """Send `message` to `toName` (a single address or a list) via SMTP.

    Best-effort: failures are logged at debug level and swallowed.
    """
    mail = MIMEText(message,'plain', 'utf-8')
    mail["From"] = Header(fromTitle, 'utf-8')
    mail["To"] = Header(toTitle, 'utf-8')
    mail["subject"] = Header(subject,'utf-8')

    receiver = toName if isinstance(toName, list) else [toName]

    try:
        # Context manager closes the SMTP connection on all paths
        # (the original leaked the socket: no quit()/close()).
        with smtplib.SMTP(host = host, port = port) as smtpObj:
            smtpObj.login(username, password)
            smtpObj.sendmail(username, receiver, mail.as_string())
            # Fixed: log the actual recipients, not the sending account.
            logging.info("[SendMail] Send mail successfully to {}".format(receiver))
    except Exception as e:
        logging.debug("[SendMail]" + str(e))
--------------------------------------------------------------------------------
/DLNest/Output/DLNestBuffer.py:
--------------------------------------------------------------------------------
1 | from DLNest.Output.OutputLayerBase import OutputLayerBase
2 |
class Singleton(object):
    """Mixin caching the first instance on the class and returning it
    for every later construction.

    Constructor arguments are accepted (they go to the subclass __init__)
    but are no longer forwarded to object.__new__ — the original forwarded
    them, which raises TypeError for subclasses that don't override
    __init__. NOTE: `hasattr` also sees an ancestor's cached instance, so
    a subclass created after its base was instantiated may share the
    base's instance (original behavior, preserved).
    """
    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, '_instance'):
            cls._instance = super(Singleton, cls).__new__(cls)
        return cls._instance
9 |
class DLNestBuffer(OutputLayerBase,Singleton):
    """Singleton styled output buffer for DLNest server-side messages."""

    def __init__(self):
        # Singleton: a repeated construction returns the already-initialized
        # instance; `styledText` (presumably set by OutputLayerBase.__init__
        # — confirm) marks initialization, so skip re-init.
        if hasattr(self,"styledText"):
            return
        OutputLayerBase.__init__(self)
        self.styleDict = {
            '' : '#efefef',
            'app' : '#ff0f9f bold',
            'ignored':'#ff7070',
            'error' : '#ff0000',
            'message' : '#efefef',
        }
        self.appName = "DLNest"

    def _emitLine(self, tagStyle, tagText, message):
        # Shared emission path: "[DLNest] " prefix, optional severity tag,
        # message body, trailing newline.
        self.putStyledText('app',"["+ self.appName +"] ")
        if tagText:
            self.putStyledText(tagStyle, tagText)
        self.putStyledText('message',message)
        self.putStyledText('message','\n')

    def logMessage(self,message : str):
        """Plain informational line."""
        self._emitLine('message', "", message)

    def logIgnError(self,message : str):
        """An error that was caught and deliberately ignored."""
        self._emitLine('ignored', "[Ignored] ", message)

    def logError(self,message : str):
        """A serious error line."""
        self._emitLine('error', "[ERROR] ", message)

    def logDebug(self,message : str):
        """Debug-level line (styled like ignored errors)."""
        self._emitLine('ignored', "[DEBUG] ", message)
46 |
--------------------------------------------------------------------------------
/DLNest/Plugins/SimpleCMDVisualize.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB
3 |
class DLNestPlugin(DPB):
    """Logs the latest value of selected log keys every `stride` iterations."""
    _NAME = "SimpleCMDVisualize"
    _config = {
        "stride" : 1,
        "keys" : [],
        "format" : {}
    }
    _defaultKeys = ["stride", "keys", "format"]

    def runnerInit(self,args : dict, datasetInfo : dict = None):
        """Read stride/keys/format from plugins_config, with fallbacks."""
        pluginName = "SimpleCMDVisualize"
        self._visualStride = DPB.getArgs(self, pluginName, "stride", 1)
        self._visualKeys = DPB.getArgs(self, pluginName, "keys" , [])
        self._visualFormat = DPB.getArgs(self, pluginName, "format", {})
        # Give every tracked key a default format, then let user-supplied
        # formats override it.
        merged = {key : "\t| {}: {}" for key in self._visualKeys}
        merged.update(self._visualFormat)
        self._visualFormat = merged

    def visualize(self, epoch : int, iter : int, log : dict):
        """Every `stride` iterations, log one line with each key's newest value."""
        try:
            if iter % self._visualStride != 0:
                return
            infoStr = ""
            for key in self._visualKeys:
                # None stands in for a series that has no entries yet.
                latest = log[key][-1] if len(log[key]) > 0 else None
                try:
                    infoStr = infoStr + self._visualFormat[key].format(key, latest)
                except Exception as e:
                    logging.debug("[SimpleCMDVisualize]" + str(e))
            logging.info("Iter : " + str(iter) + infoStr)
        except Exception as e:
            logging.debug("[SimpleCMDVisualize]" + str(e))
--------------------------------------------------------------------------------
/DLNest/Information/TaskInformation.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random
3 | try:
4 | from ..SavePackage.SavePackage import SavePackage
5 | except ImportError:
6 | from DLNest.SavePackage.SavePackage import SavePackage
7 |
class TaskInformation:
    """Bookkeeping record for one scheduled DLNest task (base for train/analyze)."""
    def __init__(
        self,
        savePackage : SavePackage,
        devices : list = None,
        memoryConsumption : int = -1,
        multiGPU : bool = False,
        DDP : bool = False,
        loadCkpt : bool = False,
        checkpointID : int = -1
    ):
        # ID: sub-second fragment of the timestamp plus a random digit —
        # unique enough for display/lookup purposes.
        self.ID = ("%.4f" % time.time())[6:] + "_" + str(random.randint(0,9))
        self.args = savePackage.args
        # Fresh list per task: the original `devices : list = []` default was
        # shared across every task constructed without explicit devices.
        self.devices = [] if devices is None else devices
        self.memoryConsumption = memoryConsumption
        self.multiGPU = multiGPU
        self.DDP = DDP
        if self.DDP:
            # DDP always implies multi-GPU scheduling.
            self.multiGPU = True
        self.startTime = time.time()

        self.type = "None"
        self.status = "Pending"
        self.process = None

        # Randomized port to reduce collisions between concurrent DDP tasks.
        self.port = str(45535 + random.randint(0,20000)) # default port
        self.address = 'localhost' # default address
        self.savePackage = savePackage
        self.loadCkpt = loadCkpt
        self.checkpointID = checkpointID

        self.extraInfo = {}

    def getDict(self):
        """Serialize scheduling-relevant fields for UI/IPC consumers."""
        return {
            "ID" : self.ID,
            "args" : self.args,
            "devices" : self.devices,
            "memory_consumption" : self.memoryConsumption,
            "multi_GPU" : self.multiGPU,
            "DDP" : self.DDP,
            "type" : self.type,
            "status" : self.status,
            "load_ckpt" : self.loadCkpt,
            "checkpoint_ID" : self.checkpointID
        }
--------------------------------------------------------------------------------
/DLNest/Information/GPUInformation.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pynvml
3 | try:
4 | from DeviceInformation import DeviceInformation
5 | except ImportError:
6 | from .DeviceInformation import DeviceInformation
7 |
class GPUInformation(DeviceInformation):
    """Device record for one CUDA GPU, backed by pynvml queries."""
    def __init__(self,ID):
        super(GPUInformation,self).__init__("GPU")
        self.ID = ID
        try:
            # If querying the card fails, treat the card as broken and set isBreak = True.
            self.handle = pynvml.nvmlDeviceGetHandleByIndex(self.ID)
            self.meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            self.totalMemory = self.meminfo.total / 1024 / 1024  # bytes -> MB
            self.isBreak = False
        except Exception as e:
            # Card unavailable: report zero total memory.
            self.totalMemory = 0
            self.isBreak = True

    def restartGPU(self):
        """
        Retry acquiring the GPU handle and memory info. On success clear
        isBreak and return True; on failure keep isBreak set and return False.
        """
        try:
            self.handle = pynvml.nvmlDeviceGetHandleByIndex(self.ID)
            self.meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            self.totalMemory = self.meminfo.total / 1024 / 1024
            self.isBreak = False
            return True
        except Exception:
            self.isBreak = True
            return False

    def getFreeMemory(self):
        """
        return in MB
        """
        # A broken card gets one reconnection attempt before reporting 0.
        if self.isBreak:
            self.restartGPU()
            if self.isBreak:
                return 0
        try:
            self.meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            return self.meminfo.free / 1024 / 1024
        except Exception as e:
            # Query failed mid-flight: mark the card broken for next time.
            self.isBreak = True
            return 0

    def getDeviceStr(self):
        # Torch-style device string, e.g. "cuda:0".
        return "cuda:" + str(self.ID)
--------------------------------------------------------------------------------
/DLNest/FactoryFiles/FactoryMNIST/Model/Runner.py:
--------------------------------------------------------------------------------
1 | from DLNest.Common.RunnerBaseTorch import RunnerBaseTorch
2 |
3 | import torch
4 | import torch.nn as nn
5 | from MNISTCNN import MNISTModel
6 |
class Runner(RunnerBaseTorch):
    """DLNest runner for the MNIST demo: CNN classifier with cross-entropy loss."""

    def init(self,args : dict,datasetInfo : dict = None):
        # To sync BN layers under DDP use: self.model = self.register(model, syncBN=True)
        self.model = MNISTModel(args)
        self.cost = nn.CrossEntropyLoss()

    def initLog(self):
        # Scalar series tracked across training.
        return {key : [] for key in ("loss", "acc")}

    def initOptimizer(self):
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 3)

    def runOneStep(self,data, log : dict, iter : int, epoch : int):
        """One optimization step on a single batch."""
        inputs, targets = data
        self.model.zero_grad()
        loss = self.cost(self.model(inputs), targets)
        loss.backward()
        self.optimizer.step()

        # Average the loss across DDP workers before logging.
        log["loss"].append(self._reduceMean(loss).detach().item())

    def visualize(self,epoch : int, iter : int, log : dict):
        pass

    def validationInit(self):
        # Running sums for averaging per-batch accuracy.
        self.totalCorrect = 0
        self.total = 0

    def validateABatch(self,data, iter : int):
        inputs, targets = data
        with torch.no_grad():
            logits = self.model(inputs)
            predicted = torch.max(logits, 1)[1]
            batchAcc = self._reduceMean((predicted == targets).sum() / targets.shape[0])
            self.totalCorrect += batchAcc
            self.total += 1

    def validationAnalyze(self, log : dict):
        # Mean of per-batch accuracies over the whole validation set.
        log["acc"].append((self.totalCorrect / self.total).item())
--------------------------------------------------------------------------------
/DLNest/Operations/New.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from pathlib import Path
3 | import json
4 | from DLNest.Operations.UsePlugin import usePlugin
5 |
def new(targetDir : str, MNIST : bool = False, pluginsName = None):
    """Create a new DLNest project at `targetDir` from the factory templates.

    MNIST=True copies the MNIST demo project, otherwise the clean skeleton.
    pluginsName: plugin names to register in the new project's config.
    Does nothing (prints a message) if `targetDir` already exists.
    """
    # Avoid a mutable default argument.
    if pluginsName is None:
        pluginsName = []

    projectPath = Path(targetDir).absolute()
    # If target path already exists, return
    if projectPath.exists():
        print("Path already exists")
        return

    # Pick the factory template and copy it to the target
    factoryPath = Path(__file__).parent.parent / "FactoryFiles"
    factoryPath = factoryPath / ("FactoryMNIST" if MNIST else "FactoryClean")
    shutil.copytree(factoryPath,projectPath)

    # Point save_root and root_file_path at the new project
    rootConfigPath = projectPath / "root_config.json"
    with rootConfigPath.open("r") as fp:
        rootConfig = json.load(fp)
    rootConfig["save_root"] = str(projectPath / "Saves")
    rootConfig["root_file_path"] = str(projectPath)
    with rootConfigPath.open("w") as fp:
        json.dump(rootConfig,fp,indent = 4,separators = (',',':'))

    # If MNIST, point data_root at a project-local MNIST directory
    if MNIST:
        datasetConfigPath = projectPath / "dataset_config.json"
        with datasetConfigPath.open("r") as fp:
            datasetConfig = json.load(fp)
        datasetConfig["dataset_config"]["data_root"] = str(projectPath / "MNIST")
        with datasetConfigPath.open("w") as fp:
            json.dump(datasetConfig,fp,indent = 4, separators = (',',':'))

    # Register requested plugins in the new project's config
    for pluginName in pluginsName:
        usePlugin(targetDir, pluginName = pluginName, full = False)

    print("Create a project in " + targetDir + ".")
33 | # If MNIST, modify the data_root in dataset_config.json
34 | if MNIST:
35 | datasetConfigPath = projectPath / "dataset_config.json"
36 | datasetConfig = {}
37 | with datasetConfigPath.open("r") as fp:
38 | datasetConfig = json.load(fp)
39 | datasetConfig["dataset_config"]["data_root"] = str(projectPath / "MNIST")
40 | with datasetConfigPath.open("w") as fp:
41 | json.dump(datasetConfig,fp,indent = 4, separators = (',',':'))
42 |
43 |
44 | # Add plugins
45 | for pluginName in pluginsName:
46 | usePlugin(targetDir, pluginName = pluginName, full = False)
47 |
48 | print("Create a project in " + targetDir + ".")
49 |
--------------------------------------------------------------------------------
/DLNest/Information/AnalyzeTask.py:
--------------------------------------------------------------------------------
1 | try:
2 | from TaskInformation import TaskInformation
3 | from ..SavePackage.SavePackage import SavePackage
4 | from ..Output.AnalyzerBuffer import AnalyzerBuffer
5 | except ImportError:
6 | from DLNest.Information.TaskInformation import TaskInformation
7 | from DLNest.SavePackage.SavePackage import SavePackage
8 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer
9 | from pathlib import Path
10 | from multiprocessing import Queue
11 |
12 | class AnalyzeTask(TaskInformation):
13 | def __init__(
14 | self,
15 | recordPath : str,
16 | scriptPath : str = "",
17 | checkpointID : int = -1,
18 | devices : list = [],
19 | memoryConsumption : int = -1,
20 | DDP : bool = False
21 | ):
22 | savePackage = SavePackage()
23 | savePackage.initFromAnExistSavePackage(recordPath)
24 | super(AnalyzeTask,self).__init__(
25 | savePackage = savePackage,
26 | devices = devices,
27 | memoryConsumption = memoryConsumption,
28 | multiGPU = False,
29 | DDP = DDP,
30 | loadCkpt = True if checkpointID != -2 else False,
31 | checkpointID = checkpointID
32 | )
33 | self.ID = "A_" + self.ID
34 | self.recordPath = Path(recordPath)
35 | if scriptPath == "":
36 | self.scriptPath = self.recordPath.parent.parent / "AnalyzeScripts"
37 | else:
38 | self.scriptPath = Path(scriptPath)
39 |
40 | self.type = "Analyze"
41 | self.commandQueue = Queue()
42 | self.outputBuffer : AnalyzerBuffer = None
43 |
44 | def getDict(self):
45 | ret = super().getDict()
46 | ret["record_path"] = str(self.recordPath)
47 | ret["script_path"] = str(self.scriptPath)
48 | ret["checkpoint_ID"] = self.checkpointID
49 | return ret
50 |
--------------------------------------------------------------------------------
/DLNest/Operations/UsePlugin.py:
--------------------------------------------------------------------------------
1 | from DLNest.Operations.GetPluginConfig import getPluginConfig
2 | from pathlib import Path
3 | import json
4 |
def _registerPlugin(configPath : Path, name : str, config : dict):
    """Append `name` and its config to the JSON file at configPath.

    No-op when the plugin is already registered. Factored out of usePlugin,
    where this read/append/write sequence was duplicated for both
    plugins_config.json and root_config.json.
    """
    with configPath.open("r") as f:
        cfg = json.load(f)
    if name in cfg["plugins"]:
        return
    cfg["plugins"].append(name)
    cfg["plugins_config"].update(config)
    with configPath.open("w") as f:
        json.dump(cfg, f, indent = 4,separators = (',',':'))

def usePlugin(targetDir : str, pluginName : str, full : bool = False):
    """
    Only return False when targetDir is wrong
    """
    name, config = getPluginConfig(pluginName, full = full)
    if not config:
        print("Wrong plugin name: {}".format(pluginName))
        return True

    root_dir = Path(targetDir).absolute()

    if not root_dir.exists():
        print("Wrong target dir: {}".format(targetDir))
        return False

    # Prefer a dedicated plugins_config.json; fall back to root_config.json.
    plugins_config_path = root_dir / "plugins_config.json"
    if plugins_config_path.exists():
        _registerPlugin(plugins_config_path, name, config)
    else:
        root_config_path = root_dir / "root_config.json"
        if not root_config_path.exists():
            print("Wrong target dir: {}".format(targetDir))
            return False
        _registerPlugin(root_config_path, name, config)

    return True
48 |
def usePlugins(targetDir, pluginsName : list, full : bool = False):
    """Register each named plugin; stop at the first bad target dir."""
    for pluginName in pluginsName:
        if not usePlugin(targetDir, pluginName, full):
            break
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/CommandInput.py:
--------------------------------------------------------------------------------
1 | from apscheduler.schedulers.background import BackgroundScheduler
2 | from typing import Callable, Iterable, List, Optional
3 |
4 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box
5 | from prompt_toolkit.buffer import Buffer
6 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
7 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
8 | from prompt_toolkit.key_binding import KeyBindings
9 | from prompt_toolkit.document import Document
10 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit
11 | from prompt_toolkit.lexers import Lexer
12 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples
13 |
14 | from DLNest.ShellClient.Windows.Utils.Completers import getCommandCompleter
15 |
class CommandInput:
    """A framed command-entry widget for the DLNest shell client.

    Combines a prompt_toolkit TextArea (with history auto-suggestion and
    command completion) and a one-line usage hint label.
    """

    def __init__(self, title: str = "DLNest Command Line", onAccept : Callable[[str],None] = None):
        self.kb = KeyBindings()
        self.onAccept = onAccept
        self.completer = getCommandCompleter()
        self.title = title

        def accept_text(buf : Buffer):
            # Forward the submitted text to the callback, if one was given.
            # Returning False tells prompt_toolkit to clear the buffer.
            if self.onAccept is not None:
                self.onAccept(buf.text)
            return False

        self.text = TextArea(
            height=3,
            auto_suggest=AutoSuggestFromHistory(),
            completer=self.completer,
            prompt=[("#fff5ee","DLNest>>")],
            accept_handler=accept_text,
            scrollbar=True,
            multiline=False
        )
        self.text.height = 3
        self.infoBar = Label(
            [("bg:#006060","Press Enter to enter a command. Press ctrl + c to exit.")]
        )

        self.content = HSplit([self.text, self.infoBar])

    def getWindow(self):
        """Return the input area and hint bar wrapped in a styled Frame."""
        return Frame(
            self.content,
            title=self.title,
            style="class:command_frame"
        )
--------------------------------------------------------------------------------
/DLNest/Information/DeviceInformation.py:
--------------------------------------------------------------------------------
1 | import time
2 | try:
3 | from TaskInformation import TaskInformation
4 | except ImportError:
5 | from DLNest.Information.TaskInformation import TaskInformation
6 |
class DeviceInformation:
    """Bookkeeping for one compute device (CPU or GPU) and its running tasks."""

    def __init__(self, type : str = ""):
        self.ID = -1             # device index; -1 until assigned by a subclass
        self.type = type         # device kind label, e.g. "CPU" / "GPU"
        self.nowTask = 0         # count of tasks believed to be running
        self.totalMemory = 0
        self.runningTask = []    # task-information objects placed on this device
        self.isBreak = False     # True when the device is marked unusable

    def getFreeMemory(self):
        """Free memory in MB; the base class reports unlimited memory."""
        return float('inf')

    def checkTasks(self):
        """
        Check the tasks in this device, if one has no subprocess or the subprocess is dead, delete the task from the device
        """
        alive = [t for t in self.runningTask
                 if t.process is not None and t.process.is_alive()]
        # Decrement the counter once per dropped task.
        self.nowTask -= len(self.runningTask) - len(alive)
        self.runningTask = alive

    def addATask(self, newTask : "TaskInformation"):
        """Register a task as running on this device."""
        self.nowTask += 1
        self.runningTask.append(newTask)

    def lastUseTime(self):
        """Start time of the most recently added live task; 0 when idle."""
        self.checkTasks()
        return self.runningTask[-1].startTime if self.runningTask else 0

    def getTaskNum(self):
        """Number of currently live tasks on this device."""
        self.checkTasks()
        return self.nowTask

    def getDict(self):
        """Serializable snapshot of this device's state."""
        self.checkTasks()
        return {
            "ID" : self.ID,
            "type" : self.type,
            "is_break" : self.isBreak,
            "num_tasks" : self.nowTask,
            "running_tasks" : [t.getDict() for t in self.runningTask],
            "total_memory" : self.totalMemory,
            "free_memory" : self.getFreeMemory(),
            "last_use_time" : self.lastUseTime()
        }

    def getDeviceStr(self):
        """Human-readable device identifier; empty in the base class."""
        return ""
--------------------------------------------------------------------------------
/DLNest/Run.py:
--------------------------------------------------------------------------------
1 | from DLNest.Operations.RunIndependent import runIndependent
2 |
3 | import os
4 | import pynvml
5 | import argparse
6 |
class Arguments:
    """Thin wrapper that owns an argparse.ArgumentParser.

    Subclasses register their own arguments on self._parser in __init__.
    """

    def __init__(self, desc : str = ""):
        self._parser = argparse.ArgumentParser(description=desc)

    def parser(self):
        """Return the underlying ArgumentParser instance."""
        return self._parser
13 |
class TaskArguments(Arguments):
    """Command-line arguments for launching a DLNest training task."""

    def __init__(self):
        super().__init__(desc="Arguments for DLNest task.")
        p = self._parser
        p.add_argument("-c", type=str, help="root configuration json file for this task.", required=True)
        p.add_argument("-d", type=str, default="", help="description for this task.(default: None)")
        p.add_argument("-f", type=str, default="", help="frequently changing configuration json file for this task.(default:None)")
        p.add_argument("-devices", default=[], nargs='+', type=int)
        p.add_argument("-ns", action='store_true', help="Set to save to the NOSAVE dir.")
        p.add_argument("-ss", action='store_false', help="Set to also show on screen.")
        p.add_argument("-DDP", action='store_true', help="Set to use DDP.")
        p.add_argument("-CPU", action='store_true', help="Set to use CPU.")
        p.add_argument("-sd", action='store_true', help="Set to use description as the save dir name.(coverd by ns)")
27 |
def runTrain(args):
    """Start an independent training run from parsed CLI arguments.

    Device selection: -CPU forces [-1]; otherwise the explicit -devices list
    is used, falling back to every GPU visible to pynvml when none are given.
    """
    # DDP together with CPU is unsupported.
    assert not (args.DDP and args.CPU)
    if args.CPU:
        devices = [-1]
    else:
        devices = args.devices
        if not devices:
            # No explicit devices: enumerate all GPUs on this machine.
            pynvml.nvmlInit()
            devices = list(range(pynvml.nvmlDeviceGetCount()))
    runIndependent(
        configPath=args.c,
        freqPath=args.f,
        devices=devices,
        description=args.d,
        DDP=args.DDP,
        noSave=args.ns,
        useDescriptionToSave=args.sd,
        showOnScreen=args.ss
    )
48 |
if __name__ == "__main__":
    import sys
    # Blank out sys.path[0] (the script's own directory) so imports resolve
    # the same way as when DLNest runs as an installed package.
    if sys.path[0] != '':
        sys.path[0] = ''
    runTrain(TaskArguments().parser().parse_args())
57 |
--------------------------------------------------------------------------------
/DLNest/Analyze.py:
--------------------------------------------------------------------------------
1 | from DLNest.Operations.AnalyzeIndependent import AnalyzeIndependent
2 |
3 | import os
4 | import pynvml
5 | import argparse
6 | from pathlib import Path
7 | import importlib
8 |
class Arguments:
    """Thin wrapper that owns an argparse.ArgumentParser.

    Subclasses register their own arguments on self._parser in __init__.
    """

    def __init__(self, desc : str = ""):
        self._parser = argparse.ArgumentParser(description=desc)

    def parser(self):
        """Return the underlying ArgumentParser instance."""
        return self._parser
15 |
class AnalyzeArguments(Arguments):
    """Command-line arguments for the standalone DLNest analyzer."""

    def __init__(self):
        super().__init__(desc="Arguments for DLNest analyze.")
        p = self._parser
        p.add_argument("-r", type=str, help="record path", required=True)
        p.add_argument("-s", type=str, help="script path", required=True)
        p.add_argument("-c", type=int, default=-1, help="checkpoint ID")
        p.add_argument("-devices", default=[], nargs='+', type=int)
        p.add_argument("-CPU", action='store_true', help="Set to use CPU.")
25 |
26 | def _loadAScript(filePath : Path,name : str):
27 | # load a script by name
28 | spec = importlib.util.spec_from_file_location(
29 | name,
30 | filePath
31 | )
32 | module = importlib.util.module_from_spec(spec)
33 | spec.loader.exec_module(module)
34 | return module
35 |
def _analyze(args):
    """Run an analysis experiment described by parsed CLI arguments.

    Loads the script at args.s, extracts its `experience` function and runs
    it over the record at args.r on the selected devices.
    """
    if args.CPU:
        devices = [-1]
    else:
        # Default to device 0 when no -devices were supplied.
        devices = args.devices or [0]

    scriptModule = _loadAScript(Path(args.s), "AScript")
    expFunc = getattr(scriptModule, "experience")

    AnalyzeIndependent(
        recordPath=args.r,
        expFunc=expFunc,
        scriptPath="",
        checkpointID=args.c,
        devices=devices
    )
54 |
def analyze(
    recordPath : str,
    checkpointID : int = -1,
    scriptPath : str = "",
    devices = None,
    expFunc = None
):
    """Programmatic entry point for running an analysis over a record.

    Args:
        recordPath: path of the saved training record to analyze.
        checkpointID: checkpoint to load; -1 means the latest.
        scriptPath: optional analysis script path.
        devices: device id list; None (the default) means [0]. The None
            sentinel replaces a mutable default list that was shared
            between calls.
        expFunc: optional pre-loaded `experience` function.
    """
    if devices is None:
        devices = [0]
    AnalyzeIndependent(
        recordPath = recordPath,
        expFunc = expFunc,
        scriptPath = scriptPath,
        devices = devices,
        checkpointID = checkpointID
    )
69 |
if __name__ == "__main__":
    import sys
    # Blank out sys.path[0] (the script's own directory) so imports resolve
    # the same way as when DLNest runs as an installed package.
    if sys.path[0] != '':
        sys.path[0] = ''
    _analyze(AnalyzeArguments().parser().parse_args())
--------------------------------------------------------------------------------
/DLNest/Information/test/test_DeviceInformation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from DLNest.Information.DeviceInformation import DeviceInformation
3 | from DLNest.Information.TaskInformation import TaskInformation
4 | import time
class FakeSubprocess:
    """Test double mimicking the `is_alive` interface of a task subprocess."""

    def __init__(self, alive):
        # Fixed liveness flag returned by is_alive().
        self.alive = alive

    def is_alive(self):
        """Return the liveness flag set at construction time."""
        return self.alive
11 |
@pytest.mark.Information
class TestDeviceInformation:
    """Unit tests for DeviceInformation task bookkeeping."""

    def test_init(self):
        self.DI = DeviceInformation("test")
        assert self.DI.type == "test"
        assert self.DI.nowTask == 0
        assert self.DI.totalMemory == 0
        assert self.DI.runningTask == []
        assert self.DI.isBreak == False

    def test_getFreeMemory(self):
        self.DI = DeviceInformation("test")
        # Base class reports effectively unlimited memory (float('inf')).
        assert self.DI.getFreeMemory() > 10000000000000000000000000

    def initTasks(self):
        # Helper: build a device holding one alive and one dead task.
        DI = DeviceInformation("test")
        fakeAliveSubprocess = FakeSubprocess(True)
        fakeDeadSubprocess = FakeSubprocess(False)
        aliveTask = TaskInformation()
        aliveTask.process = fakeAliveSubprocess
        deadTask = TaskInformation()
        deadTask.process = fakeDeadSubprocess
        DI.addATask(aliveTask)
        DI.addATask(deadTask)
        assert DI.nowTask == 2
        assert len(DI.runningTask) == 2
        assert DI.runningTask[0] == aliveTask
        assert DI.runningTask[1] == deadTask
        return DI, aliveTask

    def test_AddATask(self):
        DI, _ = self.initTasks()

    def test_lastUseTime(self):
        t = time.time()
        DI, _ = self.initTasks()
        lt = DI.lastUseTime()
        assert lt != 0
        assert abs(lt - t) < 0.1

    def test_checkTasks(self):
        DI, aliveTask = self.initTasks()
        DI.checkTasks()
        # The dead task must be dropped and the counter decremented.
        assert DI.nowTask == 1
        assert len(DI.runningTask) == 1
        assert DI.runningTask[-1] == aliveTask

    def test_getTaskNum(self):
        DI, _ = self.initTasks()
        assert DI.getTaskNum() == 1

    def test_getDict(self):
        DI, aliveTask = self.initTasks()
        retDict = DI.getDict()
        assert retDict["type"] == "test"
        assert retDict["is_break"] == False
        assert retDict["num_tasks"] == 1
        for item in retDict["running_tasks"]:
            for key in item:
                assert item[key] == aliveTask.getDict()[key]
        assert isinstance(retDict["total_memory"], float) or isinstance(retDict["total_memory"], int)
        # Bug fix: the second clause previously re-checked "total_memory"
        # instead of "free_memory" (copy-paste error).
        assert isinstance(retDict["free_memory"], float) or isinstance(retDict["free_memory"], int)
        t = time.time()
        assert abs(t - retDict["last_use_time"]) < 0.1
--------------------------------------------------------------------------------
/DLNest/Output/AnalyzerBuffer.py:
--------------------------------------------------------------------------------
1 | from DLNest.Output.OutputLayerBase import OutputLayerBase
2 |
3 | from multiprocessing import Queue
4 |
class AnalyzerBuffer(OutputLayerBase):
    """Output layer that mirrors analyzer output across processes.

    Styled text accumulates locally and is shuttled through a
    multiprocessing.Queue holding at most one snapshot: the producer side
    (isSend == True) pushes the full styled-text list on every write, the
    consumer side (isSend == False) pulls the latest snapshot before reads.
    """

    def __init__(self):
        super(AnalyzerBuffer, self).__init__()
        self.styleDict = {
            '' : '#efefef',
            'app' : '#2afa38 bold',
            'ignored':'#ff7070',
            'error' : '#ff0000',
            'message' : '#efefef',
        }
        self.appName = "DLNest Analyzer"
        self.outputQueue = Queue()
        # False on the consumer side; set True by the process owning the buffer.
        self.isSend = False

    def _logTagged(self, message : str, tag : str = None, tagStyle : str = None):
        """Emit one "[appName] [TAG] message" line and sync to the queue.

        Extracted helper: the four log* methods below were copy-paste
        duplicates differing only in the tag text and its style.
        """
        self.putStyledText('app', "[" + self.appName + "] ")
        if tag is not None:
            self.putStyledText(tagStyle, tag)
        self.putStyledText('message', message)
        self.putStyledText('message', '\n')
        self.sendData()

    def logMessage(self, message : str):
        """Log a plain informational line."""
        self._logTagged(message)

    def logIgnError(self, message : str):
        """Log an error that was ignored (execution continues)."""
        self._logTagged(message, "[Ignored] ", 'ignored')

    def logError(self, message : str):
        """Log a fatal error line."""
        self._logTagged(message, "[ERROR] ", 'error')

    def logDebug(self, message : str):
        """Log a debug line (shares the 'ignored' style)."""
        self._logTagged(message, "[DEBUG] ", 'ignored')

    def write(self, message : str):
        """File-like write: append raw text, then sync to the queue."""
        super().write(message)
        self.sendData()

    def sendData(self):
        """Producer side: replace the queue's content with the latest snapshot."""
        if not self.isSend:
            return
        if self.lock.acquire():
            try:
                # Drain stale snapshots so the queue holds at most one item.
                while not self.outputQueue.empty():
                    try:
                        self.outputQueue.get(block=False)
                    except Exception:
                        # empty() raced with another reader; nothing to drop.
                        pass

                self.outputQueue.put(self.styledText, block=False)
            finally:
                self.lock.release()

    def receiveData(self):
        """Consumer side: adopt the latest snapshot from the queue, if any."""
        if self.isSend:
            return
        if self.lock.acquire():
            try:
                self.styledText = self.outputQueue.get(block=False)
            except Exception:
                # Queue empty — keep the current text.
                pass
            finally:
                self.lock.release()

    def getPlainText(self, from_token : int = -1, length : int = -1):
        """Pull the latest snapshot, then delegate to the base class."""
        self.receiveData()
        return super().getPlainText(from_token=from_token, length=length)

    def getStyledText(self, from_token : int = -1, length : int = -1):
        """Pull the latest snapshot, then delegate to the base class."""
        self.receiveData()
        return super().getStyledText(from_token=from_token, length=length)
--------------------------------------------------------------------------------
/DLNest/Scheduler/Scheduler.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.TaskInformation import TaskInformation
2 | from DLNest.Executor.AnalyzeProcess import AnalyzeProcess
3 | from DLNest.Executor.TrainProcess import TrainProcess
4 | from DLNest.Information.InfoCenter import InfoCenter
5 | from DLNest.Scheduler.SchedulerStrategyBase import SchedulerStrategyBase
6 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer
7 | from DLNest.Scheduler.DefaultStrategy import DefaultStrategy
8 |
9 | from apscheduler.schedulers.background import BackgroundScheduler
10 |
class Singleton(object):
    """Mixin giving a class (hierarchy) a single shared instance.

    The instance is cached as `cls._instance` on the first class that gets
    instantiated; later classes in the hierarchy find it via attribute lookup.
    """

    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, '_instance'):
            # Bug fix: object.__new__ must be called with the class only.
            # Forwarding *args/**kwargs raises TypeError under Python 3 when
            # the constructor takes arguments (e.g. Scheduler(strategy=...)).
            cls._instance = super(Singleton, cls).__new__(cls)
        return cls._instance
17 |
18 |
class Scheduler(Singleton):
    """Singleton that assigns queued tasks to devices and launches them.

    A background APScheduler job re-runs the scheduling strategy every
    `timeDelay` seconds; `strategy.decide` performs the actual placement.
    """

    def __init__(
        self,
        strategy : SchedulerStrategyBase = None,
        timeDelay : int = 60,
        maxTaskPerDevice : int = 10000,
    ):
        # Because this is a singleton, __init__ runs again on every
        # Scheduler() call: strategy and the routine job are only (re)set
        # when missing or explicitly replaced.
        self.infoCenter = InfoCenter()
        self.timeDelay = timeDelay
        self.maxTaskPerDevice = maxTaskPerDevice

        if hasattr(self, "strategy"):
            if strategy is not None:
                # Replace the existing strategy.
                self.strategy = strategy
                self.strategy.scheduler = self
        else:
            if strategy is None:
                # First construction without an explicit strategy.
                strategy = DefaultStrategy()
            self.strategy = strategy
            self.strategy.scheduler = self

        if not hasattr(self, "routineScheduler"):
            self.startRoutineTask()

    def giveATask(self, task : TaskInformation, otherArgs : dict):
        """Register a pending task and immediately ask the strategy to place it."""
        print("Received task " + task.ID + ".")

        task.extraInfo = otherArgs
        self.infoCenter.addATask(task)
        self.strategy.decide(self.maxTaskPerDevice, self.timeDelay)

    def runTask(self, task : TaskInformation, devices : list):
        """Start the given Train/Analyze task on `devices`."""
        task.devices = devices
        assert task.type == "Train" or task.type == "Analyze"
        if task.type == "Train":
            TProcess = TrainProcess(task)
            task.process = TProcess
            TProcess.start()
        elif task.type == "Analyze":
            # Analyze tasks get a dedicated buffer for their output.
            ABuffer = AnalyzerBuffer()
            task.outputBuffer = ABuffer
            AProcess = AnalyzeProcess(task, outputBuffer = ABuffer)
            task.process = AProcess
            AProcess.start()

        # Typo fix: message previously read "runing".
        print(task.type + " task " + task.ID + " is running now.")

        self.infoCenter.runATask(task)

    def __routineRun(self):
        # Periodic re-evaluation of the scheduling strategy.
        self.strategy.decide(self.maxTaskPerDevice, self.timeDelay)

    def changeTimeDelay(self, delay):
        """Change the scheduling interval and reschedule the routine job."""
        self.timeDelay = delay
        self.routineScheduler.remove_job(self.routineJob.id)
        self.routineJob = self.routineScheduler.add_job(self.__routineRun, "interval", seconds = self.timeDelay)

    def changeMaxTaskPerDevice(self, newValue : int):
        """Set the per-device task cap used by the strategy."""
        self.maxTaskPerDevice = newValue

    def startRoutineTask(self):
        """Create and start the background scheduler running the routine job."""
        self.routineScheduler = BackgroundScheduler()
        self.routineJob = self.routineScheduler.add_job(self.__routineRun, "interval", seconds = self.timeDelay)
        self.routineScheduler.start()
--------------------------------------------------------------------------------
/DLNest/Executor/AnalyzeProcess.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.AnalyzeTask import AnalyzeTask
2 | from DLNest.Executor.TaskProcess import TaskProcess
3 | from DLNest.Output.AnalyzerBuffer import AnalyzerBuffer
4 | from pathlib import Path
5 | import sys
6 | import importlib
7 | import os
8 |
9 | import time
10 | from multiprocessing import Pipe
11 |
class AnalyzeWrapper:
    """Bundle handed to analysis scripts: args, runner, dataset and log."""

    def __init__(self, args : dict, runner, dataset, log : dict):
        self.args = args
        self.runner = runner
        self.model = runner  # backward compatibility: older scripts use .model
        self.dataset = dataset
        self.log = log
19 |
class AnalyzeProcess(TaskProcess):
    """Subprocess running analysis scripts against a trained record.

    Either runs a single pre-supplied `expFunc`, or loops forever pulling
    script names from the task's command queue and executing each script's
    `experience` function.
    """

    def __init__(self,task : AnalyzeTask, outputBuffer : AnalyzerBuffer = None, expFunc = None):
        # outputBuffer: optional AnalyzerBuffer; when given, this process's
        #   stdout/stderr are redirected into it (see initOutput).
        # expFunc: optional callable(AnalyzeWrapper); when given, mainLoop
        #   runs it once and returns instead of entering the command loop.
        super(AnalyzeProcess,self).__init__(task)
        self.commandQueue = task.commandQueue
        self.expFunc = expFunc
        self.output = outputBuffer

    def initOutput(self,rank = -1):
        """Redirect output into the analyzer buffer (single-process only)."""
        assert rank == -1
        os.chdir(self.task.args["root_file_path"]) # Change CWD to the save package
        # Make the analysis scripts' directory importable.
        sys.path.append(str(self.task.scriptPath))
        if self.output != None:
            self._debugf = sys.stdout  # keep a handle on the real stdout
            sys.stdout = self.output
            sys.stderr = self.output
            self.output.appName = "DLNest Analyze Process"
            self.output.isSend = True

    def runExp(self):
        """Invoke the current experience function with the wrapped context."""
        self.expFunc(self.analyzeWrapper)

    def mainLoop(self):
        """Run the preset experiment, or serve commands forever."""
        self.analyzeWrapper = AnalyzeWrapper(self.task.args,self.runner,self.dataset,self.logDict)
        self.status.startAnalyzing()
        if self.expFunc != None:
            # Have a setted exp to run
            self.runExp()
            return

        print("Waiting for command...")
        while True:
            try:
                # Blocks until the controller sends a script name.
                command = self.commandQueue.get(block=True)
                self.startTest(command)
            except Exception as e:
                # Script errors must not kill the analyzer process.
                if self.output != None:
                    self.output.logIgnError(str(e))
                else:
                    print(e)

    def __loadAScript(self,filePath : Path,name : str):
        # load a script by name
        spec = importlib.util.spec_from_file_location(
            name,
            filePath
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def startTest(self,command : str):
        """Load `<scriptPath>/<command>.py` and run its `experience` function."""
        scriptPath = self.task.scriptPath / (command + ".py")
        try:
            scriptModule = self.__loadAScript(scriptPath,"AScript")
            # Find the `experience` function inside the loaded script.
            self.expFunc = scriptModule.__getattribute__("experience")
            if self.output != None:
                # Label output lines with the script name while it runs.
                self.output.appName = command
            self.runExp()
        except Exception as e:
            if self.output != None:
                self.output.logIgnError(str(e))
            else:
                print(e)
            return
        finally:
            if self.output != None:
                self.output.appName = "DLNest Analyze Process"

    def run(self):
        """Process entry point; dump any fatal traceback to a temp file."""
        try:
            super().run()
        except Exception as e:
            import traceback
            # stdout may already be redirected/broken, so persist the
            # traceback to a file in the current working directory.
            with open("./.analyzeException.tmp.txt","w") as f:
                f.write(traceback.format_exc())
--------------------------------------------------------------------------------
/DLNest/Common/RunnerBase.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import logging
3 | from DLNest.Plugins.Utils.CheckPlugins import checkPlugins,checkDictOutputPlugins
4 | from DLNest.Common.RunningStatus import RunningStatus
5 |
6 |
class RunnerBase:
    """Base class for DLNest runners (model + training logic).

    Subclasses override the plain hooks (init, runOneStep, visualize, ...);
    the underscore-prefixed wrappers are what the framework calls and add
    plugin dispatch via @checkPlugins / @checkDictOutputPlugins.
    """

    def __init__(self, args : dict, plugins : list = None, status = None):
        # None sentinels replace the former mutable defaults (`[]` and a
        # definition-time RunningStatus()), which were shared between every
        # runner constructed without explicit arguments.
        self._plugins = [] if plugins is None else plugins
        self._args = args
        self._status = RunningStatus() if status is None else status

    # for backward compatibility
    @property
    def _rank(self):
        """Deprecated alias for self._status.rank (warns once per instance)."""
        if not hasattr(self, "_warned_rank"):
            self._warned_rank = True
            print("Runner._rank is deprecated, please use Runner._status.rank")
        return self._status.rank

    # for backward compatibility
    @property
    def _envType(self):
        """Deprecated alias for self._status.env (warns once per instance)."""
        if not hasattr(self, "_warned_env"):
            self._warned_env = True
            print("Runner._envType is deprecated, please use Runner._status.env")
        return self._status.env

    def getArgs(self):
        """Return the task's argument dict."""
        return self._args

    @checkPlugins
    def _runnerInit(self, args : dict, datasetInfo : dict = None):
        return self.init(args, datasetInfo)

    def init(self, args : dict, datasetInfo : dict = None):
        """Build the model/state; override in subclasses."""
        self.args = args

    @checkPlugins
    def _initOptimizer(self):
        return self.initOptimizer()

    def initOptimizer(self):
        """Create optimizers; override in subclasses."""
        pass

    def DDPOperation(self, rank : int):
        """Hook run per-rank when DDP is enabled; override as needed."""
        pass

    def afterDDP(self, rank : int):
        """
        For SyncBN or something else.
        """
        pass

    @checkDictOutputPlugins
    def _initLog(self):
        return self.initLog()

    def initLog(self):
        """Return the initial log dict; override in subclasses."""
        return {}

    @checkDictOutputPlugins
    def _getSaveDict(self):
        return self.getSaveDict()

    def getSaveDict(self):
        """Return the state dict to checkpoint; override in subclasses."""
        return {}

    @checkPlugins
    def _loadSaveDict(self, saveDict):
        return self.loadSaveDict(saveDict = saveDict)

    def loadSaveDict(self, saveDict):
        """Restore state from a checkpoint dict; override in subclasses."""
        pass

    @checkPlugins
    def _runOneStep(self, data, log : dict, iter : int, epoch : int):
        return self.runOneStep(data = data, log = log, iter = iter, epoch = epoch)

    def runOneStep(self, data, log : dict, iter : int, epoch : int):
        """One training step; override in subclasses."""
        pass

    @checkPlugins
    def _visualize(self, log : dict, iter : int, epoch : int):
        return self.visualize(log = log, iter = iter, epoch = epoch)

    def visualize(self, log : dict, iter : int, epoch : int):
        """Visualization hook; override in subclasses."""
        pass

    @checkPlugins
    def _validate(self, loader, log):
        # NOTE(review): no default validate() is defined on this class;
        # subclasses are expected to provide it.
        return self.validate(loader, log)

    @checkPlugins
    def _validationInit(self):
        return self.validationInit()

    def validationInit(self):
        """Called before a validation pass; override in subclasses."""
        pass

    @checkPlugins
    def _validateABatch(self, data, iter : int):
        return self.validateABatch(data = data, iter = iter)

    def validateABatch(self, data, iter : int):
        """Validate one batch; override in subclasses."""
        pass

    @checkPlugins
    def _validationAnalyze(self, log : dict):
        return self.validationAnalyze(log = log)

    def validationAnalyze(self, log : dict):
        """Summarize validation results; override in subclasses."""
        pass
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/DevicesInfoShower.py:
--------------------------------------------------------------------------------
1 | from apscheduler.schedulers.background import BackgroundScheduler
2 | from typing import Callable, Iterable, List, Optional
3 |
4 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box
5 | from prompt_toolkit.buffer import Buffer
6 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
7 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
8 | from prompt_toolkit.key_binding import KeyBindings
9 | from prompt_toolkit.document import Document
10 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit
11 | from prompt_toolkit.lexers import Lexer
12 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples
13 |
14 | from DLNest.ShellClient.Windows.SCTextArea import SCTextArea
15 |
class DevicesLexer(Lexer):
    """Renders one fixed-width status line per device for prompt_toolkit."""

    def __init__(self):
        self.devicesInfo = []        # list of device dicts (see get_line)
        # Column widths for right-aligned fields.
        self.IDLength = 5
        self.freeMemoryLength = 10
        self.runningTasksLength = 3

    def get_text(self):
        # One (empty) document line per device; content comes from the lexer.
        return "\n" * len(self.devicesInfo)

    def lex_document(self, document : Document) -> Callable[[int], StyleAndTextTuples]:
        def get_line(lineno : int) -> StyleAndTextTuples:
            try:
                device = self.devicesInfo[lineno]
                ID = ("GPU " + str(device["ID"])) if device["ID"] != -1 else "CPU "
                isBreak = "Break" if device["is_break"] else "Valid"
                break_class = "break" if device["is_break"] else "valid"
                # rjust pads only when shorter, matching the old manual padding.
                ID = ID.rjust(self.IDLength)
                freeMemory = (str(int(device["free_memory"])) + " MB").rjust(self.freeMemoryLength)
                runningTasks = str(len(device["running_tasks"])).rjust(self.runningTasksLength)

                return [
                    ("class:devices_id", "Devices: " + ID + " "),
                    ("class:devices_status_" + break_class, " " + isBreak + " "),
                    ("class:devices_free_memory", " F-Memory: " + freeMemory + " "),
                    ("class:devices_tasks", " #Tasks: " + runningTasks + " ")
                ]
            except Exception:
                # Out-of-range line or malformed device dict: render nothing.
                return []

        return get_line
55 |
class DevicesInfoShower:
    """Framed, auto-refreshing panel showing per-device status lines."""

    def __init__(
        self,
        title : str = "Tasks",
        routineTask = None,
        freq : int = 1,
        style : str = "class:devices_info_shower"
    ):
        # NOTE(review): the default title "Tasks" looks copy-pasted from
        # TaskInfoShower — kept for interface compatibility; callers usually
        # pass an explicit title.
        self.title = title
        self.routineTask = routineTask

        # Periodically invoke routineTask(self) to refresh the display.
        self.scheduler = BackgroundScheduler()
        if routineTask is not None:
            self.scheduler.add_job(self.routineTask, 'interval', seconds=freq, args=[self])
            self.scheduler.start()

        self.lexer = DevicesLexer()
        self.shower = SCTextArea(lexer=self.lexer, wrap_lines=False)

        self.style = style
        self.window = Frame(
            self.shower,
            self.title,
            style=self.style,
            height=10,
            width=60
        )

    def getWindow(self):
        """Return the framed widget."""
        return self.window
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/TaskInfoShower.py:
--------------------------------------------------------------------------------
1 | from apscheduler.schedulers.background import BackgroundScheduler
2 | from typing import Callable, Iterable, List, Optional
3 | from pathlib import Path
4 |
5 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box
6 | from prompt_toolkit.buffer import Buffer
7 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
8 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
9 | from prompt_toolkit.key_binding import KeyBindings
10 | from prompt_toolkit.document import Document
11 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit
12 | from prompt_toolkit.lexers import Lexer
13 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples
14 |
15 | from DLNest.ShellClient.Windows.SCTextArea import SCTextArea
16 |
class taskLexer(Lexer):
    """Renders one fixed-width summary line per task for prompt_toolkit."""

    def __init__(self):
        self.taskInfo = []       # list of task dicts (see get_line)
        self.saveLength = 24     # save-folder column width
        self.devicesLength = 10  # devices column width

    def get_text(self):
        # One (empty) document line per task; content comes from the lexer.
        return "\n" * len(self.taskInfo)

    def lex_document(self, document : Document) -> Callable[[int], StyleAndTextTuples]:
        def get_line(lineno : int) -> StyleAndTextTuples:
            try:
                task = self.taskInfo[lineno]
                style_base = "class:running_task" if task["status"] == "Running" else "class:pending_task"
                ID = task["ID"]
                taskType = " Train " if task["type"] == "Train" else "Analyze"
                devices = " ".join(str(item) for item in task["devices"])
                description = task.get("description", "")
                save = Path(task["args"]["root_file_path"]).stem

                # Truncate long folder names with "..", otherwise left-pad.
                if len(save) > self.saveLength:
                    save = save[:self.saveLength - 2] + ".."
                else:
                    save = save.ljust(self.saveLength)

                devices = devices.ljust(self.devicesLength)

                return [
                    (style_base + "_status", "Status : " + task["status"] + " "),
                    (style_base + "_type", " Type : " + taskType + " "),
                    (style_base + "_id", " ID : " + ID + " "),
                    (style_base + "_device", " Devices : " + devices + " "),
                    (style_base + "_time", " Folder : " + save + " "),
                    (style_base + "_des", " Note : " + description + " ")
                ]
            except Exception:
                # Out-of-range line or malformed task dict: render nothing.
                return []

        return get_line
61 |
class TaskInfoShower:
    """Framed, auto-refreshing panel listing DLNest tasks."""

    def __init__(
        self,
        title : str = "Tasks",
        routineTask = None,
        freq : int = 1,
        style : str = "class:task_info_shower"
    ):
        self.title = title
        self.routineTask = routineTask

        # Periodically invoke routineTask(self) to refresh the display.
        self.scheduler = BackgroundScheduler()
        if routineTask is not None:
            self.scheduler.add_job(self.routineTask, 'interval', seconds=freq, args=[self])
            self.scheduler.start()

        self.lexer = taskLexer()
        self.shower = SCTextArea(lexer=self.lexer, wrap_lines=False)

        self.style = style
        self.window = Frame(
            self.shower,
            self.title,
            style=self.style,
            height=10
        )

    def getWindow(self):
        """Return the framed widget."""
        return self.window
--------------------------------------------------------------------------------
/DLNest/Information/TrainTask.py:
--------------------------------------------------------------------------------
1 | try:
2 | from TaskInformation import TaskInformation
3 | from ..SavePackage.SavePackage import SavePackage
4 | except ImportError:
5 | from DLNest.Information.TaskInformation import TaskInformation
6 | from DLNest.SavePackage.SavePackage import SavePackage
7 |
8 | from multiprocessing import Queue
9 |
class TrainTask(TaskInformation):
    """Task information describing a training run (new or resumed).

    `devices = None` defaults (normalized to [] in __init__) replace the
    former shared mutable default lists.
    """

    @classmethod
    def fromRecord(
        cls,
        recordPath : str,
        checkpointID : int = -1,
        devices : list = None,
        memoryConsumption : int = -1,
        multiGPU : bool = False,
        DDP : bool = False,
        description : str = ""
    ):
        """Build a TrainTask that resumes from an existing save package."""
        # Get save package from existing package
        savePackage = SavePackage()
        savePackage.initFromAnExistSavePackage(recordPath)

        retTask = cls(
            savePackage = savePackage,
            devices = devices,
            memoryConsumption = memoryConsumption,
            multiGPU = multiGPU,
            DDP = DDP,
            description = description,
            loadCkpt = True,
            checkpointID = checkpointID
        )
        return retTask

    @classmethod
    def fromConfigFile(
        cls,
        configPath : str,
        freqPath : str = "",
        devices : list = None,
        memoryConsumption : int = -1,
        multiGPU : bool = False,
        DDP : bool = False,
        description : str = "",
        noSave : bool = False,
        useDescriptionToSave : bool = False
    ):
        """Build a brand-new TrainTask from configuration files."""
        # Get save package by config path
        savePackage = SavePackage(configPath = configPath, freqPath = freqPath)

        # Making special save dir name by noSave or useDescriptionToSave
        saveName = ""
        if noSave:
            saveName = "NOSAVE"
        elif useDescriptionToSave:
            saveName = description
        savePackage.saveToNewDir(saveName)

        retTask = cls(
            savePackage = savePackage,
            devices = devices,
            memoryConsumption = memoryConsumption,
            multiGPU = multiGPU,
            DDP = DDP,
            description = description,
            noSave = noSave,
            useDescriptionToSave = useDescriptionToSave,
            loadCkpt = False
        )
        return retTask

    def __init__(
        self,
        savePackage : SavePackage,
        devices : list = None,
        memoryConsumption : int = -1,
        multiGPU : bool = False,
        DDP : bool = False,
        description : str = "",
        noSave : bool = False,
        useDescriptionToSave : bool = False,
        loadCkpt : bool = False,
        checkpointID : int = -1
    ):
        """Construct the task; `devices = None` means no explicit devices."""
        super(TrainTask, self).__init__(
            savePackage = savePackage,
            devices = [] if devices is None else devices,
            memoryConsumption = memoryConsumption,
            multiGPU = multiGPU,
            DDP = DDP,
            loadCkpt = loadCkpt,
            checkpointID = checkpointID
        )
        self.ID = "T_" + self.ID  # training tasks carry a "T_" prefix
        self.description = description
        self.noSave = noSave
        self.useDescriptionToSave = useDescriptionToSave

        self.type = "Train"

        # Commands sent to the running training process.
        self.commandQueue = Queue()

        if description != "":
            savePackage.saveVisualString("desc: " + description)

    def getDict(self):
        """Serializable snapshot including the description."""
        ret = super().getDict()
        ret["description"] = self.description
        return ret
--------------------------------------------------------------------------------
/DLNest/Plugins/MailsNote.py:
--------------------------------------------------------------------------------
1 | from DLNest.Plugins.DLNestPluginBase import DLNestPluginBase as DPB
2 | from DLNest.Plugins.Utils.SendMailTools import sendSelfMail
3 | import DLNest
4 | import traceback
5 | import logging
6 |
class DLNestPlugin(DPB):
    """
    Plugin that sends e-mail notifications for train-task life-cycle events.

    Events enabled through the ``enable_list`` config entry:
      * "Aborting"    -> mail sent from ``trainAborting``
      * "TrainFinish" -> mail sent from ``ATrain``
    ``custom`` / ``SOTA`` send unconditional mails.

    NOTE: these hooks are invoked with ``self`` being the host life-cycle
    object rather than a DLNestPlugin instance, which is why class-level
    calls like ``DPB.getArgs(self, ...)`` are used throughout.
    """
    _NAME = "MailsNote"
    _config = {
        "enable_list" : ["Aborting","TrainFinish"],
        "username" : "",
        "password" : "",
        "host" : "mail.fudan.edu.cn",
        "port" : 25
    }
    # Keys exposed by default when the plugin config is queried ("port" is
    # intentionally absent here -- kept as-is for backward compatibility).
    _defaultKeys = ["enable_list","username","password","host"]

    def _getMailConfig(self):
        """Read the SMTP account settings from the plugin configuration."""
        pluginName = DLNestPlugin._NAME
        username = DPB.getArgs(self, pluginName, "username", "")
        password = DPB.getArgs(self, pluginName, "password", "")
        host = DPB.getArgs(self, pluginName, "host", "mail.fudan.edu.cn")
        port = DPB.getArgs(self, pluginName, "port", 25)
        return username, password, host, port

    def trainAborting(self, exception : Exception):
        """Mail the traceback of an aborted train task (rank 0 / -1 only)."""
        if self._status.rank != 0 and self._status.rank != -1:
            return
        excStr = traceback.format_exc()

        pluginName = DLNestPlugin._NAME
        enable = "Aborting" in DPB.getArgs(self, pluginName, "enable_list", [])
        if not enable:
            logging.debug("MailsNote is not enabled for aborting")
            return

        username, password, host, port = DLNestPlugin._getMailConfig(self)
        message = "Train task in {} aborted with this exception\n".format(self.getArgs()["root_file_path"]) + excStr
        subject = "Train Aborted!"
        sendSelfMail(
            username = username,
            password = password,
            message = message,
            subject = subject,
            host = host,
            port = port
        )
        logging.info("[MailsNote] " + message)

    def ATrain(self):
        """Mail a completion notice, embedding final results when present."""
        if self._status.rank != 0 and self._status.rank != -1:
            return
        pluginName = DLNestPlugin._NAME
        enable = "TrainFinish" in DPB.getArgs(self, pluginName, "enable_list", [])
        if not enable:
            logging.debug("MailsNote is not enabled for train finish")
            return

        username, password, host, port = DLNestPlugin._getMailConfig(self)
        message = "Train task in {} finished\n".format(self.getArgs()["root_file_path"])

        resultStr = DLNest.Plugins.MailsNote.DLNestPlugin._getResultsStr() # To get correct class function and var.
        if len(resultStr) != 0:
            message += "\nThe final results are: \n" + resultStr

        subject = "Train Finished!"
        sendSelfMail(
            username = username,
            password = password,
            message = message,
            subject = subject,
            host = host,
            port = port
        )
        logging.info("[MailsNote] " + message)

    # Class-level storage shared between giveResultValues and the hooks.
    _result_values = {}

    @classmethod
    def giveResultValues(cls, resultsDict : dict):
        """Store the final results dict to be embedded in the finish mail."""
        cls._result_values = resultsDict

    @classmethod
    def _getResultsStr(cls):
        """Format the stored results as one "| key : value" chain."""
        ret = ""  # renamed from ``str`` -- do not shadow the builtin
        for key in cls._result_values:
            ret += "| {} : {}".format(key, cls._result_values[key])
        return ret

    def custom(self, message, subject):
        """Send an arbitrary message with the given subject."""
        username, password, host, port = DLNestPlugin._getMailConfig(self)
        message = "Train task in {} gives a message\n".format(self.getArgs()["root_file_path"]) + message

        sendSelfMail(
            username = username,
            password = password,
            message = message,
            subject = subject,
            host = host,
            port = port
        )
        logging.info("[MailsNote] " + message)

    def SOTA(self, key, value):
        """Send a congratulation mail for a new best value of ``key``."""
        value = str(value)
        message = "Congratulations! SOTA performance in {} has been made by your train task {} with value {}!".format(key, self.getArgs()["root_file_path"], value)
        subject = "SOTA for {}!".format(key)
        DLNestPlugin.custom(self, message, subject)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### PyCharm ###
2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
4 |
5 | # User-specific stuff
6 | .idea/**/workspace.xml
7 | .idea/**/tasks.xml
8 | .idea/**/usage.statistics.xml
9 | .idea/**/dictionaries
10 | .idea/**/shelf
11 |
12 | # Generated files
13 | .idea/**/contentModel.xml
14 |
15 | # Sensitive or high-churn files
16 | .idea/**/dataSources/
17 | .idea/**/dataSources.ids
18 | .idea/**/dataSources.local.xml
19 | .idea/**/sqlDataSources.xml
20 | .idea/**/dynamic.xml
21 | .idea/**/uiDesigner.xml
22 | .idea/**/dbnavigator.xml
23 |
24 | # Gradle
25 | .idea/**/gradle.xml
26 | .idea/**/libraries
27 |
28 | # Gradle and Maven with auto-import
29 | # When using Gradle or Maven with auto-import, you should exclude module files,
30 | # since they will be recreated, and may cause churn. Uncomment if using
31 | # auto-import.
32 | # .idea/modules.xml
33 | # .idea/*.iml
34 | # .idea/modules
35 | # *.iml
36 | # *.ipr
37 |
38 | # CMake
39 | cmake-build-*/
40 |
41 | # Mongo Explorer plugin
42 | .idea/**/mongoSettings.xml
43 |
44 | # File-based project format
45 | *.iws
46 |
47 | # IntelliJ
48 | out/
49 |
50 | # mpeltonen/sbt-idea plugin
51 | .idea_modules/
52 |
53 | # JIRA plugin
54 | atlassian-ide-plugin.xml
55 |
56 | # Cursive Clojure plugin
57 | .idea/replstate.xml
58 |
59 | # Crashlytics plugin (for Android Studio and IntelliJ)
60 | com_crashlytics_export_strings.xml
61 | crashlytics.properties
62 | crashlytics-build.properties
63 | fabric.properties
64 |
65 | # Editor-based Rest Client
66 | .idea/httpRequests
67 |
68 | # Android studio 3.1+ serialized cache file
69 | .idea/caches/build_file_checksums.ser
70 |
71 | ### PyCharm Patch ###
72 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
73 |
74 | # *.iml
75 | # modules.xml
76 | # .idea/misc.xml
77 | # *.ipr
78 |
79 | # Sonarlint plugin
80 | .idea/**/sonarlint/
81 |
82 | # SonarQube Plugin
83 | .idea/**/sonarIssues.xml
84 |
85 | # Markdown Navigator plugin
86 | .idea/**/markdown-navigator.xml
87 | .idea/**/markdown-navigator/
88 |
89 | ### Python ###
90 | # Byte-compiled / optimized / DLL files
91 | __pycache__/
92 | *.py[cod]
93 | *$py.class
94 |
95 | # C extensions
96 | *.so
97 |
98 | # Distribution / packaging
99 | .Python
100 | build/
101 | develop-eggs/
102 | dist/
103 | downloads/
104 | eggs/
105 | .eggs/
106 | lib/
107 | lib64/
108 | parts/
109 | sdist/
110 | var/
111 | wheels/
112 | pip-wheel-metadata/
113 | share/python-wheels/
114 | *.egg-info/
115 | .installed.cfg
116 | *.egg
117 | MANIFEST
118 |
119 | # PyInstaller
120 | # Usually these files are written by a python script from a template
121 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
122 | *.manifest
123 | *.spec
124 |
125 | # Installer logs
126 | pip-log.txt
127 | pip-delete-this-directory.txt
128 |
129 | # Unit test / coverage reports
130 | htmlcov/
131 | .tox/
132 | .nox/
133 | .coverage
134 | .coverage.*
135 | .cache
136 | nosetests.xml
137 | coverage.xml
138 | *.cover
139 | .hypothesis/
140 | .pytest_cache/
141 |
142 | # Translations
143 | *.mo
144 | *.pot
145 |
146 | # Scrapy stuff:
147 | .scrapy
148 |
149 | # Sphinx documentation
150 | docs/_build/
151 |
152 | # PyBuilder
153 | target/
154 |
155 | # pyenv
156 | .python-version
157 |
158 | # pipenv
159 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
160 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
161 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
162 | # install all needed dependencies.
163 | #Pipfile.lock
164 |
165 | # celery beat schedule file
166 | celerybeat-schedule
167 |
168 | # SageMath parsed files
169 | *.sage.py
170 |
171 | # Spyder project settings
172 | .spyderproject
173 | .spyproject
174 |
175 | # Rope project settings
176 | .ropeproject
177 |
178 | # Mr Developer
179 | .mr.developer.cfg
180 | .project
181 | .pydevproject
182 |
183 | # mkdocs documentation
184 | /site
185 |
186 | # mypy
187 | .mypy_cache/
188 | .dmypy.json
189 | dmypy.json
190 |
191 | # Pyre type checker
192 | .pyre/
193 |
194 | ### VisualStudioCode ###
195 | .vscode/
196 | .vscode/*
197 | !.vscode/settings.json
198 | !.vscode/tasks.json
199 | !.vscode/launch.json
200 | !.vscode/extensions.json
201 |
202 | ### VisualStudioCode Patch ###
203 | # Ignore all local history of files
204 | .history
205 |
206 | # End of https://www.gitignore.io/api/python,pycharm,visualstudiocode
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/ResultsOutput.py:
--------------------------------------------------------------------------------
1 | from apscheduler.schedulers.background import BackgroundScheduler
2 | from typing import Callable, Iterable, List, Optional
3 |
4 | from prompt_toolkit.widgets import Frame,TextArea,Label,Box
5 | from prompt_toolkit.buffer import Buffer
6 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
7 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
8 | from prompt_toolkit.key_binding import KeyBindings
9 | from prompt_toolkit.document import Document
10 | from prompt_toolkit.layout.containers import VSplit, Window,HSplit
11 | from prompt_toolkit.lexers import Lexer
12 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples
13 |
14 | from DLNest.ShellClient.Windows.SCTextArea import SCTextArea
15 |
class styledTextLexer(Lexer):
    """
    A prompt_toolkit Lexer that ignores the document text and instead
    replays a pre-computed list of (style, text) fragments, split into
    per-line fragment lists on newline characters.
    """
    def __init__(self,styled_text : StyleAndTextTuples = None):
        # Fresh list per lexer: the previous mutable default ([]) was
        # shared by every instance constructed without an argument.
        self.styled_text = [] if styled_text is None else styled_text
        self.DEBUG = False

    def __get_styled_lines(self):
        """Split self.styled_text into a list of per-line fragment lists."""
        self.styled_text_lines = []
        now = []
        for item in self.styled_text:
            finish = False
            if "\n" in item[1]:
                # a newline terminates the current visual line
                finish = True
                item = (item[0],item[1].replace("\n", ""))
            now.append(item)
            if finish:
                self.styled_text_lines.append(now)
                now = []
        self.styled_text_lines.append(now)

    def lex_document(self,document : Document) -> Callable[[int], StyleAndTextTuples]:
        """Return a lineno -> fragments function backed by styled_text."""
        self.__get_styled_lines()
        lines = document.lines

        if self.DEBUG:
            with open("/root/STLexerDEBUG.txt","w") as f:
                print(self.styled_text,file = f)
                print(self.styled_text_lines,file=f)
                print(lines,file = f)
                print(len(lines),len(self.styled_text_lines),file=f)

        def get_line(lineno : int) -> StyleAndTextTuples:
            # out-of-range line numbers render as an empty fragment
            try:
                return self.styled_text_lines[lineno]
            except Exception:
                return [("","")]

        return get_line
53 |
class ResultsOutput:
    """
    A framed, periodically refreshed text window for DLNest output.

    ``routineTask`` (if given) is invoked every ``freq`` seconds by a
    background APScheduler job, receiving this object so it can refresh
    the displayed text through ``self.lexer`` / ``self.shower``.
    """
    def __init__(
        self,
        title : str = "DLNest Output",
        routineTask = None,
        freq : int = 1,
        style : str = "class:results_output"
    ):
        self.title = title
        self.routineTask = routineTask
        self.scheduler = BackgroundScheduler()
        if routineTask is not None:  # idiomatic form of ``not x is None``
            self.scheduler.add_job(self.routineTask,'interval',seconds=freq,args=[self])
            self.scheduler.start()
        self.lexer = styledTextLexer()

        self.shower = SCTextArea(
            self.lexer
        )

        self.style = style
        self.setKeyBinding()
        self.window = Frame(
            self.shower,
            self.title,
            style=self.style,
            modal=True,
            key_bindings=self.kb
        )

    def setKeyBinding(self):
        """Install key bindings: <escape> jumps the cursor to the end."""
        self.kb = KeyBindings()

        @self.kb.add("escape")
        def toEnd(event):
            self.shower.buffer.cursor_position = len(self.shower.document.text)

    def getWindow(self):
        """Return the top-level prompt_toolkit container for this widget."""
        return self.window
93 |
class AnalyzeOutput(ResultsOutput):
    """
    ResultsOutput variant for analyze tasks: adds an info label above the
    text area showing the currently running analyze task.
    """
    def __init__(
        self,
        title : str = "DLNest Output",
        routineTask = None,
        freq : int = 1,
        # NOTE(review): this default lacks the "class:" prefix used by the
        # other styles (e.g. "class:results_output") -- looks like a typo
        # for "class:analyzer_output"; confirm before changing.
        style : str = "class_analyzer_output"
    ):
        super(AnalyzeOutput,self).__init__(
            title,
            routineTask,
            freq,
            style
        )

        # Placeholder status line; presumably replaced by the routine task
        # when an analyze task is attached -- confirm with callers.
        self.infoText = FormattedTextControl(
            [("", " No analyze task is running ")],
            focusable=False,
            show_cursor=False
        )

        self.infoWindow = Window(
            content=self.infoText
        )

        self.infoLabel = Box(
            body=self.infoWindow,
            height=3,
            padding_top=1,
            padding_bottom=1,
            # padding_left=3,
            # padding_right=3,
            style="class:analyzer_info_label"
        )

        # Rebuild the frame created by the parent constructor so the info
        # label sits above the text area.
        self.window = Frame(
            HSplit([
                self.infoLabel,
                self.shower
            ]),
            title=self.title,
            style = self.style,
            modal = True,
            key_bindings=self.kb
        )
--------------------------------------------------------------------------------
/DLNest/Scheduler/DefaultStrategy.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.InfoCenter import InfoCenter
2 | from DLNest.Scheduler.SchedulerStrategyBase import SchedulerStrategyBase
3 | from DLNest.Information.DeviceInformation import DeviceInformation
4 | from DLNest.Information.TaskInformation import TaskInformation
5 |
6 | import time
7 |
class DefaultStrategy(SchedulerStrategyBase):
    """
    Default scheduling strategy.

    ``decide`` repeatedly picks the oldest pending task and tries to place
    it on one device (single-GPU) or several (multi-GPU), subject to: the
    device not being broken, holding fewer than ``maxTaskPerDevice`` tasks,
    at least ``maxTimeDelay`` seconds since its last use, and enough free
    memory (a task with undeclared consumption is assumed to need 90% of a
    device's total memory).
    """
    def __checkTaskNumAndLastUseTime(self, device : DeviceInformation):
        """
        Check isBreak, task num, last use time
        """
        if device.isBreak:
            return False

        # Check task num
        taskNum = device.getTaskNum()
        if taskNum >= self.maxTaskPerDevice:
            return False

        # Check last use time
        # NOTE(review): assumes lastUseTime() returns a UNIX timestamp;
        # devices used within the last maxTimeDelay seconds are skipped.
        delta = time.time() - device.lastUseTime()
        if delta < self.maxTimeDelay:
            return False

        return True

    def __canTaskRunOnDevice(self,task : TaskInformation, device : DeviceInformation):
        """
        Check free memory and last use time
        """
        if not self.__checkTaskNumAndLastUseTime(device):
            return False

        # Check free memory
        # An undeclared consumption (< 0) is treated as 90% of this
        # device's total memory.
        memoryConsumption = device.totalMemory * 0.9 if task.memoryConsumption < 0 else task.memoryConsumption
        freeMemory = device.getFreeMemory()
        if freeMemory <= memoryConsumption:
            return False

        return True

    def __canTaskRunOnDevices(self, task : TaskInformation, devices):
        """
        return [] if cannot
        return [ids] if can run
        won't care the CPU
        """
        # NOTE(review): uses devices[0].totalMemory as the reference for the
        # 90% fallback -- assumes similar capacity across GPUs; confirm.
        memoryConsumption = devices[0].totalMemory * 0.9 if task.memoryConsumption < 0 else task.memoryConsumption
        ret = []
        memInfo = []
        for device in devices:
            # don't consider CPU and broken devices
            if device.type == "CPU" or device.isBreak:
                continue
            if not self.__checkTaskNumAndLastUseTime(device):
                continue

            freeMemory = device.getFreeMemory()
            memInfo.append((freeMemory,device))

        # use the devices with big free memory first
        memInfo.sort(key = lambda x:x[0], reverse = True)
        memCount = 0.0
        for item in memInfo:
            freeMemory,device = item
            ret.append(device.ID)
            memCount += freeMemory
            if memCount > memoryConsumption:
                break

        # return [] if no enough free memory
        if memCount <= memoryConsumption:
            return []
        else:
            return ret

    def __findARuningDevice(self, task : TaskInformation, devices):
        """
        return [] if cannot
        return [id] if can run
        won't care the CPU
        """
        for device in devices:
            if device.type == "CPU":
                continue

            if self.__canTaskRunOnDevice(task, device):
                return [device.ID]

        return []

    def decide(self, maxTaskPerDevice : int, maxTimeDelay : int):
        """
        Run one scheduling pass: keep launching pending tasks until no
        pending task can be placed on any device.
        """
        self.maxTaskPerDevice = maxTaskPerDevice
        self.maxTimeDelay = maxTimeDelay

        continueCheck = True
        while continueCheck:
            continueCheck = False

            # NOTE(review): usingTasks()/releaseTasks() (and the device
            # counterparts below) appear to form acquire/release pairs
            # guarding shared InfoCenter state -- confirm in InfoCenter.
            tasks = self.infoCenter.usingTasks()
            try:
                numTasks = len(tasks)
                taskWait2Run = None
                for i in range(numTasks):
                    # only consider pending tasks
                    if tasks[i].status == "Pending":
                        taskWait2Run = tasks[i]
                        # only consider the oldest task
                        break
                    else:
                        continue

                if taskWait2Run != None:
                    # using device informations
                    devices = self.infoCenter.usingDevicesInformation()
                    usingDevices = []
                    if taskWait2Run.devices != []:
                        # if the task has some decided devices(such as CPU), just run it.
                        usingDevices = taskWait2Run.devices
                    elif taskWait2Run.multiGPU:
                        # multi GPU logic
                        usingDevices = self.__canTaskRunOnDevices(taskWait2Run,devices)
                    else:
                        # single GPU logic
                        usingDevices = self.__findARuningDevice(taskWait2Run,devices)
                    # must release device before run task
                    self.infoCenter.releaseDeviceInformation()
                else:
                    break

                # if have proper device
                if usingDevices != []:
                    self.scheduler.runTask(taskWait2Run,usingDevices)
                    continueCheck = True
            finally:
                self.infoCenter.releaseTasks()
138 |
--------------------------------------------------------------------------------
/DLNest/Output/OutputLayerBase.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | import prompt_toolkit
3 | from prompt_toolkit.styles import Style
4 | import threading
5 |
class OutputLayerBase():
    """
    Thread-safe, bounded buffer of (style, text) fragments that behaves
    like a writable text stream (write/flush/isatty/fileno).

    ``offset`` is the token of the first fragment still held: readers pass
    the last token they saw and get back (new_offset, data), where data is
    None when the requested position was trimmed away or not yet written.

    NOTE(review): ``write`` uses ``self.appName``, which this base class
    never sets -- subclasses are expected to define it; confirm.
    """
    def __init__(self):
        # style name -> style string used when emitting fragments
        self.styleDict = {
            '' : '#efefef',
            'app' : '#ff0f9f',
            'status' : '#ff0000',
            'message' : '#ffffff',
        }
        # A subclass may pre-populate styledText before calling super().
        if not hasattr(self,'styledText'):
            self.styledText = []
        self.maxBufferLen = 10000        # max fragments kept before trimming
        self.offset = 0                  # token of the first retained fragment
        self.lock = threading.Lock()     # guards styledText / offset
        self.isLineHead = True           # next write starts a new output line

    def getPlainText(self,from_token : int = -1,length : int = -1):
        """
        Return (offset, text) concatenating up to ``length`` fragments
        starting at token ``from_token`` (-1 means the current offset).
        Returns (offset, None) if the position was trimmed or not yet
        written, and (offset, "") when the buffer is empty.
        """
        if self.lock.acquire():
            try:
                if len(self.styledText) == 0:
                    return self.offset,""

                if from_token == -1:
                    from_token = self.offset

                if self.offset > from_token:
                    # the requested start has already been trimmed away
                    return self.offset,None
                else:
                    ret = ""
                    start = from_token - self.offset
                    end = 0
                    # no data at the requested start position yet
                    if start >= len(self.styledText):
                        return self.offset,None

                    if length == -1:
                        end = len(self.styledText)
                    else:
                        end = min(len(self.styledText),start + length)

                    for i in range(start,end):
                        ret = ret + self.styledText[i][1]
                    return self.offset,ret
            finally:
                self.lock.release()

    def getStyledText(self,from_token : int = -1,length : int = -1):
        """
        Same protocol as getPlainText, but returns the (style, text)
        fragment slice instead of concatenated plain text.
        """
        if self.lock.acquire():
            try:
                if len(self.styledText) == 0:
                    return self.offset,[("","")]
                if from_token == -1:
                    from_token = self.offset

                if self.offset > from_token:
                    # the requested start has already been trimmed away
                    return self.offset,None
                else:
                    start = from_token - self.offset
                    if start >= len(self.styledText):
                        # no data at the requested start position yet
                        return self.offset,None
                    if length == -1:
                        end = len(self.styledText)
                    else:
                        end = min(len(self.styledText),start + length)

                    return self.offset,self.styledText[start:end]
            finally:
                self.lock.release()

    def getStyledDict(self):
        """Return the style-name -> style-string mapping."""
        return self.styleDict

    def maintainText(self):
        """
        Trim whole lines from the front until at most maxBufferLen
        fragments remain, advancing ``offset`` per dropped fragment.
        NOTE(review): assumes a newline fragment exists near the front;
        otherwise the inner pop(0) loop would raise IndexError on an
        exhausted buffer. Callers already hold ``self.lock``.
        """
        if len(self.styledText) <= self.maxBufferLen:
            return
        while len(self.styledText) > self.maxBufferLen:
            while self.styledText[0][1] != "\n":
                self.styledText.pop(0)
                self.offset += 1
            self.styledText.pop(0)
            self.offset += 1

    def putPlainText(self,text : str):
        """Append ``text`` with the default style; returns True."""
        if self.lock.acquire():
            try:
                self.styledText.append((
                    (self.styleDict[''] , text)
                ))
                self.maintainText()
                return True
            finally:
                self.lock.release()

    def putStyledText(self,style_name : str, text : str):
        """Append ``text`` under a named style; False if style unknown."""
        if self.lock.acquire():
            try:
                if style_name in self.styleDict:
                    self.styledText.append((
                        self.styleDict[style_name], text
                    ))
                    self.maintainText()
                    return True
                else:
                    return False
            finally:
                self.lock.release()

    def clear(self):
        """Drop all buffered fragments and reset the token counter."""
        if self.lock.acquire():
            try:
                self.offset = 0
                self.styledText = []
            finally:
                self.lock.release()

    def write(self,message : str):
        """
        Stream-style write: prefixes each new output line with
        "[<appName>] " and splits multi-line messages into per-line
        styled fragments, tracking line-head state across calls.
        """
        if self.isLineHead:
            self.putStyledText('app',"[" + self.appName + "] ")
            self.isLineHead = False

        if "\n" in message:
            if message == '\n':
                self.putStyledText('message',message)
                self.isLineHead = True
            else:
                lines = message.split("\n")
                endBy_n = message[-1] == "\n"
                for i in range(len(lines)):
                    if self.isLineHead:
                        self.putStyledText('app',"[" + self.appName + "] ")
                        self.isLineHead = False
                    self.putStyledText('message',lines[i])
                    if i < len(lines) - 1:
                        self.putStyledText('message',"\n")
                        self.isLineHead = True
                    elif endBy_n:
                        self.putStyledText('message',"\n")
                        self.isLineHead = True
        else:
            self.putStyledText('message',message)

    def flush(self):
        # no-op: data is already visible to readers as soon as it is put
        ...

    def isatty(self):
        # pretend to be a terminal so libraries keep emitting colors etc.
        return True

    def fileno(self):
        return 3 # Maybe hazard: not a real OS file descriptor
--------------------------------------------------------------------------------
/DLNest/Common/RunnerBaseTorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | from .RunnerBase import RunnerBase
5 | from abc import ABCMeta, abstractmethod
6 | import typing
7 | from torch.optim.lr_scheduler import _LRScheduler
8 | from torch.optim.optimizer import Optimizer
9 | from DLNest.Plugins.Utils.CheckPlugins import checkPlugins,checkDictOutputPlugins
10 |
class ModuleWrapper:
    """
    Wrap a plain ``nn.Module`` together with (optionally) its DDP shell.

    Attribute access through ``.`` is delegated to the plain module, while
    calling the wrapper invokes the DDP shell when one exists (falling
    back to the plain module otherwise). Use ``singleCard`` to invoke the
    plain module's forward explicitly, bypassing DDP.
    """
    # Names the wrapper itself owns; every other attribute is delegated
    # to the wrapped module.
    keywords = ["_ModuleWrapper__M","_ModuleWrapper__DDP", "__dict__","singleCard"]

    def __init__(self, module : nn.Module, DDP_wrapper):
        self.__M = module
        self.__DDP = DDP_wrapper

    def __getattr__(self, name):
        if name in ModuleWrapper.keywords:
            return super().__getattr__(name)
        return self.__M.__getattr__(name)

    def __setattr__(self, name, value):
        if name in ModuleWrapper.keywords:
            self.__dict__[name] = value
            return
        return self.__M.__setattr__(name, value)

    def __delattr__(self, name):
        return self.__M.__delattr__(name)

    def __getattribute__(self, name):
        if name in ModuleWrapper.keywords:
            return super().__getattribute__(name)
        return self.__M.__getattribute__(name)

    def __call__(self, *args, **kwargs):
        # prefer the DDP shell when it exists
        if self.__DDP:
            return self.__DDP(*args, **kwargs)
        else:
            return self.__M(*args, **kwargs)

    def __str__(self):
        if self.__DDP:
            return self.__DDP.__str__() + " in DLNest ModuleWrapper"
        else:
            return self.__M.__str__() + " in DLNest ModuleWrapper"

    def singleCard(self, *args, **kwargs):
        # BUGFIX: the forward result was computed but silently dropped,
        # making singleCard useless to callers -- return it.
        return self.__M(*args, **kwargs)
56 |
class RunnerBaseTorch(RunnerBase):
    """
    Torch-specific runner base: auto-registers modules / optimizers /
    LR schedulers assigned as attributes so checkpoints can be saved and
    restored, and wraps modules for DDP when running distributed.
    """
    def __init__(self,args : dict, plugins = [], status = None):
        super(RunnerBaseTorch,self).__init__(args = args, plugins = plugins, status = status)
        self.__modelList = []       # registered nn.Modules, registration order
        self.__optimizerDict = {}   # attribute name -> Optimizer
        self.__schedulerDict = {}   # attribute name -> _LRScheduler

    def DDPOperation(self,rank : int):
        """Hook for extra per-rank setup in DDP mode; default: no-op."""
        pass

    def __setattr__(self, name, value):
        # Transparently register torch objects assigned as attributes.
        if isinstance(value, nn.Module):
            super().__setattr__(name, self.register(value))
        elif isinstance(value, Optimizer):
            super().__setattr__(name, self.__registerOptimizer(value, name))
        elif isinstance(value, _LRScheduler):
            super().__setattr__(name, self.__registerLRScheduler(value, name))
        else:
            super().__setattr__(name,value)

    def register(self,module : nn.Module, syncBN : bool = False):
        """
        Register a module for checkpointing and wrap it for the current
        execution environment (DDP / single GPU / CPU).
        """
        assert isinstance(module, nn.Module), "Only nn.Module can be registered. You tried to register a " + module.__class__.__name__
        self.__modelList.append(module)
        if self._status.env == "DDP":
            model = module.cuda()
            if syncBN:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            try:
                DDPmodel = nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self._status.rank],
                    output_device=self._status.rank,
                    find_unused_parameters=True
                )
            except AssertionError:
                # e.g. the module has no parameters -- fall back to the
                # plain module (ModuleWrapper handles a None DDP shell)
                DDPmodel = None
            return ModuleWrapper(model, DDPmodel)
        else:
            # BUGFIX: in the CPU environment ``model`` was never assigned,
            # raising NameError; keep the module on the CPU in that case.
            model = module if self._status.env == "CPU" else module.cuda()
            return ModuleWrapper(model, None)

    def __registerOptimizer(self, optimizer : Optimizer, name):
        """Track an optimizer under its attribute name for checkpointing."""
        self.__optimizerDict[name] = optimizer
        return optimizer

    def __registerLRScheduler(self, scheduler : _LRScheduler, name):
        """Track an LR scheduler under its attribute name for checkpointing."""
        self.__schedulerDict[name] = scheduler
        return scheduler

    def getSaveDict(self):
        """
        Collect the state dicts of all registered modules (keyed by
        registration index) plus optimizers and schedulers (keyed by
        attribute name).
        """
        stateDict = {}
        for i in range(len(self.__modelList)):
            stateDict[i] = self.__modelList[i].state_dict()

        stateDict["optimizer"] = {}
        for name in self.__optimizerDict:
            optimizer = self.__optimizerDict[name]
            try:
                stateDict["optimizer"][name] = optimizer.state_dict()
            except Exception:
                pass # may add some warning

        stateDict["scheduler"] = {}
        for name in self.__schedulerDict:
            scheduler = self.__schedulerDict[name]
            try:
                stateDict["scheduler"][name] = scheduler.state_dict()
            except Exception:
                pass # may add some warning
        return stateDict

    def loadSaveDict(self,saveDict):
        """Restore module / optimizer / scheduler states from getSaveDict output."""
        for i in range(len(self.__modelList)):
            self.__modelList[i].load_state_dict(saveDict[i])

        if "optimizer" in saveDict:
            for name in saveDict["optimizer"]:
                self.__optimizerDict[name].load_state_dict(saveDict["optimizer"][name])

        if "scheduler" in saveDict:
            for name in saveDict["scheduler"]:
                self.__schedulerDict[name].load_state_dict(saveDict["scheduler"][name])

                # to reset the learning rate (stepped inside the loop so
                # every restored scheduler is reset, not just the last one)
                self.__schedulerDict[name].step(
                    self.__schedulerDict[name].last_epoch
                )

    def _reduceSum(self,tensor):
        """
        ALPHA VERSION FUNCTION

        Reduce-sum ``tensor`` onto rank 0 when using DDP;
        if not using DDP, do nothing.
        """
        if self._status.env == "DDP":
            dist.reduce(tensor,0)
        return tensor

    def _reduceMean(self,tensor):
        """
        ALPHA VERSION FUNCTION

        Reduce ``tensor`` onto rank 0 and divide by the world size when
        using DDP; if not using DDP, do nothing.
        """
        if self._status.env == "DDP":
            dist.reduce(tensor,0)
            if self._status.rank == 0:
                tensor = tensor / self._status.worldSize
        return tensor
171 |
--------------------------------------------------------------------------------
/DLNest/Common/LifeCycleBase.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | try:
3 | from .DatasetBase import DatasetBase
4 | from .RunnerBase import RunnerBase
5 | except ImportError:
6 | from DLNest.Common.DatasetBase import DatasetBase
7 | from DLNest.Common.RunnerBase import RunnerBase
8 | from DLNest.Plugins.Utils.CheckPlugins import checkPlugins
9 | import traceback
10 | from functools import wraps
11 | import logging
12 |
13 | class LifeCycleBase:
14 | def __init__(self,runner : RunnerBase = None,dataset : DatasetBase = None, taskProcess = None, plugins : list = [], status = None):
15 | self.runner = runner
16 | self.dataset = dataset
17 | self.taskProcess = taskProcess
18 | self._plugins = plugins
19 | self._status = status
20 |
21 | # for backward compatibility
22 | @property
23 | def rank(self):
24 | if not "_warned_rank" in dir(self):
25 | self._warned_rank = True
26 | print("LifeCycle.rank is deprecated, please use LifeCycle.status.rank")
27 | return self._status.rank
28 |
29 | def getArgs(self):
30 | return self.taskProcess.task.args
31 |
32 | def BAll(self):
33 | pass
34 |
35 | def BDatasetInit(self):
36 | pass
37 |
38 | def ADatasetInit(self):
39 | pass
40 |
41 | def BModelInit(self):
42 | pass
43 |
44 | def AModelInit(self):
45 | pass
46 |
47 | def BTrain(self):
48 | pass
49 |
50 | def ATrain(self):
51 | pass
52 |
53 | def BOneEpoch(self):
54 | pass
55 |
56 | def AOneEpoch(self):
57 | pass
58 |
59 | def BGetCommand(self):
60 | pass
61 |
62 | def AGetCommand(self,command):
63 | pass
64 |
65 | def BSuspend(self):
66 | pass
67 |
68 | def ASuspend(self):
69 | pass
70 |
71 | def BLoadFromSuspend(self):
72 | pass
73 |
74 | def ALoadFromSuspend(self):
75 | pass
76 |
77 | def BModelOneStep(self):
78 | pass
79 |
80 | def AModelOneStep(self):
81 | pass
82 |
83 | @abstractmethod
84 | def needVisualize(self, epoch : int, iter : int, logdict : dict, args : dict):
85 | return False
86 |
87 | def BVisualize(self):
88 | pass
89 |
90 | def AVisualize(self):
91 | pass
92 |
93 | @abstractmethod
94 | def needValidation(self, epoch : int, logdict : dict, args : dict):
95 | return False
96 |
97 | def BValidation(self):
98 | pass
99 |
100 | def BValidateABatch(self):
101 | pass
102 |
103 | def AValidateABatch(self):
104 | pass
105 |
106 | def BValidationAnalyze(self):
107 | pass
108 |
109 | def AValidationAnalyze(self):
110 | pass
111 |
112 | def AValidation(self):
113 | pass
114 |
115 | def commandLineOutput(self,epoch : int, logdict : dict, args : dict):
116 | print("Epoch #" + str(epoch) + " finished!")
117 |
118 | @abstractmethod
119 | def needSaveModel(self, epoch : int, logdict : dict, args : dict):
120 | return True
121 |
122 | def BSaveModel(self):
123 | pass
124 |
125 | def ASaveModel(self):
126 | pass
127 |
128 | @abstractmethod
129 | def holdThisCheckpoint(self, epoch : int, logdict : dict, args : dict):
130 | return False
131 |
132 | @abstractmethod
133 | def needContinueTrain(self, epoch : int, logdict : dict, args : dict):
134 | return False
135 |
136 | def getSaveDict(self):
137 | return {}
138 |
139 | def loadSaveDict(self,saveDict):
140 | pass
141 |
142 | def AAll(self):
143 | pass
144 |
145 | def trainAborting(self,exception : Exception):
146 | traceback.print_exc()
147 |
148 | @checkPlugins
149 | def _BAll(self):
150 | return self.BAll()
151 |
152 | @checkPlugins
153 | def _BDatasetInit(self):
154 | return self.BDatasetInit()
155 |
156 | @checkPlugins
157 | def _ADatasetInit(self):
158 | return self.ADatasetInit()
159 |
160 | @checkPlugins
161 | def _BModelInit(self):
162 | return self.BModelInit()
163 |
164 | @checkPlugins
165 | def _AModelInit(self):
166 | return self.AModelInit()
167 |
168 | @checkPlugins
169 | def _BTrain(self):
170 | return self.BTrain()
171 |
172 | @checkPlugins
173 | def _ATrain(self):
174 | return self.ATrain()
175 |
176 | @checkPlugins
177 | def _BOneEpoch(self):
178 | return self.BOneEpoch()
179 |
180 | @checkPlugins
181 | def _AOneEpoch(self):
182 | return self.AOneEpoch()
183 |
184 | @checkPlugins
185 | def _BGetCommand(self):
186 | return self.BGetCommand()
187 |
188 | @checkPlugins
189 | def _AGetCommand(self,command):
190 | return self.AGetCommand(command = command)
191 |
192 | @checkPlugins
193 | def _BSuspend(self):
194 | return self.BSuspend()
195 |
196 | @checkPlugins
197 | def _ASuspend(self):
198 | return self.ASuspend()
199 |
    @checkPlugins
    def _BLoadFromSuspend(self):
        # Plugin-wrapped entry for the BLoadFromSuspend() hook (before resuming).
        return self.BLoadFromSuspend()
203 |
    @checkPlugins
    def _ALoadFromSuspend(self):
        # Plugin-wrapped entry for the ALoadFromSuspend() hook (after resuming).
        return self.ALoadFromSuspend()
207 |
    @checkPlugins
    def _BModelOneStep(self):
        # Plugin-wrapped entry for BModelOneStep(); "Skip" skips the training step.
        return self.BModelOneStep()
211 |
    @checkPlugins
    def _AModelOneStep(self):
        # Plugin-wrapped entry for the AModelOneStep() hook (after each step).
        return self.AModelOneStep()
215 |
    @checkPlugins
    def _BVisualize(self):
        # Plugin-wrapped entry for BVisualize(); "Skip" suppresses visualization.
        return self.BVisualize()
219 |
    @checkPlugins
    def _AVisualize(self):
        # Plugin-wrapped entry for the AVisualize() hook (after visualization).
        return self.AVisualize()
223 |
    @checkPlugins
    def _BValidation(self):
        # Plugin-wrapped entry for BValidation(); "Skip" skips validation entirely.
        return self.BValidation()
227 |
    @checkPlugins
    def _BValidateABatch(self):
        # Plugin-wrapped entry for BValidateABatch(); "Skip" skips the batch.
        return self.BValidateABatch()
231 |
    @checkPlugins
    def _AValidateABatch(self):
        # Plugin-wrapped entry for the AValidateABatch() hook.
        return self.AValidateABatch()
235 |
    @checkPlugins
    def _BValidationAnalyze(self):
        # Plugin-wrapped entry for BValidationAnalyze(); "Skip" skips the analysis.
        return self.BValidationAnalyze()
239 |
    @checkPlugins
    def _AValidationAnalyze(self):
        # Plugin-wrapped entry for the AValidationAnalyze() hook.
        return self.AValidationAnalyze()
243 |
    @checkPlugins
    def _AValidation(self):
        # Plugin-wrapped entry for the AValidation() hook (after validation).
        return self.AValidation()
247 |
    @checkPlugins
    def _commandLineOutput(self,epoch : int, logdict : dict, args : dict):
        # Plugin-wrapped entry for commandLineOutput(); arguments forwarded as-is.
        return self.commandLineOutput(epoch = epoch, logdict = logdict, args = args)
251 |
    @checkPlugins
    def _BSaveModel(self):
        # Plugin-wrapped entry for BSaveModel(); "Skip" suppresses checkpointing.
        return self.BSaveModel()
255 |
    @checkPlugins
    def _ASaveModel(self):
        # Plugin-wrapped entry for the ASaveModel() hook (after checkpointing).
        return self.ASaveModel()
259 |
    @checkPlugins
    def _getSaveDict(self):
        # Plugin-wrapped entry for getSaveDict() (extra state for checkpoints).
        return self.getSaveDict()
263 |
    @checkPlugins
    def _loadSaveDict(self,saveDict):
        # Plugin-wrapped entry for loadSaveDict() (restore checkpoint extras).
        return self.loadSaveDict(saveDict)
267 |
    @checkPlugins
    def _AAll(self):
        # Plugin-wrapped entry for the AAll() hook (after the whole task).
        return self.AAll()
271 |
    @checkPlugins
    def _trainAborting(self,exception : Exception):
        # Plugin-wrapped entry for trainAborting(); receives the fatal exception.
        return self.trainAborting(exception = exception)
--------------------------------------------------------------------------------
/DLNest/ShellClient/Windows/Utils/Completers.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit.completion import CompleteEvent, Completer, Completion,merge_completers,WordCompleter
2 | from prompt_toolkit.completion.nested import NestedCompleter
3 | from prompt_toolkit.buffer import Buffer
4 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
5 | from prompt_toolkit import PromptSession
6 | from prompt_toolkit.key_binding import KeyBindings
7 | from typing import Callable, Iterable, List, Optional
8 | from prompt_toolkit.document import Document
9 | import os
10 |
class PartPathCompleter(Completer):
    """
    Complete for Path variables.

    Unlike the stock prompt_toolkit PathCompleter, this variant completes only
    the last space-separated token of the input, so it can complete a path
    argument in the middle of a command line.

    :param get_paths: Callable which returns a list of directories to look into
                      when the user enters a relative path.
    :param file_filter: Callable which takes a filename and returns whether
                        this file should show up in the completion. ``None``
                        when no filtering has to be done.
    :param min_input_len: Don't do autocompletion when the input string is shorter.
    """

    def __init__(
        self,
        only_directories: bool = False,
        get_paths: Optional[Callable[[], List[str]]] = None,
        file_filter: Optional[Callable[[str], bool]] = None,
        min_input_len: int = 0,
        expanduser: bool = False,
    ) -> None:

        self.only_directories = only_directories
        self.get_paths = get_paths or (lambda: ["."])
        self.file_filter = file_filter or (lambda _: True)
        self.min_input_len = min_input_len
        self.expanduser = expanduser

    def get_completions(
        self, document: Document, complete_event: CompleteEvent
    ) -> Iterable[Completion]:
        text = document.text_before_cursor
        # Only the token under the cursor (after the last space) is completed.
        text = text.split(" ")[-1]

        # Complete only when we have at least the minimal input length,
        # otherwise, we can too many results and autocompletion will become too
        # heavy.
        if len(text) < self.min_input_len:
            return

        try:
            # Do tilde expansion.
            if self.expanduser:
                text = os.path.expanduser(text)

            # Directories where to look.
            dirname = os.path.dirname(text)
            if dirname:
                directories = [
                    os.path.dirname(os.path.join(p, text)) for p in self.get_paths()
                ]
            else:
                directories = self.get_paths()

            # Start of current file.
            prefix = os.path.basename(text)

            # Get all filenames.
            filenames = []
            for directory in directories:
                # Look for matches in this directory.
                if os.path.isdir(directory):
                    for filename in os.listdir(directory):
                        if filename.startswith(prefix):
                            filenames.append((directory, filename))

            # Sort
            filenames = sorted(filenames, key=lambda k: k[1])

            # Yield them.
            for directory, filename in filenames:
                # Insert only the part the user has not typed yet.
                completion = filename[len(prefix) :]
                full_name = os.path.join(directory, filename)

                if os.path.isdir(full_name):
                    # For directories, add a slash to the filename.
                    # (We don't add them to the `completion`. Users can type it
                    # to trigger the autocompletion themselves.)
                    filename += "/"
                elif self.only_directories:
                    continue

                if not self.file_filter(full_name):
                    continue

                yield Completion(completion, 0, display=filename)
        except OSError:
            # Unreadable directories etc. simply produce no completions.
            pass
98 |
class CommandCompleter(Completer):
    """Completer for DLNest shell command lines.

    ``rules`` maps each command name to either ``None`` (the command has no
    flags) or a dict mapping every flag to one of:
      * a ``Completer`` -- completes the flag's value (e.g. a path),
      * ``None``        -- the flag takes a free-form value; offer nothing,
      * an ``int``      -- the flag takes no value; keep offering the other flags.
    """

    def __init__(self, rules : dict):
        self.rules = rules

    def get_completions(
        self, document: Document, complete_event: CompleteEvent
    ):
        text = document.text_before_cursor.lstrip()
        tokens = text.replace("\n"," ").replace("\t"," ").split(" ")
        if len(tokens) <= 1:
            # Still typing the command word: offer every known command.
            completer = WordCompleter(
                list(self.rules.keys())
            )
            for c in completer.get_completions(document, complete_event):
                yield c
        else:
            title = tokens[0]
            if title not in self.rules:
                return
            try:
                rule_dict = self.rules[title]
                if len(tokens) > 2 and tokens[-2] in rule_dict:
                    value_rule = rule_dict[tokens[-2]]
                    if isinstance(value_rule, Completer):
                        # Previous token is a flag whose value this completer handles.
                        for c in value_rule.get_completions(document, complete_event):
                            yield c
                        return
                    if not isinstance(value_rule, int):
                        # None: flag expects a free-form value -> no suggestions.
                        return
                    # int marker: the flag takes no value, fall through to flags.
                if rule_dict is None:
                    return
                # Offer the flags not already present on the command line.
                search_list = [key for key in rule_dict if key not in tokens]
                completer = WordCompleter(
                    search_list
                )
                for c in completer.get_completions(document, complete_event):
                    yield c
            except Exception:
                # Best-effort completion: a malformed command line (e.g. a
                # command whose rules are None) must never crash the UI.
                pass
144 |
def getCommandCompleter():
    """Build the CommandCompleter used by the DLNest shell.

    Value markers (interpreted by CommandCompleter.get_completions):
      * ``PartPathCompleter()`` -- the flag takes a path; complete it.
      * ``None``                -- the flag takes a free-form value; no suggestions.
      * ``-1`` (an int)         -- the flag is a boolean switch with no value;
                                   keep offering the command's remaining flags.
    A command mapped to ``None`` (e.g. "exit") takes no flags at all.
    """
    commands = {
        "run" : {
            "-c" : PartPathCompleter(),
            "-d" : None,
            "-f" : PartPathCompleter(),
            "-m" : None,
            "-ns" : -1,
            "-mc" : -1,
            "-sd" : -1,
            "-DDP" : -1,
            "-CPU" : -1
        },
        "continue" : {
            "-r" : PartPathCompleter(),
            "-c" : None,
            "-d" : None,
            "-m" : None,
            "-DDP" : -1,
            "-CPU" : -1,
            "-mc" : -1
        },
        "analyze" : {
            "-r" : PartPathCompleter(),
            "-s" : PartPathCompleter(),
            "-c" : None,
            "-m" : None,
            "-CPU" : -1
        },
        "new" : {
            "-d" : PartPathCompleter(),
            "-MNIST" : None,
            "-p" : None,
        },
        "changeDevices" : {
            "-d" : None
        },
        "addP" : {
            "-d" : PartPathCompleter(),
            "-p" : None,
            "-F" : None,
        },
        "changeDelay" : None,
        "runExp" : None,
        "del" : None,
        "exit" : None
    }
    return CommandCompleter(commands)
--------------------------------------------------------------------------------
/DLNest/Information/InfoCenter.py:
--------------------------------------------------------------------------------
1 | try:
2 | from TaskInformation import TaskInformation
3 | from TrainTask import TrainTask
4 | from AnalyzeTask import AnalyzeTask
5 | from DeviceInformation import DeviceInformation
6 | from CPUInformation import CPUInformation
7 | except ImportError:
8 | from DLNest.Information.TaskInformation import TaskInformation
9 | from DLNest.Information.TrainTask import TrainTask
10 | from DLNest.Information.AnalyzeTask import AnalyzeTask
11 | from DLNest.Information.DeviceInformation import DeviceInformation
12 | from DLNest.Information.CPUInformation import CPUInformation
13 |
14 | HAVE_GPU = True
15 |
16 | try:
17 | from GPUInformation import GPUInformation
18 | except Exception:
19 | try:
20 | from .GPUInformation import GPUInformation
21 | except Exception:
22 | HAVE_GPU = False
23 |
24 | if HAVE_GPU:
25 | import pynvml
26 |
27 | import threading
28 |
class Singleton(object):
    """Mixin that gives a class one shared instance.

    The first instantiation creates and caches the object; later calls return
    it. Note: ``_instance`` is looked up with ``hasattr``, so a subclass whose
    ancestor was already instantiated inherits that ancestor's instance.
    """

    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, '_instance'):
            # Fix: object.__new__ must be called without constructor arguments.
            # With __new__ overridden, CPython raises TypeError on extra args,
            # which broke any Singleton subclass whose __init__ takes parameters.
            # __init__ still receives *args/**kwargs through the normal call.
            cls._instance = super(Singleton, cls).__new__(cls)
        return cls._instance
35 |
36 |
class InfoCenter(Singleton):
    """Process-wide registry of compute devices and scheduled tasks.

    Inherits Singleton, so every ``InfoCenter()`` call returns the same shared
    object. ``deviceLock`` guards ``devices``/``availableDevices``; ``taskLock``
    guards ``tasks``. The ``using*``/``release*`` pairs hand the lock to callers
    that need a longer critical section.
    """

    def __init__(self):
        # __init__ runs on every InfoCenter() call; bail out once initialised.
        if hasattr(self,'tasks'):
            return
        if HAVE_GPU:
            try:
                # Init GPUs information
                pynvml.nvmlInit()
                self.totalGPUsInSystem = pynvml.nvmlDeviceGetCount()
                # The list ends with a CPU entry addressed by index -1, so GPU
                # ids keep matching their positions in `devices`.
                self.devices = [GPUInformation(i) for i in range(self.totalGPUsInSystem)] + [CPUInformation()]
                self.availableDevices = [i for i in range(self.totalGPUsInSystem)] + [-1]
            except Exception:
                # NVML present but unusable at runtime: fall back to CPU only.
                self.devices = [CPUInformation()]
                self.availableDevices = [0]
        else:
            # using CPU only
            self.devices = [CPUInformation()]
            self.availableDevices = [0]

        self.taskLock = threading.Lock()
        self.deviceLock = threading.Lock()
        self.tasks = []

    def __getAvailableDevices(self):
        """
        return the device information classes of available devices.
        """
        return [self.devices[item] for item in self.availableDevices]

    def usingDevicesInformation(self):
        """
        Occupy the devices information.

        Acquires deviceLock and intentionally leaves it held; the caller MUST
        call releaseDeviceInformation() when done with the returned list.
        """
        if self.deviceLock.acquire():
            return self.__getAvailableDevices()

    def releaseDeviceInformation(self):
        """Release the lock taken by usingDevicesInformation()."""
        self.deviceLock.release()

    def getDevicesInformation(self):
        """
        Get the informaton of all devices.
        """
        if self.deviceLock.acquire():
            try:
                return [item.getDict() for item in self.devices]
            finally:
                self.deviceLock.release()

    def getAvailableDevicesInformation(self):
        """
        Get the information of available devices in dict.
        """
        if self.deviceLock.acquire():
            try:
                availableDevices = self.__getAvailableDevices()
                return [item.getDict() for item in availableDevices]
            finally:
                self.deviceLock.release()

    def changeDevices(self,newDevicesIDList : list):
        """
        Modify available devices.

        Invalid ids (< -1 or beyond the device list) are reported and skipped;
        -1 addresses the CPU entry at the end of `devices`.
        """
        newList = []
        for item in newDevicesIDList:
            if item < -1 or item >= len(self.devices):
                print("Wrong device " + str(item))
            else:
                newList.append(item)
        # Fix: publish the new list under deviceLock, like every other accessor,
        # so concurrent readers never race with the swap.
        if self.deviceLock.acquire():
            try:
                self.availableDevices = newList
            finally:
                self.deviceLock.release()
        return True

    def __checkTasks(self):
        """
        Delete all the tasks who are supposed to be running but have no alive process.
        """
        newList = []
        for item in self.tasks:
            if item.status == "Running":
                # A "Running" task without a live process has died: drop it.
                if item.process is None or not item.process.is_alive():
                    continue
                else:
                    newList.append(item)
            else:
                newList.append(item)
        self.tasks = newList

    def getTasksInformation(self):
        """
        Get the information of all tasks in dict.
        """
        if self.taskLock.acquire():
            try:
                self.__checkTasks()
                return [item.getDict() for item in self.tasks]
            finally:
                self.taskLock.release()

    def usingTasks(self):
        """
        Occupy the tasks information.

        Acquires taskLock and leaves it held; callers MUST pair this with
        releaseTasks().
        """
        if self.taskLock.acquire():
            return self.tasks

    def releaseTasks(self):
        """
        Release the tasks.
        """
        self.taskLock.release()

    def addATask(self,task : TaskInformation):
        """
        Add a task to the tasks list.
        """
        if self.taskLock.acquire():
            try:
                self.__checkTasks()
                self.tasks.append(task)
            finally:
                self.taskLock.release()

    def runATask(self,task : TaskInformation):
        """
        Run a task to the devices. Tasks are supposed to be in self.tasks already.
        """
        if self.deviceLock.acquire():
            try:
                for device in task.devices:
                    task.status = "Running"
                    self.devices[device].addATask(task)
            finally:
                self.deviceLock.release()

    def getTaskByID(self,taskID : str):
        """
        Get the task information class by a taskID.

        Returns None when no task matches.
        """
        if self.taskLock.acquire():
            try:
                for item in self.tasks:
                    if item.ID == taskID:
                        return item
                return None
            finally:
                self.taskLock.release()

    def delATask(self,taskID : str):
        """
        Delete a task by taskID.

        The matching task's process (if any) is terminated and the task is
        dropped from the list.
        """
        if self.taskLock.acquire():
            try:
                self.__checkTasks()
                newList = []
                for item in self.tasks:
                    if item.ID == taskID:
                        if item.process is not None:
                            item.process.terminate()
                    else:
                        newList.append(item)
                self.tasks = newList
            finally:
                self.taskLock.release()

    def delAllTask(self):
        """
        Delete all task
        """
        if self.taskLock.acquire():
            try:
                self.__checkTasks()
                for task in self.tasks:
                    if task.process is not None:
                        task.process.terminate()
                self.tasks = []
            finally:
                self.taskLock.release()
217 |
if __name__ == "__main__":
    # Manual smoke test: dump the device table of the singleton InfoCenter.
    center = InfoCenter()
    print(center.getDevicesInformation())
--------------------------------------------------------------------------------
/DLNest/Executor/TrainProcess.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.TaskInformation import TaskInformation
2 | from DLNest.Information.TrainTask import TrainTask
3 | from DLNest.Executor.TaskProcess import TaskProcess
4 | from DLNest.Output.TrainStdout import TrainStdout
5 |
6 | from pathlib import Path
7 | import sys
8 |
9 | import time
10 | import os
11 |
12 | try:
13 | import torch
14 | import torch.distributed as dist
15 | except ImportError:
16 | pass
17 |
class TrainProcess(TaskProcess):
    """Worker process that trains a model: epoch loop, validation, checkpoints.

    Builds on TaskProcess for environment setup; adds output redirection into
    the save package and a main loop driven by the task's life-cycle hooks
    (hooks returning "Skip" suppress the corresponding phase).
    """

    def __init__(self,task : TrainTask,showOnScreen = False):
        """
        if showOnScreen, also output to the stdout
        """
        super(TrainProcess,self).__init__(task)
        self.showOnScreen = showOnScreen
        self.commandQueue = task.commandQueue

    def initOutput(self,rank = -1):
        """
        redirect the output

        Replaces sys.stdout/sys.stderr with a TrainStdout writing into the
        save package's per-rank output file.
        """
        os.chdir(self.task.args["root_file_path"]) # Change CWD to the save package
        outputFP = self.task.savePackage.getOutputFile(rank)
        self.outputDelegate = TrainStdout(outputFP,showOnScreen = self.showOnScreen,originalStdout = sys.stdout)
        sys.stdout = self.outputDelegate
        sys.stderr = self.outputDelegate

    def __saveModel(self):
        # Ask the life cycle whether this checkpoint should be kept, then save.
        holdThisCheckpoint = self.lifeCycle.holdThisCheckpoint(self.finishedEpoch,self.logDict,self.task.args)
        self.saveCkpt(holdThisCheckpoint = holdThisCheckpoint)

    def __moveAData(self,data):
        """Move one object to the GPU when it supports it; otherwise return as-is."""
        try:
            # hasattr instead of `in dir(...)`: same check without building the
            # full attribute list for every batch element.
            if hasattr(data,"cuda"):
                return data.cuda()
            elif hasattr(data,"to"):
                # Borrow the default CUDA device from a throwaway tensor.
                tmp = torch.tensor(1).cuda()
                return data.to(tmp.device)
            else:
                return data
        except Exception:
            # Best effort: on any failure keep the data where it is.
            return data

    def __moveData(self,data):
        # move data to the proper location (no-op in CPU mode)
        if self.status.env != "CPU":
            try:
                if isinstance(data,list):
                    for index in range(len(data)):
                        data[index] = self.__moveAData(data[index])
                elif isinstance(data, tuple):
                    data = tuple(self.__moveAData(item) for item in data)
                elif isinstance(data,dict):
                    for key in data:
                        data[key] = self.__moveAData(data[key])
                else:
                    # BUG FIX: was `self.__moveAData(self, data)`, which raised
                    # a TypeError that the except below swallowed, silently
                    # leaving non-container batches on the CPU.
                    data = self.__moveAData(data)
            except Exception:
                pass
        return data

    def __train_an_epoch(self,start_epoch : int):
        """Run every batch of one epoch: steps plus optional visualization."""
        nowEpoch = start_epoch
        for _iter,data in enumerate(self.trainLoader):
            # run one step
            self.status._RunningStatus__iter = _iter
            if self.lifeCycle._BModelOneStep() != "Skip":
                data = self.__moveData(data)
                self.runner._runOneStep(data,self.logDict,_iter,nowEpoch)
            self.lifeCycle._AModelOneStep()

            # visualize (in DDP only rank 0 draws)
            if self.lifeCycle.needVisualize(nowEpoch,_iter,self.logDict,self.task.args):
                if self.status.env == "DDP":
                    if self.status.rank == 0 and self.lifeCycle._BVisualize() != "Skip":
                        self.runner._visualize(epoch = nowEpoch, iter = _iter, log = self.logDict)
                else:
                    if self.lifeCycle._BVisualize() != "Skip":
                        self.runner._visualize(epoch = nowEpoch, iter = _iter, log = self.logDict)
                self.lifeCycle._AVisualize()

    def __validate(self):
        """Default validation: iterate valLoader batch-wise, then analyze."""
        self.runner._validationInit()
        for _iter,data in enumerate(self.valLoader):
            self.status._RunningStatus__iter = _iter
            if self.lifeCycle._BValidateABatch() != "Skip":
                data = self.__moveData(data)
                self.runner._validateABatch(data,_iter)
            self.lifeCycle._AValidateABatch()

        if self.lifeCycle._BValidationAnalyze() != "Skip":
            self.runner._validationAnalyze(self.logDict)
        self.lifeCycle._AValidationAnalyze()

    def mainLoop(self):
        """Train epochs until the life cycle says stop; aborts go to _trainAborting."""
        try:
            nowEpoch = self.finishedEpoch + 1
            self.status._RunningStatus__epoch = nowEpoch
            if self.lifeCycle._BTrain() == "Skip":
                self.lifeCycle._ATrain()
                return
            while True:
                if self.lifeCycle._BOneEpoch() != "Skip":
                    # Training!
                    self.status.startTraining()

                    if self.status.env == "DDP":
                        # Reshuffle the distributed sampler for this epoch.
                        self.trainLoader.sampler.set_epoch(nowEpoch)

                    self.__train_an_epoch(nowEpoch)

                    if self.status.env == "DDP":
                        dist.barrier() # Sync before validation

                    self.finishedEpoch = nowEpoch
                    # output in command Line
                    self.lifeCycle._commandLineOutput(self.finishedEpoch,self.logDict,self.task.args)

                    # validation (a runner-defined `validate` takes precedence)
                    if self.lifeCycle.needValidation(self.finishedEpoch,self.logDict,self.task.args):
                        if self.lifeCycle._BValidation() != "Skip":
                            self.status.startValidating()
                            if "validate" in dir(self.runner):
                                self.runner._validate(self.valLoader,self.logDict)
                            else:
                                self.__validate()
                        self.lifeCycle._AValidation()

                    # start other operations
                    self.status.startWaiting()

                    if self.status.env == "DDP":
                        dist.barrier() # Sync before saving

                    # save checkpoint (in DDP only rank 0 writes)
                    if self.status.env == "DDP":
                        if self.status.rank == 0 and self.lifeCycle._BSaveModel() != "Skip":
                            if self.lifeCycle.needSaveModel(self.finishedEpoch,self.logDict,self.task.args):
                                self.__saveModel()
                    else:
                        if self.lifeCycle._BSaveModel() != "Skip":
                            if self.lifeCycle.needSaveModel(self.finishedEpoch,self.logDict,self.task.args):
                                self.__saveModel()
                    self.lifeCycle._ASaveModel()

                self.lifeCycle._AOneEpoch()
                # break decision
                if self.lifeCycle.needContinueTrain(self.finishedEpoch,self.logDict,self.task.args):
                    nowEpoch = self.finishedEpoch + 1
                    self.status._RunningStatus__epoch = nowEpoch
                else:
                    break
        except (Exception,SystemExit) as e:
            # SystemExit is caught on purpose so sys.exit() in user code is
            # still reported through the life cycle.
            self.lifeCycle._trainAborting(e)
        else:
            # After Train (only reached when the loop ended without exception)
            self.lifeCycle._ATrain()

    def loadCkpt(self):
        """Load the checkpoint, then position the save package's ckpt counter."""
        super().loadCkpt()
        if self.task.loadCkpt:
            if self.task.checkpointID != -1:
                # -1 means the last one,which is default option. != -1 needs to set the ckptID rather than default
                self.task.savePackage.setCkptID(self.task.checkpointID + 1)
180 |
if __name__ == "__main__":
    # Manual smoke test: launch a training task from a local config and keep
    # the parent process alive until Ctrl-C, then shut the child down cleanly.
    TT = TrainTask.fromConfigFile("/root/code/DLNestTest/root_config.json",devices = [0,1,2,3],noSave = True,DDP = True)
    # TT = TrainTask.fromRecord("/root/code/DLNestTest/Saves/2021-03-29_16-04-46_146",checkpointID = 5,devices = [0,1,2,3],DDP = False)
    TP = TrainProcess(TT,True)
    TP.start()
    try:
        while True:
            # Fix: sleep instead of a busy `pass` loop, which pinned a CPU core.
            time.sleep(1)
    except KeyboardInterrupt:
        TP.terminate()
        TP.join()
        print("terminate")
--------------------------------------------------------------------------------
/DLNest/ShellClient/Client.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit import Application
2 | from prompt_toolkit.layout.containers import VSplit,HSplit
3 | from prompt_toolkit.layout.layout import Layout
4 | from prompt_toolkit.key_binding import KeyBindings
5 | from prompt_toolkit.styles import Style
6 |
7 | import json
8 | from pathlib import Path
9 | import argparse
10 | import os
11 |
12 | from DLNest.ShellClient.Windows.DevicesInfoShower import DevicesInfoShower
13 | from DLNest.ShellClient.Windows.CommandInput import CommandInput
14 | from DLNest.ShellClient.Windows.ResultsOutput import ResultsOutput,AnalyzeOutput
15 | from DLNest.ShellClient.Windows.TaskInfoShower import TaskInfoShower
16 | from DLNest.ShellClient.Communicator import Communicator
17 |
class Client:
    """Full-screen prompt_toolkit UI for the DLNest server.

    Lays out five windows (command line, DLNest output, analyzer output, task
    table, device table), binds F1-F5 to focus them and Ctrl-C to quit, and
    polls the server through a Communicator in the per-window routine tasks.
    """

    def __init__(self, url : str = "127.0.0.1", port : int = 9999):
        # Fix: the default was the string "9999" despite the int annotation;
        # startClient passes an int, so the default now matches.
        self.communicator = Communicator(url,port)

        self.CMDIN = CommandInput(title="DLNest Command Line(F1)",onAccept=self.onCommandAccept)
        self.w1 = self.CMDIN.getWindow()

        self.DLOutput = ResultsOutput(routineTask=self.routineTaskDLOutput,title = "DLNest Output (F2)",style="class:dlnest_output")
        self.w2 = self.DLOutput.getWindow()

        self.ANOutput = AnalyzeOutput(routineTask=self.routineTaskANOutput,title = "Analyzer Output (F3)",style="class:analyzer_output")
        self.w3 = self.ANOutput.getWindow()
        # ID of the analyzer task currently "watched"; "" means none.
        self.analyzeTaskID = ""

        self.TaskInfo = TaskInfoShower(routineTask = self.routineTaskInfo,title = "Tasks (F4)")
        self.w4 = self.TaskInfo.getWindow()

        self.DevicesInfo = DevicesInfoShower(routineTask = self.routineTaskDevices, title = "Devices (F5)")
        self.w5 = self.DevicesInfo.getWindow()

        # Two layouts: side-by-side for wide terminals, stacked for tall ones.
        self.container_fat = HSplit([
            self.w1,
            VSplit([self.w2,self.w3]),
            VSplit([self.w4,self.w5])
        ])
        self.container_tall = HSplit([
            self.w1,
            self.w2,
            self.w3,
            self.w4,
            self.w5
        ])

        self.kb = KeyBindings()
        @self.kb.add('c-c')
        def exit_(event):
            event.app.exit()

        @self.kb.add('f1')
        def focus1(event):
            event.app.layout.focus(self.w1)

        @self.kb.add('f2')
        def focus2(event):
            event.app.layout.focus(self.w2)

        @self.kb.add('f3')
        def focus3(event):
            event.app.layout.focus(self.w3)

        @self.kb.add('f4')
        def focus4(event):
            event.app.layout.focus(self.w4)

        @self.kb.add('f5')
        def focus5(event):
            event.app.layout.focus(self.w5)

        self.style = Style.from_dict({
            "frame.border" : "fg:#ffb6c1",
            "frame.title" : "fg:#1ef0ff",
            "command_frame" : "bg:#008b8b",
            "dlnest_output" : "bg:#451a4a",
            "analyzer_output" : "bg:#451a4a",
            "analyzer_info_label" : "bg:#da70d6",
            "analyzer_info_text1" : "bg:#3f3f00",
            "analyzer_info_text2" : "bg:#ff00ff",
            "running_task_status" : "bg:#a01010 bold",
            "running_task_id" : "bg:#303030",
            "running_task_gpu" : "bg:#556b2f",
            "running_task_des" : "bg:#c71585",
            "running_task_time" : "bg:#2e3b37",
            "pending_task_status" : "bg:#1010a0 bold",
            "pending_task_id" : "bg:#303030",
            "pending_task_gpu" : "bg:#556b2f",
            "pending_task_des" : "bg:#c71585",
            "pending_task_time" : "bg:#2e3b37",
            "suspend_task_status" : "bg:#10a010 bold",
            "suspend_task_id" : "bg:#303030",
            "suspend_task_gpu" : "bg:#556b2f",
            "suspend_task_des" : "bg:#c71585",
            "suspend_task_time" : "bg:#2e3b37",
            "task_info_shower" : "bg:#008bc0",
            "devices_info_shower" : "bg:#008bc0",
            "devices_id" : "bg:#303030",
            "devices_status_valid" : "bg:#3cb371 bold",
            "devices_status_break" : "bg:#a01010 bold",
            "devices_free_memory" : "bg:#556b2f",
            "devices_tasks" : "bg:#c71585"
        })

        self.layout = Layout(self.container_fat,focused_element=self.w1)
        self.app = Application(key_bindings=self.kb, layout=self.layout, full_screen=True,style=self.style)
        # NOTE(review): overrides a private prompt_toolkit hook to swap layouts
        # on terminal resize -- may break across prompt_toolkit versions.
        self.app._on_resize = self.on_resize

    def on_resize(self):
        """Pick the fat/tall layout from the terminal aspect ratio and redraw."""
        cols, rows = os.get_terminal_size(0)
        focused_element = self.layout.current_window
        if cols >= 2 * rows: # fat
            self.app.layout = Layout(self.container_fat,focused_element=focused_element)
        else: # tall
            self.app.layout = Layout(self.container_tall,focused_element=focused_element)

        self.app.renderer.erase(leave_alternate_screen=False)
        self.app._request_absolute_cursor_position()
        self.app._redraw()

    def getApp(self):
        """Return the prompt_toolkit Application to run."""
        return self.app

    def onCommandAccept(self,s : str):
        """Parse a command line, forward it to the server, handle UI side effects."""
        commandWordList = s.split(" ")
        while "" in commandWordList:
            commandWordList.remove("")

        # no command
        if not commandWordList:
            return

        if commandWordList[0] == "watch":
            self.analyzeTaskID = commandWordList[1]
        elif commandWordList[0] == "withdraw":
            self.analyzeTaskID = ""

        if commandWordList[0] == "runExp":
            # `runExp <script>` is completed with the watched analyzer task id.
            if len(commandWordList) != 3:
                if self.analyzeTaskID != "":
                    commandWordList = [commandWordList[0],self.analyzeTaskID,commandWordList[1]]
                else:
                    return

        ret = self.communicator.giveACommand(commandWordList)

        if commandWordList[0] == "del":
            # Stop watching a task we just deleted.
            if ret["status"] == "success" and commandWordList[1] == self.analyzeTaskID:
                self.analyzeTaskID = ""

        if "exit" in ret:
            self.app.exit()

    def routineTaskDLOutput(self, obj):
        """Poll the server for DLNest output and refresh the output window."""
        #for buffer fresh
        if not hasattr(obj,"_count_"):
            obj._count_ = 0

        outStyledDict = self.communicator.giveACommand(["showDL","-s"])
        outPlainDict = self.communicator.giveACommand(["showDL"])
        if "text" in outStyledDict and "text" in outPlainDict:
            try:
                obj.lexer.styled_text = outStyledDict["text"]
                obj.shower.text = outPlainDict["text"]
            except Exception:
                pass

    def routineTaskANOutput(self, obj):
        """Poll the watched analyzer task's output and refresh its window."""
        #for buffer fresh
        if not hasattr(obj,"_count_"):
            obj._count_ = 0

        if self.analyzeTaskID == "":
            obj.lexer.styled_text = []
            obj.shower.text = ""
            obj.infoText.text = [("","No valid analyzer task is running")]
            obj.infoWindow.width = 33
            return

        outStyledDict = self.communicator.giveACommand(["showAN","-t",self.analyzeTaskID,"-s"])
        outPlainDict = self.communicator.giveACommand(["showAN","-t",self.analyzeTaskID])
        if "text" in outStyledDict and "text" in outPlainDict:
            try:
                obj.lexer.styled_text = outStyledDict["text"]
                obj.shower.text = outPlainDict["text"]
                obj.infoText.text = [("class:analyzer_info_text1",self.analyzeTaskID)]
                obj.infoWindow.width = len(self.analyzeTaskID)
            except Exception:
                pass
        else:
            # The watched task no longer exists on the server side.
            self.analyzeTaskID = ""

    def routineTaskInfo(self,obj):
        """Poll the server's task table and refresh the task window."""
        # for buffer fresh
        if not hasattr(obj,"_count_"):
            obj._count_ = 0

        r = self.communicator.giveACommand(["showTask"])
        if r["status"] != "success":
            obj.lexer.taskInfo = []
            obj.shower.text = obj.lexer.get_text()
            return
        taskInfo = r["info"]
        try:
            obj.lexer.taskInfo = taskInfo
            obj.shower.text = obj.lexer.get_text()
        except Exception:
            pass

    def routineTaskDevices(self, obj):
        """Poll the server's device table and refresh the devices window."""
        # for buffer fresh
        if not hasattr(obj,"_count_"):
            obj._count_ = 0

        r = self.communicator.giveACommand(["showDevice"])
        if r["status"] != "success":
            obj.lexer.devicesInfo = []
            obj.shower.text = obj.lexer.get_text()
            return
        obj.lexer.devicesInfo = r["info"]
        try:
            obj.shower.text = obj.lexer.get_text()
        except Exception:
            pass
229 |
def startClient():
    """Parse the CLI options and launch the DLNest shell client."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-u", type=str, default="127.0.0.1", help="DLNest server address")
    arg_parser.add_argument("-p", type=int, default=9999, help="DLNest server port")
    opts = arg_parser.parse_args()
    Client(url=opts.u, port=opts.p).getApp().run()
--------------------------------------------------------------------------------
/DLNest/Local.py:
--------------------------------------------------------------------------------
1 | from DLNest.Operations.Analyze import analyze
2 | from DLNest.Operations.ChangeDelay import changeDelay
3 | from DLNest.Operations.ChangeDevices import changeDevices
4 | from DLNest.Operations.ChangeMaxTaskPerDevice import changeMaxTaskPerDevice
5 | from DLNest.Operations.ContinueTrain import continueTrain
6 | from DLNest.Operations.DelATask import delATask
7 | from DLNest.Operations.GetAnalyzeOutput import getAnalyzeOutput
8 | from DLNest.Operations.GetDevicesInformation import getDevicesInformation
9 | from DLNest.Operations.GetTasksInformation import getTasksInformation
10 | from DLNest.Operations.New import new
11 | from DLNest.Operations.Run import run
12 | from DLNest.Operations.RunExp import runExp
13 | from DLNest.Operations.SafeExit import safeExit
14 | from DLNest.Operations.UsePlugin import usePlugins
15 |
16 | import argparse
17 | from prompt_toolkit import PromptSession,HTML
18 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
19 | import traceback
20 |
class Arguments:
    """Base wrapper owning an ``argparse.ArgumentParser``.

    Subclasses register their options on ``self._parser`` inside ``__init__``;
    callers retrieve the configured parser via ``parser()``.
    """

    def __init__(self, desc: str = ""):
        self._parser = argparse.ArgumentParser(description=desc)

    def parser(self):
        """Return the underlying ``argparse.ArgumentParser``."""
        return self._parser
27 |
class TrainArguments(Arguments):
    """Command-line arguments for launching a DLNest training task."""

    def __init__(self):
        super(TrainArguments, self).__init__(desc="Arguments for DLNest task.")

        self._parser.add_argument("-c",type=str, help="root configuration json file for this task.",required = True)
        self._parser.add_argument("-d",type=str, default = "", help="description for this task.(default: None)")
        self._parser.add_argument("-f",type=str, default = "", help="frequently changing configuration json file for this task.(default:None)")
        # Fix: argparse %-formats help strings, so a literal percent must be
        # escaped as "%%"; the old "\%" raised ValueError when -h rendered help.
        self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-ns",action='store_true', help="Set to save to the NOSAVE dir.")
        self._parser.add_argument("-mc",action='store_true',help="Set to use multi card.")
        self._parser.add_argument("-sd",action='store_true',help="Set to use description as the save dir name.(coverd by ns)")
        self._parser.add_argument("-DDP",action='store_true',help="Set to use DDP.")
        self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.")
41 |
class ProjectArguments(Arguments):
    """Command-line arguments for creating a new DLNest project."""

    def __init__(self):
        super(ProjectArguments, self).__init__(desc="Arguments for create a DLNest project.")

        self._parser.add_argument("-d",type=str, help="Path to the directory you want to create the project.", required = True)
        self._parser.add_argument("-MNIST",action='store_true', help="Set to new a project with MNIST task.")
        # -p accepts zero or more plugin names (default: no plugins).
        self._parser.add_argument("-p", type = str, nargs='+',default=[], help = "Set plugins need to be used.")
49 |
class AnalyzeArguments(Arguments):
    """Command-line arguments for analyzing a saved training record."""

    def __init__(self):
        super(AnalyzeArguments, self).__init__(desc="Arguments for an Analyzer")

        self._parser.add_argument("-r",type=str, help = "path to the model record directory.", required = True)
        self._parser.add_argument("-s",type=str, default = "", help = "path to the analyze scripts.")
        self._parser.add_argument("-c",type=int, default = -1, help = "which epoch you want the model to load.(int)")
        # Fix: "%" in argparse help must be escaped as "%%" ("\%" crashed -h).
        self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.")
59 |
class ContinueArguments(Arguments):
    """Command line arguments for continuing a training task from a record."""

    def __init__(self):
        # Description fixed: it was copy-pasted from AnalyzeArguments
        # ("Arguments for an Analyzer").
        super(ContinueArguments, self).__init__(desc="Arguments for continuing a DLNest task.")

        self._parser.add_argument("-r",type=str, help = "path to the model record directory.", required = True)
        self._parser.add_argument("-c",type=int, default = -1, help = "which epoch you want the model to load.(int)")
        self._parser.add_argument("-d",type=str, default = "", help="description for this task.(default: None)")
        # "%%" (not "\%"): argparse %-formats help text, so a stray "%" breaks
        # `-h`, and "\%" is also an invalid escape sequence in modern Python.
        self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.")
        self._parser.add_argument("-DDP",action='store_true',help="Set to use DDP.")
        self._parser.add_argument("-mc",action='store_true',help="Set to use multi card.")
71 |
class DeviceChangeArguments(Arguments):
    """Command line arguments for changing the set of usable devices."""

    def __init__(self):
        super(DeviceChangeArguments, self).__init__(desc="Arguments for change valid cards.")
        self._parser.add_argument("-d", type=int, nargs='+', required=True, help='valid devices')
76 |
class AddPluginsArguents(Arguments):
    """Command line arguments for adding plugins to an existing project."""

    def __init__(self):
        super(AddPluginsArguents, self).__init__(desc="Arguments for add plugins.")
        self._parser.add_argument("-d", type=str, required=True,
                                  help="Path to the directory you want to create the project.")
        self._parser.add_argument("-p", type=str, nargs='+', required=True,
                                  help="Set plugins need to be used.")
        self._parser.add_argument("-F", action='store_true',
                                  help="Set to use full config")
84 |
class DLNestLocal:
    """Interactive shell that runs DLNest operations inside the local process."""

    def __init__(self):
        # One argparse wrapper per shell command family.
        self.trainArgParser = TrainArguments()
        self.continueArgParser = ContinueArguments()
        self.projectArgParser = ProjectArguments()
        self.analyzeArgParser = AnalyzeArguments()
        self.deviceChangeArgParser = DeviceChangeArguments()
        self.addPluginsArgParser = AddPluginsArguents()

    def runTrain(self,commandWordList : list):
        """Parse `run` options and start a training task."""
        parsed, _ = self.trainArgParser.parser().parse_known_args(commandWordList[1:])
        run(
            configPath = parsed.c,
            freqPath = parsed.f,
            description = parsed.d,
            memoryConsumption = parsed.m,
            CPU = parsed.CPU,
            DDP = parsed.DDP,
            multiGPU = parsed.mc,
            noSave = parsed.ns,
            useDescriptionToSave = parsed.sd,
            otherArgs = {}
        )

    def newProject(self,commandWordList : list):
        """Parse `new` options and scaffold a project directory."""
        parsed, _ = self.projectArgParser.parser().parse_known_args(commandWordList[1:])
        new(
            targetDir = parsed.d,
            MNIST = parsed.MNIST,
            pluginsName = parsed.p
        )

    def runAnalyze(self,commandWordList : list):
        """Parse `analyze` options and start an analyze task."""
        parsed, _ = self.analyzeArgParser.parser().parse_known_args(commandWordList[1:])
        analyze(
            recordPath = parsed.r,
            scriptPath = parsed.s,
            checkpointID = parsed.c,
            CPU = parsed.CPU,
            memoryConsumption = parsed.m,
            otherArgs = {}
        )

    def continueTrain(self,commandWordList : list):
        """Parse `continue` options and resume a recorded task."""
        parsed, _ = self.continueArgParser.parser().parse_known_args(commandWordList[1:])
        # This resolves to the module-level continueTrain operation, not this
        # method: a bare name inside a method body is looked up in the module
        # scope, not on the class.
        continueTrain(
            recordPath = parsed.r,
            checkpointID = parsed.c,
            memoryConsumption = parsed.m,
            CPU = parsed.CPU,
            DDP = parsed.DDP,
            multiGPU = parsed.mc,
            description = parsed.d,
            otherArgs = {}
        )

    def changeDevices(self,commandWordList : list):
        """Parse `changeDevices` options and update the usable device list."""
        parsed, _ = self.deviceChangeArgParser.parser().parse_known_args(commandWordList[1:])
        changeDevices(parsed.d)

    def runExp(self,commandWordList : list):
        """Run an experiment: word 1 is the record path, word 2 the script."""
        runExp(commandWordList[1],commandWordList[2])

    def addPlugin(self, commandWordList : list):
        """Parse `addP` options and install plugins into a project."""
        parsed, _ = self.addPluginsArgParser.parser().parse_known_args(commandWordList[1:])
        usePlugins(
            targetDir = parsed.d,
            pluginsName = parsed.p,
            full = parsed.F
        )

    def run(self):
        """Prompt-read-dispatch loop; leaves via `exit` or Ctrl-C."""
        self.session = PromptSession(auto_suggest = AutoSuggestFromHistory())
        # Commands whose handlers receive the whole word list.
        handlers = {
            "run" : self.runTrain,
            "continue" : self.continueTrain,
            "new" : self.newProject,
            "analyze" : self.runAnalyze,
            "runExp" : self.runExp,
            "changeDevices" : self.changeDevices,
            "addP" : self.addPlugin,
        }
        while True:
            try:
                words = self.session.prompt(HTML("DLNest>>")).strip().split(' ')
                head = words[0]
                if head in handlers:
                    handlers[head](words)
                elif head == "del":
                    delATask(words[1])
                elif head == "showAN":
                    print(getAnalyzeOutput(words[1],False)[1])
                elif head == "showTask":
                    print(getTasksInformation())
                elif head == "showDevice":
                    print(getDevicesInformation())
                elif head == "exit":
                    safeExit()
                    exit(0)
                else:
                    print("Wrong command")
            except KeyboardInterrupt:
                safeExit()
                exit(0)
            except Exception:
                # Show only the last three traceback lines to keep the shell
                # readable.
                print("\n".join(traceback.format_exc().splitlines()[-3:]))
197 |
if __name__ == "__main__":
    import sys
    # Drop the script's own directory from the import path so that `DLNest`
    # resolves to the installed package instead of any local files that would
    # shadow it.
    if sys.path[0] != '':
        sys.path[0] = ''
    main = DLNestLocal()
    main.run()
--------------------------------------------------------------------------------
/DLNest/Executor/TaskProcess.py:
--------------------------------------------------------------------------------
1 | from DLNest.Information.TaskInformation import TaskInformation
2 |
3 | import multiprocessing
4 | import os
5 | import sys
6 | from multiprocessing import Process
7 | import importlib
8 | import shutil
9 | import json
10 | import random
11 | from pathlib import Path
12 | import numpy as np
13 |
14 | from abc import ABCMeta, abstractmethod
15 | try:
16 | import torch
17 | from torch.nn.parallel import DistributedDataParallel as DDP
18 | import torch.distributed as dist
19 | import torch.multiprocessing as mp
20 | USINGTORCH = True
21 | except ImportError:
22 | USINGTORCH = False
23 |
24 | from DLNest.Common.RunningStatus import RunningStatus
25 |
class TaskProcess(Process):
    """
    Subprocess that executes one DLNest task.

    It imports the user-supplied life cycle / dataset / runner modules from
    the paths stored in ``task.args``, attaches the configured plugins, and
    drives the task in one of the environments "CPU", "GPU", "GPUs" or "DDP".
    Subclasses must override ``mainLoop`` and ``initOutput``.
    """

    def __init__(self,task : TaskInformation):
        super(TaskProcess,self).__init__()
        self.task = task
        # Shared running-status object handed to every user module and plugin.
        self.status = RunningStatus()
        # Last finished epoch; -1 until a checkpoint is loaded.
        self.finishedEpoch = -1
        self.plugins = []

    def __loadAModule(self,filePath : Path,name : str):
        """Import the python file at ``filePath`` as a module called ``name``.

        Relative paths are resolved against args["root_file_path"], and the
        file's directory is appended to sys.path so the module's own relative
        imports keep working.
        """
        if not filePath.is_absolute():
            filePath = Path(self.task.args["root_file_path"]) / filePath
        # NOTE(review): relies on importlib.util being reachable as an
        # attribute of importlib; an explicit `import importlib.util` at the
        # top of the file would be safer — confirm.
        spec = importlib.util.spec_from_file_location(
            name,
            filePath
        )
        module = importlib.util.module_from_spec(spec)
        dirName = str(filePath.parent)
        if not dirName in sys.path:
            sys.path.append(dirName)
        spec.loader.exec_module(module)
        return module

    def __loadLifeCycle(self):
        # Import the user's life-cycle module under the fixed name "LifeCycle".
        self.lifeCycleModule = self.__loadAModule(Path(self.task.args["life_cycle_file_path"]),"LifeCycle")

    def __initLifeCycle(self,rank : int = -1):
        """Instantiate the class named by args['life_cycle_name'].

        Raises when the class is missing from the loaded module; restores the
        life cycle's saved state when resuming from a checkpoint.
        """
        lifeCycleName = self.task.args['life_cycle_name']
        if lifeCycleName in dir(self.lifeCycleModule):
            self.lifeCycle = self.lifeCycleModule.__getattribute__(lifeCycleName)(
                taskProcess = self,
                plugins = self.plugins,
                status = self.status
            )
            if self.task.loadCkpt:
                self.lifeCycle._loadSaveDict(self.stateDict["life_cycle"])
        else:
            raise Exception("Cannot find lifeCycle class")

    def __loadAPlugin(self,pluginName : str):
        """Load one plugin class.

        ``pluginName`` is either the stem of a bundled plugin file under
        DLNest/Plugins or an absolute path to a plugin file; the module must
        expose a class named ``DLNestPlugin``.
        """
        pluginPath = Path(__file__).parent.parent / "Plugins" / (pluginName + '.py')
        tmpPath = Path(pluginName)
        if tmpPath.is_absolute():
            pluginPath = tmpPath
            pluginName = tmpPath.name
        pluginModule = self.__loadAModule(filePath = pluginPath,name = pluginName)
        pluginClass = pluginModule.__getattribute__("DLNestPlugin")
        # Hand the shared running status to the plugin class itself.
        pluginClass._status = self.status
        return pluginClass

    def __loadPlugins(self):
        # Load every plugin listed in args["plugins"]; no-op when absent.
        if not "plugins" in self.task.args:
            return []
        pluginNames = self.task.args["plugins"]
        pluginList = []
        for name in pluginNames:
            pluginList.append(self.__loadAPlugin(name))
        self.plugins = pluginList

    def __loadRunnerAndDataset(self):
        # "model_file_path"/"model_name" are the legacy key names for the
        # runner; the "runner_*" keys take precedence when both exist.
        runnerPath = Path(self.task.args["runner_file_path"] if "runner_file_path" in self.task.args else self.task.args["model_file_path"]) # need to be deprecated
        datasetPath = Path(self.task.args["dataset_file_path"])
        self.runnerModule = self.__loadAModule(runnerPath,runnerPath.stem)
        self.datasetModule = self.__loadAModule(datasetPath,datasetPath.stem)
        # NOTE(review): registering the dataset module under its stem
        # presumably lets pickled objects (e.g. in DataLoader workers) resolve
        # it by name — confirm before removing.
        sys.modules[datasetPath.stem] = self.datasetModule # Why?

    def __initDataset(self):
        """Instantiate the dataset class and build its loaders.

        Sets ``datasetInfo``, ``trainLoader`` and ``valLoader``; restores the
        dataset state when resuming from a checkpoint.
        """
        datasetName = self.task.args['dataset_name']
        if datasetName in dir(self.datasetModule):
            datasetClass = self.datasetModule.__getattribute__(datasetName)
            self.dataset = datasetClass(
                args = self.task.args,
                plugins = self.plugins,
                status = self.status
            )
            self.datasetInfo,self.trainLoader,self.valLoader = self.dataset._datasetInit(self.task.args)
            # load from ckpt is needed.
            if self.task.loadCkpt:
                self.dataset._loadSaveDict(self.stateDict["dataset"])
        else:
            raise Exception("Cannot find dataset class")

    def __initRunner(self):
        """Instantiate the runner (model) class and its optimizer.

        Under DDP the runner's DDP hooks bracket optimizer creation.  When
        resuming, the runner state is restored ("model" is the legacy state
        key); otherwise a fresh log dict is created.
        """
        runnerName = self.task.args['runner_name'] if "runner_name" in self.task.args else self.task.args["model_name"] # need to be deprecated
        if runnerName in dir(self.runnerModule):
            runnerClass = self.runnerModule.__getattribute__(runnerName)
            self.runner = runnerClass(
                args = self.task.args,
                plugins = self.plugins,
                status = self.status
            )
            self.runner._runnerInit(self.task.args,self.datasetInfo)
            if self.status.env != "DDP":
                self.runner._initOptimizer()
            else:
                self.runner.DDPOperation(rank = self.status.rank)
                self.runner._initOptimizer()
                self.runner.afterDDP(rank = self.status.rank)
            # load from ckpt is needed.
            if self.task.loadCkpt:
                self.runner._loadSaveDict(self.stateDict["runner"] if "runner" in self.stateDict else self.stateDict["model"])
            else:
                # if load from ckpt, logDict has been loaded in self.loadCkpt
                self.logDict = self.runner._initLog()

    def checkDeviceEnviroment(self):
        """
        make correct environ params and
        return "CPU","GPU","GPUs","DDP"
        """
        assert len(self.task.devices) > 0
        # A leading -1 means the whole task runs on CPU.
        if self.task.devices[0] == -1:
            return "CPU"

        ids = [str(item) for item in self.task.devices]
        assert not ("-1" in ids) #no CPU in the devices list. Preventing world size error

        if self.task.DDP:
            assert USINGTORCH
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(ids)
            self.deviceNum = len(ids)
            return "DDP"
        elif self.task.multiGPU:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(ids)
            return "GPUs"
        else:
            # NOTE(review): set_device is called with the global device index
            # before CUDA_VISIBLE_DEVICES is narrowed to that single device —
            # confirm the intended ordering/indexing here.
            torch.cuda.set_device(self.task.devices[0])
            os.environ["CUDA_VISIBLE_DEVICES"] = ids[0]
            return "GPU"

    def setupSeed(self):
        """Seed numpy / random / torch with ``self.seed`` for reproducibility."""
        seed = self.seed
        np.random.seed(seed)
        random.seed(seed)
        if USINGTORCH:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True

    def initBeforeDDP(self):
        # Draw one seed in the parent so every spawned DDP worker seeds its
        # RNGs identically in initAfterDDP.
        self.seed = random.randint(0,2147483647)

    def initAfterDDP(self,rank,worldSize):
        """Per-worker DDP setup: join the nccl process group and seed RNGs."""
        os.environ["MASTER_ADDR"] = self.task.address
        os.environ["MASTER_PORT"] = self.task.port
        torch.cuda.set_device(rank)
        dist.init_process_group("nccl", rank = rank, world_size = worldSize)
        # Write through RunningStatus's name-mangled private fields.
        self.status._RunningStatus__rank = rank
        self.status._RunningStatus__worldSize = worldSize
        self.setupSeed()

    def runDDP(self,rank):
        """Entry point of each spawned DDP worker process."""
        self.initAfterDDP(rank,self.deviceNum)
        self.initOutput(rank = rank)
        self.loadCkpt()
        if self.initModules(rank = rank) != "Skip":
            dist.barrier() # start together
            self.mainLoop()

        self.lifeCycle._AAll()
        dist.destroy_process_group()
        exit(0)

    @abstractmethod
    def mainLoop(self):
        """
        needs to be override by other subprocess.
        """
        pass

    @abstractmethod
    def initOutput(self,rank = -1):
        """
        needs to be override by other subprocess.
        """
        pass

    def loadCkpt(self):
        """
        If task have ckpt to load, load it to the pointed device.
        """
        deviceStr = ""
        if self.task.devices[0] == -1:
            deviceStr = "cpu"
        else:
            deviceStr = "cuda"
        if self.task.loadCkpt:
            self.stateDict = self.task.savePackage.getStateDict(self.task.checkpointID,deviceStr)
            self.logDict = self.stateDict["log_dict"]
            self.finishedEpoch = self.stateDict["finished_epoch"]

    def saveCkpt(self,otherDict : dict = {}, holdThisCheckpoint = False):
        """
        Save the states of life cycle, dataset and runner.
        """
        # NOTE(review): the mutable default for otherDict is only safe because
        # it is read, never mutated, here.
        stateDict = {}
        stateDict["life_cycle"] = self.lifeCycle._getSaveDict()
        stateDict["dataset"] = self.dataset._getSaveDict()
        stateDict["runner"] = self.runner._getSaveDict()
        stateDict["log_dict"] = self.logDict
        stateDict["finished_epoch"] = self.finishedEpoch
        for key in otherDict:
            stateDict[key] = otherDict[key]
        self.task.savePackage.saveACheckpoint(stateDict,holdThisCheckpoint=holdThisCheckpoint)

    def initModules(self,rank : int = -1):
        """
        Init all modules
        """
        # Life cycle and plugins come first so their hooks can veto the rest;
        # hooks returning "Skip" short-circuit the corresponding init step.
        self.__loadLifeCycle()
        self.__loadPlugins()
        self.__initLifeCycle(rank = rank)

        self.__loadRunnerAndDataset()

        if self.lifeCycle._BAll() != "Skip":
            if self.lifeCycle._BDatasetInit() != "Skip":
                self.__initDataset()
                self.lifeCycle.dataset = self.dataset
            self.lifeCycle._ADatasetInit()

            if self.lifeCycle._BModelInit() != "Skip":
                self.__initRunner()
                self.lifeCycle.runner = self.runner
            self.lifeCycle._AModelInit()
            return
        else:
            return "Skip"

    def run(self):
        """Process entry point: pick the environment and run the task."""
        self.status._RunningStatus__env = self.checkDeviceEnviroment()
        if self.status.env == "DDP":
            # run DDP
            self.ppid = os.getpid()
            self.initBeforeDDP()
            context = mp.spawn(
                self.runDDP,
                nprocs=self.deviceNum,
                join=False,
                daemon = True
            )
            self.initOutput()
            # Busy-wait until every spawned worker has exited.
            while not context.join():
                pass
        else:
            self.initOutput()
            self.loadCkpt()
            if self.initModules() != "Skip":
                self.mainLoop()
            self.lifeCycle._AAll()
275 |
--------------------------------------------------------------------------------
/DLNest/SavePackage/SavePackage.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import json
3 | from pathlib import Path
4 | import time
5 | import random
6 | try:
7 | import torch
8 | except ImportError:
9 | # need to costomize the doSaveOperation and doLoadOperation functions
10 | pass
11 |
class SavePackage:
    """
    On-disk record of one task: the merged configuration args, copies of the
    user's source files, and a pool of checkpoints managed on three retention
    tracks (fast, slow and consistent).
    """

    def __init__(
        self,
        configPath : str = "",
        freqPath : str = ""
    ):
        # Merge the root config and then the frequently-changed config (the
        # latter overrides) into self.args.
        self.args = {}
        haveArgs = False
        if configPath != "":
            configPath = Path(configPath)
            self.__loadArgs(configPath,self.args)
            haveArgs = True
        if freqPath != "":
            freqPath = Path(freqPath)
            self.__loadArgs(freqPath,self.args)
            haveArgs = True

        # Checkpoint bookkeeping: id -> file path, plus the three tracks.
        self.ckpts = {}
        self.nowCkptID = 0
        self.ckptSlow = []
        self.ckptFast = []
        self.ckptConsistent = []
        self.checkpointsDir = None
        if haveArgs:
            self.maxCkptSlow = self.args["checkpoint_args"]["max_ckpt_in_slow_track"]
            self.maxCkptFast = self.args["checkpoint_args"]["max_ckpt_in_fast_track"]
            self.maxCkptConsisitent = self.args["checkpoint_args"]["max_ckpt_in_consistent_track"]
            self.slowDilation = self.args["checkpoint_args"]["dilation_in_slow_track"]

        self.root = None
        # (assigned above as well; kept so the attribute always exists)
        self.checkpointsDir = None
        self.prefix = "state_"

    def giveArgs(self,args : dict): # Not used
        """Adopt an externally built args dict and refresh the track limits."""
        self.args = args
        self.maxCkptSlow = self.args["checkpoint_args"]["max_ckpt_in_slow_track"]
        self.maxCkptFast = self.args["checkpoint_args"]["max_ckpt_in_fast_track"]
        self.maxCkptConsisitent = self.args["checkpoint_args"]["max_ckpt_in_consistent_track"]
        self.slowDilation = self.args["checkpoint_args"]["dilation_in_slow_track"]

    def __replaceArgs(self,newArgName,newArgValue,args):
        """
        Create or overwrite args[newArgName] with newArgValue, merging dicts.

        newArgName is a key directly under `args`.  If the new value is not a
        dict, or the key does not exist yet, it is created/overwritten
        directly.  If the new value is a dict that collides with an existing
        entry, recurse key by key so sibling entries of the existing dict
        survive.
        """
        if not newArgName in args:
            args[newArgName] = newArgValue
            return

        if not isinstance(newArgValue,dict):
            args[newArgName] = newArgValue
            return

        for key in newArgValue:
            self.__replaceArgs(key,newArgValue[key],args[newArgName])
        return

    def __loadArgs(self,filePath : Path,args : dict):
        """
        Load a json config file (and, recursively, its "child_jsons") into
        the args dict.
        """
        # Fail fast when the config path is missing or not a regular file.
        if not filePath.is_file():
            raise BaseException("Config file doesn't exists. " + str(filePath))
        fp = filePath.open('r')
        newArgs = json.load(fp)

        # Create or override args for every key except child_jsons.
        for key in newArgs:
            if key == "child_jsons":
                continue

            # Merge via __replaceArgs so dict-valued entries are not wiped
            # out wholesale.
            self.__replaceArgs(key,newArgs[key],args)

        # Recurse into child jsons; children override their parents (in DFS
        # order).
        if "child_jsons" in newArgs:
            for item in newArgs["child_jsons"]:
                path = Path(item)

                # A relative child path is resolved against this json's
                # directory.
                if not path.is_absolute():
                    path = filePath.parent / item

                # Load the child json.
                self.__loadArgs(path,args)

        fp.close()
        return

    def __copyAFile(self,filePath : Path, saveDir : Path):
        '''
        Copy a file or directory into the save package.

        A relative filePath is copied to the same relative location below
        saveDir; an absolute filePath is copied directly under the package
        root.
        '''
        if filePath.is_absolute():
            if filePath.is_dir():
                shutil.copytree(filePath,saveDir / filePath.stem)
            else:
                shutil.copy(filePath,saveDir / (filePath.stem + filePath.suffix))
        else:
            abFilePath = Path(self.args["root_file_path"]) / filePath
            if abFilePath.is_dir():
                shutil.copytree(abFilePath,saveDir / filePath)
            else:
                target = saveDir / filePath
                target_dir = target.parent
                if not target_dir.exists():
                    target_dir.mkdir(parents = True, exist_ok = True)
                shutil.copy(abFilePath,target)

    def __savePackageInformation(self):
        """Write the track lists and the prefix to _package.json in the root."""
        packageInfoPath = self.root / "_package.json"
        packageInfo = {
            "ckpt_slow" : self.ckptSlow,
            "ckpt_fast" : self.ckptFast,
            "ckpt_consistent" : self.ckptConsistent,
            "prefix" : self.prefix
        }
        with packageInfoPath.open("w") as fp:
            json.dump(packageInfo,fp,sort_keys=True, indent=4, separators=(',', ':'))

    def __loadPackageInformation(self):
        """Restore the track lists and the prefix from _package.json."""
        packageInfoPath = self.root / "_package.json"
        packageInfo = {}
        with packageInfoPath.open("r") as fp:
            packageInfo = json.load(fp)
        self.ckptSlow = packageInfo["ckpt_slow"]
        self.ckptFast = packageInfo["ckpt_fast"]
        self.ckptConsistent = packageInfo["ckpt_consistent"]
        self.prefix = packageInfo["prefix"]

    def saveToNewDir(self,overrideSaveName : str = "",copyFiles = True):
        """
        Make a new save package. If having the override save name, use it. If not, use timestamp to save.
        """
        saveRoot = Path(self.args["save_root"])
        if not saveRoot.is_absolute():
            saveRoot = Path(self.args["root_file_path"]) / saveRoot
        self.args["save_root"] = str(saveRoot)
        nowTime = time.time()
        saveName = ""
        if overrideSaveName != "":
            saveName = overrideSaveName
        else:
            saveName = time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(nowTime)) + "_" + str(random.randint(100,999)) # avoid conflict. May be changed later
        saveDir = saveRoot / saveName
        self.root = saveDir

        # An existing directory of the same name is replaced wholesale.
        if saveDir.exists():
            shutil.rmtree(saveDir)
        saveDir.mkdir(parents=True,exist_ok=True)

        #checkpoints save dir
        self.checkpointsDir = saveDir / "Checkpoints"
        self.checkpointsDir.mkdir(parents=True,exist_ok=True)
        self.ckpts = {}

        # copy python files into dir if copy Files is TRUE
        if copyFiles:
            self.__copyAFile(Path(self.args["runner_file_path"] if "runner_file_path" in self.args else self.args["model_file_path"]),saveDir) # need to be deprecated
            self.__copyAFile(Path(self.args["dataset_file_path"]),saveDir)
            self.__copyAFile(Path(self.args["life_cycle_file_path"]),saveDir)
            for item in self.args["other_file_paths"]:
                self.__copyAFile(Path(item),saveDir)

        # save args
        argsPath = saveDir / "args.json"
        argsFP = argsPath.open('w')
        self.args["root_file_path"] = str(saveDir)
        json.dump(self.args,argsFP,sort_keys=True, indent=4, separators=(',', ':'))
        argsFP.close()

        # save package info
        self.__savePackageInformation()

    def saveACheckpoint(self,stateDict : dict, holdThisCheckpoint : bool = False):
        """
        Save a new checkpoint and manage the storage.
        """
        idsNeed2Delete = []

        # save this state dict
        saveFile = self.checkpointsDir / (self.prefix + str(self.nowCkptID) + ".ckpt")
        saveName = str(saveFile)
        self.doSaveOperation(stateDict,saveName)

        # add to ckpts dict
        self.ckpts[self.nowCkptID] = saveFile

        # append in fast track
        self.ckptFast.append(self.nowCkptID)
        if len(self.ckptFast) > self.maxCkptFast:
            w2did = self.ckptFast.pop(0)
            idsNeed2Delete.append(w2did)

        # append in slow track
        if self.nowCkptID % self.slowDilation == 0:
            self.ckptSlow.append(self.nowCkptID)
            if len(self.ckptSlow) > self.maxCkptSlow:
                w2did = self.ckptSlow.pop(0)
                if not w2did in idsNeed2Delete:
                    idsNeed2Delete.append(w2did)

        # append in consistent track
        if holdThisCheckpoint:
            self.ckptConsistent.append(self.nowCkptID)
            if len(self.ckptConsistent) > self.maxCkptConsisitent:
                w2did = self.ckptConsistent.pop(0)
                if not w2did in idsNeed2Delete:
                    idsNeed2Delete.append(w2did)

        #delete useless checkpoints on disk
        # An id evicted from one track survives while any other track still
        # references it.
        for id in idsNeed2Delete:
            if not (id in self.ckptFast or
                id in self.ckptSlow or
                id in self.ckptConsistent):
                path = self.ckpts[id]
                path.unlink()
                self.ckpts.pop(id)

        self.nowCkptID += 1
        # save package info
        self.__savePackageInformation()

    def doSaveOperation(self,stateDict,fileName):
        """
        pytorch version of save ckpt. If no torch in the environment, please override this function.
        """
        torch.save(stateDict,fileName)

    def doLoadOperation(self,fileName : str,device):
        """pytorch version of loading a ckpt onto `device`; override when torch is absent."""
        return torch.load(fileName,map_location = device)

    def saveVisualString(self, visualString : str):
        """
        Make a new file with this visual string which makes the information noticeable.
        """
        visualFile = self.root / ("_" + visualString)
        visualFile.touch()
        with open(visualFile,"w") as f:
            f.write(visualString)

    def initCkptsFromExist(self):
        """Rebuild the id->path map by scanning the Checkpoints directory."""
        self.checkpointsDir = self.root / "Checkpoints"
        filelist = self.checkpointsDir.iterdir()
        for item in filelist:
            if item.suffix == ".ckpt":
                # File names look like <prefix><id>.ckpt; the id is the part
                # after the last underscore.
                id = int(item.stem.split("_")[-1])
                self.ckpts[id] = item
                self.nowCkptID = max(self.nowCkptID,id)
        # The next checkpoint id continues after the largest one found.
        self.nowCkptID += 1

    def initFromAnExistSavePackage(self,packagePath : str):
        """Reload args, package info and checkpoints from an existing package dir."""
        self.root = Path(packagePath)
        configPath = self.root / "args.json"
        self.args = {}
        self.__loadArgs(configPath,self.args)

        self.maxCkptSlow = self.args["checkpoint_args"]["max_ckpt_in_slow_track"]
        self.maxCkptFast = self.args["checkpoint_args"]["max_ckpt_in_fast_track"]
        self.maxCkptConsisitent = self.args["checkpoint_args"]["max_ckpt_in_consistent_track"]
        self.slowDilation = self.args["checkpoint_args"]["dilation_in_slow_track"]

        # load package info
        self.__loadPackageInformation()

        # load ckpt
        self.initCkptsFromExist()

    def getStateDict(self,id = -1,device = "cuda:0"):
        """Load and return checkpoint `id` (-1 means the newest) onto `device`."""
        if id == -1:
            id = self.nowCkptID - 1
        if not id in self.ckpts:
            raise BaseException("The checkpoint doesn't exist.")
        return self.doLoadOperation(str(self.ckpts[id]),device)

    def getOutputFile(self,rank = -1):
        """Open the package's output log (append when present); per-rank under DDP."""
        outputFileName = "_output.txt"
        if rank != -1:
            outputFileName = "_output_" + str(rank) + ".txt"
        outputFilePath = self.root / outputFileName
        if outputFilePath.exists():
            return outputFilePath.open("a")
        else:
            return outputFilePath.open("w")

    def setCkptID(self,ckptID : int):
        """Roll the package back so the next checkpoint gets id `ckptID`.

        Every checkpoint with id >= ckptID is dropped from all three tracks
        and from the id->path map (files on disk are not deleted here).
        """
        self.nowCkptID = ckptID

        newslow = []
        for item in self.ckptSlow:
            if item >= self.nowCkptID:
                pass
            else:
                newslow.append(item)
        self.ckptSlow = newslow

        newfast = []
        for item in self.ckptFast:
            if item >= self.nowCkptID:
                pass
            else:
                newfast.append(item)
        self.ckptFast = newfast

        newconsistent = []
        for item in self.ckptConsistent:
            if item >= self.nowCkptID:
                pass
            else:
                newconsistent.append(item)
        self.ckptConsistent = newconsistent

        popList = []
        for key in self.ckpts:
            if key >= self.nowCkptID:
                popList.append(key)
        for item in popList:
            self.ckpts.pop(item)
--------------------------------------------------------------------------------
/DLNest/ShellClient/Communicator.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import argparse
3 | import traceback
4 | import json
5 | from functools import wraps
6 |
def raiseError(self):
    # Installed as ArgumentParser.error (see Arguments.__init__), so argparse
    # calls this as error(message): `self` here is actually the error message
    # string, not an instance.  Raising keeps a malformed command from
    # terminating the whole client via argparse's default sys.exit.
    raise Exception("command error")
9 |
class Arguments:
    """Thin argparse wrapper that raises instead of exiting on parse errors."""

    def __init__(self, desc : str = ""):
        # Swap argparse's default error handler (which calls sys.exit) for
        # raiseError so callers can catch bad commands as exceptions.
        parser = argparse.ArgumentParser(description=desc)
        parser.error = raiseError
        self._parser = parser

    def parser(self):
        """Return the underlying ArgumentParser."""
        return self._parser
17 |
class TrainArguments(Arguments):
    """Command line arguments for submitting a DLNest training task."""

    def __init__(self):
        super(TrainArguments, self).__init__(desc="Arguments for DLNest task.")

        self._parser.add_argument("-c",type=str, help="root configuration json file for this task.",required = True)
        self._parser.add_argument("-d",type=str, default = "", help="description for this task.(default: None)")
        self._parser.add_argument("-f",type=str, default = "", help="frequently changing configuration json file for this task.(default:None)")
        # "%%" (not "\%"): argparse %-formats help text, so a stray "%" breaks
        # `-h`, and "\%" is also an invalid escape sequence in modern Python.
        self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-ns",action='store_true', help="Set to save to the NOSAVE dir.")
        self._parser.add_argument("-mc",action='store_true',help="Set to use multi card.")
        self._parser.add_argument("-sd",action='store_true',help="Set to use description as the save dir name.(coverd by ns)")
        self._parser.add_argument("-DDP",action='store_true',help="Set to use DDP.")
        self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.")
31 |
class ProjectArguments(Arguments):
    """Command line arguments for creating a new DLNest project."""

    def __init__(self):
        super(ProjectArguments, self).__init__(desc="Arguments for create a DLNest project.")

        self._parser.add_argument("-d",type=str, help="Path to the directory you want to create the project.", required = True)
        self._parser.add_argument("-MNIST",action='store_true', help="Set to new a project with MNIST task.")
        # default=[] keeps args.p a list when -p is omitted (argparse would
        # otherwise yield None), matching the local-client ProjectArguments.
        self._parser.add_argument("-p", type = str, nargs='+',default=[], help = "Set plugins need to be used.")
39 |
class AnalyzeArguments(Arguments):
    """Command line arguments for launching an analyzer on a saved record."""

    def __init__(self):
        super(AnalyzeArguments, self).__init__(desc="Arguments for an Analyzer")

        self._parser.add_argument("-r",type=str, help = "path to the model record directory.", required = True)
        self._parser.add_argument("-s",type=str, default = "", help = "path to the analyze scripts.")
        self._parser.add_argument("-c",type=int, default = -1, help = "which epoch you want the model to load.(int)")
        # "%%" (not "\%"): argparse %-formats help text, so a stray "%" breaks
        # `-h`, and "\%" is also an invalid escape sequence in modern Python.
        self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.")
49 |
class ContinueArguments(Arguments):
    """Command line arguments for continuing a training task from a record."""

    def __init__(self):
        # Description fixed: it was copy-pasted from AnalyzeArguments
        # ("Arguments for an Analyzer").
        super(ContinueArguments, self).__init__(desc="Arguments for continuing a DLNest task.")

        self._parser.add_argument("-r",type=str, help = "path to the model record directory.", required = True)
        self._parser.add_argument("-c",type=int, default = -1, help = "which epoch you want the model to load.(int)")
        self._parser.add_argument("-d",type=str, default = "", help="description for this task.(default: None)")
        # "%%" (not "\%"): argparse %-formats help text, so a stray "%" breaks
        # `-h`, and "\%" is also an invalid escape sequence in modern Python.
        self._parser.add_argument("-m",type=int, default = -1, help="predicted GPU memory consumption for this task in MB.(default: 90%% of the total memory)")
        self._parser.add_argument("-CPU",action='store_true',help="Set to use CPU.")
        self._parser.add_argument("-DDP",action='store_true',help="Set to use DDP.")
        self._parser.add_argument("-mc",action='store_true',help="Set to use multi card.")
61 |
class DeviceChangeArguments(Arguments):
    """Command line arguments for changing the set of usable devices."""

    def __init__(self):
        super(DeviceChangeArguments, self).__init__(desc="Arguments for change valid cards.")
        self._parser.add_argument("-d", type=int, nargs='+', required=True, help='valid devices')
66 |
class OutputArguments(Arguments):
    """Command line arguments for fetching task output, optionally styled."""

    def __init__(self):
        # Fixed the "Argumetns" typo in the user-visible help description.
        super(OutputArguments, self).__init__(desc="Arguments for get styled or not styled outputs.")
        self._parser.add_argument("-t",type=str, default = "", help="task ID")
        self._parser.add_argument("-s",action="store_true",help = "set to get styled")
72 |
class AddPluginsArguents(Arguments):
    """Command line arguments for adding plugins to an existing project."""

    def __init__(self):
        super(AddPluginsArguents, self).__init__(desc="Arguments for add plugins.")
        self._parser.add_argument("-d", type=str, required=True,
                                  help="Path to the directory you want to create the project.")
        self._parser.add_argument("-p", type=str, nargs='+', required=True,
                                  help="Set plugins need to be used.")
        self._parser.add_argument("-F", action='store_true',
                                  help="Set to use full config")
80 |
def stableRun(f):
    """
    Decorator: run `f`, and instead of letting an exception escape, return a
    DLNest-style error response dict {"status": "error", "error": <message>}.
    """
    # Bug fix: the original called wraps(f) without applying its result,
    # which is a no-op; @wraps(f) actually copies f's metadata onto the
    # wrapper so decorated commands keep their names in tracebacks/debugging.
    @wraps(f)
    def doStableRun(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception as e:
            return {"status" : "error", "error" : str(e)}
    return doStableRun
89 |
90 |
class Communicator:
    """HTTP client that turns shell command word lists into DLNest server
    requests and returns the decoded JSON responses as dicts."""

    def __init__(self, url : str = "127.0.0.1", port : int = 9999):
        # One parser instance per command family, reused for every call.
        self.trainArgParser = TrainArguments()
        self.continueArgParser = ContinueArguments()
        self.projectArgParser = ProjectArguments()
        self.analyzeArgParser = AnalyzeArguments()
        self.deviceChangeArgParser = DeviceChangeArguments()
        self.outputArgParser = OutputArguments()
        self.addPluginsArgParser = AddPluginsArguents()
        self.url = f"http://{url}:{port}"

    def shortenList(self,commandWordList : list):
        """Re-join words so a double-quoted phrase becomes one token, then
        strip the surrounding quote characters from fully-quoted tokens."""
        merged = []
        pending = ""
        for word in commandWordList:
            head = word[0]
            if head == "-":
                # An option flag always terminates the phrase in progress.
                if pending:
                    merged.append(pending)
                    pending = ""
                merged.append(word)
            elif head == "\"":
                # A new quoted phrase begins; flush whatever came before.
                if pending:
                    merged.append(pending)
                pending = word
            else:
                pending = word if not pending else pending + " " + word
        if pending:
            merged.append(pending)
        # Remove the quote characters from tokens quoted at both ends.
        for idx, token in enumerate(merged):
            if token[0] == "\"" and token[-1] == "\"":
                merged[idx] = token[1:-1]
        return merged

    @stableRun
    def runTrain(self,commandWordList : list):
        """POST /run_train with options parsed from the command words."""
        words = self.shortenList(commandWordList)
        parsed, _extra = self.trainArgParser.parser().parse_known_args(words[1:])
        payload = {
            "config_path" : parsed.c,
            "freq_path" : parsed.f,
            "description" : parsed.d,
            "memory_consumption" : parsed.m,
            "CPU" : parsed.CPU,
            "DDP" : parsed.DDP,
            "multi_GPU" : parsed.mc,
            "no_save" : parsed.ns,
            "use_description" : parsed.sd,
        }
        response = requests.post(self.url + "/run_train", payload)
        return json.loads(response.text)

    @stableRun
    def newProject(self,commandWordList : list):
        """POST /new_proj to scaffold a new project directory."""
        words = self.shortenList(commandWordList)
        parsed, _extra = self.projectArgParser.parser().parse_known_args(words[1:])
        payload = {
            "target_dir" : parsed.d,
            "MNIST" : parsed.MNIST,
            "plugins" : parsed.p,
        }
        response = requests.post(self.url + "/new_proj", payload)
        return json.loads(response.text)

    @stableRun
    def runAnalyze(self,commandWordList : list):
        """POST /analyze to start an analysis task on a saved record."""
        words = self.shortenList(commandWordList)
        parsed, _extra = self.analyzeArgParser.parser().parse_known_args(words[1:])
        payload = {
            "record_path" : parsed.r,
            "script_path" : parsed.s,
            "checkpoint_ID" : parsed.c,
            "CPU" : parsed.CPU,
            "memory_consumption" : parsed.m,
        }
        response = requests.post(self.url + "/analyze", payload)
        return json.loads(response.text)

    @stableRun
    def continueTrain(self,commandWordList : list):
        """POST /continue_train to resume training from a checkpoint."""
        words = self.shortenList(commandWordList)
        parsed, _extra = self.continueArgParser.parser().parse_known_args(words[1:])
        payload = {
            "record_path" : parsed.r,
            "checkpoint_ID" : parsed.c,
            "memory_consumption" : parsed.m,
            "CPU" : parsed.CPU,
            "DDP" : parsed.DDP,
            "multi_GPU" : parsed.mc,
            "description" : parsed.d,
        }
        response = requests.post(self.url + "/continue_train", payload)
        return json.loads(response.text)

    @stableRun
    def changeDevices(self,commandWordList : list):
        """POST /change_devices with the new list of valid device IDs."""
        parsed, _extra = self.deviceChangeArgParser.parser().parse_known_args(commandWordList[1:])
        response = requests.post(self.url + "/change_devices", {
            "new_devices_IDs" : parsed.d
        })
        return json.loads(response.text)

    @stableRun
    def runExp(self,commandWordList : list):
        """
        commandWordList : taskID , command
        """
        response = requests.post(self.url + "/run_exp", {
            "task_ID" : commandWordList[1],
            "command" : commandWordList[2]
        })
        return json.loads(response.text)

    @stableRun
    def getAnalyzeOutput(self,commandWordList : list):
        """
        commandWordList : -t task_ID -s
        """
        parsed, _extra = self.outputArgParser.parser().parse_known_args(commandWordList[1:])
        response = requests.get(self.url + "/get_analyze_output", {
            "task_ID" : parsed.t,
            "styled" : parsed.s
        })
        return json.loads(response.text)

    @stableRun
    def getDLNestOutput(self,commandWordList : list):
        """
        commandWordList : -s
        """
        parsed, _extra = self.outputArgParser.parser().parse_known_args(commandWordList[1:])
        response = requests.get(self.url + "/get_DLNest_output", {
            "styled" : parsed.s
        })
        return json.loads(response.text)

    @stableRun
    def getTasksInformation(self):
        """GET /get_task_info; returns the server's task table."""
        response = requests.get(self.url + "/get_task_info", {})
        return json.loads(response.text)

    @stableRun
    def getDevicesInformation(self):
        """GET /get_devices_info; returns the server's device table."""
        response = requests.get(self.url + "/get_devices_info", {})
        return json.loads(response.text)

    @stableRun
    def delATask(self,commandWordList : list):
        """
        commandWordList: taskID
        """
        response = requests.post(self.url + "/del_task", {
            "task_ID" : commandWordList
        })
        return json.loads(response.text)

    @stableRun
    def clear(self):
        """POST /clear to wipe the DLNest output buffer."""
        response = requests.post(self.url + "/clear", {})
        return json.loads(response.text)

    @stableRun
    def addPlugins(self, commandWordList : list):
        """POST /add_plugins to install plugins into a project."""
        words = self.shortenList(commandWordList)
        parsed, _extra = self.addPluginsArgParser.parser().parse_known_args(words[1:])
        payload = {
            "target_dir" : parsed.d,
            "plugins" : parsed.p,
            "full" : parsed.F,
        }
        response = requests.post(self.url + "/add_plugins", payload)
        return json.loads(response.text)

    def giveACommand(self, commandWordList : list):
        """Route a raw command word list to the matching request method."""
        command = commandWordList[0]
        if command == "exit":
            return {"exit" : True}
        handlers = {
            "run" : lambda: self.runTrain(commandWordList),
            "continue" : lambda: self.continueTrain(commandWordList),
            "new" : lambda: self.newProject(commandWordList),
            "analyze" : lambda: self.runAnalyze(commandWordList),
            "runExp" : lambda: self.runExp(commandWordList),
            "del" : lambda: self.delATask(commandWordList[1]),
            "showAN" : lambda: self.getAnalyzeOutput(commandWordList),
            "showDL" : lambda: self.getDLNestOutput(commandWordList),
            "showTask" : lambda: self.getTasksInformation(),
            "showDevice" : lambda: self.getDevicesInformation(),
            "changeDevices" : lambda: self.changeDevices(commandWordList),
            "clear" : lambda: self.clear(),
            "addP" : lambda: self.addPlugins(commandWordList),
        }
        handler = handlers.get(command)
        if handler is None:
            return {"error" : "Wrong command"}
        return handler()
--------------------------------------------------------------------------------