├── mec ├── __init__.py ├── comms │ ├── __init__.py │ ├── transmit.py │ └── sync_rpc.py ├── configs │ ├── __init__.py │ └── default_config.py ├── scoring │ ├── __init__.py │ └── tester.py ├── utils │ ├── __init__.py │ ├── logs.py │ ├── history.py │ └── monitor.py ├── data_manip │ ├── __init__.py │ ├── criterions.py │ ├── transfroms │ │ ├── batch_transforms.py │ │ ├── label_transforms.py │ │ └── data_transforms.py │ ├── lr_scheduler.py │ ├── data_utils.py │ └── metrics.py ├── training │ ├── __init__.py │ ├── async_trainer.py │ ├── basic_trainer.py │ ├── sync_trainer.py │ └── old_sync_trainer.py └── unit_test │ ├── train_test.py │ ├── dist_test.py │ └── rpc_test.py ├── uninstall.sh ├── install.sh ├── requirements.txt ├── setup.py ├── README.en.md ├── README.md ├── demo_dataset.py ├── .gitignore ├── demo.py ├── demo_start_workers.py └── configs.py /mec/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/comms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/data_manip/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/data_manip/criterions.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mec/training/async_trainer.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /uninstall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | xargs rm -rf < install.record -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 setup.py install --record install.record -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | tqdm==4.47.0 3 | torch==1.5.1 4 | torchvision==0.6.0a0+35d732a 5 | pandas==1.0.5 6 | Pillow==7.2.0 7 | pyzmq==19.0.1 8 | -------------------------------------------------------------------------------- /mec/data_manip/transfroms/batch_transforms.py: -------------------------------------------------------------------------------- 1 | # batch_transfroms.py 2 | # created: CS 3 | # 针对批次数据的整体变换 4 | # 因批次变换涉及到批次内多个样本的数据和标签的次序 5 | # 因此无法拆分为针对data的和针对label的 6 | 7 | def mix_up(batch_data, batch_indexs): 8 | pass -------------------------------------------------------------------------------- /mec/unit_test/train_test.py: -------------------------------------------------------------------------------- 1 | import mec.training.sync_trainer 2 | 3 | 4 | 5 | 6 | 7 | train_set = None 8 | 9 | 10 | 11 | def main(): 12 | pass 13 | 14 | 15 | 16 | if __name__ == '__main__': 17 | main() -------------------------------------------------------------------------------- /mec/data_manip/transfroms/label_transforms.py: -------------------------------------------------------------------------------- 1 | # label_transfroms.py 2 | # created: CS 3 | # 各种标签变换 4 | # 针对单个样本 5 | 6 | 7 | 8 | # 9 | def to_onehot(indexed_label): 10 | pass 11 | 12 | # 13 | def label_smoothing(onehot_label): 14 | pass 15 | 16 | # 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup(name='mec', 4 | version='1.0.1', 5 | description='Multi Gpu training Library', 6 | author='CS', 7 | author_email='artintel@163.com', 8 | url='https://gitee.com/shinong/cv_dl_multi_gpu_basic_components', 9 | packages=[ 10 | 'mec', 11 | 'mec.comms', 12 | 'mec.data_manip', 13 | 'mec.training', 14 | 'mec.utils', 15 | 'mec.configs' 16 | ], 17 | install_requires=['zmq', 'torch'] 18 | ) -------------------------------------------------------------------------------- /mec/utils/logs.py: -------------------------------------------------------------------------------- 1 | # logs.py 2 | # 创建:cs 3 | # 创建日期:2020.06.30 4 | # 日志功能 5 | 6 | import sys 7 | import time 8 | 9 | class Logger(): 10 | def __init__(self, filepath=None, logfile=sys.stdout, prefix=''): 11 | self.log_file = logfile 12 | if filepath is not None: 13 | try: 14 | self.log_file = open(filepath, 'w') 15 | except FileNotFoundError as e: 16 | print(e) 17 | self.prefix = '[' + prefix + '|{}]' 18 | 19 | def __del__(self): 20 | #self.log_file.close() 21 | #self.log_file = sys.stderr 22 | pass 23 | 24 | def __call__(self, *args, **kwargs): 25 | print( 26 | self.prefix.format( time.strftime("%Y-%m-%d_%H:%M:%S") ), 27 | *args, 28 | **kwargs, 29 | file=self.log_file, 30 | flush=True 31 | ) 32 | 33 | def __getattr__(self, name): 34 | return self.log_file.__getattribute__(name) 35 | -------------------------------------------------------------------------------- /README.en.md: -------------------------------------------------------------------------------- 1 | # 图像深度学习基础部件集 2 | 3 | #### Description 4 | 本项目的主要功能: 5 | 多卡训练代码封装; 6 | 常用工具性代码集成 7 | 8 | #### Software Architecture 9 | Software architecture description 10 | 11 | #### Installation 12 | 13 | 1. xxxx 14 | 2. xxxx 15 | 3. xxxx 16 | 17 | #### Instructions 18 | 19 | 1. xxxx 20 | 2. xxxx 21 | 3. xxxx 22 | 23 | #### Contribution 24 | 25 | 1. Fork the repository 26 | 2. Create Feat_xxx branch 27 | 3. Commit your code 28 | 4. Create Pull Request 29 | 30 | 31 | #### Gitee Feature 32 | 33 | 1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md 34 | 2. Gitee blog [blog.gitee.com](https://blog.gitee.com) 35 | 3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore) 36 | 4. The most valuable open source project [GVP](https://gitee.com/gvp) 37 | 5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help) 38 | 6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) 39 | -------------------------------------------------------------------------------- /mec/configs/default_config.py: -------------------------------------------------------------------------------- 1 | # default configs 2 | 3 | class Options: 4 | def __init__(self): 5 | self.batch_size = 1 6 | self.learning_rate = 1e-3 7 | self.epochs = 1 8 | self.process_num_per_loader = 0 # 每个DataLoader启用的进程数 9 | self.path = 'results/temp' 10 | self.history_filename = 'results/temp/history.json' 11 | self.model_filename = 'results/temp/current_model.pth' 12 | self.best_model_filename = 'results/temp/best_model.pth' 13 | self.excel_filename = 'results/temp/scores.xlsx' 14 | self.control_ip = "127.0.0.1" # manager的IP 15 | self.basic_port = 12500 16 | self.worker_gpu_ids = [0] # worker所使用的gpu编号 [0,1,2,3] 17 | self.worker_ranks = [0] # worker本身编号 [0,1,2,3] 18 | self.sync_worker_num = 1 # 总worker数,单机的情况等于上两者的长度 19 | 20 | conf_g = Options() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 深度学习多卡训练框架 2 | 3 | #### 介绍 4 | 本项目的主要功能: 5 | 1,多卡训练代码封装; 6 | 2,常用工具性代码集成 7 | 8 | #### 软件架构 9 | 软件架构说明 10 | . 11 | ├── comms # 通信 12 | ├── configs # 参数 13 | ├── data_manip # 数据处理 14 | ├── scoring # 结果分析代码 15 | ├── training # 训练代码封装 16 | └── utils # 各种内部工具 17 | 18 | #### 安装教程 19 | 20 | 1, git clone xxx 21 | 2. python setup build 22 | 3. python setup install 23 | 24 | #### 使用说明 25 | 26 | 参见demo.py 27 | 单机可使用trainAndValLocal 28 | 多机可使用startWorkers和trainAndVal 29 | 30 | #### 参与贡献 31 | 32 | 1. Fork 本仓库 33 | 2. 新建 Feat_xxx 分支 34 | 3. 提交代码 35 | 4. 新建 Pull Request 36 | 37 | 38 | #### 码云特技 39 | 40 | 1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md 41 | 2. 码云官方博客 [blog.gitee.com](https://blog.gitee.com) 42 | 3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解码云上的优秀开源项目 43 | 4. [GVP](https://gitee.com/gvp) 全称是码云最有价值开源项目,是码云综合评定出的优秀开源项目 44 | 5. 码云官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) 45 | 6. 码云封面人物是一档用来展示码云会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) 46 | -------------------------------------------------------------------------------- /demo_dataset.py: -------------------------------------------------------------------------------- 1 | # demo dataset 2 | 3 | # 测试数据集 4 | 5 | from PIL import Image 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 8 | 9 | pre_image_size = (256, 256) 10 | image_size = (224, 224) 11 | 12 | data_transform = transforms.Compose([ 13 | transforms.Resize(pre_image_size), 14 | transforms.RandomHorizontalFlip(), 15 | transforms.RandomAffine(degrees=25, translate=(.2, .2) , 16 | scale=(0.8, 1.2), shear=8, 17 | resample=Image.BILINEAR, fillcolor=0), 18 | transforms.RandomCrop(image_size, padding=2, fill=(0,0,0) ), 19 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 20 | transforms.ToTensor(), 21 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 22 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 23 | ]) 24 | test_transform = transforms.Compose([ 25 | transforms.Resize(image_size), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 29 | ]) 30 | train_set = CIFAR10('downloaded_models', train=True, transform=data_transform, download=True) 31 | valid_set = CIFAR10('downloaded_models', train=False, transform=test_transform, download=True) -------------------------------------------------------------------------------- /mec/utils/history.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class History: 5 | def __init__(self, filename, logger=print): 6 | self.filename = filename 7 | self.printToLog = logger 8 | self.data = { 9 | 'epochs':0, 10 | 'train_loss':[], 11 | 'train_met':[], 12 | 'val_loss':[], 13 | 'val_met':[], 14 | 'lr':[] 15 | } 16 | 17 | def updateHistory(self, train_loss, train_met, val_loss, val_met, lr): 18 | self.data['epochs'] += 1 19 | self.data['train_loss'].append( train_loss ) 20 | self.data['train_met'].append( train_met ) 21 | self.data['val_loss'].append( val_loss ) 22 | self.data['val_met'].append( val_met ) 23 | self.data['lr'].append( lr ) 24 | 25 | def loadHistory(self): 26 | try: 27 | with open(self.filename) as file: 28 | self.data = json.load(file) 29 | print('history loading successful') 30 | print(self.data) 31 | except FileNotFoundError as e: 32 | self.printToLog('warining: training history record file not found') 33 | 34 | def saveHistory(self): 35 | with open(self.filename, 'w') as file: 36 | json.dump(self.data, file, indent=True, ensure_ascii=False) -------------------------------------------------------------------------------- /mec/unit_test/dist_test.py: -------------------------------------------------------------------------------- 1 | # dist_test.py 2 | # created: CS 3 | # 测试显卡通信组件 4 | 5 | import time 6 | import torch 7 | import torch.multiprocessing as mp 8 | import mec.comms.sync_rpc as rpc 9 | import mec.comms.transmit as trans 10 | import mec.utils.logs as logs 11 | 12 | ip = '192.168.1.99' 13 | port = '9999' 14 | 15 | 16 | 17 | def start_process(rank, gpu_list): 18 | world_size = len(gpu_list) 19 | log = logs.Logger(prefix='process {}|'.format(rank)) 20 | env = trans.DistEnv(rank, world_size, control_ip=ip, dist_port=port, logger=log) 21 | test_group_list = [0,1] 22 | group = env.newGroup(test_group_list) 23 | transmittor = trans.TensorTransmittor([0,1,2,3],logger=log) 24 | 25 | p = torch.nn.Parameter( 26 | torch.randn( 27 | (1,4), 28 | requires_grad=True, 29 | device=torch.device( 30 | 'cuda:{}'.format(gpu_list[rank]) 31 | ) 32 | ) 33 | ) 34 | 35 | l = torch.sum(p*p) 36 | l.backward() 37 | 38 | 39 | log("tensor: ", p) 40 | log("grad: " , p.grad) 41 | time.sleep(1) 42 | 43 | if rank in test_group_list: 44 | transmittor.crossTensors(p, group) 45 | transmittor.crossGrads(p, group) 46 | log("tensor: ", p) 47 | #log("grad: ", p.grad) 48 | 49 | 50 | 51 | def main(): 52 | gpu_list = [0,1,2,3] 53 | rank_list = list(range(len(gpu_list))) 54 | process_pool = [] 55 | for rank in gpu_list: 56 | p = mp.Process(target=start_process, args=(rank, gpu_list)) 57 | process_pool.append( p ) 58 | for p in process_pool: 59 | p.start() 60 | 61 | if __name__ == '__main__': 62 | main() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # vscode environment 118 | .vscode/ -------------------------------------------------------------------------------- /mec/unit_test/rpc_test.py: -------------------------------------------------------------------------------- 1 | # rpc_test.py 2 | # created: CS 3 | # 测试rpc组件 4 | 5 | import time 6 | import numpy as np 7 | import torch.multiprocessing as mp 8 | 9 | import mec.comms.sync_rpc as rpc 10 | 11 | #ip = '192.168.1.99' 12 | ip = '127.0.0.1' 13 | port = 9900 14 | worker_num = 4 15 | 16 | count = 0 17 | def test(): 18 | global count 19 | count+=1 20 | print("recieved ", count, " times") 21 | return 'got' 22 | 23 | def start_worker(rank): 24 | worker = rpc.SyncRpcWorker(ip, port, rank, lambda *x: print('[worker {}]'.format(rank), *x)) 25 | worker.registerMethod(lambda x,y,z: len(x)+len(y)+len(z), 'a.b.c') 26 | worker.registerMethod(lambda x,y,z: x+y+z, 'all.add') 27 | worker.registerMethod(test) 28 | worker.mainLoop() 29 | 30 | def start_controller(worker_num): 31 | controller = rpc.SyncRpcController(ip, port, worker_num, lambda *x: print('[controller]', *x)) 32 | controller.startWorking() 33 | # print( controller.a('a') ) 34 | # print( controller.a.b("123") ) 35 | #print( controller.a.b.c("123", [4, 'abc', 3.875], {1: 5, 666:(254, 'aba')}) ) 36 | print( controller.all.add( 8, 9, 10) ) 37 | #controller.stopWorking() 38 | controller.stopLooping() 39 | 40 | #recieve_count = 0 41 | #for i in range(100000): 42 | # result = controller.test() 43 | # if result == ['got'] * controller.worker_num: 44 | # recieve_count += 1 45 | #print("{} messages recieved".format(recieve_count) ) 46 | 47 | controller.closeSocket() 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | 53 | 54 | 55 | process_pool = [] 56 | for i in range(worker_num): 57 | wp = mp.Process(target=start_worker, args=(i,)) 58 | process_pool.append(wp) 59 | 60 | #process_pool.append(cp) 61 | cp = mp.Process(target=start_controller, args=(worker_num,) ) 62 | #cp.start() 63 | process_pool.append(cp) 64 | 65 | np.random.shuffle(process_pool) 66 | 67 | for p in process_pool: 68 | p.start() 69 | #time.sleep(np.random.rand() ) 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /mec/training/basic_trainer.py: -------------------------------------------------------------------------------- 1 | # basic_trainer.py 2 | # created: CS 3 | # 基本的训练类 4 | 5 | import torch 6 | 7 | class BasicTrainer(): 8 | """ 9 | Trainer类 10 | 封装训练中的各个步骤 11 | 不负责流程组织 12 | 不负责数据传递 13 | 参数: 14 | model: 模型 15 | optimizer: 优化器 16 | criterion: 损失函数 17 | metrics: 评分器 18 | """ 19 | def __init__(self, model, optimizer, criterion, metrics): 20 | # ----- basic elements ----- 21 | self.model = model # 模型 22 | self.optimizer = optimizer # 优化器 23 | self.criterion = criterion # 损失函数 24 | self.metrics = metrics # 评分函数 25 | # ----- temporary figures ----- 26 | self.loss = 0 27 | self.met = 0 28 | 29 | def initEpoch(self): 30 | self.metrics.init() 31 | self.optimizer.zero_grad() 32 | 33 | def forwardData(self, data): 34 | self.optimizer.zero_grad() 35 | self.result = self.model(data) 36 | 37 | # 增量式前向:不归零梯度 38 | def forwardDataInc(self, data): 39 | self.result=self.model(data) 40 | 41 | # 前向但不记录梯度 42 | @torch.no_grad() 43 | def forwardNoGrad(self, data): 44 | self.result=self.model(data) 45 | 46 | def backwardGrad(self, target): 47 | self.loss = self.criterion(self.result, target) 48 | self.loss.backward() 49 | self.met, _ = self.metrics(self.result, target) 50 | 51 | # 只计算loss,不回传梯度 52 | @torch.no_grad() 53 | def calcScores(self, target): 54 | self.loss = self.criterion(self.result, target) 55 | self.met, _ = self.metrics(self.result, target) 56 | 57 | def setLearningRate(self, lr): 58 | for param_group in self.optimizer.param_groups: 59 | param_group['lr'] = lr 60 | 61 | def updateWeights(self): 62 | self.optimizer.step() 63 | 64 | def getScores(self): 65 | return self.loss.item(), self.met 66 | 67 | def saveModel(self, filename): 68 | torch.save(self.model.state_dict(), filename) 69 | 70 | def loadModel(self, filename, map_location=None): 71 | self.model.load_state_dict(torch.load(filename, map_location=map_location)) -------------------------------------------------------------------------------- /mec/data_manip/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # lr_scheduler.py 2 | # 创建: CS 3 | # 修改: TM 4 | # 保存各种学习率策略 5 | # 所有函数保持函数格式: 6 | # function_name(epoch) -> lr 7 | 8 | import math 9 | import numpy as np 10 | 11 | class CosineLR(): 12 | def __init__(self, warm_epoch=0, lr=0.001, period=30, only_decrease=True): 13 | """学习率余弦更新策略,通过参数可设置warmup的代数、最大学习率、余弦更新周期、是否去除上升段。 14 | Args: 15 | warm_epoch (int): warmup的代数,epoch在此区间内线性上升至最大学习率。Default: 0. 16 | lr (float): 最大学习率,余弦函数最高点的纵坐标。Default: 0.001. 17 | period (int): 余弦更新的周期,指余弦函数从最高点到最低点需要的代数。Default: 30. 18 | only_decrease (bool): 若为``True``则仅保留余弦下降段,若为``False``则使用完整的余弦函数。Default: ``True``. 19 | 20 | Example: 21 | >>> scheduler = CosineLR(warm_epoch=5, lr=0.001, period=30, only_decrease=True) 22 | >>> for epoch in range(180): 23 | >>> lr = scheduler(epoch) 24 | >>> train(...) 25 | """ 26 | self.warm_epoch = warm_epoch 27 | self.lr = lr 28 | self.period = period 29 | self.only_decrease = only_decrease 30 | 31 | def __call__(self, epoch): 32 | """根据输入的epoch返回当前代的学习率""" 33 | if epoch < self.warm_epoch: 34 | return (epoch % self.warm_epoch) / self.warm_epoch * self.lr 35 | elif self.only_decrease is True: 36 | return (1 + math.cos(math.pi * ((epoch - self.warm_epoch) % self.period) / self.period)) * self.lr / 2 37 | else: 38 | return (1 + math.cos(math.pi * (epoch - self.warm_epoch) / self.period)) * self.lr / 2 39 | 40 | class ExponentialLR(): 41 | def __init__(self, warm_epoch=0, lr=0.001, rate=0.9): 42 | """学习率指数下降更新策略,通过参数可设置warmup的代数、最大学习率、指数下降速率。指数下降公式为:lr*(rate^epoch)。 43 | Args: 44 | warm_epoch (int): warmup的代数,epoch在此区间内线性上升至最大学习率。Default: 0. 45 | lr (float): 最大学习率,指数函数下降起点的纵坐标。Default: 0.001. 46 | rate (float): 指数下降速率,即指数函数的底数。Default: 0.9 47 | 48 | Example: 49 | >>> scheduler = ExponentialLR(warm_epoch=5, lr=0.001, rate=0.9) 50 | >>> for epoch in range(180): 51 | >>> lr = scheduler(epoch) 52 | >>> train(...) 53 | """ 54 | self.warm_epoch = warm_epoch 55 | self.lr = lr 56 | self.rate = rate 57 | 58 | def __call__(self, epoch): 59 | """根据输入的epoch返回当前代的学习率""" 60 | if epoch < self.warm_epoch: 61 | return (epoch % self.warm_epoch) / self.warm_epoch * self.lr 62 | else: 63 | return self.lr * np.power(self.rate , (epoch - self.warm_epoch)) 64 | -------------------------------------------------------------------------------- /mec/data_manip/data_utils.py: -------------------------------------------------------------------------------- 1 | # 数据集相关操作 2 | import zmq 3 | import numpy as np 4 | from torch.utils.data.dataset import Dataset 5 | from torch.utils.data.sampler import Sampler 6 | 7 | class IndicedDataset(Dataset): 8 | def __init__(self, dataset, indices): 9 | assert np.max(indices)0 else 0.0 44 | propotion_column.append(precision) 45 | # recall 46 | recall = true_positive_count/real_positive_count if real_positive_count>0 else 0.0 47 | propotion_column.append(recall) 48 | # f1 49 | f1_score = 2*true_positive_count/(real_positive_count+pred_positive_total) if true_positive_count>0 else 0.0 50 | propotion_column.append(f1_score) 51 | for j in range(1, num_classes+1): # 比例 52 | propotion_column.append(counts_column[j]/column_total if column_total>0 else 0.0) 53 | # 正确标签名 54 | target_class_name = idx_to_class_list[i] 55 | # ----- ----- 56 | counts_dict[ target_class_name ] = counts_column 57 | propotion_dict[ target_class_name ] = propotion_column # 总数 58 | 59 | count_frame = pd.DataFrame(data=counts_dict) 60 | propotion_frame = pd.DataFrame(data=propotion_dict) 61 | writer = pd.ExcelWriter(output_file_path, engine='xlsxwriter') 62 | 63 | count_frame.to_excel(writer, sheet_name='判别计数', index=False) 64 | workbook1 = writer.book 65 | worksheet1 = writer.sheets['判别计数'] 66 | worksheet1.set_column(0, 0, 8) 67 | worksheet1.set_column(1, num_classes, 5) 68 | 69 | propotion_frame.to_excel(writer, sheet_name='判别比例', index=False) 70 | workbook2 = writer.book 71 | worksheet2 = writer.sheets['判别比例'] 72 | worksheet2.set_column(0, 0, 8) 73 | worksheet2.set_column(1, num_classes, 5) 74 | 75 | writer.save() 76 | return 77 | 78 | def test_mix_async( 79 | model, 80 | test_set, 81 | train_idx_to_class = None, 82 | test_idx_to_class = None, 83 | test_loader=None): 84 | num_classes = len(idx_to_class_list) 85 | mix_mat_count_tensor = torch.zeros((num_classes, num_classes+1), dtype=torch.long).to(device) 86 | 87 | pass 88 | 89 | def test_mix( 90 | model, 91 | test_set = None, 92 | train_idx_to_class = None, 93 | test_idx_to_class = None, 94 | test_loader=None): 95 | """ 测试model在test_set上的准确率及混淆矩阵 96 | * 要求单标签流程 97 | * test_set跟model的序号排列不同时,要求给出训练集和测试集上的两个idx_to_class 98 | 参数: 99 | model : 测试用的模型 100 | test_set : 测试针对的数据集 101 | test_loader : 测试数据集的读取器,不可与test_set共存 102 | train_idx_to_class : 训练集的类别排列,indexable 103 | test_idx_to_class : 测试集的类别排列 104 | """ 105 | -------------------------------------------------------------------------------- /mec/utils/monitor.py: -------------------------------------------------------------------------------- 1 | # 进度监视条 2 | 3 | import time 4 | from tqdm import tqdm 5 | 6 | def reset_tqdm(pbar): 7 | pbar.n = 0 8 | pbar.last_print_n = 0 9 | pbar.start_t = time.time() 10 | pbar.last_print_t = time.time() 11 | 12 | # 打印类,用于打印训练信息 13 | class Monitor(): 14 | def __init__(self, init_epoch, total_epochs, 15 | train_batch_num, val_batch_num, 16 | metric_name, bar_cols=100, show=True): 17 | # --------------- tast attributes --------------- 18 | self.init_epoch = init_epoch 19 | self.total_epochs = total_epochs 20 | self.train_batch_num = train_batch_num 21 | self.val_batch_num = val_batch_num 22 | self.metric_name = metric_name 23 | # --------------- status records --------------- 24 | self.is_closed = False 25 | self.current_epoch = init_epoch 26 | self.current_train_batch = 0 27 | self.current_val_batch = 0 28 | # --------------- initiation actions --------------- 29 | self.bar_cols = bar_cols 30 | self.bar0 = None 31 | self.bar1 = None 32 | self.bar2 = None 33 | if show: self._initBars() 34 | # --------------- data --------------- 35 | self.avg_loss = 0 36 | self.avg_met = 0 37 | self.avg_val_loss = 0 38 | self.avg_val_met = 9 39 | 40 | 41 | def __del__(self): 42 | if not self.is_closed: 43 | self.close() 44 | 45 | def _initBars(self): 46 | # 监控初始化 47 | self.bar0 = tqdm(range(self.init_epoch, self.init_epoch+self.total_epochs), desc = '', position=0, 48 | bar_format='{desc}│{bar}│{elapsed}s{postfix}', ncols=self.bar_cols) 49 | # for i in range(self.init_epoch): 50 | # self.bar0.update(self.init_epoch) 51 | self.bar0.set_description_str('epoch:{:4d}/{:4d}'.format(self.init_epoch, self.init_epoch+self.total_epochs)) 52 | self.bar0.set_postfix_str(" t_los={1:.3f}, t_{0}={2:.3f}| v_los={3:.3f}, v_{0}={4:.3f}".format( 53 | self.metric_name, 0, 0, 0, 0)) 54 | self.bar1 = tqdm(range(self.train_batch_num), desc="", position=1, 55 | bar_format='{desc}│{bar}│{elapsed}s{postfix}', ncols=self.bar_cols) 56 | self.bar2 = tqdm(range(self.val_batch_num), desc = '', position=2, 57 | bar_format='{desc}│{bar}│{elapsed}s{postfix}', ncols=self.bar_cols, leave=True) 58 | self.bar2.set_description_str('batch:{:4d}/{:4d}'.format(0, self.val_batch_num)) 59 | self.bar2.set_postfix_str("validate─➤ los={1:.3f} avg={2:.3f}| {0}={3:.3f} avg={4:.3f}".format( 60 | self.metric_name, 0, 0, 0, 0)) 61 | 62 | def close(self): 63 | self.bar0.close() 64 | self.bar1.close() 65 | self.bar2.close() 66 | print("\n\n") 67 | self.is_closed = True 68 | 69 | def updateTraining(self, loss, avg_loss, met, avg_met): 70 | self.bar1.update() 71 | self.current_train_batch += 1 72 | self.bar1.set_description_str('batch:{:4d}/{:4d}'.format(self.current_train_batch, self.train_batch_num)) 73 | self.bar1.set_postfix_str( 74 | "training─➤ los={1:.3f} avg={2:.3f}| {0}={3:.3f} avg={4:.3f}".format( 75 | self.metric_name, loss, avg_loss, met, avg_met 76 | ) # format 77 | ) 78 | self.avg_loss = avg_loss 79 | self.avg_met = avg_met 80 | 81 | def updateValidation(self, val_loss, avg_val_loss, val_met, avg_val_met): 82 | self.bar2.update() 83 | self.current_val_batch += 1 84 | self.bar2.set_description_str('v_bth:{:4d}/{:4d}'.format(self.current_val_batch, self.val_batch_num)) 85 | self.bar2.set_postfix_str( 86 | "validate─➤ los={1:.3f} avg={2:.3f}| {0}={3:.3f} avg={4:.3f}".format( 87 | self.metric_name, val_loss, avg_val_loss, val_met, avg_val_met 88 | ) 89 | ) 90 | self.avg_val_loss = avg_val_loss 91 | self.avg_val_met = avg_val_met 92 | 93 | def beginEpoch(self): 94 | self.current_train_batch = 0 95 | self.current_val_batch = 0 96 | reset_tqdm(self.bar1) 97 | reset_tqdm(self.bar2) 98 | pass 99 | 100 | def endEpoch(self, avg_loss=None, avg_val_loss=None, avg_met=None, avg_val_met=None): 101 | avg_loss = self.avg_loss if avg_loss is None else avg_loss 102 | avg_met = self.avg_met if avg_met is None else avg_met 103 | avg_val_loss = self.avg_val_loss if avg_val_loss is None else avg_val_loss 104 | avg_val_met = self.avg_val_met if avg_val_met is None else avg_val_met 105 | self.current_epoch += 1 106 | self.bar0.update() 107 | self.bar0.set_description_str('epoch:{:4d}/{:4d}'.format(self.current_epoch, self.init_epoch+self.total_epochs)) 108 | self.bar0.set_postfix_str( 109 | " t_los={1:.3f}, t_acc={2:.3f}| v_los={3:.3f}, v_{0}={4:.3f}".format( 110 | self.metric_name, avg_loss, avg_met, avg_val_loss, avg_val_met 111 | ) 112 | ) 113 | 114 | def updateEpoch(self, avg_loss=None, avg_val_loss=None, avg_met=None, avg_val_met=None): 115 | avg_loss = self.avg_loss if avg_loss is None else avg_loss 116 | avg_met = self.avg_met if avg_met is None else avg_met 117 | avg_val_loss = self.avg_val_loss if avg_val_loss is None else avg_val_loss 118 | avg_val_met = self.avg_val_met if avg_val_met is None else avg_val_met 119 | self.current_epoch += 1 120 | self.current_train_batch = 0 121 | self.current_val_batch = 0 122 | self.bar0.update() 123 | self.bar0.set_description_str('epoch:{:4d}/{:4d}'.format(self.current_epoch, self.init_epoch+self.total_epochs)) 124 | self.bar0.set_postfix_str( 125 | " t_los={1:.3f}, t_acc={2:.3f}| v_los={3:.3f}, v_{0}={4:.3f}".format( 126 | self.metric_name, avg_loss, avg_met, avg_val_loss, avg_val_met 127 | ) 128 | ) 129 | reset_tqdm(self.bar1) 130 | reset_tqdm(self.bar2) 131 | 132 | -------------------------------------------------------------------------------- /mec/comms/transmit.py: -------------------------------------------------------------------------------- 1 | # data_trasmitting.py 2 | # 创建:陈硕 3 | # 创建日期:2020.06.20 4 | # 文件描述:封装跨卡、跨机模型传递功能 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | 11 | from torch.distributed.distributed_c10d import _get_default_group 12 | 13 | 14 | class DistEnv: 15 | """ 16 | torch.distributed所使用的分布式环境 17 | """ 18 | def __init__(self, rank, world_size, control_ip, dist_port, logger=print): 19 | self.printToLog = logger 20 | self.rank = rank 21 | self.world_size = world_size 22 | self.control_ip = control_ip 23 | self.dist_port = dist_port 24 | self._initTorchDist() 25 | 26 | def _initTorchDist(self): 27 | self.printToLog("dist args:", 'nccl', self.control_ip, 28 | self.rank, self.world_size) 29 | os.environ['MASTER_ADDR'] = self.control_ip 30 | os.environ['MASTER_PORT'] = self.dist_port 31 | dist.init_process_group( 32 | backend='nccl', 33 | rank=self.rank, 34 | world_size=self.world_size) 35 | self.printToLog("torch distributed environment initiated successfully") 36 | #self.worker_group = dist.new_group(list(range(1,self.world_size)) ) 37 | 38 | def newGroup(self, rank_list): 39 | """ 40 | 按照rank_list建立一个新的组 41 | """ 42 | return dist.new_group(rank_list) 43 | 44 | # 用于cube all-reduce的分配函数 45 | def cube_correspond(n, turn): 46 | return (1<1: 68 | batch_target = batch_target.max(dim=-1)[1] 69 | hits = (batch_output.max(dim=-1)[1] == batch_target).sum().item() 70 | #hits = torch.sum(batch_output.max(dim=-1)[1] == batch_target).item() 71 | self.batchScore = hits / currentCount 72 | self.totalHits += hits 73 | self.totalSamples += currentCount 74 | self.epochScore = self.totalHits / self.totalSamples 75 | 76 | # 计算每个batch的分数 77 | def getBatchScore(self): 78 | return self.batchScore 79 | 80 | # 应重载 81 | # 计算每个epoch的分数 82 | def getEpochScore(self): 83 | return self.epochScore 84 | 85 | # 评价标准:accuracy 准确率 86 | # 用于单标签分类、且标签为one hot表示时的情况 87 | # 可适用平滑和混淆 88 | class AccuracyOneHot(MetricBase): 89 | pass 90 | 91 | 92 | # # 评价标准:mean average precision 93 | # # 适用于多标签分类 94 | # # 每个batch单独输入数据 95 | # # 每个epoch先清除缓存 96 | # class MeanAveragePrecision(MetricBase): 97 | # def __init__(self, tag_dimension=1, total_sample_num=1, device=None): #torch.device('cpu') ): 98 | # MetricBase.__init__(self) 99 | # self.tagDimension = tag_dimension 100 | # self.totalSampleNum = total_sample_num 101 | # self.currentCount = 0 102 | # self.batchSize = 1 103 | # self.outputs = torch.zeros( (total_sample_num, tag_dimension) ).to(device) 104 | # self.targets = torch.zeros( (total_sample_num, tag_dimension) ).to(device) 105 | # self.tempPrecisionArray = torch.zeros(total_sample_num).to(device) 106 | # self.tempRecallArray = torch.zeros(total_sample_num).to(device) 107 | 108 | # def name(self): 109 | # return 'map' 110 | 111 | # def init(self): 112 | # self.batchScore = 0.0 113 | # self.epochScore = 0.0 114 | # self.currentCount = 0 115 | # pass 116 | 117 | # def addData(self, batch_output, batch_target): 118 | # assert batch_output.ndimension()==2, 'output must be 2 dimension: batch/feature' 119 | # assert batch_target.ndimension()==2, 'target must be 2 dimension: batch/feature' 120 | # #print(self.outputs.size()) 121 | # #print(batch_output.size()) 122 | # self.batchSize = len(batch_output) 123 | # self.outputs[self.currentCount: self.currentCount+self.batchSize] = torch.sigmoid(batch_output) 124 | # self.targets[self.currentCount: self.currentCount+self.batchSize] = batch_target 125 | # self.currentCount += self.batchSize 126 | 127 | # def getBatchScore(self): 128 | # return 0.0 129 | # #return self.batchScore 130 | 131 | # def getAveragePrecision(self, index): 132 | # outputs, indices = self.outputs[0:self.currentCount, index].sort(descending=True) 133 | # targets = self.targets[0:self.currentCount, index][indices] 134 | # none_zero_indices = outputs!=0 135 | # outputs = outputs[none_zero_indices] 136 | # targets = targets[none_zero_indices] 137 | # total_positive = torch.sum(targets>0) 138 | # # 按threshold从到到低的顺序分别计算precision和recall 139 | # for i in range(self.currentCount): 140 | # threshold = outputs[i] 141 | # true_positive = torch.sum(targets[0:i]>0) 142 | # predicted_positive = torch.sum(outputs>threshold) 143 | # precision = true_positive / (predicted_positive.type(torch.double) + 1e-6) 144 | # recall = true_positive / (total_positive.type(torch.double) + 1e-6) 145 | # # print("\n------ ") 146 | # # print("true pos: {}, pred pos: {}, precision: {}".format(true_positive, predicted_positive, precision) ) 147 | # # print("true pos: {}, total pos: {}, recall: {}".format(true_positive, total_positive, recall) ) 148 | # self.tempPrecisionArray[i] = precision 149 | # self.tempRecallArray[i] = recall 150 | # # print(self.tempRecallArray[0:self.currentCount]) 151 | # # print(self.tempPrecisionArray[0:self.currentCount]) 152 | # # 按recall从高到低排序,方便后面的运算 153 | # recallArray, recallIndices = self.tempRecallArray[0:self.currentCount].sort(descending=True) 154 | # precisionArray = self.tempPrecisionArray[recallIndices] 155 | # # 11-点-AP 156 | # total_precision = 0 157 | # current_precision = 0 158 | # #for index in range(11): 159 | # recall_index = 10 160 | # required_recall = index*0.1 161 | # next_recall = (index-1)*0.1 162 | # maxPrecision = recallArray[0] 163 | # for index, recall in enumerate(recallArray) : # recall 从高到低 164 | # # print("\nindex: ", index) 165 | # while recall < next_recall: # 下一个recall点 166 | # # print("\nrecall; ", recall) 167 | # # print(maxPrecision) 168 | # total_precision += maxPrecision 169 | # recall_index -= 1 170 | # required_recall = recall_index*0.1 171 | # next_recall = (recall_index-1)*0.1 172 | # precision = precisionArray[index] 173 | # if precision > maxPrecision: 174 | # maxPrecision = precision 175 | # total_precision += maxPrecision # recall == 0 时的precision 176 | # # print(total_precision, total_precision/11.0, index) 177 | # return total_precision / 11.0 178 | 179 | # def getEpochScore(self): 180 | # totalAP = 0.0 181 | # for i in range(self.tagDimension): # 对所有输出维度 182 | # totalAP += self.getAveragePrecision(i) 183 | # self.epochScore = totalAP / self.tagDimension 184 | # return self.epochScore.item() 185 | 186 | # # mean F1 score 187 | # # 评价标准:平均F1分数 188 | # # 适用于多标签分类 189 | # class meanF1Score(MetricBase): 190 | # def __init__(self, tag_dimension=1, threshold=0.0, device=torch.device('cuda')): 191 | # MetricBase.__init__(self) 192 | # self.device = device 193 | # self.batchScore = 0.0 194 | # self.epochScore = 0.0 195 | # self.threshold = threshold 196 | # self.tp = torch.zeros(tag_dimension).to(device) 197 | # self.tn = torch.zeros(tag_dimension).to(device) 198 | # self.fp = torch.zeros(tag_dimension).to(device) 199 | # self.fn = torch.zeros(tag_dimension).to(device) 200 | 201 | # # 可重载 202 | # # 定义评价标准的名字 203 | # def getMetricName(self): 204 | # return 'mF1' 205 | 206 | # # 重载 207 | # # 每个batch输入数据 208 | # def addData(self, batch_output, batch_target): 209 | # guess_pos = batch_output > self.threshold # 判断为真 210 | # target_pos = batch_target > 0 # 实际为真 211 | # guess_neg = batch_output < self.threshold # 判断为假 212 | # target_neg = batch_target < 0 # 实际为假 213 | # # 以下为向量,维度为tag dimension 214 | # tp = torch.sum(guess_pos*target_pos, dim=0).type(torch.float) # 猜真实真,为tp 215 | # fp = torch.sum(guess_pos*target_neg, dim=0).type(torch.float) + 1e-6 # 猜真实假,为fp 216 | # fn = torch.sum(guess_neg*target_pos, dim=0).type(torch.float) + 1e-6 # 猜假实真,为fp 217 | # #tn = torch.sum(guess_neg*target_neg, dim=0) # 猜假实假,为tn 218 | # # 计算每个batch的分数 219 | # self.batchScore = tp *2 / (tp*2 + fp + fn) 220 | # # 计算每个epoch的分数 221 | # self.tp += tp 222 | # self.fp += fp 223 | # self.fn += fn 224 | # # print("---- shape: ----\n", self.tp.size(), tp.size() ) 225 | # # print("---- shape: ----\n", self.fp.size(), fp.size() ) 226 | # # print("---- shape: ----\n", self.fn.size(), fn.size() ) 227 | # self.epochScore = self.tp * 2 / (self.tp*2 + self.fp + self.fn) 228 | 229 | # # 重载 230 | # # 每epoch初始重置数据 231 | # def init(self): 232 | # self.tp.zero_() 233 | # self.tn.zero_() 234 | # self.fp.zero_() 235 | # self.fn.zero_() 236 | # self.tp += 1e-6 237 | # self.tn += 1e-6 238 | # self.fp += 1e-6 239 | # self.fn += 1e-6 240 | 241 | # # 重载 242 | # # 计算每个batch的分数 243 | # def getBatchScore(self): 244 | # return torch.mean(self.batchScore).item() 245 | 246 | # # 重载 247 | # # 计算每个epoch的分数 248 | # def getEpochScore(self): 249 | # return torch.mean(self.epochScore).item() 250 | 251 | # def getEpochScore_(self): 252 | # return self.epochScore -------------------------------------------------------------------------------- /mec/data_manip/transfroms/data_transforms.py: -------------------------------------------------------------------------------- 1 | # data_transfroms.py 2 | # created: CS 3 | # 各种数据变换 4 | # 针对单个样本 5 | 6 | ''' 7 | ElasticTransform: 8 | 弹性变换,根据扭曲场的平滑度与强度逐一地移动局部像素点实现模糊效果。 9 | 依赖: albumentations 10 | 11 | 参数: 12 | alpha (float): 13 | sigma (float): Gaussian filter parameter. 14 | alpha_affine (float): The range will be (-alpha_affine, alpha_affine) 15 | interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: 16 | cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. 17 | Default: cv2.INTER_LINEAR. 18 | border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: 19 | cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. 20 | Default: cv2.BORDER_REFLECT_101 21 | value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. 22 | mask_value (int, float, 23 | list of ints, 24 | list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. 25 | approximate (boolean): Whether to smooth displacement map with fixed kernel size. 26 | Enabling this option gives ~2X speedup on large images. 27 | 输入: 28 | numpy数组 29 | 返回: 30 | 字典,形式为 {'image': array(...)} 31 | 示例: 32 | >>> import albumentations as A 33 | >>> import numpy as np 34 | >>> from PIL import Image 35 | >>> img = Image.open('img.jpg') 36 | >>> np_img = np.asarray(img) 37 | >>> t = A.ElasticTransform() 38 | >>> img = Image.fromarray(t(image = np_img)['image']) 39 | ''' 40 | 41 | ''' 42 | HueSaturationValue: 43 | HSV对比度变换,通过向HSV空间中的每个像素添加或减少V值,修改色调和饱和度实现对比度转换。 44 | 依赖: albumentations 45 | 46 | 参数: 47 | hue_shift_limit ((int, int) or int): range for changing hue. If hue_shift_limit is a single int, the range 48 | will be (-hue_shift_limit, hue_shift_limit). Default: (-20, 20). 49 | sat_shift_limit ((int, int) or int): range for changing saturation. If sat_shift_limit is a single int, 50 | the range will be (-sat_shift_limit, sat_shift_limit). Default: (-30, 30). 51 | val_shift_limit ((int, int) or int): range for changing value. If val_shift_limit is a single int, the range 52 | will be (-val_shift_limit, val_shift_limit). Default: (-20, 20). 53 | p (float): probability of applying the transform. Default: 0.5. 54 | 输入: 55 | numpy数组 56 | 返回: 57 | 字典,形式为 {'image': array(...)} 58 | 示例: 59 | >>> import albumentations as A 60 | >>> import numpy as np 61 | >>> from PIL import Image 62 | >>> img = Image.open('img.jpg') 63 | >>> np_img = np.asarray(img) 64 | >>> t = A.HueSaturationValue() 65 | >>> img = Image.fromarray(t(image = np_img)['image']) 66 | ''' 67 | 68 | ''' 69 | IAASuperpixels: 70 | 超像素法,在最大分辨率处生成图像的若干个超像素,并将其调整到原始大小,再将原始图像中所有超像素区域按一定比例替换为超像素,其他区域不改变。 71 | 依赖: albumentations 72 | 注意: 该方法可能速度较慢。 73 | 参数: 74 | p_replace (float): defines the probability of any superpixel area being replaced by the superpixel, i.e. by 75 | the average pixel color within its area. Default: 0.1. 76 | n_segments (int): target number of superpixels to generate. Default: 100. 77 | p (float): probability of applying the transform. Default: 0.5. 78 | 输入: 79 | numpy数组 80 | 返回: 81 | 字典,形式为 {'image': array(...)} 82 | 示例: 83 | >>> import albumentations as A 84 | >>> import numpy as np 85 | >>> from PIL import Image 86 | >>> img = Image.open('img.jpg') 87 | >>> np_img = np.asarray(img) 88 | >>> t = A.IAASuperpixels() 89 | >>> img = Image.fromarray(t(image = np_img)['image']) 90 | ''' 91 | 92 | ''' 93 | IAAPerspective: 94 | 随机四点透视变换 95 | 依赖: albumentations 96 | 参数: 97 | scale ((float, float): standard deviation of the normal distributions. These are used to sample 98 | the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). 99 | p (float): probability of applying the transform. Default: 0.5. 100 | 输入: 101 | numpy数组 102 | 返回: 103 | 字典,形式为 {'image': array(...)} 104 | 示例: 105 | >>> import albumentations as A 106 | >>> import numpy as np 107 | >>> from PIL import Image 108 | >>> img = Image.open('img.jpg') 109 | >>> np_img = np.asarray(img) 110 | >>> t = A.IAAPerspective() 111 | >>> img = Image.fromarray(t(image = np_img)['image']) 112 | ''' 113 | 114 | ''' 115 | CoarseDropout: 116 | 在面积大小可选定、位置随机的矩形区域上丢失信息实现转换,产生黑色矩形块。 117 | 依赖: albumentations 118 | 参数: 119 | max_holes (int): Maximum number of regions to zero out. 120 | max_height (int): Maximum height of the hole. 121 | max_width (int): Maximum width of the hole. 122 | min_holes (int): Minimum number of regions to zero out. If `None`, 123 | `min_holes` is be set to `max_holes`. Default: `None`. 124 | min_height (int): Minimum height of the hole. Default: None. If `None`, 125 | `min_height` is set to `max_height`. Default: `None`. 126 | min_width (int): Minimum width of the hole. If `None`, `min_height` is 127 | set to `max_width`. Default: `None`. 128 | fill_value (int, float, lisf of int, list of float): value for dropped pixels. 129 | 输入: 130 | numpy数组 131 | 返回: 132 | 字典,形式为 {'image': array(...)} 133 | 示例: 134 | >>> import albumentations as A 135 | >>> import numpy as np 136 | >>> from PIL import Image 137 | >>> img = Image.open('img.jpg') 138 | >>> np_img = np.asarray(img) 139 | >>> t = A.CoarseDropout() 140 | >>> img = Image.fromarray(t(image = np_img)['image']) 141 | ''' 142 | 143 | ''' 144 | EdgeDetect: 145 | 边界检测,检测图像中的所有边缘,将它们标记为黑白图像,再将结果与原始图像叠加。 146 | 依赖: imgaug 147 | 参数: 148 | alpha : number or tuple of number or list of number or imgaug.parameters.StochasticParameter, optional 149 | Blending factor of the edge image. At ``0.0``, only the original 150 | image is visible, at ``1.0`` only the edge image is visible. 151 | 152 | * If a number, exactly that value will always be used. 153 | * If a tuple ``(a, b)``, a random value will be sampled from the 154 | interval ``[a, b]`` per image. 155 | * If a list, a random value will be sampled from that list 156 | per image. 157 | * If a ``StochasticParameter``, a value will be sampled from that 158 | parameter per image. 159 | 160 | seed : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional 161 | See :func:`~imgaug.augmenters.meta.Augmenter.__init__`. 162 | 163 | name : None or str, optional 164 | See :func:`~imgaug.augmenters.meta.Augmenter.__init__`. 165 | 166 | random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional 167 | Old name for parameter `seed`. 168 | Its usage will not yet cause a deprecation warning, 169 | but it is still recommended to use `seed` now. 170 | Outdated since 0.4.0. 171 | 172 | deterministic : bool, optional 173 | Deprecated since 0.4.0. 174 | See method ``to_deterministic()`` for an alternative and for 175 | details about what the "deterministic mode" actually does. 176 | 输入: 177 | numpy数组 178 | 返回: 179 | numpy数组 180 | 示例: 181 | >>> import imgaug.augmenters as iaa 182 | >>> import numpy as np 183 | >>> from PIL import Image 184 | >>> img = Image.open('img.jpg') 185 | >>> np_img = np.asarray(img) 186 | >>> t = iaa.EdgeDetect(alpha=(0.0, 1.0)) 187 | >>> img = Image.fromarray(t(image = np_img)) 188 | ''' 189 | 190 | ''' 191 | RandomAffine: 192 | 仿射变换,对图像进行旋转、水平偏移、裁剪、放缩等操作,保持中心不变。 193 | 依赖: torchvision 194 | 参数: 195 | degrees (sequence or float or int): Range of degrees to select from. 196 | If degrees is a number instead of sequence like (min, max), the range of degrees 197 | will be (-degrees, +degrees). Set to 0 to deactivate rotations. 198 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 199 | and vertical translations. For example translate=(a, b), then horizontal shift 200 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 201 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 202 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 203 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 204 | shear (sequence or float or int, optional): Range of degrees to select from. 205 | If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) 206 | will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the 207 | range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, 208 | a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. 209 | Will not apply shear by default 210 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 211 | An optional resampling filter. See `filters`_ for more information. 212 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 213 | fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area 214 | outside the transform in the output image.(Pillow>=5.0.0) 215 | 输入: 216 | PIL图片 217 | 返回: 218 | PIL图片 219 | 示例: 220 | >>> from torchvision import transforms as T 221 | >>> from PIL import Image 222 | >>> img = Image.open('img.jpg') 223 | >>> t = T.RandomAffine(25, translate=(0.2,0.2), scale=(0.8,1.2), shear=8, resample=Image.BILINEAR) 224 | >>> img = t(img) 225 | ''' 226 | 227 | ''' 228 | randomGausNoise: 229 | 高斯模糊。 230 | 实现: 231 | def randomGausNoise(image): 232 | dice = random() 233 | if dice<0.5: 234 | return image.filter(ImageFilter.GaussianBlur(radius=random()*1.7+0.5) ) 235 | else: 236 | return image 237 | 输入: 238 | PIL图片 239 | 返回: 240 | PIL图片 241 | 示例: 242 | >>> from torchvision import transforms as T 243 | >>> from random import random 244 | >>> from PIL import Image, ImageFilter 245 | >>> img = Image.open('img.jpg') 246 | >>> t = T.Lambda(randomGausNoise) 247 | >>> img = t(img) 248 | ''' 249 | 250 | ''' 251 | cropAndPadImage: 252 | 裁切图像并用黑色像素补齐图像至目标大小。 253 | 实现: 254 | target_size = (224, 224) 255 | background_color = (0,0,0) 256 | def cropAndPadImage(img): 257 | w, h = img.size 258 | if w==h: 259 | if w>target_size[0]: 260 | return img.resize(target_size) 261 | else: 262 | return img 263 | if w>h: 264 | x0 = int( (w-h)/4 ) 265 | x1 = w - x0 266 | y0 = 0 267 | y1 = h 268 | padding_length = x1-x0 269 | padding_size = (padding_length, padding_length) 270 | pad_x0 = 0 271 | pad_x1 = padding_length 272 | pad_y0 = int( (w-h)/4 ) 273 | pad_y1 = pad_y0 + h 274 | else : 275 | x0 = 0 276 | x1 = w 277 | y0 = int( (h-w)/4 ) 278 | y1 = h - y0 279 | padding_length = y1-y0 280 | padding_size = (padding_length, padding_length) 281 | pad_x0 = int( (h-w)/4 ) 282 | pad_x1 = pad_x0 + w 283 | pad_y0 = 0 284 | pad_y1 = padding_length 285 | cropped_img = img.crop( (x0,y0, x1,y1) ) 286 | padded_img = Image.new('RGB', padding_size, background_color) 287 | #print(img.size, padding_size, cropped_img.size, (pad_x0, pad_y0, pad_x1, pad_y1) ) 288 | padded_img.paste(cropped_img, (pad_x0, pad_y0, pad_x1, pad_y1) ) 289 | resized_img = padded_img.resize(target_size) 290 | return resized_img 291 | 输入: 292 | PIL图片 293 | 返回: 294 | PIL图片 295 | 示例: 296 | >>> from torchvision import transforms as T 297 | >>> from random import random 298 | >>> from PIL import Image 299 | >>> img = Image.open('img.jpg') 300 | >>> t = T.Lambda(cropAndPadImage) 301 | >>> img = t(img) 302 | ''' -------------------------------------------------------------------------------- /mec/comms/sync_rpc.py: -------------------------------------------------------------------------------- 1 | # process_signalling.py 2 | # 创建:陈硕 3 | # 创建日期:2020.06.20 4 | # 文件描述:进程通信封装,只用于通知,不用于大量传输数据 5 | 6 | 7 | 8 | import zmq 9 | import time 10 | 11 | 12 | # ----------------------------- 同步RPC ----------------------------- 13 | # 多个server对应一个client 14 | 15 | test_signal = 'test' 16 | good_signal = 'good' 17 | check_signal = 'check' 18 | quit_signal = 'quit' 19 | start_signal = 'start' 20 | 21 | 22 | 23 | class SyncRpcBase: 24 | def __init__(self, server_ip, port, logger): 25 | self.printToLog = logger 26 | # 27 | self.context = None 28 | self.publish_addr = "tcp://{}:{}".format(server_ip, str(port) ) 29 | self.publish_socket = None 30 | self.report_addr = "tcp://{}:{}".format(server_ip, str(port+1) ) 31 | self.report_socket = None 32 | self.check_addr = "tcp://{}:{}".format(server_ip, str(port+2) ) 33 | self.check_socket = None 34 | self.logger = logger 35 | self.initSocket() 36 | 37 | def __del__(self): 38 | self.closeSocket() 39 | 40 | 41 | class _Method: 42 | # some magic to bind an RPC method to an RPC server. 43 | # supports "nested" methods (e.g. examples.getStateName) 44 | def __init__(self, name, send, logger=print): 45 | self.__name = name 46 | self.__send = send 47 | self.printToLog = logger 48 | def __getattr__(self, name): 49 | self.printToLog("attribute {:s}".format(name) ) 50 | method = _Method("{:s}.{:s}".format(self.__name, name), self.__send) 51 | self.__setattr__(name, method) 52 | return method 53 | def __call__(self, *args, **kwargs): 54 | return self.__send(self.__name, *args, **kwargs) 55 | 56 | class SyncRpcController(SyncRpcBase): 57 | """ 58 | 同步RPC服务端 59 | """ 60 | def __init__(self, server_ip, port, worker_num, logger=print): 61 | # 62 | self.printToLog = logger 63 | self.printToLog('initiating synchronized rpc controller') 64 | SyncRpcBase.__init__(self, server_ip, port, logger) 65 | self.worker_num = worker_num 66 | self.is_working = False 67 | self.is_looping = False 68 | # check workers 69 | # self check 70 | self.printToLog('waiting for publishing socket') 71 | self._WaitPubSockReady() 72 | self.printToLog('publishing socket ready') 73 | self.printToLog('synchronizing all workers') 74 | self._waitAllWorkersReady() 75 | self.printToLog('confirmed {} workers ready'.format(self.worker_num) ) 76 | 77 | def __del__(self): 78 | self.closeSocket() 79 | 80 | def __getattr__(self, name): 81 | method = _Method(name, self._callMethod, self.printToLog) 82 | self.__setattr__(name, method) 83 | return method 84 | 85 | def initSocket(self): 86 | self.printToLog("initizating socket:") 87 | self.printToLog("publish addr = '{}'".format(self.publish_addr) ) 88 | self.printToLog("report addr = '{}'".format(self.report_addr) ) 89 | self.context = zmq.Context() 90 | # publish socket 91 | self.publish_socket = self.context.socket(zmq.PUB) 92 | self.publish_socket.bind(self.publish_addr) 93 | # report socket 94 | self.report_socket = self.context.socket(zmq.PULL) 95 | self.report_socket.bind(self.report_addr) 96 | # self check socket 97 | self.self_check_sub_socket = self.context.socket(zmq.SUB) 98 | self.self_check_sub_socket.connect(self.publish_addr) 99 | self.self_check_sub_socket.setsockopt(zmq.SUBSCRIBE, b'') 100 | # workers check socket 101 | self.check_socket = self.context.socket(zmq.REQ) 102 | self.check_socket.bind(self.check_addr) 103 | # 104 | self.printToLog("socket initizating complete") 105 | 106 | def closeSocket(self): 107 | #self.printToLog("closing socket ...") 108 | if self.publish_socket != None: 109 | self.publish_socket.unbind(self.publish_addr) 110 | self.publish_socket = None 111 | if self.report_socket != None: 112 | self.report_socket.unbind(self.report_addr) 113 | self.report_socket = None 114 | if self.self_check_sub_socket != None: 115 | self.self_check_sub_socket.disconnect(self.publish_addr) 116 | self.self_check_sub_socket = None 117 | if self.check_socket !=None: 118 | self.check_socket.unbind(self.check_addr) 119 | self.check_socket = None 120 | #self.printToLog("socket closed") 121 | 122 | def _callMethod(self, name, *args, **kwargs): 123 | """ 124 | 调用工作者的方法,并获取返回值 125 | """ 126 | self.printToLog("calling: ", (name, args, kwargs) ) 127 | self._broadcastMessage( (name, args, kwargs) ) 128 | result = self._gatherMessages() 129 | self.printToLog("result: ", result) 130 | return result 131 | 132 | def _broadcastMessage(self, msg): 133 | """ 134 | 将消息广播至所有的工作者 135 | """ 136 | self.printToLog("message sent:", msg) 137 | self.publish_socket.send( repr(msg).encode() ) 138 | 139 | def _recieveSingleMessage(self): 140 | """ 141 | 从工作者收集信息 142 | 一次只收集一个 143 | """ 144 | return eval(self.report_socket.recv().decode()) 145 | 146 | def _gatherMessages(self): 147 | """ 148 | 从所有的工作者汇集消息 149 | 等候直到所有的工作者消息汇总完毕 150 | 返回一个list 151 | """ 152 | result_list = [] 153 | for i in range(self.worker_num): 154 | result_list.append( eval(self.report_socket.recv(0).decode()) ) 155 | self.printToLog(i+1, "results recieved") 156 | return result_list 157 | 158 | def _WaitPubSockReady(self): 159 | while True: 160 | self.publish_socket.send(repr(test_signal).encode()) 161 | try: 162 | result = self.self_check_sub_socket.recv(zmq.NOBLOCK) 163 | if result is not None: 164 | result = eval(result.decode()) 165 | self.printToLog('message: ',result) 166 | if result == test_signal: break 167 | except zmq.ZMQError: 168 | self.printToLog('not ready') 169 | time.sleep(1) 170 | 171 | def _waitAllWorkersReady(self): 172 | self._sendControlSignal(check_signal, sync_check=True) 173 | # if self.is_working: 174 | # self.printToLog('warning! checking workers in working status') 175 | # return 176 | # workers_set = set() 177 | # while True: 178 | # self.printToLog('sending control signal, confirmed worker num: ',len(workers_set)) 179 | # # count workers 180 | # self.check_socket.send(repr(check_signal).encode() ) 181 | # rank = eval(self.check_socket.recv().decode()) 182 | # self.printToLog('worker respond got, rank {}'.format(rank)) 183 | # assert type(rank) is int, 'check respond signal error, should be int' 184 | # assert rank=0, \ 185 | # 'worker respond rank exceeds limit, worker num {}, get {}'.format( 186 | # self.worker_num, rank) 187 | # if rank not in workers_set: 188 | # workers_set.add(rank) 189 | # self.printToLog('counted workers: ', workers_set) 190 | # if len(workers_set)==self.worker_num: # counted all 191 | # break 192 | # else: #rank in workers_set: indicate delay joined worker, count again 193 | # self.printToLog(' [warning] delay joined worker, rank {}, count again'.format(rank) ) 194 | # time.sleep(1) 195 | # workers_set = { rank } 196 | 197 | def _sendControlSignal(self, signal, sync_check=False): 198 | if self.is_working: 199 | self.printToLog('warning! sending control signal in working status') 200 | return 201 | workers_set = set() 202 | while len(workers_set)=0, \ 209 | 'worker respond rank exceeds limit, worker num {}, get {}'.format( 210 | self.worker_num, rank) 211 | if rank not in workers_set: 212 | workers_set.add(rank) 213 | self.printToLog('counted workers: ', workers_set) 214 | if len(workers_set)==self.worker_num: # counted all 215 | break 216 | elif sync_check: # not synchronized: count again 217 | self.printToLog('==== delay joined worker, rank {}, count again'.format(rank) ) 218 | time.sleep(1) 219 | workers_set = { rank } 220 | else: 221 | raise Exception('unhandled asynchronized signal, probably indicates disorder of workers, \ 222 | try start all workers before start controllers') 223 | 224 | 225 | def startWorking(self): 226 | if not self.is_working: 227 | self.printToLog('calling all workers to start') 228 | self._sendControlSignal(start_signal) 229 | self.is_working = True 230 | self.is_looping = True 231 | 232 | def stopWorking(self): 233 | self._callMethod('stopWorking') 234 | self.is_working = False 235 | 236 | def stopLooping(self): 237 | if self.is_working: 238 | self._broadcastMessage(quit_signal) 239 | #self._callMethod('stopLooping') 240 | else: 241 | self._sendControlSignal(quit_signal) 242 | self.is_working = False 243 | self.is_looping = False 244 | 245 | class SyncRpcWorker(SyncRpcBase): 246 | """ 247 | 同步RPC客户端 248 | """ 249 | def __init__(self, server_ip, port, rank, logger= 250 | print): 251 | SyncRpcBase.__init__(self, server_ip, port, logger) 252 | self.printToLog = logger 253 | self.rank = rank 254 | self.function_dict = {} 255 | self.is_working = False 256 | self.is_looping = False 257 | self.registerMethod(self.stopWorking) 258 | self.registerMethod(self.stopLooping) 259 | #self.registerMethod(self.detectRespond) 260 | 261 | def __del__(self): 262 | self.closeSocket() 263 | 264 | def initSocket(self): 265 | self.printToLog("initilzating socket:") 266 | self.printToLog("publish addr = '{}'".format(self.publish_addr) ) 267 | self.printToLog("report addr = '{}'".format(self.report_addr) ) 268 | self.printToLog("check addr = '{}'".format(self.check_addr) ) 269 | self.context = zmq.Context() 270 | # publish socket 271 | self.publish_socket = self.context.socket(zmq.SUB) 272 | self.publish_socket.connect(self.publish_addr) 273 | self.publish_socket.setsockopt(zmq.SUBSCRIBE, b'') 274 | # report socket 275 | self.report_socket = self.context.socket(zmq.PUSH) 276 | self.report_socket.connect(self.report_addr) 277 | # workers check socket 278 | self.check_socket = self.context.socket(zmq.REP) 279 | self.check_socket.connect(self.check_addr) 280 | 281 | def closeSocket(self): 282 | self.printToLog('closing socket, rank {}'.format(self.rank) ) 283 | if self.publish_socket != None: 284 | self.publish_socket.disconnect(self.publish_addr) 285 | self.publish_socket = None 286 | if self.report_socket != None: 287 | self.report_socket.disconnect(self.report_addr) 288 | self.report_socket = None 289 | if self.check_socket !=None: 290 | self.check_socket.disconnect(self.check_addr) 291 | self.check_socket = None 292 | 293 | def recieveBroadcast(self): 294 | return eval(self.publish_socket.recv().decode()) 295 | 296 | def reportMessage(self, msg): 297 | """ 298 | 将消息发送至控制者 299 | """ 300 | return self.report_socket.send( repr(msg).encode() ) 301 | 302 | def registerMethod(self, function, name=None): 303 | if name is None: 304 | name = function.__name__ 305 | self.function_dict[name] = function 306 | 307 | def excecuteMethod(self, func_name, args, kwargs): 308 | if func_name in self.function_dict: 309 | self.printToLog("excecuting function: [func name] {}; [args] {}; [kwargs] {}".format( 310 | func_name, args, kwargs ) 311 | ) 312 | result = self.function_dict[func_name](*args, **kwargs) 313 | else: 314 | self.printToLog("warning: wrong function name. [func name] {}; [args] {}; [kwargs] {}".format( 315 | func_name, args, kwargs 316 | )) 317 | result = None 318 | self.printToLog('result: ', result) 319 | return result 320 | 321 | # def detectRespond(self): 322 | # return good_signal 323 | 324 | def waitControlSignal(self): 325 | self.printToLog('waiting for control signal, rank {}'.format(self.rank) ) 326 | signal = eval(self.check_socket.recv().decode()) 327 | self.printToLog('controller signal recieved: \"{}\"'.format(signal)) 328 | self.printToLog('respond control signal, rank {}'.format(self.rank) ) 329 | #time.sleep(0.2) 330 | self.check_socket.send( repr(self.rank).encode() ) 331 | self.printToLog('respond sent: {}'.format(self.rank) ) 332 | return signal 333 | 334 | def mainLoop(self): 335 | #self.is_working = False 336 | self.is_looping = True 337 | while self.is_looping: 338 | signal = self.waitControlSignal() 339 | if signal == quit_signal: 340 | self.is_looping = False 341 | time.sleep(3) 342 | break 343 | if signal == check_signal: 344 | continue 345 | time.sleep(1) 346 | elif signal==start_signal: 347 | self.printToLog('start working loop') 348 | self.is_working = True 349 | while self.is_working: 350 | self.printToLog("waiting for task message") 351 | msg = self.recieveBroadcast() 352 | self.printToLog("message recieved: \"{}\"".format(msg) ) 353 | if msg==test_signal: continue 354 | elif msg==quit_signal: 355 | self.is_looping = False 356 | break 357 | func_name, func_args, func_kwargs = msg 358 | result = self.excecuteMethod(func_name, func_args, func_kwargs) 359 | self.reportMessage(result) 360 | #time.sleep(3) 361 | self.printToLog('exiting') 362 | #self.closeSocket() 363 | 364 | def stopWorking(self): 365 | self.is_working = False 366 | return good_signal 367 | 368 | def stopLooping(self): 369 | self.is_looping = False 370 | self.is_working = False -------------------------------------------------------------------------------- /mec/training/sync_trainer.py: -------------------------------------------------------------------------------- 1 | # sync_trainer.py 2 | # created: CS 3 | # 多进程多卡同步训练模块封装 4 | 5 | 6 | 7 | import json 8 | import os 9 | import time 10 | import numpy as np 11 | import torch 12 | import torch.multiprocessing as mp 13 | from torch.utils.data import DataLoader, DistributedSampler 14 | 15 | from .basic_trainer import BasicTrainer 16 | from ..comms.sync_rpc import SyncRpcWorker, SyncRpcController 17 | from ..comms.transmit import DistEnv, TensorTransmittor 18 | from ..utils.logs import Logger 19 | from ..utils.monitor import Monitor 20 | from ..utils.history import History 21 | from ..configs.default_config import conf_g 22 | 23 | class WorkerSync(): 24 | """ 25 | 同步训练worker类 26 | 每次接受manager广播来分配batch 27 | forward和backward结束后 28 | 用dist.reduce上传grad 29 | 用dist.broadcast回传weights 30 | grad在manager上延迟更新 31 | 32 | 参数: 33 | model : 数据模型 34 | dataset_dict : 数据集表,key为name,value为单个dataset 35 | device : 使用的设备 36 | rank : 进程编号,在组内应唯一 37 | sync_worker_num : worker数量 38 | controll_ip : 仅IP地址,如"127.0.0.1" 39 | control_port : 发布使用的端口 40 | report_port : 返回分数使用的端口 41 | dist_port : torch.dist使用的端口 42 | """ 43 | def __init__(self, 44 | model, optimizer, criterion, metrics, 45 | dataset_dict, batch_transform_dict, 46 | batch_size, process_num_per_loader, 47 | rank, gpu_id, sync_worker_num, control_ip='127.0.0.1', 48 | port=12500, logger=print): 49 | self.printToLog = logger 50 | rpc_port = port 51 | dist_port = str(port+10) 52 | self.printToLog("initiating worker {} ...".format(rank)) 53 | self.dataset_dict = dataset_dict 54 | self.batch_transform_dict = batch_transform_dict 55 | self.sync_worker_num = sync_worker_num 56 | self.printToLog("worker num :", sync_worker_num) 57 | self.world_size = sync_worker_num 58 | self.printToLog('initiating data loader ...') 59 | # for dataset_name in dataset_dict: 60 | # print( 61 | # "\ndataset_name:", dataset_name, dataset_dict[dataset_name], 62 | # "\nbatch_size:", int(batch_size/self.world_size), 63 | # "\nsampler:", DistributedSampler(dataset_dict[dataset_name], sync_worker_num, rank), 64 | # "\nnum_workers:", process_num_per_loader, 65 | # ) 66 | self.dataloader_dict = { 67 | dataset_name: DataLoader( 68 | dataset_dict[dataset_name], 69 | batch_size = int(batch_size/self.world_size), 70 | sampler = DistributedSampler(dataset_dict[dataset_name], sync_worker_num, rank), 71 | num_workers = process_num_per_loader, 72 | pin_memory = True 73 | ) 74 | for dataset_name in dataset_dict 75 | } 76 | self.printToLog("initiating device") 77 | self.device = torch.device('cuda:{}'.format(gpu_id)) 78 | self.printToLog("device:", self.device) 79 | self.rank = rank 80 | self.printToLog("rank:", rank) 81 | self.printToLog("initiating model") 82 | self.model = model #.to(self.device) 83 | self.model.to(self.device) 84 | self.printToLog("initiating trainer") 85 | self.trainer = BasicTrainer(model, optimizer, criterion, metrics) 86 | self.printToLog("initiating torch.distributed environment") 87 | self.env = DistEnv(rank, self.world_size, control_ip, dist_port, self.printToLog) 88 | self.default_group = self.env.newGroup(range(self.world_size)) 89 | self.printToLog("initiating transmittor") 90 | self.transmittor = TensorTransmittor(list(range(self.world_size)) , logger=self.printToLog) 91 | self.printToLog("initialing rpc proxy") 92 | self.rpc_server = SyncRpcWorker(control_ip, rpc_port, rank, self.printToLog) 93 | self.rpc_server.registerMethod(self.averagingGrads) 94 | self.rpc_server.registerMethod(self.averagingWeights) 95 | self.rpc_server.registerMethod(self.broadCastModelWeights) 96 | self.rpc_server.registerMethod(self.gatherAveragedModelWeights) 97 | # ------------ training methods ------------ 98 | self.rpc_server.registerMethod(self.initTrainEpoch) 99 | self.rpc_server.registerMethod(self.batchTrainNoUpdate) 100 | self.rpc_server.registerMethod(self.updateWeights) 101 | # ------------ validation methods ------------ 102 | self.rpc_server.registerMethod(self.initTestEpoch) 103 | self.rpc_server.registerMethod(self.batchTest) 104 | # ------------ saving methods ------------ 105 | self.rpc_server.registerMethod(self.saveModelWeights) 106 | self.rpc_server.registerMethod(self.loadModelWeights) 107 | # init rpc proxy last, after all preparations are ready 108 | self.printToLog('workers ready') 109 | 110 | def mainLoop(self): 111 | self.rpc_server.mainLoop() 112 | 113 | # def _returnScore(self, flag, batch, sample_num, loss, met): 114 | # respond = { 115 | # 'flag' : flag, 116 | # 'batch' : batch, 117 | # 'samples': sample_num, 118 | # 'loss' : loss, 119 | # 'met' : met 120 | # } 121 | # self.printToLog(repr(respond)[1:-2]) 122 | # self.rpc_server.reportMessage(respond) 123 | 124 | # data communicating -------------------------- 125 | def averagingWeights(self, style='full'): 126 | """ 127 | 各个进程的模型参数取平均 128 | """ 129 | self.printToLog("averaging weights") 130 | self.transmittor.crossTensors(self.trainer.model, style=style) 131 | 132 | def averagingGrads(self, style='full'): 133 | """ 134 | 各个进程的模型梯度取平均 135 | """ 136 | self.printToLog("averaging grads") 137 | self.transmittor.crossGrads(self.trainer.model, style=style) 138 | 139 | def gatherAveragedModelWeights(self, rank, group=None): 140 | """ 141 | 将所有进程的权重集中求平均 142 | 结果保存至一个进程 143 | 用于初始化时同步 144 | """ 145 | self.transmittor.meanGatherTensors(self.trainer.model, rank, group) 146 | 147 | def broadCastModelWeights(self, rank, group=None): 148 | """ 149 | 从一个进程向其他进程广播模型权重 150 | """ 151 | self.transmittor.broadcastTensors(self.trainer.model, rank, group) 152 | 153 | # training methods ---------------------------- 154 | def initTrainEpoch(self, dataset_name, epoch, lr): 155 | self.printToLog("initializing training epoch {}".format(epoch)) 156 | self.printToLog("learning rate: {}".format(lr)) 157 | self.trainer.initEpoch() 158 | self.trainer.setLearningRate(lr) 159 | self.trainer.model.train() 160 | self.train_batch_index = 0 161 | self.printToLog("initializing train loader iter") 162 | self.printToLog('dataloader length: ', len(self.dataloader_dict[dataset_name]) ) 163 | self.printToLog('dataloader dict: ', self.dataloader_dict) 164 | train_loader = self.dataloader_dict[dataset_name] 165 | self.printToLog('dataloader: ', train_loader) 166 | try: 167 | self.train_iter = iter(train_loader) 168 | except Exception as e: 169 | self.printToLog(e) 170 | raise e 171 | self.printToLog('dataloader iter', self.train_iter) 172 | if dataset_name in self.batch_transform_dict: 173 | self.train_batch_transform = self.batch_transform_dict[dataset_name] 174 | self.printToLog("setting up batch transforms") 175 | else: 176 | self.train_batch_transform = {lambda x: x} 177 | self.printToLog("no batch transforms") 178 | self.printToLog("epoch {}, begin training".format(epoch)) 179 | 180 | def batchTrainNoUpdate(self): 181 | self.printToLog("train batch {}".format(self.train_batch_index) ) 182 | data, target = next(self.train_iter) 183 | batch_sample_num = len(target) 184 | self.printToLog("getting data") 185 | data, target = data.to(self.device), target.to(self.device) 186 | if self.train_batch_transform is not None: 187 | self.printToLog(self.train_batch_transform.__name__) 188 | data, target = self.train_batch_transform(data, target) 189 | self.printToLog("forwarding") 190 | self.trainer.forwardData(data) 191 | self.printToLog("backwarding") 192 | self.trainer.backwardGrad(target) 193 | self.train_batch_index += 1 194 | loss, met = self.trainer.getScores() 195 | return batch_sample_num, loss, met 196 | 197 | def updateWeights(self): 198 | self.printToLog("updating weights") 199 | self.trainer.updateWeights() 200 | 201 | # validation methods ---------------------------------- 202 | def initTestEpoch(self, dataset_name, epoch): 203 | self.printToLog("initizating validation epoch {}".format(epoch)) 204 | valid_loader = self.dataloader_dict[dataset_name] 205 | valid_loader.sampler.set_epoch(epoch) 206 | self.valid_iter = iter(valid_loader) 207 | self.trainer.model.eval() 208 | self.valid_batch_index = 0 209 | if dataset_name in self.batch_transform_dict: 210 | batch_transform = self.batch_transform_dict[dataset_name] 211 | self.printToLog("setting up batch transforms") 212 | else: 213 | batch_transform = {lambda x: x} 214 | self.printToLog("no batch transforms") 215 | self.printToLog("epoch {}, begin validation".format(epoch)) 216 | 217 | # for batch, (data, target) in enumerate(valid_loader): 218 | # data, target = data.to(self.device), target.to(self.device) 219 | # data, target = batch_transform(data, target) 220 | # self.trainer.forwardNoGrad(data) 221 | # self.trainer.calcScores(target) 222 | # self._returnScore('valid_batch', batch, len(data), *self.trainer.getScores()) 223 | 224 | def batchTest(self): 225 | self.printToLog("validation batch {}".format(self.valid_batch_index)) 226 | data, target = next(self.valid_iter) 227 | batch_sample_num = len(target) 228 | self.printToLog("getting data") 229 | data, target = data.to(self.device), target.to(self.device) 230 | if self.train_batch_transform is not None: 231 | data, target = self.train_batch_transform(data, target) 232 | self.printToLog("forwarding") 233 | self.trainer.forwardData(data) 234 | self.printToLog("backwarding") 235 | self.trainer.backwardGrad(target) 236 | self.valid_batch_index += 1 237 | loss, met = self.trainer.getScores() 238 | return batch_sample_num, loss, met 239 | 240 | # ------------ saving methods ------------ 241 | def saveModelWeights(self, filename, rank=0, rank_list=[-1]): 242 | """ 243 | 指定位置保存当前模型权重 244 | """ 245 | if self.rank == rank or self.rank in rank_list: 246 | self.trainer.saveModel(filename) 247 | 248 | def loadModelWeights(self, filename, rank=0, rank_list=[-1]): 249 | """ 250 | 从指定位置读取模型权重 251 | """ 252 | self.printToLog("rank", self.rank, ", loading model, ") 253 | self.printToLog("filename={}, rank={}".format(filename, rank) ) 254 | if self.rank==rank or self.rank in rank_list: 255 | if not os.path.exists(filename): 256 | self.printToLog('warning: no model weight file found') 257 | return 258 | try: 259 | self.trainer.loadModel(filename, map_location=self.device) 260 | except Exception as e: 261 | self.printToLog(e) 262 | raise e 263 | 264 | 265 | class ControllerSync(): 266 | """ 267 | 268 | """ 269 | def __init__(self, 270 | train_set_len, valid_set_len, batch_size, 271 | control_ip, port, sync_worker_num, 272 | logger=print 273 | ): 274 | rpc_port = port 275 | dist_port = str(port+10) 276 | # 日志 277 | self.printToLog = logger 278 | # 控制接口 279 | self.rpcWorkers = SyncRpcController(control_ip, port, sync_worker_num, self.printToLog) 280 | #self.rpcWorkers.startWorking() 281 | # 数据 282 | self.train_set_len = train_set_len 283 | self.valid_set_len = valid_set_len 284 | self.train_loader_len = int(np.ceil(train_set_len/batch_size) ) 285 | self.valid_loader_len = int(np.ceil(valid_set_len/batch_size) ) 286 | # 训练 287 | #self.lr_scheduler = lr_scheduler 288 | #self.metric_name = metric_name 289 | #self.monitor = Monitor(init_epoch, total_epochs, self.train_loader_len, self.valid_loader_len, metric_name) 290 | # 评价 291 | self.best_epoch_loss = 1e6 292 | self.best_epoch_met = 0. 293 | 294 | def __del__(self): 295 | #self.rpcWorkers.stopLooping() 296 | pass 297 | 298 | # def averagingGrads(self, style='full'): 299 | # self.rpcWorkers.averagingGrads(style='full') 300 | 301 | # def averagingWeights(self, style='full'): 302 | # self.rpcWorkers.averagingWeights(style='full') 303 | 304 | # def broadCastModelWeights(self, rank, group=None): 305 | # self.rpcWorkers.broadCastModelWeights(rank, group) 306 | 307 | # def gatherAveragedModelWeights(self, rank, group=None): 308 | # self.rpcWorkers.gatherAveragedModelWeights(rank, group=None) 309 | 310 | # def initTrainEpoch(self, dataset_name, epoch, lr): 311 | # #self.printToLog("initiating training epoch {}, lr={}".format(epoch, lr)) 312 | # self.rpcWorkers.initTrainEpoch(dataset_name, epoch, lr) 313 | 314 | # def batchTrainNoUpdate(self, batch_index): 315 | # #self.printToLog("training batch {}".format(batch_index)) 316 | # return self.rpcWorkers.batchTrainNoUpdate() 317 | 318 | # def updateWeights(self): 319 | # self.rpcWorkers.updateWeights() 320 | 321 | # def initTestEpoch(self, dataset_name, epoch): 322 | # #self.printToLog("initiating validation epoch {}".format(epoch)) 323 | # self.rpcWorkers.initTestEpoch(dataset_name, epoch) 324 | 325 | # def batchTest(self): 326 | # #self.printToLog("validation batch {}".format(batch_index)) 327 | # return self.rpcWorkers.batchTest() 328 | 329 | # def saveModelWeights(self, filename, rank=1, rank_list=[-1]): 330 | # self.rpcWorkers.saveModelWeights(filename, rank=1, rank_list=[-1]) 331 | 332 | # def loadModelWeights(self, filename, rank=1, rank_list=[-1]): 333 | # self.rpcWorkers.loadModelWeights(filename, rank=1, rank_list=[-1]) 334 | 335 | def startWorking(self): 336 | self.rpcWorkers.startWorking() 337 | 338 | def stopWorking(self): 339 | self.rpcWorkers.stopWorking() 340 | 341 | def stopLooping(self): 342 | self.rpcWorkers.stopLooping() 343 | 344 | def trainEpoch(self, epoch, lr, dataset_name='train', style='full', monitor=None): 345 | #lr = self.lr_scheduler(epoch) 346 | self.printToLog("initiating training epoch {}; lr={}".format(epoch, lr) ) 347 | self.rpcWorkers.initTrainEpoch(dataset_name, epoch, lr) 348 | total_sample_num = 0 349 | total_loss = 0 350 | total_met = 0 351 | total_avg_loss = 0 352 | total_avg_met = 0 353 | for batch_index in range(self.train_loader_len): 354 | self.printToLog("training batch {}".format(batch_index)) 355 | result_list = self.rpcWorkers.batchTrainNoUpdate() 356 | #self.averagingGrads() 357 | self.rpcWorkers.updateWeights() 358 | self.rpcWorkers.averagingWeights(style=style) 359 | batch_sample_num = 0 360 | batch_total_loss = 0 361 | batch_total_met = 0 362 | for sample_num, loss, met in result_list: 363 | self.printToLog('single_result: loss={}, score={}'.format(loss, met) ) 364 | batch_sample_num += sample_num 365 | batch_total_loss += loss * sample_num 366 | batch_total_met += met * sample_num 367 | total_sample_num += batch_sample_num 368 | total_loss += batch_total_loss 369 | total_met += batch_total_met 370 | batch_loss = batch_total_loss / batch_sample_num 371 | batch_met = batch_total_met / batch_sample_num 372 | total_avg_loss = total_loss/total_sample_num 373 | total_avg_met = total_met /total_sample_num 374 | if monitor!=None: 375 | monitor.updateTraining(batch_loss, total_avg_loss, batch_met, total_avg_met) 376 | return total_avg_loss, total_avg_met 377 | 378 | def testEpoch(self, epoch, dataset_name='valid', monitor=None): 379 | self.printToLog("initiating test epoch {}".format(epoch)) 380 | self.rpcWorkers.initTestEpoch(dataset_name, epoch) 381 | total_val_sample_num = 0 382 | total_val_loss = 0 383 | total_val_met = 0 384 | total_avg_val_loss = 0 385 | total_avg_val_met = 0 386 | for val_batch_index in range(self.valid_loader_len): 387 | self.printToLog("test batch {}".format(val_batch_index)) 388 | result_list = self.rpcWorkers.batchTest() 389 | batch_val_sample_num = 0 390 | batch_val_total_loss = 0 391 | batch_val_total_met = 0 392 | for val_sample_num, val_loss, val_met in result_list: 393 | self.printToLog('single_result: loss={}, score={}'.format(val_loss, val_met) ) 394 | batch_val_sample_num += val_sample_num 395 | batch_val_total_loss += val_loss * val_sample_num 396 | batch_val_total_met += val_met * val_sample_num 397 | total_val_sample_num += batch_val_sample_num 398 | total_val_loss += batch_val_total_loss 399 | total_val_met += batch_val_total_met 400 | batch_val_loss = batch_val_total_loss / batch_val_sample_num 401 | batch_val_met = batch_val_total_met / batch_val_sample_num 402 | total_avg_val_loss = total_val_loss/total_val_sample_num 403 | total_avg_val_met = total_val_met /total_val_sample_num 404 | if monitor!=None: 405 | monitor.updateValidation(batch_val_loss, total_avg_val_loss, batch_val_met, total_avg_val_met) 406 | if total_val_met>self.best_epoch_met or self.best_epoch_met==None: 407 | self.is_best_epoch = True 408 | self.best_epoch_loss = total_val_loss 409 | self.best_epoch_met = total_val_met 410 | return total_avg_val_loss, total_avg_val_met 411 | 412 | # def endEpoch(self): 413 | # self.monitor.updateEpoch() 414 | 415 | def saveModelWeights(self, filename, rank=0): 416 | self.printToLog('saving model') 417 | self.rpcWorkers.saveModelWeights(filename, rank) 418 | 419 | def loadModelWeights(self, filename, rank=0): 420 | self.printToLog('loading model') 421 | self.rpcWorkers.loadModelWeights(filename, rank) 422 | self.rpcWorkers.broadCastModelWeights(rank) 423 | 424 | # ========================== # ========================== 425 | def startWorkerProcess( 426 | model, optimizer, criterion, metrics, 427 | dataset_dict, batch_transform_dict, 428 | batch_size, process_num_per_loader, 429 | rank, gpu_id, 430 | sync_worker_num, 431 | control_ip, port 432 | ): 433 | logger = Logger( 434 | filepath ='logs/worker_{}.log'.format(rank), 435 | prefix ='worker_{}|gpu_{}'.format(rank, gpu_id) 436 | ) 437 | torch.cuda.set_device( gpu_id ) 438 | os.environ['CUDA_VISIBLE_DEVICE']='{}'.format(gpu_id) 439 | #time.sleep(3) 440 | worker = WorkerSync( 441 | model, optimizer, criterion, metrics, 442 | dataset_dict, batch_transform_dict, 443 | batch_size, process_num_per_loader, rank, gpu_id, 444 | sync_worker_num, control_ip, port, logger) 445 | worker.mainLoop() 446 | 447 | def startWorkers( 448 | model, optimizer, criterion, metrics, 449 | train_set, valid_set, 450 | batch_size = conf_g.batch_size, 451 | sync_worker_num = conf_g.sync_worker_num, 452 | process_num_per_loader = conf_g.process_num_per_loader, 453 | rank_list = conf_g.worker_ranks, 454 | gpu_id_list = conf_g.worker_gpu_ids, 455 | control_ip = conf_g.control_ip, 456 | port = conf_g.basic_port, 457 | train_batch_transform = None, 458 | valid_batch_transform = None 459 | ): 460 | assert len(rank_list)==len(gpu_id_list), 'rank_list has different length from gpu_id_list' 461 | assert min(rank_list)>=0, 'rank must be greater than 0' 462 | assert max(rank_list)0: 714 | self.best_loss = min(self.history['train_loss']) 715 | self.best_met = max(self.history['train_met']) 716 | self.best_val_loss = min(self.history['val_loss']) 717 | self.best_val_met = max(self.history['val_met']) 718 | # 通讯相关 719 | self.context = None 720 | self.task_addr = "tcp://" + manager_ip + ":" + task_port 721 | self.task_socket = None 722 | self.score_addr = "tcp://" + manager_ip + ":" + score_port 723 | self.score_socket = None 724 | # 分布式相关 725 | os.environ['MASTER_ADDR'] = manager_ip 726 | os.environ['MASTER_PORT'] = dist_port 727 | self.dist_addr = "tcp://" + manager_ip + ":" + dist_port 728 | self.rank = 0 729 | self.world_size = sync_worker_num + 1 730 | # 日志文件 731 | self.logfile = open("logs/manager.log", 'w') 732 | 733 | def __del__(self): 734 | self.closeSocket() 735 | self.logfile.close() 736 | 737 | def printToLog(self, *content): 738 | print("[manager|{}]".format(time.strftime("%y-%m-%d_%H:%M:%S") ), 739 | *content, file=self.logfile, flush=True) 740 | #print("[manager]", *content) 741 | 742 | def initSocket(self): 743 | self.printToLog("initizating socket:") 744 | self.printToLog("task addr = '{}'".format(self.task_addr) ) 745 | self.printToLog("score addr = '{}'".format(self.score_addr) ) 746 | self.context = zmq.Context() 747 | self.task_socket = self.context.socket(zmq.PUB) 748 | self.task_socket.bind(self.task_addr) 749 | self.score_socket = self.context.socket(zmq.PULL) 750 | self.score_socket.bind(self.score_addr) 751 | 752 | def closeSocket(self): 753 | self.printToLog("closing socket") 754 | if self.task_socket != None: 755 | self.task_socket.unbind(self.task_addr) 756 | self.task_socket = None 757 | if self.score_socket != None: 758 | self.score_socket.unbind(self.score_addr) 759 | self.score_socket = None 760 | 761 | def initTorchDist(self): 762 | self.printToLog("dist args:", 'nccl', self.dist_addr, 763 | self.rank, self.world_size) 764 | dist.init_process_group('nccl', 765 | rank=self.rank, world_size=self.world_size) 766 | 767 | def sendMessage(self, msg): 768 | return self.task_socket.send(repr(msg).encode() ) 769 | 770 | def recvMessage(self): 771 | return eval(self.score_socket.recv().decode() ) 772 | 773 | # 返回值:是否是最好结果 774 | def updateHistory(self, epoch, lr, loss, met, val_loss, val_met): 775 | self.history['epochs'] += 1 776 | self.history['lr'] .append(lr) 777 | self.history['train_loss'].append(loss) 778 | self.history['train_met'] .append(met) 779 | self.history['val_loss'] .append(val_loss) 780 | self.history['val_met'] .append(val_met) 781 | self.printToLog("saving history:") 782 | self.printToLog("filename = {}".format(self.history_filename)) 783 | self.printToLog("content:") 784 | self.printToLog("epoch {}, lr {}".format(epoch, lr) ) 785 | self.printToLog("lr {}, t_los {:3f}, t_met {:3f}, v_los {:3f}, v_met {:3f}".format( 786 | lr, loss, met, val_loss, val_met) 787 | ) 788 | with open(self.history_filename, 'w', encoding='utf8') as file: 789 | json.dump(self.history, file, ensure_ascii=False, indent=2) 790 | if val_met > self.best_val_met: 791 | self.best_loss = loss 792 | self.best_met = met 793 | self.best_val_loss = val_loss 794 | self.best_val_met = val_met 795 | self.printToLog("new best score at epoch {:d}:".format(epoch)) 796 | self.printToLog("loss {:.3f}, met {:.3f}, v_los {:.3f}, v_met {:.3f}".format( 797 | loss, met, val_loss, val_met) 798 | ) 799 | return True 800 | return False 801 | 802 | 803 | # manager 进程 804 | def sync_manager_process_no_gpu( 805 | metrics_name, train_set_len, valid_set_len, batch_size, 806 | init_epoch, total_epochs, lr_scheduler, sync_worker_num, 807 | manager_ip="127.0.0.1", result_path="results/temp", 808 | history=None, sync_flag='cross' 809 | ): 810 | world_size = sync_worker_num+1 811 | manager_no_gpu = ManagerSyncNoGpu( 812 | train_set_len, valid_set_len, 813 | batch_size, sync_worker_num, 814 | history=history, result_path=result_path, 815 | manager_ip=manager_ip) 816 | manager_no_gpu.initSocket() 817 | manager_no_gpu.printToLog("init monitor") 818 | monitor = Monitor(init_epoch, total_epochs, 819 | manager_no_gpu.train_loader_len, 820 | manager_no_gpu.valid_loader_len, 821 | metrics_name) 822 | manager_no_gpu.initTorchDist() 823 | message = { 824 | 'flag' : 'init', 825 | 'sync_flag' : 'cross' 826 | } 827 | manager_no_gpu.sendMessage(message) 828 | manager_no_gpu.printToLog("message sent:") 829 | manager_no_gpu.printToLog(message) 830 | for epoch in range(init_epoch, init_epoch+total_epochs): 831 | manager_no_gpu.printToLog( 832 | ( 833 | "===== train epoch {}; " 834 | "train loader len : {} =====" 835 | ).format(epoch, manager_no_gpu.train_loader_len) 836 | ) 837 | message = { 838 | 'flag' : 'train_epoch', 839 | 'epoch': epoch+1, 840 | 'lr' : lr_scheduler(epoch+1) 841 | } 842 | manager_no_gpu.sendMessage(message) 843 | manager_no_gpu.printToLog("message sent:") 844 | manager_no_gpu.printToLog(message) 845 | # training 846 | total_samples = 0 847 | total_loss = 0 848 | total_met = 0 849 | for batch in range(manager_no_gpu.train_loader_len): 850 | manager_no_gpu.printToLog("workers corssing grads ...") 851 | sample_num = 0 852 | loss = 0 853 | met = 0 854 | for worker_index in range(sync_worker_num): 855 | respond = manager_no_gpu.recvMessage() 856 | sample_num += respond['samples'] 857 | loss += respond['loss']/sync_worker_num 858 | total_loss += respond['loss'] * respond['samples'] 859 | met += respond['met']/sync_worker_num 860 | total_met += respond['met'] * respond['samples'] 861 | total_samples += sample_num 862 | avg_loss = total_loss / total_samples 863 | avg_met = total_met / total_samples 864 | manager_no_gpu.printToLog( 865 | ("batch {:d}; " 866 | "sample num : {}; " 867 | "loss : {:.3f}; " 868 | "{} : {:.3f}").format( 869 | batch, 870 | sample_num, 871 | loss, 872 | metrics_name, 873 | met 874 | ) 875 | ) 876 | monitor.updateTraining(loss, avg_loss, met, avg_met) 877 | # validation 878 | manager_no_gpu.printToLog( 879 | ( 880 | "===== valid epoch {}; " 881 | "train loader len : {} =====" 882 | ).format(epoch, manager_no_gpu.train_loader_len) 883 | ) 884 | total_val_samples = 0 885 | total_val_loss = 0 886 | total_val_met = 0 887 | message = { 888 | 'flag' : 'valid_epoch', 889 | 'epoch': epoch 890 | } 891 | manager_no_gpu.sendMessage(message) 892 | for val_batch in range(manager_no_gpu.valid_loader_len): 893 | val_sample_num = 0 894 | val_loss = 0 895 | val_met = 0 896 | for i in range(sync_worker_num): 897 | respond = manager_no_gpu.recvMessage() 898 | val_sample_num += respond['valid_samples'] 899 | val_loss += respond['valid_loss'] / sync_worker_num 900 | total_val_loss += respond['valid_loss'] * respond['valid_samples'] 901 | val_met += respond['valid_met']/sync_worker_num 902 | total_val_met += respond['valid_met'] * respond['valid_samples'] 903 | total_val_samples += val_sample_num 904 | avg_val_loss = total_val_loss / total_val_samples 905 | avg_val_met = total_val_met / total_val_samples 906 | monitor.updateValidation(val_loss, avg_val_loss, val_met, avg_val_met) 907 | manager_no_gpu.printToLog( 908 | ("v batch {:d}; " 909 | "v sample num : {}; " 910 | "v loss : {:.3f}; " 911 | "v {} : {:.3f}").format( 912 | val_batch, 913 | val_sample_num, 914 | val_loss, 915 | metrics_name, 916 | val_met 917 | ) 918 | ) 919 | monitor.updateEpoch(avg_loss, avg_val_loss, avg_met, avg_val_met) 920 | is_best_epoch = manager_no_gpu.updateHistory( 921 | epoch, lr_scheduler(epoch+1), avg_loss, avg_met, avg_val_loss, avg_val_met 922 | ) 923 | # 发送信号保存训练结果 924 | # 只有在本地无GPU manager模型中才需要让worker保存模型 925 | # 因此模型文件名只保存在manager中,需要一并发送给worker 926 | message = { 927 | 'flag' : 'save_model', 928 | 'ranks_to_save' : [1], # 指定哪几个进程来保存模型 929 | 'model_filename' : manager_no_gpu.model_filename, 930 | 'is_best' : False 931 | } 932 | if is_best_epoch: 933 | message['best_model_fname'] = manager_no_gpu.best_model_fname 934 | message['is_best'] = True 935 | manager_no_gpu.sendMessage(message) 936 | # epoch 937 | # epoch for loop end 938 | manager_no_gpu.printToLog("sending quit") 939 | message = { 940 | 'flag': 'quit' 941 | } 942 | manager_no_gpu.sendMessage(message) 943 | monitor.close() 944 | 945 | --------------------------------------------------------------------------------