├── .gitignore
├── README.en.md
├── README.md
├── configs.py
├── demo.py
├── demo_dataset.py
├── demo_start_workers.py
├── install.sh
├── mec
│   ├── __init__.py
│   ├── comms
│   │   ├── __init__.py
│   │   ├── sync_rpc.py
│   │   └── transmit.py
│   ├── configs
│   │   ├── __init__.py
│   │   └── default_config.py
│   ├── data_manip
│   │   ├── __init__.py
│   │   ├── criterions.py
│   │   ├── data_utils.py
│   │   ├── lr_scheduler.py
│   │   ├── metrics.py
│   │   └── transfroms
│   │       ├── batch_transforms.py
│   │       ├── data_transforms.py
│   │       └── label_transforms.py
│   ├── scoring
│   │   ├── __init__.py
│   │   └── tester.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── async_trainer.py
│   │   ├── basic_trainer.py
│   │   ├── old_sync_trainer.py
│   │   └── sync_trainer.py
│   ├── unit_test
│   │   ├── dist_test.py
│   │   ├── rpc_test.py
│   │   └── train_test.py
│   └── utils
│       ├── __init__.py
│       ├── history.py
│       ├── logs.py
│       └── monitor.py
├── requirements.txt
├── setup.py
└── uninstall.sh
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | #  Usually these files are written by a python script from a template
31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # IPython
78 | profile_default/
79 | ipython_config.py
80 | 
81 | # pyenv
82 | .python-version
83 | 
84 | # celery beat schedule file
85 | celerybeat-schedule
86 | 
87 | # SageMath parsed files
88 | *.sage.py
89 | 
90 | # Environments
91 | .env
92 | .venv
93 | env/
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 | 
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 | 
103 | # Rope project settings
104 | .ropeproject
105 | 
106 | # mkdocs documentation
107 | /site
108 | 
109 | # mypy
110 | .mypy_cache/
111 | .dmypy.json
112 | dmypy.json
113 | 
114 | # Pyre type checker
115 | .pyre/
116 | 
117 | # vscode environment
118 | .vscode/
--------------------------------------------------------------------------------
/README.en.md:
--------------------------------------------------------------------------------
1 | # Basic Components for Deep Learning on Images
2 | 
3 | #### Description
4 | Main features of this project:
5 | wrappers for multi-GPU training code;
6 | a collection of common utility code
7 | 
8 | #### Software Architecture
9 | .
10 | ├── comms       # communication
11 | ├── configs     # configuration
12 | ├── data_manip  # data processing
13 | ├── scoring     # result analysis
14 | ├── training    # training wrappers
15 | └── utils       # assorted internal utilities
16 | 
17 | #### Installation
18 | 
19 | 1. git clone xxx
20 | 2. python setup.py build
21 | 3. python setup.py install
22 | 
23 | #### Instructions
24 | 
25 | See demo.py.
26 | Use trainAndValLocal on a single machine;
27 | use startWorkers together with trainAndVal across multiple machines.
28 | 
29 | #### Contribution
30 | 
31 | 1. Fork the repository
32 | 2. Create Feat_xxx branch
33 | 3. Commit your code
34 | 4. Create Pull Request
35 | 
36 | 
37 | #### Gitee Feature
38 | 
39 | 1. 
You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
40 | 2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
41 | 3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
42 | 4. The most valuable open source project [GVP](https://gitee.com/gvp)
43 | 5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
44 | 6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 深度学习多卡训练框架
2 | 
3 | #### 介绍
4 | 本项目的主要功能:
5 | 1. 多卡训练代码封装;
6 | 2. 常用工具性代码集成
7 | 
8 | #### 软件架构
9 | 软件架构说明
10 | .
11 | ├── comms      # 通信
12 | ├── configs    # 参数
13 | ├── data_manip # 数据处理
14 | ├── scoring    # 结果分析代码
15 | ├── training   # 训练代码封装
16 | └── utils      # 各种内部工具
17 | 
18 | #### 安装教程
19 | 
20 | 1. git clone xxx
21 | 2. python setup.py build
22 | 3. python setup.py install
23 | 
24 | #### 使用说明
25 | 
26 | 参见demo.py
27 | 单机可使用trainAndValLocal
28 | 多机可使用startWorkers和trainAndVal
29 | 
30 | #### 参与贡献
31 | 
32 | 1. Fork 本仓库
33 | 2. 新建 Feat_xxx 分支
34 | 3. 提交代码
35 | 4. 新建 Pull Request
36 | 
37 | 
38 | #### 码云特技
39 | 
40 | 1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md
41 | 2. 码云官方博客 [blog.gitee.com](https://blog.gitee.com)
42 | 3. 你可以访问 [https://gitee.com/explore](https://gitee.com/explore) 来了解码云上的优秀开源项目
43 | 4. [GVP](https://gitee.com/gvp) 全称是码云最有价值开源项目,是码云综合评定出的优秀开源项目
44 | 5. 码云官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help)
45 | 6. 码云封面人物是一档用来展示码云会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | from mec.configs.default_config import conf_g
2 | import os
3 | import argparse
4 | 
5 | 
6 | train = False
7 | test = False
8 | score = False
9 | prod = False
10 | mix = False
11 | deploy = False
12 | continue_training = False
13 | batch_size = 1
14 | learning_rate = 1e-3
15 | epochs = 1
16 | process_num_per_loader = 8  # 每个DataLoader启用的进程数
17 | path = 'results/temp'
18 | history_filename = 'history.json'
19 | model_filename = 'current_model.pth'
20 | best_model_filename = 'best_model.pth'
21 | excel_filename = 'scores.xls'
22 | control_ip = "127.0.0.1"  # control IP
23 | basic_port = 12500  # 基本端口,会占用其后几个连续端口
24 | worker_gpu_ids = [0,1,2,3]  # worker所使用的gpu编号 [0,1,2,3]
25 | worker_ranks = [0,1,2,3]  # worker本身编号 [0,1,2,3]
26 | sync_worker_num = 4  # 总worker数,单机的情况等于上两者的长度
27 | batch_size = 256*sync_worker_num
28 | 
29 | # 多机运行时需指定本地使用哪个网卡,否则可能因网络连接速度太慢拖累训练速度
30 | # 单机训练时不需要此参数,默认指定本地地址127.0.0.1
31 | # os.environ['NCCL_SOCKET_IFNAME'] = 'eno2'
32 | # os.environ['NCCL_SOCKET_IFNAME'] = 'eno1np0'
33 | 
34 | def parse_configs():
35 |     global train
36 |     global test
37 |     global score
38 |     global prod
39 |     global mix
40 |     global deploy
41 |     global continue_training
42 |     global batch_size
43 |     global learning_rate
44 |     global epochs
45 |     global process_num_per_loader
46 |     global path
47 |     global history_filename
48 |     global model_filename
49 |     global best_model_filename
50 |     global excel_filename
51 | 
52 |     parser = argparse.ArgumentParser()
53 |     parser.add_argument('-train', '--train', action='store_true',
54 |                         help='train model')
55 |     parser.add_argument('-test', '--test', action='store_true',
56 |                         help='evaluate model on test set')
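    # Example invocations (illustrative only — the paths and values below are
    # hypothetical; the flags are the ones defined by this parser):
    #   python demo.py -train -e 30 -lr 0.001 -p results/resnet50
    #   python demo.py -train -c          # resume from the last checkpoint
    #   python demo.py -test -score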
57 |     parser.add_argument('-c', '--continue_training', action='store_true',
58 |                         help='continue training from last point')
59 |     parser.add_argument('-score', '--score', action='store_true',
60 |                         help='calc precision, recall and F1, then write to an excel file')
61 |     parser.add_argument('-prod', '--prod', action='store_true',
62 |                         help='test production per image')
63 |     parser.add_argument('-mix', '--mix', action='store_true',
64 |                         help='output confusion (mix) matrix as an xlsx file')
65 |     parser.add_argument('-d', '--deploy', action='store_true',
66 |                         help='generate index to wiki_idx json file')
67 |     parser.add_argument('-lr', '--learning_rate', type=float,
68 |                         help='designate a static learning rate')
69 |     parser.add_argument('-e', '--epochs', type=int,
70 |                         help='how many epochs to train in this run')
71 |     parser.add_argument('-p', '--path', type=str,
72 |                         help='path to store results')
73 |     parser.add_argument('-manager', '--start_manager', action='store_true',
74 |                         help='start the training manager (controller) process')
75 |     parser.add_argument('-workers', '--start_workers', action='store_true',
76 |                         help='start the training worker processes')
77 | 
78 |     args = parser.parse_args()
79 | 
80 |     if args.train:
81 |         print("training")
82 |         train=True
83 |     if args.test:
84 |         test=True
85 |     if args.score:
86 |         score=True
87 |     if args.prod:
88 |         prod=True
89 |     if args.mix:
90 |         mix=True
91 |     if args.deploy:
92 |         deploy=True
93 |     if args.continue_training:
94 |         continue_training=True
95 |     if args.learning_rate:
96 |         learning_rate = args.learning_rate
97 |     if args.epochs:
98 |         epochs = args.epochs
99 |     if args.path:
100 |         path = args.path
101 |     history_filename = os.path.join(path, history_filename)
102 |     model_filename = os.path.join(path, model_filename)
103 |     best_model_filename = os.path.join(path, best_model_filename)
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | import torch.distributed as dist
10 | import torch.multiprocessing as mp
11 | import json
12 | import argparse
13 | 
14 | from PIL import Image
15 | from torch.utils.data import DataLoader, Dataset
16 | from mec.data_manip.metrics import Accuracy
17 | from mec.training.sync_trainer import startWorkers, trainAndVal, trainAndValLocal
18 | 
19 | # 演示数据
20 | from demo_dataset import train_set, valid_set
21 | 
22 | # 预训练公开模型
23 | from torchvision.models.resnet import resnet50, resnet18
24 | 
25 | # 运行参数
26 | from configs import parse_configs
27 | parse_configs()
28 | from configs import *
29 | 
30 | #print( [(k,eval(k)) for k in dir()] )
31 | 
32 | 
33 | # 多机运行时需指定本地使用哪个网卡,否则可能因网络连接速度太慢拖累训练速度
34 | # 单机训练时不需要此参数,默认指定本地地址127.0.0.1
35 | # os.environ['NCCL_SOCKET_IFNAME'] = 'eno2'
36 | # os.environ['NCCL_SOCKET_IFNAME'] = 'eno1np0'
37 | 
38 | 
39 | 
40 | # -------------------------------------------------------------------------
41 | 
42 | 
43 | def main():
44 |     # model
45 |     class_to_idx = train_set.class_to_idx
46 |     idx_to_class = {class_to_idx[x]: x for x in class_to_idx}
47 |     num_classes = len(class_to_idx)
48 |     print("classes: ", num_classes)
49 |     print(idx_to_class)
50 | 
51 |     model = resnet50(pretrained=True)
52 |     model.fc = nn.Linear(2048, num_classes)
53 | 
54 |     opt = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.01, nesterov=True)
55 |     criterion = 
torch.nn.CrossEntropyLoss() 56 | metrics = Accuracy() 57 | print(metrics) 58 | lr_scheduler=lambda epoch: learning_rate 59 | 60 | if train: 61 | trainAndValLocal( 62 | model, opt, criterion, metrics, 63 | train_set, valid_set, 64 | batch_size, lr_scheduler, epochs, 65 | process_num_per_loader = process_num_per_loader, 66 | rank_list = worker_ranks, 67 | gpu_id_list = worker_gpu_ids, 68 | control_ip = control_ip, 69 | port = basic_port, 70 | continue_training = continue_training 71 | ) 72 | # startWorkers( 73 | # model, opt, criterion, metrics, 74 | # train_set, valid_set, 75 | # batch_size, sync_worker_num, process_num_per_loader, 76 | # worker_ranks, worker_gpu_ids, 77 | # control_ip=control_ip 78 | # ) 79 | # trainAndVal( 80 | # train_set, valid_set, metrics, 81 | # batch_size, lr_scheduler, 82 | # sync_worker_num=sync_worker_num, 83 | # control_ip=control_ip 84 | # ) 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | main() -------------------------------------------------------------------------------- /demo_dataset.py: -------------------------------------------------------------------------------- 1 | # demo dataset 2 | 3 | # 测试数据集 4 | 5 | from PIL import Image 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 8 | 9 | pre_image_size = (256, 256) 10 | image_size = (224, 224) 11 | 12 | data_transform = transforms.Compose([ 13 | transforms.Resize(pre_image_size), 14 | transforms.RandomHorizontalFlip(), 15 | transforms.RandomAffine(degrees=25, translate=(.2, .2) , 16 | scale=(0.8, 1.2), shear=8, 17 | resample=Image.BILINEAR, fillcolor=0), 18 | transforms.RandomCrop(image_size, padding=2, fill=(0,0,0) ), 19 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 20 | transforms.ToTensor(), 21 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 22 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 23 | ]) 24 | test_transform = transforms.Compose([ 25 | transforms.Resize(image_size), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 29 | ]) 30 | train_set = CIFAR10('downloaded_models', train=True, transform=data_transform, download=True) 31 | valid_set = CIFAR10('downloaded_models', train=False, transform=test_transform, download=True) -------------------------------------------------------------------------------- /demo_start_workers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | import json 12 | import argparse 13 | 14 | from PIL import Image 15 | from torch.utils.data import DataLoader, Dataset 16 | from torchvision import transforms 17 | from mec.data_manip.metrics import Accuracy 18 | from mec.training.sync_trainer import startWorkers, trainAndVal 19 | 20 | # 测试数据集 21 | from torchvision.datasets import CIFAR10 22 | pre_image_size = (34, 34) 23 | image_size = (32, 32) 24 | data_transform = transforms.Compose([ 25 | transforms.Resize(pre_image_size), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.RandomAffine(degrees=25, translate=(.2, .2) , 28 | scale=(0.8, 1.2), shear=8, 29 | resample=Image.BILINEAR, fillcolor=0), 30 | 
transforms.RandomCrop(image_size, padding=2, fill=(0,0,0) ), 31 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 32 | transforms.ToTensor(), 33 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 34 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | ]) 36 | test_transform = transforms.Compose([ 37 | transforms.Resize(image_size), 38 | transforms.ToTensor(), 39 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 40 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 41 | ]) 42 | train_set = CIFAR10('downloaded_models', train=True, transform=data_transform, download=True) 43 | valid_set = CIFAR10('downloaded_models', train=False, transform=test_transform, download=True) 44 | 45 | # 预训练公开模型 46 | from torchvision.models.resnet import resnet50, resnet18 47 | 48 | from mec.configs.arguments import * 49 | 50 | #print( [(k,eval(k)) for k in dir()] ) 51 | 52 | # 多机运行时需指定本地使用哪个网卡,否则可能因网络连接速度太慢拖累训练速度 53 | # 单机训练时不需要此参数,默认指定本地地址127.0.0.1 54 | os.environ['NCCL_SOCKET_IFNAME'] = 'eno1np0' 55 | 56 | process_num_per_loader = 8 # 每个DataLoader启用的进程数 57 | worker_gpu_ids = [0,1,2,3] # worker所使用的gpu编号 58 | worker_ranks = [4,5,6,7] # worker编号 59 | sync_worker_num = 8 # 总worker数,单机的情况等于上两者的长度 60 | batch_size = 256*sync_worker_num 61 | control_ip = "192.168.1.99" # manager的IP 62 | 63 | 64 | 65 | # ------------------------------------------------------------------------- 66 | 67 | 68 | def main(): 69 | # model 70 | class_to_idx = train_set.class_to_idx 71 | idx_to_class = {class_to_idx[x]: x for x in class_to_idx} 72 | num_classes = len(class_to_idx) 73 | print("classes: ", num_classes) 74 | print(idx_to_class) 75 | 76 | model = resnet50(pretrained=True) 77 | model.fc = nn.Linear(2048, num_classes) 78 | 79 | opt = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.01, nesterov=True) 80 | criterion = torch.nn.CrossEntropyLoss() 81 | metrics = Accuracy() 82 | lr_scheduler=lambda epoch: learning_rate 83 | 84 | startWorkers( 85 | model, opt, criterion, metrics, 86 | train_set, valid_set, 87 | batch_size, sync_worker_num, process_num_per_loader, 88 | worker_ranks, worker_gpu_ids, 89 | control_ip=control_ip 90 | ) 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | main() -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 setup.py install --record install.record -------------------------------------------------------------------------------- /mec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artintel2017/torch-multi-gpu/94c9171035b33e8f9cb9e816f1d3e5f3c39e5018/mec/__init__.py -------------------------------------------------------------------------------- /mec/comms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artintel2017/torch-multi-gpu/94c9171035b33e8f9cb9e816f1d3e5f3c39e5018/mec/comms/__init__.py -------------------------------------------------------------------------------- /mec/comms/sync_rpc.py: -------------------------------------------------------------------------------- 1 | # process_signalling.py 2 | # 创建:陈硕 3 | # 创建日期:2020.06.20 4 | # 文件描述:进程通信封装,只用于通知,不用于大量传输数据 5 | 6 | 7 | 8 | import zmq 9 | import time 10 | 11 | 12 | # 
----------------------------- 同步RPC -----------------------------
13 | # 多个server对应一个client
14 | 
15 | test_signal = 'test'
16 | good_signal = 'good'
17 | check_signal = 'check'
18 | quit_signal = 'quit'
19 | start_signal = 'start'
20 | 
21 | 
22 | 
23 | class SyncRpcBase:
24 |     def __init__(self, server_ip, port, logger):
25 |         self.printToLog = logger
26 |         #
27 |         self.context = None
28 |         self.publish_addr = "tcp://{}:{}".format(server_ip, str(port) )
29 |         self.publish_socket = None
30 |         self.report_addr = "tcp://{}:{}".format(server_ip, str(port+1) )
31 |         self.report_socket = None
32 |         self.check_addr = "tcp://{}:{}".format(server_ip, str(port+2) )
33 |         self.check_socket = None
34 |         self.logger = logger
35 |         self.initSocket()
36 | 
37 |     def __del__(self):
38 |         self.closeSocket()
39 | 
40 | 
41 | class _Method:
42 |     # some magic to bind an RPC method to an RPC server.
43 |     # supports "nested" methods (e.g. examples.getStateName)
44 |     def __init__(self, name, send, logger=print):
45 |         self.__name = name
46 |         self.__send = send
47 |         self.printToLog = logger
48 |     def __getattr__(self, name):
49 |         self.printToLog("attribute {:s}".format(name) )
50 |         method = _Method("{:s}.{:s}".format(self.__name, name), self.__send)
51 |         self.__setattr__(name, method)
52 |         return method
53 |     def __call__(self, *args, **kwargs):
54 |         return self.__send(self.__name, *args, **kwargs)
55 | 
56 | class SyncRpcController(SyncRpcBase):
57 |     """
58 |     同步RPC服务端
59 |     """
60 |     def __init__(self, server_ip, port, worker_num, logger=print):
61 |         #
62 |         self.printToLog = logger
63 |         self.printToLog('initiating synchronized rpc controller')
64 |         SyncRpcBase.__init__(self, server_ip, port, logger)
65 |         self.worker_num = worker_num
66 |         self.is_working = False
67 |         self.is_looping = False
68 |         # check workers
69 |         # self check
70 |         self.printToLog('waiting for publishing socket')
71 |         self._WaitPubSockReady()
72 |         self.printToLog('publishing socket ready')
73 |         self.printToLog('synchronizing all workers')
74 |         self._waitAllWorkersReady()
75 |         self.printToLog('confirmed {} workers ready'.format(self.worker_num) )
76 | 
77 |     def __del__(self):
78 |         self.closeSocket()
79 | 
80 |     def __getattr__(self, name):
81 |         method = _Method(name, self._callMethod, self.printToLog)
82 |         self.__setattr__(name, method)
83 |         return method
84 | 
85 |     def initSocket(self):
86 |         self.printToLog("initializing socket:")
87 |         self.printToLog("publish addr = '{}'".format(self.publish_addr) )
88 |         self.printToLog("report addr = '{}'".format(self.report_addr) )
89 |         self.context = zmq.Context()
90 |         # publish socket
91 |         self.publish_socket = self.context.socket(zmq.PUB)
92 |         self.publish_socket.bind(self.publish_addr)
93 |         # report socket
94 |         self.report_socket = self.context.socket(zmq.PULL)
95 |         self.report_socket.bind(self.report_addr)
96 |         # self check socket
97 |         self.self_check_sub_socket = self.context.socket(zmq.SUB)
98 |         self.self_check_sub_socket.connect(self.publish_addr)
99 |         self.self_check_sub_socket.setsockopt(zmq.SUBSCRIBE, b'')
100 |         # workers check socket
101 |         self.check_socket = self.context.socket(zmq.REQ)
102 |         self.check_socket.bind(self.check_addr)
103 |         #
104 |         self.printToLog("socket initialization complete")
105 | 
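    # Socket topology sketch, as set up above (one controller, N workers):
    #   PUB  (publish_addr) --> workers' SUB : broadcast method calls to all workers
    #   PULL (report_addr)  <-- workers' PUSH: gather one result per worker
    #   REQ  (check_addr)  <--> workers' REP : lock-step control handshake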
106 |     def closeSocket(self):
107 |         #self.printToLog("closing socket ...")
108 |         if self.publish_socket != None:
109 |             self.publish_socket.unbind(self.publish_addr)
110 |             self.publish_socket = None
111 |         if self.report_socket != None:
112 |             self.report_socket.unbind(self.report_addr)
113 |             self.report_socket = None
114 |         if self.self_check_sub_socket != None:
115 |             self.self_check_sub_socket.disconnect(self.publish_addr)
116 |             self.self_check_sub_socket = None
117 |         if self.check_socket !=None:
118 |             self.check_socket.unbind(self.check_addr)
119 |             self.check_socket = None
120 |         #self.printToLog("socket closed")
121 | 
122 |     def _callMethod(self, name, *args, **kwargs):
123 |         """
124 |         调用工作者的方法,并获取返回值
125 |         """
126 |         self.printToLog("calling: ", (name, args, kwargs) )
127 |         self._broadcastMessage( (name, args, kwargs) )
128 |         result = self._gatherMessages()
129 |         self.printToLog("result: ", result)
130 |         return result
131 | 
132 |     def _broadcastMessage(self, msg):
133 |         """
134 |         将消息广播至所有的工作者
135 |         """
136 |         self.printToLog("message sent:", msg)
137 |         self.publish_socket.send( repr(msg).encode() )
138 | 
139 |     def _recieveSingleMessage(self):
140 |         """
141 |         从工作者收集信息
142 |         一次只收集一个
143 |         """
144 |         return eval(self.report_socket.recv().decode())
145 | 
146 |     def _gatherMessages(self):
147 |         """
148 |         从所有的工作者汇集消息
149 |         等候直到所有的工作者消息汇总完毕
150 |         返回一个list
151 |         """
152 |         result_list = []
153 |         for i in range(self.worker_num):
154 |             result_list.append( eval(self.report_socket.recv(0).decode()) )
155 |             self.printToLog(i+1, "results received")
156 |         return result_list
157 | 
158 |     def _WaitPubSockReady(self):
159 |         while True:
160 |             self.publish_socket.send(repr(test_signal).encode())
161 |             try:
162 |                 result = self.self_check_sub_socket.recv(zmq.NOBLOCK)
163 |                 if result is not None:
164 |                     result = eval(result.decode())
165 |                     self.printToLog('message: ',result)
166 |                     if result == test_signal: break
167 |             except zmq.ZMQError:
168 |                 self.printToLog('not ready')
169 |                 time.sleep(1)
170 | 
171 |     def _waitAllWorkersReady(self):
172 |         self._sendControlSignal(check_signal, sync_check=True)
173 |         # if self.is_working:
174 |         #     self.printToLog('warning! checking workers in working status')
175 |         #     return
176 |         # workers_set = set()
177 |         # while True:
178 |         #     self.printToLog('sending control signal, confirmed worker num: ',len(workers_set))
179 |         #     # count workers
180 |         #     self.check_socket.send(repr(check_signal).encode() )
181 |         #     rank = eval(self.check_socket.recv().decode())
182 |         #     self.printToLog('worker respond got, rank {}'.format(rank))
183 |         #     assert type(rank) is int, 'check respond signal error, should be int'
184 |         #     assert rank<self.worker_num and rank>=0, \
185 |         #         'worker respond rank exceeds limit, worker num {}, get {}'.format(
186 |         #             self.worker_num, rank)
187 |         #     if rank not in workers_set:
188 |         #         workers_set.add(rank)
189 |         #         self.printToLog('counted workers: ', workers_set)
190 |         #         if len(workers_set)==self.worker_num: # counted all
191 |         #             break
192 |         #     else: #rank in workers_set: indicate delay joined worker, count again
193 |         #         self.printToLog(' [warning] delay joined worker, rank {}, count again'.format(rank) )
194 |         #         time.sleep(1)
195 |         #         workers_set = { rank }
196 | 
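    # Handshake sketch: the controller REQ-sends a signal on check_addr; each
    # worker REP-replies with its rank. Ranks are collected into a set until all
    # worker_num workers have answered; a duplicate rank means a worker joined
    # late, so with sync_check=True the count restarts from that worker.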
197 |     def _sendControlSignal(self, signal, sync_check=False):
198 |         if self.is_working:
199 |             self.printToLog('warning! sending control signal in working status')
200 |             return
201 |         workers_set = set()
202 |         while len(workers_set) < self.worker_num:
203 |             self.printToLog('sending control signal, confirmed worker num: ', len(workers_set))
204 |             self.check_socket.send(repr(signal).encode() )
205 |             rank = eval(self.check_socket.recv().decode())
206 |             self.printToLog('worker respond got, rank {}'.format(rank))
207 |             assert type(rank) is int, 'check respond signal error, should be int'
208 |             assert rank<self.worker_num and rank>=0, \
209 |                 'worker respond rank exceeds limit, worker num {}, get {}'.format(
210 |                     self.worker_num, rank)
211 |             if rank not in workers_set:
212 |                 workers_set.add(rank)
213 |                 self.printToLog('counted workers: ', workers_set)
214 |                 if len(workers_set)==self.worker_num: # counted all
215 |                     break
216 |             elif sync_check: # not synchronized: count again
217 |                 self.printToLog('==== delay joined worker, rank {}, count again'.format(rank) )
218 |                 time.sleep(1)
219 |                 workers_set = { rank }
220 |             else:
221 |                 raise Exception('unhandled asynchronized signal, probably indicates disorder of workers, \
222 |                     try start all workers before start controllers')
223 | 
224 | 
225 |     def startWorking(self):
226 |         if not self.is_working:
227 |             self.printToLog('calling all workers to start')
228 |             self._sendControlSignal(start_signal)
229 |             self.is_working = True
230 |             self.is_looping = True
231 | 
232 |     def stopWorking(self):
233 |         self._callMethod('stopWorking')
234 |         self.is_working = False
235 | 
236 |     def stopLooping(self):
237 |         if self.is_working:
238 |             self._broadcastMessage(quit_signal)
239 |             #self._callMethod('stopLooping')
240 |         else:
241 |             self._sendControlSignal(quit_signal)
242 |         self.is_working = False
243 |         self.is_looping = False
244 | 
245 | class SyncRpcWorker(SyncRpcBase):
246 |     """
247 |     同步RPC客户端
248 |     """
249 |     def __init__(self, server_ip, port, rank, logger=print):
250 |         SyncRpcBase.__init__(self, server_ip, port, logger)
251 |         self.printToLog = logger
252 |         self.rank = rank
253 |         self.function_dict = {}
254 |         self.is_working = False
255 |         self.is_looping = False
256 |         self.registerMethod(self.stopWorking)
257 |         self.registerMethod(self.stopLooping)
258 |         #self.registerMethod(self.detectRespond)
259 | 
260 |     def __del__(self):
261 |         self.closeSocket()
262 | 
263 |     def initSocket(self):
264 |         self.printToLog("initializing socket:")
265 |         self.printToLog("publish addr = '{}'".format(self.publish_addr) )
266 |         self.printToLog("report addr = '{}'".format(self.report_addr) )
267 |         self.printToLog("check addr = '{}'".format(self.check_addr) )
268 |         self.context = zmq.Context()
269 |         # publish socket
270 |         self.publish_socket = self.context.socket(zmq.SUB)
271 |         self.publish_socket.connect(self.publish_addr)
272 |         self.publish_socket.setsockopt(zmq.SUBSCRIBE, b'')
273 |         # report socket
274 |         self.report_socket = self.context.socket(zmq.PUSH)
275 |         self.report_socket.connect(self.report_addr)
276 |         # workers check socket
277 |         self.check_socket = self.context.socket(zmq.REP)
278 |         self.check_socket.connect(self.check_addr)
279 | 
280 |     def closeSocket(self):
281 |         self.printToLog('closing socket, rank {}'.format(self.rank) )
282 |         if self.publish_socket != None:
283 |             self.publish_socket.disconnect(self.publish_addr)
284 |             self.publish_socket = None
285 |         if self.report_socket != None:
286 |             self.report_socket.disconnect(self.report_addr)
287 |             self.report_socket = None
288 |         if self.check_socket !=None:
289 |             self.check_socket.disconnect(self.check_addr)
290 |             self.check_socket = None
291 | 
292 |     def recieveBroadcast(self):
293 |         return eval(self.publish_socket.recv().decode())
294 | 
295 |     def reportMessage(self, msg):
296 |         """
297 |         将消息发送至控制者
298 |         """
299 |         return self.report_socket.send( repr(msg).encode() )
300 | 
301 |     def registerMethod(self, function, name=None):
302 |         if name is None:
303 |             name = function.__name__
304 |         self.function_dict[name] = function
305 | 
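    # Registration sketch (illustrative; `trainer` is a hypothetical object
    # whose bound methods become remotely callable by name):
    #   worker.registerMethod(trainer.trainOneEpoch)
    #   worker.registerMethod(trainer.saveModel, name='save')
    # The controller can then invoke them on every worker at once via its
    # __getattr__ magic: controller.trainOneEpoch(...), controller.save(...)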
306 |     def excecuteMethod(self, func_name, args, kwargs):
307 |         if func_name in self.function_dict:
308 |             self.printToLog("executing function: [func name] {}; [args] {}; [kwargs] {}".format(
309 |                 func_name, args, kwargs )
310 |             )
311 |             result = self.function_dict[func_name](*args, **kwargs)
312 |         else:
313 |             self.printToLog("warning: wrong function name. [func name] {}; [args] {}; [kwargs] {}".format(
314 |                 func_name, args, kwargs
315 |             ))
316 |             result = None
317 |         self.printToLog('result: ', result)
318 |         return result
319 | 
320 |     # def detectRespond(self):
321 |     #     return good_signal
322 | 
323 |     def waitControlSignal(self):
324 |         self.printToLog('waiting for control signal, rank {}'.format(self.rank) )
325 |         signal = eval(self.check_socket.recv().decode())
326 |         self.printToLog('controller signal received: \"{}\"'.format(signal))
327 |         self.printToLog('respond control signal, rank {}'.format(self.rank) )
328 |         #time.sleep(0.2)
329 |         self.check_socket.send( repr(self.rank).encode() )
330 |         self.printToLog('respond sent: {}'.format(self.rank) )
331 |         return signal
332 | 
333 |     def mainLoop(self):
334 |         #self.is_working = False
335 |         self.is_looping = True
336 |         while self.is_looping:
337 |             signal = self.waitControlSignal()
338 |             if signal == quit_signal:
339 |                 self.is_looping = False
340 |                 time.sleep(3)
341 |                 break
342 |             elif signal == check_signal:
343 |                 time.sleep(1)
344 |                 continue
345 |             elif signal==start_signal:
346 |                 self.printToLog('start working loop')
347 |                 self.is_working = True
348 |                 while self.is_working:
349 |                     self.printToLog("waiting for task message")
350 |                     msg = self.recieveBroadcast()
351 |                     self.printToLog("message received: \"{}\"".format(msg) )
352 |                     if msg==test_signal: continue
353 |                     elif msg==quit_signal:
354 |                         self.is_looping = False
355 |                         break
356 |                     func_name, func_args, func_kwargs = msg
357 |                     result = self.excecuteMethod(func_name, func_args, func_kwargs)
358 |                     self.reportMessage(result)
359 |         #time.sleep(3)
360 |         self.printToLog('exiting')
361 |         #self.closeSocket()
362 | 
363 |     def stopWorking(self):
364 |         self.is_working = False
365 |         return good_signal
366 | 
367 |     def stopLooping(self):
368 |         self.is_looping = False
369 |         self.is_working = False
--------------------------------------------------------------------------------
/mec/comms/transmit.py:
--------------------------------------------------------------------------------
1 | # data_transmitting.py
2 | # 创建:陈硕
3 | # 创建日期:2020.06.20
4 | # 文件描述:封装跨卡、跨机模型传递功能
5 | 
6 | import os
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | 
11 | from torch.distributed.distributed_c10d import _get_default_group
12 | 
13 | 
14 | class DistEnv:
15 |     """
16 |     torch.distributed所使用的分布式环境
17 |     """
18 |     def __init__(self, rank, world_size, control_ip, dist_port, logger=print):
19 |         self.printToLog = logger
20 |         self.rank = rank
21 |         self.world_size = world_size
22 |         self.control_ip = control_ip
23 |         self.dist_port = dist_port
24 |         self._initTorchDist()
25 | 
26 |     def _initTorchDist(self):
27 |         self.printToLog("dist args:", 'nccl', self.control_ip,
28 |             self.rank, self.world_size)
29 |         os.environ['MASTER_ADDR'] = self.control_ip
30 |         os.environ['MASTER_PORT'] = self.dist_port
31 |         dist.init_process_group(
32 |             backend='nccl',
33 |             rank=self.rank,
34 |             world_size=self.world_size)
35 |         self.printToLog("torch distributed environment initiated successfully")
36 |         #self.worker_group = dist.new_group(list(range(1,self.world_size)) )
37 | 
38 |     def newGroup(self, rank_list):
39 |         """
40 |         按照rank_list建立一个新的组
41 |         """
42 |         return dist.new_group(rank_list)
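# Usage sketch (hypothetical values; one process per GPU, NCCL backend;
# note dist_port is written into os.environ, so it must be a string):
#   env = DistEnv(rank=0, world_size=4, control_ip='127.0.0.1', dist_port='12503')
#   workers_group = env.newGroup([1, 2, 3])   # e.g. a group excluding rank 0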
43 | 
44 | # 用于cube all-reduce的分配函数
45 | def cube_correspond(n, turn):
46 |     return (1 << turn) ^ n  # 第turn轮中与节点n配对的超立方体节点
--------------------------------------------------------------------------------
/mec/data_manip/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # lr_scheduler.py
2 | # 创建:陈硕
3 | # 文件描述:学习率更新策略
4 | # 所有调度器均为可调用对象
5 | # 统一接口为:
6 | # epoch -> lr
7 | 
8 | import math
9 | import numpy as np
10 | 
11 | class CosineLR():
12 |     def __init__(self, warm_epoch=0, lr=0.001, period=30, only_decrease=True):
13 |         """学习率余弦更新策略,通过参数可设置warmup的代数、最大学习率、余弦更新周期、是否去除上升段。
14 |         Args:
15 |             warm_epoch (int): warmup的代数,epoch在此区间内线性上升至最大学习率。Default: 0.
16 |             lr (float): 最大学习率,余弦函数最高点的纵坐标。Default: 0.001.
17 |             period (int): 余弦更新的周期,指余弦函数从最高点到最低点需要的代数。Default: 30.
18 |             only_decrease (bool): 若为``True``则仅保留余弦下降段,若为``False``则使用完整的余弦函数。Default: ``True``.
19 | 
20 |         Example:
21 |             >>> scheduler = CosineLR(warm_epoch=5, lr=0.001, period=30, only_decrease=True)
22 |             >>> for epoch in range(180):
23 |             >>>     lr = scheduler(epoch)
24 |             >>>     train(...)
25 |         """
26 |         self.warm_epoch = warm_epoch
27 |         self.lr = lr
28 |         self.period = period
29 |         self.only_decrease = only_decrease
30 | 
31 |     def __call__(self, epoch):
32 |         """根据输入的epoch返回当前代的学习率"""
33 |         if epoch < self.warm_epoch:
34 |             return (epoch % self.warm_epoch) / self.warm_epoch * self.lr
35 |         elif self.only_decrease is True:
36 |             return (1 + math.cos(math.pi * ((epoch - self.warm_epoch) % self.period) / self.period)) * self.lr / 2
37 |         else:
38 |             return (1 + math.cos(math.pi * (epoch - self.warm_epoch) / self.period)) * self.lr / 2
39 | 
40 | class ExponentialLR():
41 |     def __init__(self, warm_epoch=0, lr=0.001, rate=0.9):
42 |         """学习率指数下降更新策略,通过参数可设置warmup的代数、最大学习率、指数下降速率。指数下降公式为:lr*(rate^epoch)。
43 |         Args:
44 |             warm_epoch (int): warmup的代数,epoch在此区间内线性上升至最大学习率。Default: 0.
45 |             lr (float): 最大学习率,指数函数下降起点的纵坐标。Default: 0.001.
46 |             rate (float): 指数下降速率,即指数函数的底数。Default: 0.9
47 | 
48 |         Example:
49 |             >>> scheduler = ExponentialLR(warm_epoch=5, lr=0.001, rate=0.9)
50 |             >>> for epoch in range(180):
51 |             >>>     lr = scheduler(epoch)
52 |             >>>     train(...)
53 |         """
54 |         self.warm_epoch = warm_epoch
55 |         self.lr = lr
56 |         self.rate = rate
57 | 
58 |     def __call__(self, epoch):
59 |         """根据输入的epoch返回当前代的学习率"""
60 |         if epoch < self.warm_epoch:
61 |             return (epoch % self.warm_epoch) / self.warm_epoch * self.lr
62 |         else:
63 |             return self.lr * np.power(self.rate , (epoch - self.warm_epoch))
--------------------------------------------------------------------------------
/mec/data_manip/metrics.py:
--------------------------------------------------------------------------------
1 | #import torch
2 | 
3 | 
4 | 
5 | # metrics base “评价标准” 类
6 | # 定义了基本的评价标准接口
7 | # 所有的评价标准必须继承此类
8 | 
9 | class MetricBase():
10 |     def __init__(self):
11 |         pass
12 | 
13 |     def __call__(self, batch_output, batch_target):
14 |         self.addData(batch_output, batch_target)
15 |         return self.getBatchScore(), self.getEpochScore()
16 | 
17 |     def __str__(self):
18 |         return self.name()
19 | 
20 |     # 可重载
21 |     # 定义评价标准的名字
22 |     def name(self):
23 |         return 'met'
24 | 
25 |     # 应重载
26 |     # 每个batch输入数据
27 |     def addData(self, batch_output, batch_target):
28 |         pass
29 | 
30 |     # 应重载
31 |     # 每epoch初始重置数据
32 |     def init(self):
33 |         pass
34 | 
35 |     # 应重载
36 |     # 计算每个batch的分数
37 |     def getBatchScore(self):
38 |         pass
39 | 
40 |     # 应重载
41 |     # 计算每个epoch的分数
42 |     def getEpochScore(self):
43 |         pass
44 | 
45 | # 评价标准:accuracy 准确率
46 | # 用于单标签分类
47 | class Accuracy(MetricBase):
48 |     def __init__(self):
49 |         MetricBase.__init__(self)
50 |         self.totalSamples = 0
51 |         self.totalHits = 0
52 |         self.batchScore = 0.
53 |         self.epochScore = 0.
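    # Usage sketch (illustrative shapes: output is [N, C] logits, target is
    # [N] class indices or [N, C] one-hot, handled by addData below):
    #   metric = Accuracy()
    #   metric.init()                                  # reset at epoch start
    #   batch_acc, epoch_acc = metric(output, target)  # via MetricBase.__call__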
54 | 55 | def name(self): 56 | return 'acc' 57 | 58 | def init(self): 59 | self.totalSamples = 0 60 | self.totalHits = 0 61 | pass 62 | 63 | # 应重载 64 | # 每个batch输入数据 65 | def addData(self, batch_output, batch_target): 66 | currentCount = len(batch_output) 67 | if batch_target.ndimension()>1: 68 | batch_target = batch_target.max(dim=-1)[1] 69 | hits = (batch_output.max(dim=-1)[1] == batch_target).sum().item() 70 | #hits = torch.sum(batch_output.max(dim=-1)[1] == batch_target).item() 71 | self.batchScore = hits / currentCount 72 | self.totalHits += hits 73 | self.totalSamples += currentCount 74 | self.epochScore = self.totalHits / self.totalSamples 75 | 76 | # 计算每个batch的分数 77 | def getBatchScore(self): 78 | return self.batchScore 79 | 80 | # 应重载 81 | # 计算每个epoch的分数 82 | def getEpochScore(self): 83 | return self.epochScore 84 | 85 | # 评价标准:accuracy 准确率 86 | # 用于单标签分类、且标签为one hot表示时的情况 87 | # 可适用平滑和混淆 88 | class AccuracyOneHot(MetricBase): 89 | pass 90 | 91 | 92 | # # 评价标准:mean average precision 93 | # # 适用于多标签分类 94 | # # 每个batch单独输入数据 95 | # # 每个epoch先清除缓存 96 | # class MeanAveragePrecision(MetricBase): 97 | # def __init__(self, tag_dimension=1, total_sample_num=1, device=None): #torch.device('cpu') ): 98 | # MetricBase.__init__(self) 99 | # self.tagDimension = tag_dimension 100 | # self.totalSampleNum = total_sample_num 101 | # self.currentCount = 0 102 | # self.batchSize = 1 103 | # self.outputs = torch.zeros( (total_sample_num, tag_dimension) ).to(device) 104 | # self.targets = torch.zeros( (total_sample_num, tag_dimension) ).to(device) 105 | # self.tempPrecisionArray = torch.zeros(total_sample_num).to(device) 106 | # self.tempRecallArray = torch.zeros(total_sample_num).to(device) 107 | 108 | # def name(self): 109 | # return 'map' 110 | 111 | # def init(self): 112 | # self.batchScore = 0.0 113 | # self.epochScore = 0.0 114 | # self.currentCount = 0 115 | # pass 116 | 117 | # def addData(self, batch_output, batch_target): 118 | # assert batch_output.ndimension()==2, 'output must be 2 dimension: batch/feature' 119 | # assert batch_target.ndimension()==2, 'target must be 2 dimension: batch/feature' 120 | # #print(self.outputs.size()) 121 | # #print(batch_output.size()) 122 | # self.batchSize = len(batch_output) 123 | # self.outputs[self.currentCount: self.currentCount+self.batchSize] = torch.sigmoid(batch_output) 124 | # self.targets[self.currentCount: self.currentCount+self.batchSize] = batch_target 125 | # self.currentCount += self.batchSize 126 | 127 | # def getBatchScore(self): 128 | # return 0.0 129 | # #return self.batchScore 130 | 131 | # def getAveragePrecision(self, index): 132 | # outputs, indices = self.outputs[0:self.currentCount, index].sort(descending=True) 133 | # targets = self.targets[0:self.currentCount, index][indices] 134 | # none_zero_indices = outputs!=0 135 | # outputs = outputs[none_zero_indices] 136 | # targets = targets[none_zero_indices] 137 | # total_positive = torch.sum(targets>0) 138 | # # 按threshold从到到低的顺序分别计算precision和recall 139 | # for i in range(self.currentCount): 140 | # threshold = outputs[i] 141 | # true_positive = torch.sum(targets[0:i]>0) 142 | # predicted_positive = torch.sum(outputs>threshold) 143 | # precision = true_positive / (predicted_positive.type(torch.double) + 1e-6) 144 | # recall = true_positive / (total_positive.type(torch.double) + 1e-6) 145 | # # print("\n------ ") 146 | # # print("true pos: {}, pred pos: {}, precision: {}".format(true_positive, predicted_positive, precision) ) 147 | # # print("true pos: {}, total pos: {}, recall: 
{}".format(true_positive, total_positive, recall) ) 148 | # self.tempPrecisionArray[i] = precision 149 | # self.tempRecallArray[i] = recall 150 | # # print(self.tempRecallArray[0:self.currentCount]) 151 | # # print(self.tempPrecisionArray[0:self.currentCount]) 152 | # # 按recall从高到低排序,方便后面的运算 153 | # recallArray, recallIndices = self.tempRecallArray[0:self.currentCount].sort(descending=True) 154 | # precisionArray = self.tempPrecisionArray[recallIndices] 155 | # # 11-点-AP 156 | # total_precision = 0 157 | # current_precision = 0 158 | # #for index in range(11): 159 | # recall_index = 10 160 | # required_recall = index*0.1 161 | # next_recall = (index-1)*0.1 162 | # maxPrecision = recallArray[0] 163 | # for index, recall in enumerate(recallArray) : # recall 从高到低 164 | # # print("\nindex: ", index) 165 | # while recall < next_recall: # 下一个recall点 166 | # # print("\nrecall; ", recall) 167 | # # print(maxPrecision) 168 | # total_precision += maxPrecision 169 | # recall_index -= 1 170 | # required_recall = recall_index*0.1 171 | # next_recall = (recall_index-1)*0.1 172 | # precision = precisionArray[index] 173 | # if precision > maxPrecision: 174 | # maxPrecision = precision 175 | # total_precision += maxPrecision # recall == 0 时的precision 176 | # # print(total_precision, total_precision/11.0, index) 177 | # return total_precision / 11.0 178 | 179 | # def getEpochScore(self): 180 | # totalAP = 0.0 181 | # for i in range(self.tagDimension): # 对所有输出维度 182 | # totalAP += self.getAveragePrecision(i) 183 | # self.epochScore = totalAP / self.tagDimension 184 | # return self.epochScore.item() 185 | 186 | # # mean F1 score 187 | # # 评价标准:平均F1分数 188 | # # 适用于多标签分类 189 | # class meanF1Score(MetricBase): 190 | # def __init__(self, tag_dimension=1, threshold=0.0, device=torch.device('cuda')): 191 | # MetricBase.__init__(self) 192 | # self.device = device 193 | # self.batchScore = 0.0 194 | # self.epochScore = 0.0 195 | # self.threshold = threshold 196 | # self.tp = torch.zeros(tag_dimension).to(device) 197 | # self.tn = torch.zeros(tag_dimension).to(device) 198 | # self.fp = torch.zeros(tag_dimension).to(device) 199 | # self.fn = torch.zeros(tag_dimension).to(device) 200 | 201 | # # 可重载 202 | # # 定义评价标准的名字 203 | # def getMetricName(self): 204 | # return 'mF1' 205 | 206 | # # 重载 207 | # # 每个batch输入数据 208 | # def addData(self, batch_output, batch_target): 209 | # guess_pos = batch_output > self.threshold # 判断为真 210 | # target_pos = batch_target > 0 # 实际为真 211 | # guess_neg = batch_output < self.threshold # 判断为假 212 | # target_neg = batch_target < 0 # 实际为假 213 | # # 以下为向量,维度为tag dimension 214 | # tp = torch.sum(guess_pos*target_pos, dim=0).type(torch.float) # 猜真实真,为tp 215 | # fp = torch.sum(guess_pos*target_neg, dim=0).type(torch.float) + 1e-6 # 猜真实假,为fp 216 | # fn = torch.sum(guess_neg*target_pos, dim=0).type(torch.float) + 1e-6 # 猜假实真,为fp 217 | # #tn = torch.sum(guess_neg*target_neg, dim=0) # 猜假实假,为tn 218 | # # 计算每个batch的分数 219 | # self.batchScore = tp *2 / (tp*2 + fp + fn) 220 | # # 计算每个epoch的分数 221 | # self.tp += tp 222 | # self.fp += fp 223 | # self.fn += fn 224 | # # print("---- shape: ----\n", self.tp.size(), tp.size() ) 225 | # # print("---- shape: ----\n", self.fp.size(), fp.size() ) 226 | # # print("---- shape: ----\n", self.fn.size(), fn.size() ) 227 | # self.epochScore = self.tp * 2 / (self.tp*2 + self.fp + self.fn) 228 | 229 | # # 重载 230 | # # 每epoch初始重置数据 231 | # def init(self): 232 | # self.tp.zero_() 233 | # self.tn.zero_() 234 | # self.fp.zero_() 235 | # self.fn.zero_() 236 | # self.tp += 1e-6 
237 | # self.tn += 1e-6 238 | # self.fp += 1e-6 239 | # self.fn += 1e-6 240 | 241 | # # 重载 242 | # # 计算每个batch的分数 243 | # def getBatchScore(self): 244 | # return torch.mean(self.batchScore).item() 245 | 246 | # # 重载 247 | # # 计算每个epoch的分数 248 | # def getEpochScore(self): 249 | # return torch.mean(self.epochScore).item() 250 | 251 | # def getEpochScore_(self): 252 | # return self.epochScore -------------------------------------------------------------------------------- /mec/data_manip/transfroms/batch_transforms.py: -------------------------------------------------------------------------------- 1 | # batch_transfroms.py 2 | # created: CS 3 | # 针对批次数据的整体变换 4 | # 因批次变换涉及到批次内多个样本的数据和标签的次序 5 | # 因此无法拆分为针对data的和针对label的 6 | 7 | def mix_up(batch_data, batch_indexs): 8 | pass -------------------------------------------------------------------------------- /mec/data_manip/transfroms/data_transforms.py: -------------------------------------------------------------------------------- 1 | # data_transfroms.py 2 | # created: CS 3 | # 各种数据变换 4 | # 针对单个样本 5 | 6 | ''' 7 | ElasticTransform: 8 | 弹性变换,根据扭曲场的平滑度与强度逐一地移动局部像素点实现模糊效果。 9 | 依赖: albumentations 10 | 11 | 参数: 12 | alpha (float): 13 | sigma (float): Gaussian filter parameter. 14 | alpha_affine (float): The range will be (-alpha_affine, alpha_affine) 15 | interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: 16 | cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. 17 | Default: cv2.INTER_LINEAR. 18 | border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: 19 | cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. 20 | Default: cv2.BORDER_REFLECT_101 21 | value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. 22 | mask_value (int, float, 23 | list of ints, 24 | list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. 25 | approximate (boolean): Whether to smooth displacement map with fixed kernel size. 26 | Enabling this option gives ~2X speedup on large images. 27 | 输入: 28 | numpy数组 29 | 返回: 30 | 字典,形式为 {'image': array(...)} 31 | 示例: 32 | >>> import albumentations as A 33 | >>> import numpy as np 34 | >>> from PIL import Image 35 | >>> img = Image.open('img.jpg') 36 | >>> np_img = np.asarray(img) 37 | >>> t = A.ElasticTransform() 38 | >>> img = Image.fromarray(t(image = np_img)['image']) 39 | ''' 40 | 41 | ''' 42 | HueSaturationValue: 43 | HSV对比度变换,通过向HSV空间中的每个像素添加或减少V值,修改色调和饱和度实现对比度转换。 44 | 依赖: albumentations 45 | 46 | 参数: 47 | hue_shift_limit ((int, int) or int): range for changing hue. If hue_shift_limit is a single int, the range 48 | will be (-hue_shift_limit, hue_shift_limit). Default: (-20, 20). 49 | sat_shift_limit ((int, int) or int): range for changing saturation. If sat_shift_limit is a single int, 50 | the range will be (-sat_shift_limit, sat_shift_limit). Default: (-30, 30). 51 | val_shift_limit ((int, int) or int): range for changing value. If val_shift_limit is a single int, the range 52 | will be (-val_shift_limit, val_shift_limit). Default: (-20, 20). 53 | p (float): probability of applying the transform. Default: 0.5. 
54 | 输入: 55 | numpy数组 56 | 返回: 57 | 字典,形式为 {'image': array(...)} 58 | 示例: 59 | >>> import albumentations as A 60 | >>> import numpy as np 61 | >>> from PIL import Image 62 | >>> img = Image.open('img.jpg') 63 | >>> np_img = np.asarray(img) 64 | >>> t = A.HueSaturationValue() 65 | >>> img = Image.fromarray(t(image = np_img)['image']) 66 | ''' 67 | 68 | ''' 69 | IAASuperpixels: 70 | 超像素法,在最大分辨率处生成图像的若干个超像素,并将其调整到原始大小,再将原始图像中所有超像素区域按一定比例替换为超像素,其他区域不改变。 71 | 依赖: albumentations 72 | 注意: 该方法可能速度较慢。 73 | 参数: 74 | p_replace (float): defines the probability of any superpixel area being replaced by the superpixel, i.e. by 75 | the average pixel color within its area. Default: 0.1. 76 | n_segments (int): target number of superpixels to generate. Default: 100. 77 | p (float): probability of applying the transform. Default: 0.5. 78 | 输入: 79 | numpy数组 80 | 返回: 81 | 字典,形式为 {'image': array(...)} 82 | 示例: 83 | >>> import albumentations as A 84 | >>> import numpy as np 85 | >>> from PIL import Image 86 | >>> img = Image.open('img.jpg') 87 | >>> np_img = np.asarray(img) 88 | >>> t = A.IAASuperpixels() 89 | >>> img = Image.fromarray(t(image = np_img)['image']) 90 | ''' 91 | 92 | ''' 93 | IAAPerspective: 94 | 随机四点透视变换 95 | 依赖: albumentations 96 | 参数: 97 | scale ((float, float): standard deviation of the normal distributions. These are used to sample 98 | the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). 99 | p (float): probability of applying the transform. Default: 0.5. 100 | 输入: 101 | numpy数组 102 | 返回: 103 | 字典,形式为 {'image': array(...)} 104 | 示例: 105 | >>> import albumentations as A 106 | >>> import numpy as np 107 | >>> from PIL import Image 108 | >>> img = Image.open('img.jpg') 109 | >>> np_img = np.asarray(img) 110 | >>> t = A.IAAPerspective() 111 | >>> img = Image.fromarray(t(image = np_img)['image']) 112 | ''' 113 | 114 | ''' 115 | CoarseDropout: 116 | 在面积大小可选定、位置随机的矩形区域上丢失信息实现转换,产生黑色矩形块。 117 | 依赖: albumentations 118 | 参数: 119 | max_holes (int): Maximum number of regions to zero out. 120 | max_height (int): Maximum height of the hole. 121 | max_width (int): Maximum width of the hole. 122 | min_holes (int): Minimum number of regions to zero out. If `None`, 123 | `min_holes` is be set to `max_holes`. Default: `None`. 124 | min_height (int): Minimum height of the hole. Default: None. If `None`, 125 | `min_height` is set to `max_height`. Default: `None`. 126 | min_width (int): Minimum width of the hole. If `None`, `min_height` is 127 | set to `max_width`. Default: `None`. 128 | fill_value (int, float, lisf of int, list of float): value for dropped pixels. 129 | 输入: 130 | numpy数组 131 | 返回: 132 | 字典,形式为 {'image': array(...)} 133 | 示例: 134 | >>> import albumentations as A 135 | >>> import numpy as np 136 | >>> from PIL import Image 137 | >>> img = Image.open('img.jpg') 138 | >>> np_img = np.asarray(img) 139 | >>> t = A.CoarseDropout() 140 | >>> img = Image.fromarray(t(image = np_img)['image']) 141 | ''' 142 | 143 | ''' 144 | EdgeDetect: 145 | 边界检测,检测图像中的所有边缘,将它们标记为黑白图像,再将结果与原始图像叠加。 146 | 依赖: imgaug 147 | 参数: 148 | alpha : number or tuple of number or list of number or imgaug.parameters.StochasticParameter, optional 149 | Blending factor of the edge image. At ``0.0``, only the original 150 | image is visible, at ``1.0`` only the edge image is visible. 151 | 152 | * If a number, exactly that value will always be used. 153 | * If a tuple ``(a, b)``, a random value will be sampled from the 154 | interval ``[a, b]`` per image. 
155 | * If a list, a random value will be sampled from that list 156 | per image. 157 | * If a ``StochasticParameter``, a value will be sampled from that 158 | parameter per image. 159 | 160 | seed : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional 161 | See :func:`~imgaug.augmenters.meta.Augmenter.__init__`. 162 | 163 | name : None or str, optional 164 | See :func:`~imgaug.augmenters.meta.Augmenter.__init__`. 165 | 166 | random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional 167 | Old name for parameter `seed`. 168 | Its usage will not yet cause a deprecation warning, 169 | but it is still recommended to use `seed` now. 170 | Outdated since 0.4.0. 171 | 172 | deterministic : bool, optional 173 | Deprecated since 0.4.0. 174 | See method ``to_deterministic()`` for an alternative and for 175 | details about what the "deterministic mode" actually does. 176 | 输入: 177 | numpy数组 178 | 返回: 179 | numpy数组 180 | 示例: 181 | >>> import imgaug.augmenters as iaa 182 | >>> import numpy as np 183 | >>> from PIL import Image 184 | >>> img = Image.open('img.jpg') 185 | >>> np_img = np.asarray(img) 186 | >>> t = iaa.EdgeDetect(alpha=(0.0, 1.0)) 187 | >>> img = Image.fromarray(t(image = np_img)) 188 | ''' 189 | 190 | ''' 191 | RandomAffine: 192 | 仿射变换,对图像进行旋转、水平偏移、裁剪、放缩等操作,保持中心不变。 193 | 依赖: torchvision 194 | 参数: 195 | degrees (sequence or float or int): Range of degrees to select from. 196 | If degrees is a number instead of sequence like (min, max), the range of degrees 197 | will be (-degrees, +degrees). Set to 0 to deactivate rotations. 198 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 199 | and vertical translations. For example translate=(a, b), then horizontal shift 200 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 201 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 202 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 203 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 204 | shear (sequence or float or int, optional): Range of degrees to select from. 205 | If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) 206 | will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the 207 | range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, 208 | a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. 209 | Will not apply shear by default 210 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 211 | An optional resampling filter. See `filters`_ for more information. 212 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 
213 | fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area 214 | outside the transform in the output image.(Pillow>=5.0.0) 215 | 输入: 216 | PIL图片 217 | 返回: 218 | PIL图片 219 | 示例: 220 | >>> from torchvision import transforms as T 221 | >>> from PIL import Image 222 | >>> img = Image.open('img.jpg') 223 | >>> t = T.RandomAffine(25, translate=(0.2,0.2), scale=(0.8,1.2), shear=8, resample=Image.BILINEAR) 224 | >>> img = t(img) 225 | ''' 226 | 227 | ''' 228 | randomGausNoise: 229 | 高斯模糊。 230 | 实现: 231 | def randomGausNoise(image): 232 | dice = random() 233 | if dice<0.5: 234 | return image.filter(ImageFilter.GaussianBlur(radius=random()*1.7+0.5) ) 235 | else: 236 | return image 237 | 输入: 238 | PIL图片 239 | 返回: 240 | PIL图片 241 | 示例: 242 | >>> from torchvision import transforms as T 243 | >>> from random import random 244 | >>> from PIL import Image, ImageFilter 245 | >>> img = Image.open('img.jpg') 246 | >>> t = T.Lambda(randomGausNoise) 247 | >>> img = t(img) 248 | ''' 249 | 250 | ''' 251 | cropAndPadImage: 252 | 裁切图像并用黑色像素补齐图像至目标大小。 253 | 实现: 254 | target_size = (224, 224) 255 | background_color = (0,0,0) 256 | def cropAndPadImage(img): 257 | w, h = img.size 258 | if w==h: 259 | if w>target_size[0]: 260 | return img.resize(target_size) 261 | else: 262 | return img 263 | if w>h: 264 | x0 = int( (w-h)/4 ) 265 | x1 = w - x0 266 | y0 = 0 267 | y1 = h 268 | padding_length = x1-x0 269 | padding_size = (padding_length, padding_length) 270 | pad_x0 = 0 271 | pad_x1 = padding_length 272 | pad_y0 = int( (w-h)/4 ) 273 | pad_y1 = pad_y0 + h 274 | else : 275 | x0 = 0 276 | x1 = w 277 | y0 = int( (h-w)/4 ) 278 | y1 = h - y0 279 | padding_length = y1-y0 280 | padding_size = (padding_length, padding_length) 281 | pad_x0 = int( (h-w)/4 ) 282 | pad_x1 = pad_x0 + w 283 | pad_y0 = 0 284 | pad_y1 = padding_length 285 | cropped_img = img.crop( (x0,y0, x1,y1) ) 286 | padded_img = Image.new('RGB', padding_size, background_color) 287 | #print(img.size, padding_size, cropped_img.size, (pad_x0, pad_y0, pad_x1, pad_y1) ) 288 | padded_img.paste(cropped_img, (pad_x0, pad_y0, pad_x1, pad_y1) ) 289 | resized_img = padded_img.resize(target_size) 290 | return resized_img 291 | 输入: 292 | PIL图片 293 | 返回: 294 | PIL图片 295 | 示例: 296 | >>> from torchvision import transforms as T 297 | >>> from random import random 298 | >>> from PIL import Image 299 | >>> img = Image.open('img.jpg') 300 | >>> t = T.Lambda(cropAndPadImage) 301 | >>> img = t(img) 302 | ''' -------------------------------------------------------------------------------- /mec/data_manip/transfroms/label_transforms.py: -------------------------------------------------------------------------------- 1 | # label_transfroms.py 2 | # created: CS 3 | # 各种标签变换 4 | # 针对单个样本 5 | 6 | 7 | 8 | # 9 | def to_onehot(indexed_label): 10 | pass 11 | 12 | # 13 | def label_smoothing(onehot_label): 14 | pass 15 | 16 | # 17 | -------------------------------------------------------------------------------- /mec/scoring/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artintel2017/torch-multi-gpu/94c9171035b33e8f9cb9e816f1d3e5f3c39e5018/mec/scoring/__init__.py -------------------------------------------------------------------------------- /mec/scoring/tester.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import time 6 | from tqdm import tqdm 7 | 
from PIL import Image
8 | 
9 | # ----- ----- ----- ----- 混淆矩阵测试 ----- ----- ----- -----
10 | 
11 | def test_mix_sync(
12 |     model,
13 |     test_loader,
14 |     idx_to_class_list,  # 类别列表,有序排列
15 |     output_file_path,
16 |     device=torch.device('cpu')):
17 |     model = model.to(device)
18 |     model.eval()
19 |     num_classes = len(idx_to_class_list)
20 |     mix_mat_count_tensor = torch.zeros((num_classes, num_classes+1), dtype=torch.long).to(device)
21 |     with torch.no_grad():
22 |         for test_data_batch, test_target_index_batch in tqdm(test_loader, ncols=80):
23 |             test_data_batch = test_data_batch.to(device)
24 |             test_output_batch = model(test_data_batch)
25 |             test_output_index_batch = test_output_batch.max(dim=-1)[1]
26 |             for test_target_index, test_output_index in zip(test_target_index_batch, test_output_index_batch):
27 |                 mix_mat_count_tensor[test_target_index][test_output_index] += 1  # 计数一次判断
28 |                 mix_mat_count_tensor[test_target_index][num_classes] += 1  # 总数计数
29 |     mix_mat_count_list = mix_mat_count_tensor.tolist()
30 |     counts_dict = { 'name': ['总计'] + idx_to_class_list }
31 |     propotion_dict = { 'name': ['总计', '精确率', '召回率', 'F1-score'] + idx_to_class_list}  # propotion dict
32 |     for i in range(num_classes):
33 |         #print(mix_mat_count_list)
34 |         counts_column = mix_mat_count_list[i]
35 |         counts_column = [counts_column[num_classes]] + counts_column[0:num_classes]
36 |         column_total = counts_column[0]
37 |         propotion_column = [column_total]
38 |         row_total = torch.sum(mix_mat_count_tensor[:,i]).item()
39 |         true_positive_count = counts_column[i+1]
40 |         real_positive_count = column_total
41 |         pred_positive_total = row_total
42 |         # precision
43 |         precision = true_positive_count/pred_positive_total if pred_positive_total>0 else 0.0
44 |         propotion_column.append(precision)
45 |         # recall
46 |         recall = true_positive_count/real_positive_count if real_positive_count>0 else 0.0
47 |         propotion_column.append(recall)
48 |         # f1
49 |         f1_score = 2*true_positive_count/(real_positive_count+pred_positive_total) if true_positive_count>0 else 0.0
50 |         propotion_column.append(f1_score)
51 |         for j in range(1, num_classes+1):  # 比例
52 |             propotion_column.append(counts_column[j]/column_total if column_total>0 else 0.0)
53 |         # 正确标签名
54 |         target_class_name = idx_to_class_list[i]
55 |         # ----- -----
56 |         counts_dict[ target_class_name ] = counts_column
57 |         propotion_dict[ target_class_name ] = propotion_column  # 总数
58 | 
59 |     count_frame = pd.DataFrame(data=counts_dict)
60 |     propotion_frame = pd.DataFrame(data=propotion_dict)
61 |     writer = pd.ExcelWriter(output_file_path, engine='xlsxwriter')
62 | 
63 |     count_frame.to_excel(writer, sheet_name='判别计数', index=False)
64 |     workbook1 = writer.book
65 |     worksheet1 = writer.sheets['判别计数']
66 |     worksheet1.set_column(0, 0, 8)
67 |     worksheet1.set_column(1, num_classes, 5)
68 | 
69 |     propotion_frame.to_excel(writer, sheet_name='判别比例', index=False)
70 |     workbook2 = writer.book
71 |     worksheet2 = writer.sheets['判别比例']
72 |     worksheet2.set_column(0, 0, 8)
73 |     worksheet2.set_column(1, num_classes, 5)
74 | 
75 |     writer.save()
76 |     return
77 | 
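# Usage sketch (hypothetical names; idx_to_class_list must be ordered by class index):
#   test_mix_sync(model, test_loader,
#                 idx_to_class_list=['cat', 'dog', 'ship'],
#                 output_file_path='results/confusion.xlsx',
#                 device=torch.device('cuda:0'))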
78 | def test_mix_async(
79 |     model,
80 |     test_set,
81 |     train_idx_to_class = None,
82 |     test_idx_to_class = None,
83 |     test_loader=None, device=torch.device('cpu')):
84 |     num_classes = len(test_idx_to_class)  # TODO: unfinished asynchronous variant
85 |     mix_mat_count_tensor = torch.zeros((num_classes, num_classes+1), dtype=torch.long).to(device)
86 | 
87 |     pass
88 | 
89 | def test_mix(
90 |     model,
91 |     test_set = None,
92 |     train_idx_to_class = None,
93 |     test_idx_to_class = None,
94 |     test_loader=None):
95 |     """ Tests the model's accuracy on test_set and builds its confusion matrix
96 |     * single-label pipelines only
97 |     * when test_set and the model order the classes differently, both the training-set and test-set idx_to_class must be supplied
98 |     Args:
99 |         model : the model under test
100 |         test_set : the dataset to test against
101 |         test_loader : loader for the test set; must not be passed together with test_set
102 |         train_idx_to_class : class ordering of the training set, indexable
103 |         test_idx_to_class : class ordering of the test set
104 |     """
105 | 
-------------------------------------------------------------------------------- /mec/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artintel2017/torch-multi-gpu/94c9171035b33e8f9cb9e816f1d3e5f3c39e5018/mec/training/__init__.py -------------------------------------------------------------------------------- /mec/training/async_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artintel2017/torch-multi-gpu/94c9171035b33e8f9cb9e816f1d3e5f3c39e5018/mec/training/async_trainer.py -------------------------------------------------------------------------------- /mec/training/basic_trainer.py: --------------------------------------------------------------------------------
1 | # basic_trainer.py
2 | # created: CS
3 | # A basic trainer class.
4 | 
5 | import torch
6 | 
7 | class BasicTrainer():
8 |     """
9 |     Trainer class.
10 |     Wraps the individual steps of training.
11 |     Does not organize the overall workflow.
12 |     Does not move data around.
13 |     Args:
14 |         model: the model
15 |         optimizer: the optimizer
16 |         criterion: the loss function
17 |         metrics: the scorer
18 |     """
19 |     def __init__(self, model, optimizer, criterion, metrics):
20 |         # ----- basic elements -----
21 |         self.model = model  # model
22 |         self.optimizer = optimizer  # optimizer
23 |         self.criterion = criterion  # loss function
24 |         self.metrics = metrics  # scoring function
25 |         # ----- temporary figures -----
26 |         self.loss = 0
27 |         self.met = 0
28 | 
29 |     def initEpoch(self):
30 |         self.metrics.init()
31 |         self.optimizer.zero_grad()
32 | 
33 |     def forwardData(self, data):
34 |         self.optimizer.zero_grad()
35 |         self.result = self.model(data)
36 | 
37 |     # incremental forward: does not zero the gradients
38 |     def forwardDataInc(self, data):
39 |         self.result=self.model(data)
40 | 
41 |     # forward without recording gradients
42 |     @torch.no_grad()
43 |     def forwardNoGrad(self, data):
44 |         self.result=self.model(data)
45 | 
46 |     def backwardGrad(self, target):
47 |         self.loss = self.criterion(self.result, target)
48 |         self.loss.backward()
49 |         self.met, _ = self.metrics(self.result, target)
50 | 
51 |     # compute the loss only; no backward pass
52 |     @torch.no_grad()
53 |     def calcScores(self, target):
54 |         self.loss = self.criterion(self.result, target)
55 |         self.met, _ = self.metrics(self.result, target)
56 | 
57 |     def setLearningRate(self, lr):
58 |         for param_group in self.optimizer.param_groups:
59 |             param_group['lr'] = lr
60 | 
61 |     def updateWeights(self):
62 |         self.optimizer.step()
63 | 
64 |     def getScores(self):
65 |         return self.loss.item(), self.met
66 | 
67 |     def saveModel(self, filename):
68 |         torch.save(self.model.state_dict(), filename)
69 | 
70 |     def loadModel(self, filename, map_location=None):
71 |         self.model.load_state_dict(torch.load(filename, map_location=map_location))
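A minimal single-process sketch of the step protocol BasicTrainer expects (the loader and metrics object are placeholders; per the repo convention, `metrics` is assumed to be callable, to return a (score, detail) pair, and to expose init()):

import torch.nn as nn
import torch.optim as optim

from mec.training.basic_trainer import BasicTrainer

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
trainer = BasicTrainer(model, optimizer, criterion, metrics)  # metrics: placeholder scorer

trainer.initEpoch()                    # metrics.init() + zero_grad()
for data, target in train_loader:      # train_loader: placeholder DataLoader
    trainer.forwardData(data)          # zero_grad + forward
    trainer.backwardGrad(target)       # loss.backward() + metric update
    trainer.updateWeights()            # optimizer.step()
    loss, met = trainer.getScores()    # latest batch scores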
-------------------------------------------------------------------------------- /mec/training/old_sync_trainer.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | import zmq
5 | import numpy as np
6 | from tqdm import tqdm
7 | 
8 | import torch
9 | import torch.distributed as dist
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.distributed import DistributedSampler
12 | 
13 | from ..utils.monitor import Monitor
14 | 
15 | 
16 | from torchvision.models.resnet import resnet50
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | 
20 | 
21 | #from dataset_single import get_data_loader
22 | 
23 | # --------------------------------- training utils -------------------------------
24 | # -------------------------------------------------------------------------
25 | 
26 | 
27 | # Trainer class
28 | # Wraps the individual steps of training.
29 | # Does not organize the overall workflow.
30 | # Does not move data around.
31 | class Trainer():
32 |     def __init__(self, model, optimizer, criterion, metrics):
33 |         # ----- basic elements -----
34 |         self.model = model  # model
35 |         self.optimizer = optimizer  # optimizer
36 |         self.criterion = criterion  # loss function
37 |         self.metrics = metrics  # scoring function
38 |         # ----- temporary figures -----
39 |         self.loss = 0
40 |         self.met = 0
41 | 
42 |     def zero_grad(self):
43 |         self.optimizer.zero_grad()
44 | 
45 |     def forwardData(self, data):
46 |         self.optimizer.zero_grad()
47 |         self.result = self.model(data)
48 | 
49 |     # incremental forward: does not zero the gradients
50 |     def forwardDataInc(self, data):
51 |         self.result=self.model(data)
52 | 
53 |     # forward without recording gradients
54 |     @torch.no_grad()
55 |     def forwardNoGrad(self, data):
56 |         self.result=self.model(data)
57 | 
58 |     def backwardGrad(self, target):
59 |         self.loss = self.criterion(self.result, target)
60 |         self.loss.backward()
61 |         self.met, _ = self.metrics(self.result, target)
62 | 
63 |     # compute the loss only; no backward pass
64 |     @torch.no_grad()
65 |     def calcScores(self, target):
66 |         self.loss = self.criterion(self.result, target)
67 |         self.met, _ = self.metrics(self.result, target)
68 | 
69 |     def setLearningRate(self, lr):
70 |         for param_group in self.optimizer.param_groups:
71 |             param_group['lr'] = lr
72 | 
73 |     def updateWeights(self):
74 |         self.optimizer.step()
75 | 
76 |     def getScores(self):
77 |         return self.loss.item(), self.met
78 | 
79 | # Worker classes:
80 | # move the data around
81 | # and carry out the basic tasks.
82 | # class WorkerBase():
83 | #     pass
84 | 
85 | #
86 | class WorkerSync():
87 |     r"""
88 |     Synchronous training worker.
89 |     Receives each batch assignment from the manager's broadcast;
90 |     after forward and backward it
91 |     uploads grads with dist.reduce
92 |     and receives weights back with dist.broadcast.
93 |     Grads are applied on the manager with a delayed update.
94 | 
95 |     Args:
96 |         model : the model
97 |         dataset : the dataset; must be indexable
98 |         device : device to use
99 |         rank : process rank, unique within the group
100 |         sync_worker_num : number of workers
101 |         manager_ip : IP address only, e.g. "127.0.0.1"
102 |         task_port : port used for task broadcast
103 |         score_port : port used to return scores
104 |         dist_port : port used by torch.dist
105 |     """
106 |     def __init__(self, trainer, train_loader, valid_loader,
107 |                  device, rank, sync_worker_num, num_workers_per_loader, manager_ip,
108 |                  train_batch_transform=None, valid_batch_transform=None,
109 |                  dist_port="8100", task_port="8101", score_port="8102",
110 |                  manager_with_gpu=True):
111 |         # training
112 |         self.trainer = trainer
113 |         self.device = device
114 |         self.train_loader = train_loader
115 |         self.valid_loader = valid_loader
116 |         self.train_batch_transform = train_batch_transform
117 |         self.valid_batch_transform = valid_batch_transform
118 |         self.train_iter = None
119 |         self.valid_iter = None
120 |         # distributed
121 |         os.environ['MASTER_ADDR'] = manager_ip
122 |         os.environ['MASTER_PORT'] = dist_port
123 |         self.dist_addr = "tcp://" + manager_ip + ":" + dist_port
124 |         self.rank = rank
125 |         self.sync_worker_num = sync_worker_num
126 |         self.manager_with_gpu = manager_with_gpu
127 |         self.world_size = sync_worker_num + 1
128 |         # communication
129 |         self.context = None
130 |         self.task_addr = "tcp://" + manager_ip + ":" + task_port
131 |         self.task_socket = None
132 |         self.score_addr = "tcp://" + manager_ip + ":" + score_port
133 |         self.score_socket = None
134 |         # log file
135 |         self.logfile = open("logs/worker_{:d}.log".format(rank), 'w')
136 |         # try:
137 |         #     self.logfile = open("logs/worker_{:d}.log".format(rank), 'a')
138 |         # except FileNotFoundError:
139 |         #     self.logfile = open("logs/worker_{:d}.log".format(rank), 'w')
140 | 
141 |     def __del__(self):
142 |         self.closeSocket()
143 |         self.logfile.close()
144 | 
145 |     def printToLog(self, *content):
146 |         print("[worker_{:d}|{}]".format(self.rank, time.strftime("%y-%m-%d_%H:%M:%S")),
147 |               *content, file=self.logfile, flush=True)
148 |         #print("[manager]", *content)
149 | 
150 |     def initSocket(self):
151 |         self.printToLog("initializing socket:")
152 |         self.printToLog("task addr = '{}'".format(self.task_addr) )
153 |         self.printToLog("score addr = '{}'".format(self.score_addr) )
154 |         self.context = zmq.Context()
155 |         self.task_socket = self.context.socket(zmq.SUB)
156 |         self.task_socket.connect(self.task_addr)
157 |         self.task_socket.setsockopt(zmq.SUBSCRIBE, b'')
158 |         self.score_socket = self.context.socket(zmq.PUSH)
159 |         self.score_socket.connect(self.score_addr)
160 | 
161 |     def closeSocket(self):
162 |         if self.task_socket != None:
163 |             self.task_socket.disconnect(self.task_addr)
164 |             self.task_socket = None
165 |         if self.score_socket != None:
166 |             self.score_socket.disconnect(self.score_addr)
167 |             self.score_socket = None
168 | 
169 |     def recvMessage(self):
170 |         return eval(self.task_socket.recv().decode() )
171 | 
172 |     def sendMessage(self, msg):
173 |         return self.score_socket.send(repr(msg).encode() )
174 | 
175 |     def initTorchDist(self):
176 |         self.printToLog("dist args:", 'nccl', self.dist_addr,
177 |                         self.rank, self.world_size)
178 |         dist.init_process_group('nccl',
179 |                                 rank=self.rank, world_size=self.world_size)
180 |         # dist.init_process_group('nccl', init_method=self.dist_addr,
181 |         #     rank=self.rank, world_size=self.world_size)
182 |         self.worker_group = dist.new_group(list(range(1,self.world_size)) )
183 | 
184 |     # def setBatch(self, batch_index_list):
185 |     #     self.dataloader.batch_sampler.setBatch(batch_index_list)
186 | 
187 |     # start an epoch
188 |     def initTrainEpoch(self, epoch, lr):
189 |         self.printToLog("setting training epoch")
190 |         self.train_loader.sampler.set_epoch(epoch)
191 |         self.printToLog("setting training iter")
192 |         # if self.train_iter != None:
193 |         #     while True:
194 |         #         try :
195 |         #             next(self.train_iter)
196 |         #         except StopIteration:
197 |         #             break
198 |         self.train_iter = iter(self.train_loader)
199 |         self.printToLog("setting train iter complete")
200 |         self.trainer.setLearningRate(lr)
201 |         self.trainer.model.train()
202 | 
203 |     def initValidEpoch(self, epoch):
204 |         self.printToLog("setting validation epoch")
205 |         self.valid_loader.sampler.set_epoch(epoch)
206 |         self.printToLog("setting validation iter")
207 |         if self.valid_iter != None:
208 |             while True:
209 |                 try :
210 |                     next(self.valid_iter)
211 |                 except StopIteration:
212 |                     break
213 |         self.valid_iter = iter(self.valid_loader)
214 |         self.trainer.model.eval()
215 | 
216 |     # compute the grads only; does not update the weights
217 |     def batchTrainNoUpdate(self):
218 |         self.printToLog("batch train")
219 |         data, target = next(self.train_iter)
220 |         self.batch_sample_num = len(target)
221 |         self.printToLog("getting data")
222 |         data, target = data.to(self.device), target.to(self.device)
223 |         if self.train_batch_transform is not None:
224 |             data, target = self.train_batch_transform(data, target)
225 |         self.printToLog("forwarding")
226 |         self.trainer.forwardData(data)
227 |         self.printToLog("backwarding")
228 |         self.trainer.backwardGrad(target)
229 |         self.printToLog("batch train complete")
230 | 
231 |     def updateWeights(self):
232 |         self.trainer.updateWeights()
233 | 
234 |     # no grad
235 |     def batchValidate(self):
236 |         data, target = next(self.valid_iter)
237 |         self.batch_sample_num = len(target)
238 |         data, target = data.to(self.device), target.to(self.device)
239 |         self.trainer.forwardNoGrad(data)
240 |         self.trainer.calcScores(target)
241 | 
242 |     def getScores(self):
243 |         return self.trainer.getScores()
244 | 
245 |     def saveState(self, filename):
246 |         return torch.save(self.trainer.model.state_dict(), filename)
247 | 
248 |     # exchange gradients with the other workers;
249 |     # a group can be passed in to choose which peers take part
250 |     def crossGrads(self, async_op=False):
251 |         for p_group in self.trainer.optimizer.param_groups:
252 |             for param in p_group['params']:
253 |                 #print(param.size())
254 |                 dist.all_reduce(param.grad, group=self.worker_group, op=dist.ReduceOp.SUM, async_op=async_op)
255 |         if async_op: dist.barrier(group=self.worker_group)
256 |         for p_group in self.trainer.optimizer.param_groups:
257 |             for param in p_group['params']:
258 |                 param.grad /= self.sync_worker_num
259 | 
260 | 
261 |     # upload gradients;
262 |     # rank 0 is the manager process by default
263 |     def uploadGrads(self, async_op=False):
264 |         for p_group in self.trainer.optimizer.param_groups:
265 |             for param in p_group['params']:
266 |                 dist.reduce(param.grad, 0, op=dist.ReduceOp.SUM, async_op=async_op)
267 |         if async_op: dist.barrier()
268 | 
269 |     # download weights;
270 |     # rank 0 is the manager process by default
271 |     def downloadWeights(self, async_op=False):
272 |         for p_group in self.trainer.optimizer.param_groups:
273 |             for param in p_group['params']:
274 |                 dist.broadcast(param, 0, async_op=async_op)
275 |         if async_op: dist.barrier()
276 | 
277 |     # upload grads and receive the reduced grads back;
278 |     # rank 0 is the manager process by default
279 |     def exchangeGradsAndWeights(self, async_op=False):
280 |         for p_group in self.trainer.optimizer.param_groups:
281 |             for param in p_group['params']:
282 |                 dist.reduce(param.grad, 0, op=dist.ReduceOp.SUM, async_op=async_op)
283 |                 dist.broadcast(param.grad, 0, async_op=async_op)
284 |         if async_op: dist.barrier()
285 | 
286 | #def gen_optimizer(model):
287 | #    return ...
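The two gradient-synchronization styles implemented above, reduced to their torch.distributed core (a standalone sketch; it assumes init_process_group has already been called and that every parameter already holds a .grad):

import torch.distributed as dist

def all_reduce_average(model, world_size):
    # 'cross' style: all workers sum their grads with all_reduce, then divide
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

def reduce_then_broadcast(model):
    # manager style: rank 0 accumulates the grads, steps the optimizer,
    # then broadcasts the updated weights back to every worker
    for param in model.parameters():
        if param.grad is not None:
            dist.reduce(param.grad, dst=0, op=dist.ReduceOp.SUM)
    # ... rank 0 runs optimizer.step() here ...
    for param in model.parameters():
        dist.broadcast(param.data, src=0)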
288 | 
289 | def sync_worker_process(model, optimizer, criterion, metrics, train_set, valid_set,
290 |     batch_size, device, rank, sync_worker_num, num_workers_per_loader,
291 |     train_batch_transform, valid_batch_transform,
292 |     manager_ip, sync_flag='cross'):
293 |     #print("worker process ", rank)
294 |     world_size = sync_worker_num+1
295 |     ###
296 |     torch.cuda.set_device(device)
297 |     # without this line, once pin_memory=True is set in the DataLoader,
298 |     # every worker process's DataLoader would also occupy GPU 0
299 | 
300 |     model = model.to(device)
301 |     #opt = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.01, nesterov=True)
302 |     #trainer = Trainer(model, opt, criterion, metrics)
303 | 
304 |     trainer = Trainer(model, optimizer, criterion, metrics)
305 |     batch_size_per_worker = int(batch_size/sync_worker_num)
306 |     train_loader = DataLoader(
307 |         train_set,
308 |         batch_size = batch_size_per_worker,
309 |         sampler = DistributedSampler(
310 |             train_set,
311 |             num_replicas=sync_worker_num,
312 |             rank=rank%sync_worker_num,
313 |             #shuffle=True
314 |         ),
315 |         num_workers = num_workers_per_loader,
316 |         pin_memory=True
317 |     )
318 |     valid_loader = DataLoader(
319 |         valid_set,
320 |         batch_size = batch_size_per_worker,
321 |         sampler = DistributedSampler(
322 |             valid_set,
323 |             num_replicas=sync_worker_num,
324 |             rank=rank%sync_worker_num
325 |             #shuffle=True
326 |         ),
327 |         num_workers = num_workers_per_loader,
328 |         pin_memory=True
329 |     )
330 |     #train_loader, valid_loader, _ = get_data_loader(batch_size=64, val_split=0.2, num_workers=num_workers_per_loader)
331 |     worker = WorkerSync(trainer, train_loader, valid_loader,
332 |                         device, rank, sync_worker_num,
333 |                         num_workers_per_loader, manager_ip,
334 |                         train_batch_transform=train_batch_transform, valid_batch_transform=valid_batch_transform)
335 |     worker.printToLog("train loader len {}".format(len(train_loader)) )
336 |     worker.initSocket()
337 |     worker.initTorchDist()
338 |     #worker.sendMessage({'flag':'ready'})
339 |     worker.printToLog("beginning loop")
340 |     while True:
341 |         msg = worker.recvMessage()
342 |         worker.printToLog("message received:")
343 |         worker.printToLog(msg)
344 |         if msg['flag'] == 'init':
345 |             sync_flag = msg['sync_flag']
346 |         elif msg['flag'] == 'quit':
347 |             break
348 |         elif msg['flag'] == 'train_epoch':
349 |             epoch = msg['epoch']
350 |             lr = msg['lr']
351 |             worker.printToLog(("===== train epoch {}; "
352 |                                "train loader len : {} ===== "
353 |                                ).format(epoch, len(train_loader))
354 |             )
355 |             worker.initTrainEpoch(epoch, lr)
356 |             for train_batch_index in range(len(worker.train_loader) ):
357 |                 worker.batchTrainNoUpdate()
358 |                 if sync_flag=='cross':
359 |                     worker.printToLog("cross grads")
360 |                     worker.crossGrads(async_op=True)
361 |                     worker.printToLog("updating weights")
362 |                     worker.updateWeights()
363 |                 else :
364 |                     worker.printToLog("exchanging")
365 |                     #worker.exchangeGradsAndWeights(async_op=True)
366 |                     worker.printToLog("uploading grads")
367 |                     worker.uploadGrads()
368 |                     worker.printToLog("downloading weights")
369 |                     worker.downloadWeights()
370 |                 #print("getting scores")
371 |                 sample_num = worker.batch_sample_num
372 |                 loss, met = worker.getScores()
373 |                 respond = {
374 |                     'flag' : 'train_score',
375 |                     'samples': sample_num,
376 |                     'loss' : loss,
377 |                     'met' : met
378 |                 }
379 |                 worker.sendMessage(respond)
380 |                 worker.printToLog(
381 |                     ("batch {:d}; "
382 |                      "sample num : {}; "
383 |                      "loss : {:.4f}; "
384 |                      "{} : {:.4f}").format(
385 |                         train_batch_index, sample_num, loss, metrics.getMetricName(), met
386 |                     )
387 |                 )
388 |         elif msg['flag'] == 'valid_epoch':
389 |             #worker.printToLog("valid batch")
390 |             epoch = msg['epoch']
391 |             worker.initValidEpoch(epoch)
392 |             for valid_batch_index in range(len(worker.valid_loader) ):
393 |                 worker.batchValidate()
394 |                 loss, met = worker.getScores()
395 |                 sample_num = worker.batch_sample_num
396 |                 respond = {
397 |                     'flag' : 'valid_score',
398 |                     'valid_samples': sample_num,
399 |                     'valid_loss' : loss,
400 |                     'valid_met' : met
401 |                 }
402 |                 worker.sendMessage(respond)
403 |                 worker.printToLog(
404 |                     ("v batch {:d}; "
405 |                      "v samples {}; "
406 |                      "v loss: {:.3f}; "
407 |                      "v {}: {:.3f}").format(
408 |                         valid_batch_index, sample_num, loss, metrics.getMetricName(), met
409 |                     )
410 |                 )
411 |         elif msg['flag'] == 'save_model':
412 |             if rank in msg['ranks_to_save']:
413 |                 worker.printToLog('---- saving model ----')
414 |                 model_filename = msg['model_filename']
415 |                 worker.saveState(model_filename)
416 |                 if msg['is_best'] == True:
417 |                     worker.printToLog('---- saving best model ----')
418 |                     best_model_fname = msg['best_model_fname']
419 |                     worker.saveState(best_model_fname)
420 |         else:
421 |             worker.printToLog("--- unknown message types ---")
422 |             worker.printToLog(msg)
423 |             continue
424 |     worker.closeSocket()
425 | 
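For reference, the complete control vocabulary handled by the loop above. Every message is a plain dict sent as repr(...).encode() and rebuilt with eval on the other side (field values here are examples):

task_messages = [
    {'flag': 'init',        'sync_flag': 'cross'},
    {'flag': 'train_epoch', 'epoch': 3, 'lr': 1e-3},
    {'flag': 'valid_epoch', 'epoch': 3},
    {'flag': 'save_model',  'ranks_to_save': [1],
     'model_filename': 'results/temp/current_model.pth', 'is_best': False},
    {'flag': 'quit'},
]
train_reply = {'flag': 'train_score', 'samples': 16, 'loss': 0.42, 'met': 0.88}
valid_reply = {'flag': 'valid_score', 'valid_samples': 16,
               'valid_loss': 0.51, 'valid_met': 0.85}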
426 | 
427 | # Manager class:
428 | # organizes the workflow
429 | # and steers the worker processes via message passing.
430 | # torch.distributed cannot run more than one backend at a time; with nccl chosen for speed, CPU tensors cannot be transferred.
431 | class ManagerSync():
432 | 
433 |     def __init__(self, trainer, train_set_len, valid_set_len, batch_size, rank, sync_worker_num,
434 |                  sync_flag='cross', history_filename="history.json",
435 |                  result_path="results/temp", model_filename ="current_model.pth",
436 |                  best_model_fname="best_model.pth", manager_ip="127.0.0.1",
437 |                  dist_port="8100", task_port="8101", score_port="8102"):
438 |         # basic data
439 |         self.trainer = trainer
440 |         self.sync_worker_num = sync_worker_num  # number of training processes (distinct from num_workers in the DataLoader)
441 |         # data
442 |         self.train_loader_len = int(np.ceil(train_set_len/batch_size/sync_worker_num) )
443 |         self.valid_loader_len = int(np.ceil(valid_set_len/batch_size/sync_worker_num) )
444 |         # result saving
445 |         self.model_filename = result_path + '/' + model_filename
446 |         self.best_model_fname = result_path + '/' + best_model_fname
447 |         self.history_filename = result_path + '/' + history_filename
448 |         # distributed
449 |         os.environ['MASTER_ADDR'] = manager_ip
450 |         os.environ['MASTER_PORT'] = dist_port
451 |         self.dist_addr = "tcp://" + manager_ip + ":" + dist_port
452 |         self.rank = rank
453 |         self.world_size = sync_worker_num + 1
454 |         # communication
455 |         self.context = None
456 |         self.task_addr = "tcp://" + manager_ip + ":" + task_port
457 |         self.task_socket = None
458 |         self.score_addr = "tcp://" + manager_ip + ":" + score_port
459 |         self.score_socket = None
460 |         # log file
461 |         self.logfile = open("logs/manager.log", 'w')
462 |         # try:
463 |         #     self.logfile = open("logs/manager.log", 'a')
464 |         # except FileNotFoundError:
465 |         #     self.logfile = open("logs/manager.log", 'w')
466 | 
467 |     def __del__(self):
468 |         self.closeSocket()
469 |         self.logfile.close()
470 | 
471 |     def printToLog(self, *content):
472 |         print("[manager|{}]".format(time.strftime("%y-%m-%d_%H:%M:%S") ),
473 |               *content, file=self.logfile, flush=True)
474 |         #print("[manager]", *content)
475 | 
476 |     def initSocket(self):
477 |         self.printToLog("initializing socket:")
478 |         self.printToLog("task addr = '{}'".format(self.task_addr) )
479 |         self.printToLog("score addr = '{}'".format(self.score_addr) )
480 |         self.context = zmq.Context()
481 |         self.task_socket = self.context.socket(zmq.PUB)
482 |         self.task_socket.bind(self.task_addr)
483 |         self.score_socket = self.context.socket(zmq.PULL)
484 |         self.score_socket.bind(self.score_addr)
485 | 
486 |     def closeSocket(self):
487 |         self.printToLog("closing socket")
488 |         if self.task_socket != None:
489 |             self.task_socket.unbind(self.task_addr)
490 |             self.task_socket = None
491 |         if self.score_socket != None:
492 |             self.score_socket.unbind(self.score_addr)
493 |             self.score_socket = None
494 | 
495 |     def sendMessage(self, msg):
496 |         return self.task_socket.send(repr(msg).encode() )
497 | 
498 |     def recvMessage(self):
499 |         return eval(self.score_socket.recv().decode() )
500 | 
501 |     def initTorchDist(self):
502 |         self.printToLog("dist args:", 'nccl', self.dist_addr,
503 |                         self.rank, self.world_size)
504 |         dist.init_process_group('nccl',
505 |                                 rank=self.rank, world_size=self.world_size)
506 |         # dist.init_process_group('nccl', init_method=self.dist_addr,
507 |         #     rank=self.rank, world_size=self.world_size)
508 | 
509 |     # gather gradients;
510 |     # rank 0 is the manager process by default
511 |     def gatherGrads(self, async_op=False):
512 |         for p_group in self.trainer.optimizer.param_groups:
513 |             for param in p_group['params']:
514 |                 dist.reduce(param.grad, 0, op=dist.ReduceOp.SUM, async_op=async_op)
515 |         if async_op: dist.barrier()
516 | 
517 |     # broadcast weights;
518 |     # rank 0 is the manager process by default
519 |     def broadcastWeights(self, async_op=False):
520 |         for p_group in self.trainer.optimizer.param_groups:
521 |             for param in p_group['params']:
522 |                 dist.broadcast(param, 0, async_op=async_op)
523 |         if async_op: dist.barrier()
524 | 
525 |     def updateWeights(self):
526 |         self.trainer.updateWeights()
527 | 
528 |     # gather grads and broadcast the reduced grads back;
529 |     # rank 0 is the manager process by default
530 |     def exchangeGradsAndWeights(self, async_op=False):
531 |         for p_group in self.trainer.optimizer.param_groups:
532 |             for param in p_group['params']:
533 |                 dist.reduce(param.grad, 0, op=dist.ReduceOp.SUM, async_op=async_op)
534 |                 #dist.broadcast(param.data, 0, async_op=async_op)
535 |                 dist.broadcast(param.grad, 0, async_op=async_op)
536 |         if async_op: dist.barrier()
537 | 
538 | # manager process
539 | def sync_manager_process(
540 |     model, optimizer, criterion, metrics,
541 |     train_set, valid_set, batch_size,
542 |     init_epoch, total_epochs, lr_scheduler,  # lr_scheduler: callable, epoch -> learning rate
543 |     device, rank, sync_worker_num,
544 |     manager_ip="127.0.0.1", sync_flag='cross',
545 |     result_path="results/temp"
546 | ):
547 |     train_set_len = len(train_set)
548 |     valid_set_len = len(valid_set)
549 |     world_size = sync_worker_num+1
550 |     trainer = Trainer(model.to(device), optimizer, None, metrics)
551 |     manager = ManagerSync(trainer, train_set_len, valid_set_len,
552 |                           batch_size, rank, sync_worker_num,
553 |                           manager_ip=manager_ip)
554 |     manager.initSocket()
555 |     manager.initTorchDist()
556 |     manager.printToLog("getting data")
557 |     tempdata = torch.randn(3,224,224)
558 |     manager.printToLog("moving data")
559 |     tempdata = tempdata.to(device).unsqueeze(0)
560 |     manager.printToLog("calc result")
561 |     tempresult = trainer.model(tempdata)
562 |     manager.printToLog("calc loss")
563 |     loss = torch.mean(tempresult)
564 |     manager.printToLog("backward")
565 |     loss.backward()
566 |     manager.printToLog("zero_grad")
567 |     trainer.zero_grad()
568 |     manager.printToLog("init monitor")
569 |     monitor = Monitor(init_epoch, total_epochs,
570 |                       manager.train_loader_len, manager.valid_loader_len,
571 |                       metrics.getMetricName())
572 |     message = {
573 |         'flag' : 'init',
574 |         'sync_flag' : sync_flag
575 |     }
576 |     manager.sendMessage(message)
577 |     for epoch in range(init_epoch+1, init_epoch+total_epochs+1):
578 |         manager.printToLog(
579 |             (
580 |                 "===== train epoch {}; "
581 |                 "train loader len : {} ====="
582 |             ).format(epoch, manager.train_loader_len)
583 |         )
584 |         message = {
585 |             'flag' : 'train_epoch',
586 |             'epoch': epoch, 'lr' : lr_scheduler(epoch)
587 |         }
588 |         manager.sendMessage(message)
589 |         # training
590 |         total_samples = 0
591 |         total_loss = 0
592 |         total_met = 0
593 |         for batch in range(manager.train_loader_len):
594 |             if sync_flag == "cross":
595 |                 manager.printToLog("workers crossing grads ...")
596 |                 pass
597 |             else :
598 |                 manager.printToLog("-- exchanging --" )
599 |                 # manager.exchangeGradsAndWeights(async_op=True)
600 |                 manager.printToLog("gathering grads")
601 |                 manager.gatherGrads()
602 |                 manager.printToLog("broadcasting weights")
603 |                 manager.broadcastWeights()
604 |                 manager.printToLog("updating weights")
605 |                 manager.updateWeights()
606 |                 trainer.zero_grad()
607 |             sample_num = 0
608 |             loss = 0
609 |             met = 0
610 |             for worker_index in range(sync_worker_num):
611 |                 respond = manager.recvMessage()
612 |                 sample_num += respond['samples']
613 |                 loss += respond['loss']/sync_worker_num
614 |                 total_loss += respond['loss'] * respond['samples']
615 |                 met += respond['met']/sync_worker_num
616 |                 total_met += respond['met'] * respond['samples']
617 |             total_samples += sample_num
618 |             avg_loss = total_loss / total_samples
619 |             avg_met = total_met / total_samples
620 |             manager.printToLog(
621 |                 ("batch {:d}; "
622 |                  "sample num : {}; "
623 |                  "loss : {:.3f}; "
624 |                  "{} : {:.3f}").format(
625 |                     batch,
626 |                     sample_num,
627 |                     loss,
628 |                     metrics.getMetricName(),
629 |                     met
630 |                 )
631 |             )
632 |             monitor.updateTraining(loss, avg_loss, met, avg_met)
633 |         # validation
634 |         manager.printToLog(
635 |             (
636 |                 "===== valid epoch {}; "
637 |                 "valid loader len : {} ====="
638 |             ).format(epoch, manager.valid_loader_len)
639 |         )
640 |         total_val_samples = 0
641 |         total_val_loss = 0
642 |         total_val_met = 0
643 |         message = {
644 |             'flag' : 'valid_epoch',
645 |             'epoch': epoch
646 |         }
647 |         manager.sendMessage(message)
648 |         for val_batch in range(manager.valid_loader_len):
649 |             val_sample_num = 0
650 |             val_loss = 0
651 |             val_met = 0
652 |             for i in range(sync_worker_num):
653 |                 respond = manager.recvMessage()
654 |                 val_sample_num += respond['valid_samples']
655 |                 val_loss += respond['valid_loss'] / sync_worker_num
656 |                 total_val_loss += respond['valid_loss'] * respond['valid_samples']
657 |                 val_met += respond['valid_met']/sync_worker_num
658 |                 total_val_met += respond['valid_met'] * respond['valid_samples']
659 |             total_val_samples += val_sample_num
660 |             avg_val_loss = total_val_loss / total_val_samples
661 |             avg_val_met = total_val_met / total_val_samples
662 |             monitor.updateValidation(val_loss, avg_val_loss, val_met, avg_val_met)
663 |             manager.printToLog(
664 |                 ("val batch {:d}; "
665 |                  "val sample num : {}; "
666 |                  "val loss : {:.3f}; "
667 |                  "val {} : {:.3f}").format(
668 |                     val_batch,
669 |                     val_sample_num,
670 |                     val_loss,
671 |                     metrics.getMetricName(),
672 |                     val_met
673 |                 )
674 |             )
675 |         monitor.updateEpoch(avg_loss, avg_val_loss, avg_met, avg_val_met)
676 |         # epoch
677 |     # epoch for loop end
678 |     manager.printToLog("sending quit")
679 |     message = {
680 |         'flag': 'quit'
681 |     }
682 |     manager.sendMessage(message)
683 |     monitor.close()
684 | 
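The aggregation in the two loops above is a sample-weighted mean: each worker i reports (n_i, loss_i, met_i) per batch, and the epoch average weights every report by its sample count. The same computation in isolation:

def weighted_epoch_average(reports):
    # reports: list of (sample_num, loss, met) tuples, one per worker per batch
    total_n = sum(n for n, _, _ in reports)
    avg_loss = sum(n * loss for n, loss, _ in reports) / total_n
    avg_met  = sum(n * met for n, _, met in reports) / total_n
    return avg_loss, avg_met

print(weighted_epoch_average([(16, 0.50, 0.80), (16, 0.40, 0.90)]))
# -> approximately (0.45, 0.85)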
685 | # ------------------------------ no gpu manager--------------------------------
686 | 
687 | # Manager class for hosts without a GPU:
688 | # organizes the workflow
689 | # and handles saving/loading the model.
690 | class ManagerSyncNoGpu():
691 | 
692 |     def __init__(self, train_set_len, valid_set_len, batch_size, sync_worker_num,
693 |                  history=None, result_path="results/temp", model_filename ="current_model.pth",
694 |                  best_model_fname="best_model.pth", history_filename="history.json",
695 |                  manager_ip="127.0.0.1", dist_port="8100", task_port="8101", score_port="8102"):
696 |         # basic data
697 |         self.sync_worker_num = sync_worker_num  # number of training processes (distinct from num_workers in the DataLoader)
698 |         # data
699 |         self.train_loader_len = int(np.ceil(train_set_len/batch_size) )
700 |         self.valid_loader_len = int(np.ceil(valid_set_len/batch_size) )
701 |         # result saving
702 |         self.model_filename = result_path + '/' + model_filename
703 |         self.best_model_fname = result_path + '/' + best_model_fname
704 |         self.history_filename = result_path + '/' + history_filename
705 |         if history == None:
706 |             self.history = {'epochs':0, 'train_loss':[], 'train_met':[], 'val_loss':[], 'val_met':[], 'lr':[]}
707 |             self.best_loss = 0
708 |             self.best_met = 0
709 |             self.best_val_loss = 1e6
710 |             self.best_val_met = 0
711 |         else:
712 |             self.history = history
713 |             if self.history['epochs']>0:
714 |                 self.best_loss = min(self.history['train_loss'])
715 |                 self.best_met = max(self.history['train_met'])
716 |                 self.best_val_loss = min(self.history['val_loss'])
717 |                 self.best_val_met = max(self.history['val_met'])
718 |         # communication
719 |         self.context = None
720 |         self.task_addr = "tcp://" + manager_ip + ":" + task_port
721 |         self.task_socket = None
722 |         self.score_addr = "tcp://" + manager_ip + ":" + score_port
723 |         self.score_socket = None
724 |         # distributed
725 |         os.environ['MASTER_ADDR'] = manager_ip
726 |         os.environ['MASTER_PORT'] = dist_port
727 |         self.dist_addr = "tcp://" + manager_ip + ":" + dist_port
728 |         self.rank = 0
729 |         self.world_size = sync_worker_num + 1
730 |         # log file
731 |         self.logfile = open("logs/manager.log", 'w')
732 | 
733 |     def __del__(self):
734 |         self.closeSocket()
735 |         self.logfile.close()
736 | 
737 |     def printToLog(self, *content):
738 |         print("[manager|{}]".format(time.strftime("%y-%m-%d_%H:%M:%S") ),
739 |               *content, file=self.logfile, flush=True)
740 |         #print("[manager]", *content)
741 | 
742 |     def initSocket(self):
743 |         self.printToLog("initializing socket:")
744 |         self.printToLog("task addr = '{}'".format(self.task_addr) )
745 |         self.printToLog("score addr = '{}'".format(self.score_addr) )
746 |         self.context = zmq.Context()
747 |         self.task_socket = self.context.socket(zmq.PUB)
748 |         self.task_socket.bind(self.task_addr)
749 |         self.score_socket = self.context.socket(zmq.PULL)
750 |         self.score_socket.bind(self.score_addr)
751 | 
752 |     def closeSocket(self):
753 |         self.printToLog("closing socket")
754 |         if self.task_socket != None:
755 |             self.task_socket.unbind(self.task_addr)
756 |             self.task_socket = None
757 |         if self.score_socket != None:
758 |             self.score_socket.unbind(self.score_addr)
759 |             self.score_socket = None
760 | 
761 |     def initTorchDist(self):
762 |         self.printToLog("dist args:", 'nccl', self.dist_addr,
763 |                         self.rank, self.world_size)
764 |         dist.init_process_group('nccl',
765 |                                 rank=self.rank, world_size=self.world_size)
766 | 
767 |     def sendMessage(self, msg):
768 |         return self.task_socket.send(repr(msg).encode() )
769 | 
770 |     def recvMessage(self):
771 |         return eval(self.score_socket.recv().decode() )
772 | 
773 |     # returns True if this epoch is the best so far
774 |     def updateHistory(self, epoch, lr, loss, met, val_loss, val_met):
775 |         self.history['epochs'] += 1
776 |         self.history['lr'] .append(lr)
777 |         self.history['train_loss'].append(loss)
778 |         self.history['train_met'] .append(met)
779 |         self.history['val_loss'] .append(val_loss)
780 |         self.history['val_met'] .append(val_met)
781 |         self.printToLog("saving history:")
782 |         self.printToLog("filename = {}".format(self.history_filename))
783 |         self.printToLog("content:")
784 |         self.printToLog("epoch {}, lr {}".format(epoch, lr) )
785 |         self.printToLog("lr {}, t_los {:.3f}, t_met {:.3f}, v_los {:.3f}, v_met {:.3f}".format(
786 |             lr, loss, met, val_loss, val_met)
787 |         )
788 |         with open(self.history_filename, 'w', encoding='utf8') as file:
789 |             json.dump(self.history, file, ensure_ascii=False, indent=2)
790 |         if val_met > self.best_val_met:
791 |             self.best_loss = loss
792 |             self.best_met = met
793 |             self.best_val_loss = val_loss
794 |             self.best_val_met = val_met
795 |             self.printToLog("new best score at epoch {:d}:".format(epoch))
796 |             self.printToLog("loss {:.3f}, met {:.3f}, v_los {:.3f}, v_met {:.3f}".format(
797 |                 loss, met, val_loss, val_met)
798 |             )
799 |             return True
800 |         return False
801 | 
802 | 
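For reference, the shape of the record updateHistory persists to history.json (values are illustrative; the best-epoch decision above keys on val_met only):

history = {
    'epochs': 2,                     # number of completed epochs
    'train_loss': [1.21, 0.87],      # one entry per epoch, in order
    'train_met':  [0.55, 0.68],
    'val_loss':   [1.05, 0.92],
    'val_met':    [0.58, 0.64],
    'lr':         [1e-3, 1e-3],
}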
803 | # manager process
804 | def sync_manager_process_no_gpu(
805 |     metrics_name, train_set_len, valid_set_len, batch_size,
806 |     init_epoch, total_epochs, lr_scheduler, sync_worker_num,
807 |     manager_ip="127.0.0.1", result_path="results/temp",
808 |     history=None, sync_flag='cross'
809 | ):
810 |     world_size = sync_worker_num+1
811 |     manager_no_gpu = ManagerSyncNoGpu(
812 |         train_set_len, valid_set_len,
813 |         batch_size, sync_worker_num,
814 |         history=history, result_path=result_path,
815 |         manager_ip=manager_ip)
816 |     manager_no_gpu.initSocket()
817 |     manager_no_gpu.printToLog("init monitor")
818 |     monitor = Monitor(init_epoch, total_epochs,
819 |                       manager_no_gpu.train_loader_len,
820 |                       manager_no_gpu.valid_loader_len,
821 |                       metrics_name)
822 |     manager_no_gpu.initTorchDist()
823 |     message = {
824 |         'flag' : 'init',
825 |         'sync_flag' : sync_flag
826 |     }
827 |     manager_no_gpu.sendMessage(message)
828 |     manager_no_gpu.printToLog("message sent:")
829 |     manager_no_gpu.printToLog(message)
830 |     for epoch in range(init_epoch, init_epoch+total_epochs):
831 |         manager_no_gpu.printToLog(
832 |             (
833 |                 "===== train epoch {}; "
834 |                 "train loader len : {} ====="
835 |             ).format(epoch, manager_no_gpu.train_loader_len)
836 |         )
837 |         message = {
838 |             'flag' : 'train_epoch',
839 |             'epoch': epoch+1,
840 |             'lr' : lr_scheduler(epoch+1)
841 |         }
842 |         manager_no_gpu.sendMessage(message)
843 |         manager_no_gpu.printToLog("message sent:")
844 |         manager_no_gpu.printToLog(message)
845 |         # training
846 |         total_samples = 0
847 |         total_loss = 0
848 |         total_met = 0
849 |         for batch in range(manager_no_gpu.train_loader_len):
850 |             manager_no_gpu.printToLog("workers crossing grads ...")
851 |             sample_num = 0
852 |             loss = 0
853 |             met = 0
854 |             for worker_index in range(sync_worker_num):
855 |                 respond = manager_no_gpu.recvMessage()
856 |                 sample_num += respond['samples']
857 |                 loss += respond['loss']/sync_worker_num
858 |                 total_loss += respond['loss'] * respond['samples']
859 |                 met += respond['met']/sync_worker_num
860 |                 total_met += respond['met'] * respond['samples']
861 |             total_samples += sample_num
862 |             avg_loss = total_loss / total_samples
863 |             avg_met = total_met / total_samples
864 |             manager_no_gpu.printToLog(
865 |                 ("batch {:d}; "
866 |                  "sample num : {}; "
867 |                  "loss : {:.3f}; "
868 |                  "{} : {:.3f}").format(
869 |                     batch,
870 |                     sample_num,
871 |                     loss,
872 |                     metrics_name,
873 |                     met
874 |                 )
875 |             )
876 |             monitor.updateTraining(loss, avg_loss, met, avg_met)
877 |         # validation
878 |         manager_no_gpu.printToLog(
879 |             (
880 |                 "===== valid epoch {}; "
881 |                 "valid loader len : {} ====="
882 |             ).format(epoch, manager_no_gpu.valid_loader_len)
883 |         )
884 |         total_val_samples = 0
885 |         total_val_loss = 0
886 |         total_val_met = 0
887 |         message = {
888 |             'flag' : 'valid_epoch',
889 |             'epoch': epoch
890 |         }
891 |         manager_no_gpu.sendMessage(message)
892 |         for val_batch in range(manager_no_gpu.valid_loader_len):
893 |             val_sample_num = 0
894 |             val_loss = 0
895 |             val_met = 0
896 |             for i in range(sync_worker_num):
897 |                 respond = manager_no_gpu.recvMessage()
898 |                 val_sample_num += respond['valid_samples']
899 |                 val_loss += respond['valid_loss'] / sync_worker_num
900 |                 total_val_loss += respond['valid_loss'] * respond['valid_samples']
901 |                 val_met += respond['valid_met']/sync_worker_num
902 |                 total_val_met += respond['valid_met'] * respond['valid_samples']
903 |             total_val_samples += val_sample_num
904 |             avg_val_loss = total_val_loss / total_val_samples
905 |             avg_val_met = total_val_met / total_val_samples
906 |             monitor.updateValidation(val_loss, avg_val_loss, val_met, avg_val_met)
907 |             manager_no_gpu.printToLog(
908 |                 ("v batch {:d}; "
909 |                  "v sample num : {}; "
910 |                  "v loss : {:.3f}; "
911 |                  "v {} : {:.3f}").format(
912 |                     val_batch,
913 |                     val_sample_num,
914 |                     val_loss,
915 |                     metrics_name,
916 |                     val_met
917 |                 )
918 |             )
919 |         monitor.updateEpoch(avg_loss, avg_val_loss, avg_met, avg_val_met)
920 |         is_best_epoch = manager_no_gpu.updateHistory(
921 |             epoch, lr_scheduler(epoch+1), avg_loss, avg_met, avg_val_loss, avg_val_met
922 |         )
923 |         # signal the workers to save the training results;
924 |         # only in the no-local-GPU manager mode do the workers hold the model,
925 |         # so the file names live in the manager and must be sent along to the workers
926 |         message = {
927 |             'flag' : 'save_model',
928 |             'ranks_to_save' : [1],  # which ranks save the model
929 |             'model_filename' : manager_no_gpu.model_filename,
930 |             'is_best' : False
931 |         }
932 |         if is_best_epoch:
933 |             message['best_model_fname'] = manager_no_gpu.best_model_fname
934 |             message['is_best'] = True
935 |         manager_no_gpu.sendMessage(message)
936 |         # epoch
937 |     # epoch for loop end
938 |     manager_no_gpu.printToLog("sending quit")
939 |     message = {
940 |         'flag': 'quit'
941 |     }
942 |     manager_no_gpu.sendMessage(message)
943 |     monitor.close()
944 | 
945 | 
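Before moving on to the reworked module: the old trainer's control plane boils down to two ZeroMQ socket pairs. The manager binds PUB (task fan-out) and PULL (score fan-in); every worker connects SUB and PUSH. A self-contained pyzmq sketch of that topology (ports and the in-process pairing are illustrative):

import time
import zmq

ctx = zmq.Context()
# manager side
task_pub   = ctx.socket(zmq.PUB);  task_pub.bind('tcp://127.0.0.1:8101')
score_pull = ctx.socket(zmq.PULL); score_pull.bind('tcp://127.0.0.1:8102')
# worker side
task_sub   = ctx.socket(zmq.SUB);  task_sub.connect('tcp://127.0.0.1:8101')
task_sub.setsockopt(zmq.SUBSCRIBE, b'')        # subscribe to everything
score_push = ctx.socket(zmq.PUSH); score_push.connect('tcp://127.0.0.1:8102')

time.sleep(0.2)                                # let the SUB finish joining
task_pub.send(repr({'flag': 'quit'}).encode()) # broadcast a task
print(eval(task_sub.recv().decode()))          # {'flag': 'quit'}
score_push.send(repr({'flag': 'train_score', 'samples': 16}).encode())
print(eval(score_pull.recv().decode()))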
...".format(rank)) 53 | self.dataset_dict = dataset_dict 54 | self.batch_transform_dict = batch_transform_dict 55 | self.sync_worker_num = sync_worker_num 56 | self.printToLog("worker num :", sync_worker_num) 57 | self.world_size = sync_worker_num 58 | self.printToLog('initiating data loader ...') 59 | # for dataset_name in dataset_dict: 60 | # print( 61 | # "\ndataset_name:", dataset_name, dataset_dict[dataset_name], 62 | # "\nbatch_size:", int(batch_size/self.world_size), 63 | # "\nsampler:", DistributedSampler(dataset_dict[dataset_name], sync_worker_num, rank), 64 | # "\nnum_workers:", process_num_per_loader, 65 | # ) 66 | self.dataloader_dict = { 67 | dataset_name: DataLoader( 68 | dataset_dict[dataset_name], 69 | batch_size = int(batch_size/self.world_size), 70 | sampler = DistributedSampler(dataset_dict[dataset_name], sync_worker_num, rank), 71 | num_workers = process_num_per_loader, 72 | pin_memory = True 73 | ) 74 | for dataset_name in dataset_dict 75 | } 76 | self.printToLog("initiating device") 77 | self.device = torch.device('cuda:{}'.format(gpu_id)) 78 | self.printToLog("device:", self.device) 79 | self.rank = rank 80 | self.printToLog("rank:", rank) 81 | self.printToLog("initiating model") 82 | self.model = model #.to(self.device) 83 | self.model.to(self.device) 84 | self.printToLog("initiating trainer") 85 | self.trainer = BasicTrainer(model, optimizer, criterion, metrics) 86 | self.printToLog("initiating torch.distributed environment") 87 | self.env = DistEnv(rank, self.world_size, control_ip, dist_port, self.printToLog) 88 | self.default_group = self.env.newGroup(range(self.world_size)) 89 | self.printToLog("initiating transmittor") 90 | self.transmittor = TensorTransmittor(list(range(self.world_size)) , logger=self.printToLog) 91 | self.printToLog("initialing rpc proxy") 92 | self.rpc_server = SyncRpcWorker(control_ip, rpc_port, rank, self.printToLog) 93 | self.rpc_server.registerMethod(self.averagingGrads) 94 | self.rpc_server.registerMethod(self.averagingWeights) 95 | self.rpc_server.registerMethod(self.broadCastModelWeights) 96 | self.rpc_server.registerMethod(self.gatherAveragedModelWeights) 97 | # ------------ training methods ------------ 98 | self.rpc_server.registerMethod(self.initTrainEpoch) 99 | self.rpc_server.registerMethod(self.batchTrainNoUpdate) 100 | self.rpc_server.registerMethod(self.updateWeights) 101 | # ------------ validation methods ------------ 102 | self.rpc_server.registerMethod(self.initTestEpoch) 103 | self.rpc_server.registerMethod(self.batchTest) 104 | # ------------ saving methods ------------ 105 | self.rpc_server.registerMethod(self.saveModelWeights) 106 | self.rpc_server.registerMethod(self.loadModelWeights) 107 | # init rpc proxy last, after all preparations are ready 108 | self.printToLog('workers ready') 109 | 110 | def mainLoop(self): 111 | self.rpc_server.mainLoop() 112 | 113 | # def _returnScore(self, flag, batch, sample_num, loss, met): 114 | # respond = { 115 | # 'flag' : flag, 116 | # 'batch' : batch, 117 | # 'samples': sample_num, 118 | # 'loss' : loss, 119 | # 'met' : met 120 | # } 121 | # self.printToLog(repr(respond)[1:-2]) 122 | # self.rpc_server.reportMessage(respond) 123 | 124 | # data communicating -------------------------- 125 | def averagingWeights(self, style='full'): 126 | """ 127 | 各个进程的模型参数取平均 128 | """ 129 | self.printToLog("averaging weights") 130 | self.transmittor.crossTensors(self.trainer.model, style=style) 131 | 132 | def averagingGrads(self, style='full'): 133 | """ 134 | 各个进程的模型梯度取平均 135 | """ 136 | 
self.printToLog("averaging grads") 137 | self.transmittor.crossGrads(self.trainer.model, style=style) 138 | 139 | def gatherAveragedModelWeights(self, rank, group=None): 140 | """ 141 | 将所有进程的权重集中求平均 142 | 结果保存至一个进程 143 | 用于初始化时同步 144 | """ 145 | self.transmittor.meanGatherTensors(self.trainer.model, rank, group) 146 | 147 | def broadCastModelWeights(self, rank, group=None): 148 | """ 149 | 从一个进程向其他进程广播模型权重 150 | """ 151 | self.transmittor.broadcastTensors(self.trainer.model, rank, group) 152 | 153 | # training methods ---------------------------- 154 | def initTrainEpoch(self, dataset_name, epoch, lr): 155 | self.printToLog("initializing training epoch {}".format(epoch)) 156 | self.printToLog("learning rate: {}".format(lr)) 157 | self.trainer.initEpoch() 158 | self.trainer.setLearningRate(lr) 159 | self.trainer.model.train() 160 | self.train_batch_index = 0 161 | self.printToLog("initializing train loader iter") 162 | self.printToLog('dataloader length: ', len(self.dataloader_dict[dataset_name]) ) 163 | self.printToLog('dataloader dict: ', self.dataloader_dict) 164 | train_loader = self.dataloader_dict[dataset_name] 165 | self.printToLog('dataloader: ', train_loader) 166 | try: 167 | self.train_iter = iter(train_loader) 168 | except Exception as e: 169 | self.printToLog(e) 170 | raise e 171 | self.printToLog('dataloader iter', self.train_iter) 172 | if dataset_name in self.batch_transform_dict: 173 | self.train_batch_transform = self.batch_transform_dict[dataset_name] 174 | self.printToLog("setting up batch transforms") 175 | else: 176 | self.train_batch_transform = {lambda x: x} 177 | self.printToLog("no batch transforms") 178 | self.printToLog("epoch {}, begin training".format(epoch)) 179 | 180 | def batchTrainNoUpdate(self): 181 | self.printToLog("train batch {}".format(self.train_batch_index) ) 182 | data, target = next(self.train_iter) 183 | batch_sample_num = len(target) 184 | self.printToLog("getting data") 185 | data, target = data.to(self.device), target.to(self.device) 186 | if self.train_batch_transform is not None: 187 | self.printToLog(self.train_batch_transform.__name__) 188 | data, target = self.train_batch_transform(data, target) 189 | self.printToLog("forwarding") 190 | self.trainer.forwardData(data) 191 | self.printToLog("backwarding") 192 | self.trainer.backwardGrad(target) 193 | self.train_batch_index += 1 194 | loss, met = self.trainer.getScores() 195 | return batch_sample_num, loss, met 196 | 197 | def updateWeights(self): 198 | self.printToLog("updating weights") 199 | self.trainer.updateWeights() 200 | 201 | # validation methods ---------------------------------- 202 | def initTestEpoch(self, dataset_name, epoch): 203 | self.printToLog("initizating validation epoch {}".format(epoch)) 204 | valid_loader = self.dataloader_dict[dataset_name] 205 | valid_loader.sampler.set_epoch(epoch) 206 | self.valid_iter = iter(valid_loader) 207 | self.trainer.model.eval() 208 | self.valid_batch_index = 0 209 | if dataset_name in self.batch_transform_dict: 210 | batch_transform = self.batch_transform_dict[dataset_name] 211 | self.printToLog("setting up batch transforms") 212 | else: 213 | batch_transform = {lambda x: x} 214 | self.printToLog("no batch transforms") 215 | self.printToLog("epoch {}, begin validation".format(epoch)) 216 | 217 | # for batch, (data, target) in enumerate(valid_loader): 218 | # data, target = data.to(self.device), target.to(self.device) 219 | # data, target = batch_transform(data, target) 220 | # self.trainer.forwardNoGrad(data) 221 | # 
224 |     def batchTest(self):
225 |         self.printToLog("validation batch {}".format(self.valid_batch_index))
226 |         data, target = next(self.valid_iter)
227 |         batch_sample_num = len(target)
228 |         self.printToLog("getting data")
229 |         data, target = data.to(self.device), target.to(self.device)
230 |         if self.valid_batch_transform is not None:
231 |             data, target = self.valid_batch_transform(data, target)
232 |         self.printToLog("forwarding")
233 |         self.trainer.forwardNoGrad(data)
234 |         self.printToLog("scoring")
235 |         self.trainer.calcScores(target)
236 |         self.valid_batch_index += 1
237 |         loss, met = self.trainer.getScores()
238 |         return batch_sample_num, loss, met
239 | 
240 |     # ------------ saving methods ------------
241 |     def saveModelWeights(self, filename, rank=0, rank_list=[-1]):
242 |         """
243 |         Save the current model weights to the given path.
244 |         """
245 |         if self.rank == rank or self.rank in rank_list:
246 |             self.trainer.saveModel(filename)
247 | 
248 |     def loadModelWeights(self, filename, rank=0, rank_list=[-1]):
249 |         """
250 |         Load model weights from the given path.
251 |         """
252 |         self.printToLog("rank", self.rank, ", loading model, ")
253 |         self.printToLog("filename={}, rank={}".format(filename, rank) )
254 |         if self.rank==rank or self.rank in rank_list:
255 |             if not os.path.exists(filename):
256 |                 self.printToLog('warning: no model weight file found')
257 |                 return
258 |             try:
259 |                 self.trainer.loadModel(filename, map_location=self.device)
260 |             except Exception as e:
261 |                 self.printToLog(e)
262 |                 raise e
263 | 
264 | 
265 | class ControllerSync():
266 |     """
267 |     Controller for synchronous training: drives the workers' epochs over RPC and aggregates their scores.
268 |     """
269 |     def __init__(self,
270 |                  train_set_len, valid_set_len, batch_size,
271 |                  control_ip, port, sync_worker_num,
272 |                  logger=print
273 |     ):
274 |         rpc_port = port
275 |         dist_port = str(port+10)
276 |         # logging
277 |         self.printToLog = logger
278 |         # control interface
279 |         self.rpcWorkers = SyncRpcController(control_ip, port, sync_worker_num, self.printToLog)
280 |         #self.rpcWorkers.startWorking()
281 |         # data
282 |         self.train_set_len = train_set_len
283 |         self.valid_set_len = valid_set_len
284 |         self.train_loader_len = int(np.ceil(train_set_len/batch_size) )
285 |         self.valid_loader_len = int(np.ceil(valid_set_len/batch_size) )
286 |         # training
287 |         #self.lr_scheduler = lr_scheduler
288 |         #self.metric_name = metric_name
289 |         #self.monitor = Monitor(init_epoch, total_epochs, self.train_loader_len, self.valid_loader_len, metric_name)
290 |         # evaluation
291 |         self.best_epoch_loss = 1e6
292 |         self.best_epoch_met = 0.
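A hypothetical driver for this class, matching the public methods defined below (the datasets, port, and learning-rate schedule are placeholders, and the worker processes are assumed to have been launched already via startWorkers):

controller = ControllerSync(len(train_set), len(valid_set), batch_size=64,
                            control_ip='127.0.0.1', port=12500,
                            sync_worker_num=2)
controller.startWorking()
for epoch in range(1, 11):
    lr = 1e-3 * (0.9 ** epoch)                       # any schedule works
    loss, met = controller.trainEpoch(epoch, lr)
    val_loss, val_met = controller.testEpoch(epoch)
    controller.saveModelWeights('results/temp/current_model.pth')
controller.stopLooping()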
293 | 294 | def __del__(self): 295 | #self.rpcWorkers.stopLooping() 296 | pass 297 | 298 | # def averagingGrads(self, style='full'): 299 | # self.rpcWorkers.averagingGrads(style='full') 300 | 301 | # def averagingWeights(self, style='full'): 302 | # self.rpcWorkers.averagingWeights(style='full') 303 | 304 | # def broadCastModelWeights(self, rank, group=None): 305 | # self.rpcWorkers.broadCastModelWeights(rank, group) 306 | 307 | # def gatherAveragedModelWeights(self, rank, group=None): 308 | # self.rpcWorkers.gatherAveragedModelWeights(rank, group=None) 309 | 310 | # def initTrainEpoch(self, dataset_name, epoch, lr): 311 | # #self.printToLog("initiating training epoch {}, lr={}".format(epoch, lr)) 312 | # self.rpcWorkers.initTrainEpoch(dataset_name, epoch, lr) 313 | 314 | # def batchTrainNoUpdate(self, batch_index): 315 | # #self.printToLog("training batch {}".format(batch_index)) 316 | # return self.rpcWorkers.batchTrainNoUpdate() 317 | 318 | # def updateWeights(self): 319 | # self.rpcWorkers.updateWeights() 320 | 321 | # def initTestEpoch(self, dataset_name, epoch): 322 | # #self.printToLog("initiating validation epoch {}".format(epoch)) 323 | # self.rpcWorkers.initTestEpoch(dataset_name, epoch) 324 | 325 | # def batchTest(self): 326 | # #self.printToLog("validation batch {}".format(batch_index)) 327 | # return self.rpcWorkers.batchTest() 328 | 329 | # def saveModelWeights(self, filename, rank=1, rank_list=[-1]): 330 | # self.rpcWorkers.saveModelWeights(filename, rank=1, rank_list=[-1]) 331 | 332 | # def loadModelWeights(self, filename, rank=1, rank_list=[-1]): 333 | # self.rpcWorkers.loadModelWeights(filename, rank=1, rank_list=[-1]) 334 | 335 | def startWorking(self): 336 | self.rpcWorkers.startWorking() 337 | 338 | def stopWorking(self): 339 | self.rpcWorkers.stopWorking() 340 | 341 | def stopLooping(self): 342 | self.rpcWorkers.stopLooping() 343 | 344 | def trainEpoch(self, epoch, lr, dataset_name='train', style='full', monitor=None): 345 | #lr = self.lr_scheduler(epoch) 346 | self.printToLog("initiating training epoch {}; lr={}".format(epoch, lr) ) 347 | self.rpcWorkers.initTrainEpoch(dataset_name, epoch, lr) 348 | total_sample_num = 0 349 | total_loss = 0 350 | total_met = 0 351 | total_avg_loss = 0 352 | total_avg_met = 0 353 | for batch_index in range(self.train_loader_len): 354 | self.printToLog("training batch {}".format(batch_index)) 355 | result_list = self.rpcWorkers.batchTrainNoUpdate() 356 | #self.averagingGrads() 357 | self.rpcWorkers.updateWeights() 358 | self.rpcWorkers.averagingWeights(style=style) 359 | batch_sample_num = 0 360 | batch_total_loss = 0 361 | batch_total_met = 0 362 | for sample_num, loss, met in result_list: 363 | self.printToLog('single_result: loss={}, score={}'.format(loss, met) ) 364 | batch_sample_num += sample_num 365 | batch_total_loss += loss * sample_num 366 | batch_total_met += met * sample_num 367 | total_sample_num += batch_sample_num 368 | total_loss += batch_total_loss 369 | total_met += batch_total_met 370 | batch_loss = batch_total_loss / batch_sample_num 371 | batch_met = batch_total_met / batch_sample_num 372 | total_avg_loss = total_loss/total_sample_num 373 | total_avg_met = total_met /total_sample_num 374 | if monitor!=None: 375 | monitor.updateTraining(batch_loss, total_avg_loss, batch_met, total_avg_met) 376 | return total_avg_loss, total_avg_met 377 | 378 | def testEpoch(self, epoch, dataset_name='valid', monitor=None): 379 | self.printToLog("initiating test epoch {}".format(epoch)) 380 | 
| self.rpcWorkers.initTestEpoch(dataset_name, epoch)
381 |         total_val_sample_num = 0
382 |         total_val_loss = 0
383 |         total_val_met = 0
384 |         total_avg_val_loss = 0
385 |         total_avg_val_met = 0
386 |         for val_batch_index in range(self.valid_loader_len):
387 |             self.printToLog("test batch {}".format(val_batch_index))
388 |             result_list = self.rpcWorkers.batchTest()
389 |             batch_val_sample_num = 0
390 |             batch_val_total_loss = 0
391 |             batch_val_total_met = 0
392 |             for val_sample_num, val_loss, val_met in result_list:
393 |                 self.printToLog('single_result: loss={}, score={}'.format(val_loss, val_met) )
394 |                 batch_val_sample_num += val_sample_num
395 |                 batch_val_total_loss += val_loss * val_sample_num
396 |                 batch_val_total_met += val_met * val_sample_num
397 |             total_val_sample_num += batch_val_sample_num
398 |             total_val_loss += batch_val_total_loss
399 |             total_val_met += batch_val_total_met
400 |             batch_val_loss = batch_val_total_loss / batch_val_sample_num
401 |             batch_val_met = batch_val_total_met / batch_val_sample_num
402 |             total_avg_val_loss = total_val_loss/total_val_sample_num
403 |             total_avg_val_met = total_val_met /total_val_sample_num
404 |             if monitor!=None:
405 |                 monitor.updateValidation(batch_val_loss, total_avg_val_loss, batch_val_met, total_avg_val_met)
406 |         self.is_best_epoch = total_avg_val_met > self.best_epoch_met
407 |         if self.is_best_epoch:
408 |             self.best_epoch_loss = total_avg_val_loss
409 |             self.best_epoch_met = total_avg_val_met
410 |         return total_avg_val_loss, total_avg_val_met
411 | 
412 |     # def endEpoch(self):
413 |     #     self.monitor.updateEpoch()
414 | 
415 |     def saveModelWeights(self, filename, rank=0):
416 |         self.printToLog('saving model')
417 |         self.rpcWorkers.saveModelWeights(filename, rank)
418 | 
419 |     def loadModelWeights(self, filename, rank=0):
420 |         self.printToLog('loading model')
421 |         self.rpcWorkers.loadModelWeights(filename, rank)
422 |         self.rpcWorkers.broadCastModelWeights(rank)
423 | 
424 | # ========================== # ==========================
425 | def startWorkerProcess(
426 |     model, optimizer, criterion, metrics,
427 |     dataset_dict, batch_transform_dict,
428 |     batch_size, process_num_per_loader,
429 |     rank, gpu_id,
430 |     sync_worker_num,
431 |     control_ip, port
432 | ):
433 |     logger = Logger(
434 |         filepath ='logs/worker_{}.log'.format(rank),
435 |         prefix ='worker_{}|gpu_{}'.format(rank, gpu_id)
436 |     )
437 |     torch.cuda.set_device( gpu_id )
438 |     os.environ['CUDA_VISIBLE_DEVICES']='{}'.format(gpu_id)
439 |     #time.sleep(3)
440 |     worker = WorkerSync(
441 |         model, optimizer, criterion, metrics,
442 |         dataset_dict, batch_transform_dict,
443 |         batch_size, process_num_per_loader, rank, gpu_id,
444 |         sync_worker_num, control_ip, port, logger)
445 |     worker.mainLoop()
446 | 
447 | def startWorkers(
448 |     model, optimizer, criterion, metrics,
449 |     train_set, valid_set,
450 |     batch_size = conf_g.batch_size,
451 |     sync_worker_num = conf_g.sync_worker_num,
452 |     process_num_per_loader = conf_g.process_num_per_loader,
453 |     rank_list = conf_g.worker_ranks,
454 |     gpu_id_list = conf_g.worker_gpu_ids,
455 |     control_ip = conf_g.control_ip,
456 |     port = conf_g.basic_port,
457 |     train_batch_transform = None,
458 |     valid_batch_transform = None
459 | ):
460 |     assert len(rank_list)==len(gpu_id_list), 'rank_list has different length from gpu_id_list'
461 |     assert min(rank_list)>=0, 'ranks must be non-negative'
462 |     assert max(rank_list)