├── .gitignore ├── README.md ├── __init__.py ├── classes.py ├── example.py ├── pandas.jpg ├── prepare.py ├── profile.py ├── surgery.py ├── surgery2.py ├── walk.py └── zebra.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | model_weight 2 | __pycache__ 3 | parameters 4 | .vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorchtool 2 | DNN_Partition辅助工具,用于对pytorch模型进行简单的性能分析 3 | 4 | #### 使用步骤 5 | 6 | 1. 对一个初始化后的模型(已加载权重文件)使用`save_model`函数分别将各层权重保存到**“./model_weight/”** 7 | 2. 对一个未初始化的模型,使用`with Profile(model) as p`,然后执行模型推理。模型分析被保存在`self.information`中,使用`p.printCsv`可以以csv文件格式输出。 8 | 9 | **添加了示例(example.py classes.py)** *已验证* 10 | 11 | ## 下载模型权重文件 12 | ```shell 13 | # AlexNet 14 | mkdir alexnet 15 | cd ./alexnet 16 | wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth 17 | ``` 18 | 19 | ```shell 20 | # inceptionV3 21 | mkdir inception_v3 22 | cd ./inception_v3 23 | wget https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth 24 | ``` 25 | 26 | ## 性能分析 27 | 28 | ### 1. 保存各层权重文件 29 | 30 | `doPrepare=True`,执行`python example.py` 31 | 32 | ### 2. 获取各层性能参数 33 | 34 | `doPrepare=False `, `doProf=True`,执行`python example.py` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from pytorchtool.profile import Profile 2 | from pytorchtool.prepare import save_model 3 | from pytorchtool.surgery import Surgery 4 | from pytorchtool.surgery2 import Surgery2 5 | from pytorchtool.walk import walk_modules 6 | 7 | name = "pytorchtool" 8 | 9 | __all__ = ["Profile", "save_model", "Surgery", "Surgery2", "walk_modules"] 10 | __version__ = "0.2.0" 11 | 12 | -------------------------------------------------------------------------------- /classes.py: -------------------------------------------------------------------------------- 1 | class_names = '''丁鲷 2 | 金鱼 3 | 大白鲨 4 | 虎鲨 5 | 锤头鲨 6 | 电鳐 7 | 黄貂鱼 8 | 公鸡 9 | 母鸡 10 | 鸵鸟 11 | 燕雀 12 | 金翅雀 13 | 家朱雀 14 | 灯芯草雀 15 | 靛蓝雀 16 | 蓝鹀 17 | 夜莺 18 | 松鸦 19 | 喜鹊 20 | 山雀 21 | 河鸟 22 | 鸢(猛禽) 23 | 秃头鹰 24 | 秃鹫 25 | 大灰猫头鹰 26 | 欧洲火蝾螈 27 | 普通蝾螈 28 | 水蜥 29 | 斑点蝾螈 30 | 蝾螈 31 | 牛蛙 32 | 树蛙 33 | 尾蛙 34 | 红海龟 35 | 皮革龟 36 | 泥龟 37 | 淡水龟 38 | 箱龟 39 | 带状壁虎 40 | 普通鬣蜥 41 | 美国变色龙 42 | 鞭尾蜥蜴 43 | 飞龙科蜥蜴 44 | 褶边蜥蜴 45 | 鳄鱼蜥蜴 46 | 毒蜥 47 | 绿蜥蜴 48 | 非洲变色龙 49 | 科莫多蜥蜴 50 | 非洲鳄 51 | 美国鳄鱼 52 | 三角龙 53 | 雷蛇 54 | 环蛇 55 | 希腊蛇 56 | 绿蛇 57 | 国王蛇 58 | 袜带蛇 59 | 水蛇 60 | 藤蛇 61 | 夜蛇 62 | 大蟒蛇 63 | 岩石蟒蛇 64 | 印度眼镜蛇 65 | 绿曼巴 66 | 海蛇 67 | 角腹蛇 68 | 菱纹响尾蛇 69 | 角响尾蛇 70 | 三叶虫 71 | 盲蜘蛛 72 | 蝎子 73 | 黑金花园蜘蛛 74 | 谷仓蜘蛛 75 | 花园蜘蛛 76 | 黑寡妇蜘蛛 77 | 狼蛛 78 | 狼蜘蛛 79 | 壁虱 80 | 蜈蚣 81 | 黑松鸡 82 | 松鸡 83 | 披肩鸡 84 | 草原鸡 85 | 孔雀 86 | 鹌鹑 87 | 鹧鸪 88 | 非洲灰鹦鹉 89 | 金刚鹦鹉 90 | 硫冠鹦鹉 91 | 短尾鹦鹉 92 | 褐翅鸦鹃 93 | 蜜蜂 94 | 犀鸟 95 | 蜂鸟 96 | 鹟? 97 | 犀鸟 98 | 野鸭 99 | 红胸秋沙鸭 100 | 鹅 101 | 黑天鹅 102 | 大象 103 | 针鼹鼠 104 | 鸭嘴兽 105 | 沙袋鼠 106 | 考拉 107 | 袋熊 108 | 水母 109 | 海葵 110 | 脑珊瑚 111 | 扁形虫扁虫 112 | 线虫 113 | 海螺 114 | 蜗牛 115 | 鼻涕虫 116 | 海参 117 | 石鳖 118 | 鹦鹉螺 119 | 珍宝蟹 120 | 石蟹 121 | 招潮蟹 122 | 帝王蟹 123 | 美国龙虾 124 | 大螯虾 125 | 小龙虾 126 | 寄居蟹 127 | 等足目动物(明虾和螃蟹近亲) 128 | 白鹳 129 | 黑鹳 130 | 鹭 131 | 火烈鸟 132 | 小蓝鹭 133 | 美国鹭 134 | 麻鸦 135 | 鹤 136 | 秧鹤 137 | 欧洲水鸡 138 | 沼泽泥母鸡 139 | 鸨 140 | 红翻石鹬 141 | 红背鹬 142 | 红脚鹬 143 | 半蹼鹬 144 | 蛎鹬 145 | 鹈鹕 146 | 国王企鹅 147 | 信天翁 148 | 灰鲸 149 | 杀人鲸 150 | 海牛 151 | 海狮 152 | 奇瓦瓦 153 | 日本猎犬 154 | 马尔济斯犬 155 | 狮子狗 156 | 西施犬 157 | 布莱尼姆猎犬 158 | 巴比狗 159 | 玩具犬 160 | 罗得西亚长背猎狗 161 | 阿富汗猎犬 162 | 猎犬 163 | 比格犬 164 | 侦探犬 165 | 蓝色快狗 166 | 黑褐猎浣熊犬 167 | 沃克猎犬 168 | 英国猎狐犬 169 | 美洲赤狗 170 | 俄罗斯猎狼犬 171 | 爱尔兰猎狼犬 172 | 意大利灰狗 173 | 惠比特犬 174 | 依比沙猎犬 175 | 挪威猎犬 176 | 奥达猎犬 177 | 沙克犬 178 | 苏格兰猎鹿犬 179 | 威玛猎犬 180 | 斯塔福德郡牛头梗 181 | 美国斯塔福德郡梗 182 | 贝德灵顿梗 183 | 边境梗 184 | 凯丽蓝梗 185 | 爱尔兰梗 186 | 诺福克梗 187 | 诺维奇梗 188 | 约克郡梗 189 | 刚毛猎狐梗 190 | 莱克兰梗 191 | 锡利哈姆梗 192 | 艾尔谷犬 193 | 凯恩梗 194 | 澳大利亚梗 195 | 丹迪丁蒙梗 196 | 波士顿梗 197 | 迷你雪纳瑞犬 198 | 巨型雪纳瑞犬 199 | 标准雪纳瑞犬 200 | 苏格兰梗 201 | 西藏梗 202 | 丝毛梗 203 | 软毛麦色梗 204 | 西高地白梗 205 | 拉萨阿普索犬 206 | 平毛寻回犬 207 | 卷毛寻回犬 208 | 金毛猎犬 209 | 拉布拉多猎犬 210 | 乞沙比克猎犬 211 | 德国短毛猎犬 212 | 维兹拉犬 213 | 英国谍犬 214 | 爱尔兰雪达犬 215 | 戈登雪达犬 216 | 布列塔尼犬猎犬 217 | 黄毛 218 | 英国史宾格犬 219 | 威尔士史宾格犬 220 | 可卡犬 221 | 萨塞克斯猎犬 222 | 爱尔兰水猎犬 223 | 哥威斯犬 224 | 舒柏奇犬 225 | 比利时牧羊犬 226 | 马里努阿犬 227 | 伯瑞犬 228 | 凯尔皮犬 229 | 匈牙利牧羊犬 230 | 老英国牧羊犬 231 | 喜乐蒂牧羊犬 232 | 牧羊犬 233 | 边境牧羊犬 234 | 法兰德斯牧牛狗 235 | 罗特韦尔犬 236 | 德国牧羊犬 237 | 多伯曼犬 238 | 迷你杜宾犬 239 | 大瑞士山地犬 240 | 伯恩山犬 241 | Appenzeller狗 242 | EntleBucher狗 243 | 拳师狗 244 | 斗牛獒 245 | 藏獒 246 | 法国斗牛犬 247 | 大丹犬 248 | 圣伯纳德狗 249 | 爱斯基摩犬 250 | 雪橇犬 251 | 哈士奇 252 | 达尔马提亚 253 | 狮毛狗 254 | 巴辛吉狗 255 | 哈巴狗 256 | 莱昂贝格狗 257 | 纽芬兰岛狗 258 | 大白熊犬 259 | 萨摩耶犬 260 | 博美犬 261 | 松狮 262 | 荷兰卷尾狮毛狗 263 | 布鲁塞尔格林芬犬 264 | 彭布洛克威尔士科基犬 265 | 威尔士柯基犬 266 | 玩具贵宾犬 267 | 迷你贵宾犬 268 | 标准贵宾犬 269 | 墨西哥无毛犬 270 | 灰狼 271 | 白狼 272 | 红太狼 273 | 狼 274 | 澳洲野狗 275 | 豺 276 | 非洲猎犬 277 | 鬣狗 278 | 红狐狸 279 | 沙狐 280 | 北极狐狸 281 | 灰狐狸 282 | 虎斑猫 283 | 山猫 284 | 波斯猫 285 | 暹罗暹罗猫 286 | 埃及猫 287 | 美洲狮 288 | 猞猁 289 | 豹子 290 | 雪豹 291 | 美洲虎 292 | 狮子 293 | 老虎 294 | 猎豹 295 | 棕熊 296 | 美洲黑熊 297 | 冰熊 298 | 懒熊 299 | 猫鼬 300 | 猫鼬 301 | 虎甲虫 302 | 瓢虫 303 | 土鳖虫 304 | 天牛 305 | 龟甲虫 306 | 粪甲虫 307 | 犀牛甲虫 308 | 象甲 309 | 苍蝇 310 | 蜜蜂 311 | 蚂蚁 312 | 蚱蜢 313 | 蟋蟀 314 | 竹节虫 315 | 蟑螂 316 | 螳螂 317 | 蝉 318 | 叶蝉 319 | 草蜻蛉 320 | 蜻蜓 321 | 豆娘 322 | 优红蛱蝶 323 | 小环蝴蝶 324 | 君主蝴蝶 325 | 菜粉蝶 326 | 白蝴蝶 327 | 灰蝶 328 | 海星 329 | 海胆 330 | 海参 331 | 野兔 332 | 兔 333 | 安哥拉兔 334 | 仓鼠 335 | 刺猬 336 | 黑松鼠 337 | 土拨鼠 338 | 海狸 339 | 豚鼠 340 | 栗色马 341 | 斑马 342 | 猪 343 | 野猪 344 | 疣猪 345 | 河马 346 | 牛 347 | 水牛 348 | 野牛 349 | 公羊 350 | 大角羊 351 | 山羊 352 | 狷羚 353 | 黑斑羚 354 | 瞪羚 355 | 阿拉伯单峰骆驼 356 | 骆驼 357 | 黄鼠狼 358 | 水貂 359 | 臭猫 360 | 黑足鼬 361 | 水獭 362 | 臭鼬 363 | 獾 364 | 犰狳 365 | 树懒 366 | 猩猩 367 | 大猩猩 368 | 黑猩猩 369 | 长臂猿 370 | 合趾猿长臂猿 371 | 长尾猴 372 | 赤猴 373 | 狒狒 374 | 恒河猴 375 | 白头叶猴 376 | 疣猴 377 | 长鼻猴 378 | 狨(美洲产小型长尾猴) 379 | 卷尾猴 380 | 吼猴 381 | 伶猴 382 | 蜘蛛猴 383 | 松鼠猴 384 | 马达加斯加环尾狐猴 385 | 大狐猴 386 | 印度大象 387 | 非洲象 388 | 小熊猫 389 | 大熊猫 390 | 杖鱼 391 | 鳗鱼 392 | 银鲑 393 | 三色刺蝶鱼 394 | 海葵鱼 395 | 鲟鱼 396 | 雀鳝 397 | 狮子鱼 398 | 河豚 399 | 算盘 400 | 长袍 401 | 学位袍 402 | 手风琴 403 | 原声吉他 404 | 航空母舰 405 | 客机 406 | 飞艇 407 | 祭坛 408 | 救护车 409 | 水陆两用车 410 | 模拟时钟 411 | 蜂房 412 | 围裙 413 | 垃圾桶 414 | 攻击步枪 415 | 背包 416 | 面包店 417 | 平衡木 418 | 热气球 419 | 圆珠笔 420 | 创可贴 421 | 班卓琴 422 | 栏杆 423 | 杠铃 424 | 理发师的椅子 425 | 理发店 426 | 牲口棚 427 | 晴雨表 428 | 圆筒 429 | 园地小车 430 | 棒球 431 | 篮球 432 | 婴儿床 433 | 巴松管 434 | 游泳帽 435 | 沐浴毛巾 436 | 浴缸 437 | 沙滩车 438 | 灯塔 439 | 高脚杯 440 | 熊皮高帽 441 | 啤酒瓶 442 | 啤酒杯 443 | 钟塔 444 | (小儿用的)围嘴 445 | 串联自行车 446 | 比基尼 447 | 装订册 448 | 双筒望远镜 449 | 鸟舍 450 | 船库 451 | 雪橇 452 | 饰扣式领带 453 | 阔边女帽 454 | 书橱 455 | 书店 456 | 瓶盖 457 | 弓箭 458 | 蝴蝶结领结 459 | 铜制牌位 460 | 奶罩 461 | 防波堤 462 | 铠甲 463 | 扫帚 464 | 桶 465 | 扣环 466 | 防弹背心 467 | 动车 468 | 肉铺 469 | 出租车 470 | 大锅 471 | 蜡烛 472 | 大炮 473 | 独木舟 474 | 开瓶器 475 | 开衫 476 | 车镜 477 | 旋转木马 478 | 木匠的工具包 479 | 纸箱 480 | 车轮 481 | 取款机 482 | 盒式录音带 483 | 卡带播放器 484 | 城堡 485 | 双体船 486 | CD播放器 487 | 大提琴 488 | 移动电话 489 | 铁链 490 | 围栏 491 | 链甲 492 | 电锯 493 | 箱子 494 | 衣柜 495 | 编钟 496 | 中国橱柜 497 | 圣诞袜 498 | 教堂 499 | 电影院 500 | 切肉刀 501 | 悬崖屋 502 | 斗篷 503 | 木屐 504 | 鸡尾酒调酒器 505 | 咖啡杯 506 | 咖啡壶 507 | 螺旋结构(楼梯) 508 | 组合锁 509 | 电脑键盘 510 | 糖果 511 | 集装箱船 512 | 敞篷车 513 | 开瓶器 514 | 短号 515 | 牛仔靴 516 | 牛仔帽 517 | 摇篮 518 | 起重机 519 | 头盔 520 | 板条箱 521 | 小儿床 522 | 砂锅 523 | 槌球 524 | 拐杖 525 | 胸甲 526 | 大坝 527 | 书桌 528 | 台式电脑 529 | 有线电话 530 | 尿布湿 531 | 数字时钟 532 | 数字手表 533 | 餐桌板 534 | 抹布 535 | 洗碗机 536 | 盘式制动器 537 | 码头 538 | 狗拉雪橇 539 | 圆顶 540 | 门垫 541 | 钻井平台 542 | 鼓 543 | 鼓槌 544 | 哑铃 545 | 荷兰烤箱 546 | 电风扇 547 | 电吉他 548 | 电力机车 549 | 电视 550 | 信封 551 | 浓缩咖啡机 552 | 扑面粉 553 | 女用长围巾 554 | 文件 555 | 消防船 556 | 消防车 557 | 火炉栏 558 | 旗杆 559 | 长笛 560 | 折叠椅 561 | 橄榄球头盔 562 | 叉车 563 | 喷泉 564 | 钢笔 565 | 有四根帷柱的床 566 | 运货车厢 567 | 圆号 568 | 煎锅 569 | 裘皮大衣 570 | 垃圾车 571 | 防毒面具 572 | 汽油泵 573 | 高脚杯 574 | 卡丁车 575 | 高尔夫球 576 | 高尔夫球车 577 | 狭长小船 578 | 锣 579 | 礼服 580 | 钢琴 581 | 温室 582 | 散热器格栅 583 | 杂货店 584 | 断头台 585 | 小发夹 586 | 头发喷雾 587 | 半履带装甲车 588 | 锤子 589 | 大篮子 590 | 手摇鼓风机 591 | 手提电脑 592 | 手帕 593 | 硬盘 594 | 口琴 595 | 竖琴 596 | 收割机 597 | 斧头 598 | 手枪皮套 599 | 家庭影院 600 | 蜂窝 601 | 钩爪 602 | 衬裙 603 | 单杠 604 | 马车 605 | 沙漏 606 | iPod 607 | 熨斗 608 | 南瓜灯笼 609 | 牛仔裤 610 | 吉普车 611 | 运动衫 612 | 拼图 613 | 人力车 614 | 操纵杆 615 | 和服 616 | 护膝 617 | 蝴蝶结 618 | 大褂 619 | 长柄勺 620 | 灯罩 621 | 笔记本电脑 622 | 割草机 623 | 镜头盖 624 | 开信刀 625 | 图书馆 626 | 救生艇 627 | 点火器 628 | 豪华轿车 629 | 远洋班轮 630 | 唇膏 631 | 平底便鞋 632 | 洗剂 633 | 扬声器 634 | 放大镜 635 | 锯木厂 636 | 磁罗盘 637 | 邮袋 638 | 信箱 639 | 女游泳衣 640 | 有肩带浴衣 641 | 窨井盖 642 | 沙球(一种打击乐器) 643 | 马林巴木琴 644 | 面膜 645 | 火柴 646 | 花柱 647 | 迷宫 648 | 量杯 649 | 药箱 650 | 巨石 651 | 麦克风 652 | 微波炉 653 | 军装 654 | 奶桶 655 | 迷你巴士 656 | 迷你裙 657 | 面包车 658 | 导弹 659 | 连指手套 660 | 搅拌钵 661 | 活动房屋(由汽车拖拉的) 662 | T型发动机小汽车 663 | 调制解调器 664 | 修道院 665 | 显示器 666 | 电瓶车 667 | 砂浆 668 | 学士 669 | 清真寺 670 | 蚊帐 671 | 摩托车 672 | 山地自行车 673 | 登山帐 674 | 鼠标 675 | 捕鼠器 676 | 搬家车 677 | 口套 678 | 钉子 679 | 颈托 680 | 项链 681 | 乳头(瓶) 682 | 笔记本 683 | 方尖碑 684 | 双簧管 685 | 陶笛 686 | 里程表 687 | 滤油器 688 | 风琴 689 | 示波器 690 | 罩裙 691 | 牛车 692 | 氧气面罩 693 | 包装 694 | 船桨 695 | 明轮 696 | 挂锁 697 | 画笔 698 | 睡衣 699 | 宫殿 700 | 排箫 701 | 纸巾 702 | 降落伞 703 | 双杠 704 | 公园长椅 705 | 停车收费表 706 | 客车 707 | 露台 708 | 付费电话 709 | 基座 710 | 铅笔盒 711 | 卷笔刀 712 | 香水(瓶) 713 | 培养皿 714 | 复印机 715 | 拨弦片 716 | 尖顶头盔 717 | 栅栏 718 | 皮卡 719 | 桥墩 720 | 存钱罐 721 | 药瓶 722 | 枕头 723 | 乒乓球 724 | 风车 725 | 海盗船 726 | 水罐 727 | 木工刨 728 | 天文馆 729 | 塑料袋 730 | 板架 731 | 犁型铲雪机 732 | 手压皮碗泵 733 | 宝丽来相机 734 | 电线杆 735 | 警车 736 | 雨披 737 | 台球桌 738 | 充气饮料瓶 739 | 花盆 740 | 陶工旋盘 741 | 电钻 742 | 祈祷垫 743 | 打印机 744 | 监狱 745 | 炮弹 746 | 投影仪 747 | 冰球 748 | 沙包 749 | 钱包 750 | 羽管笔 751 | 被子 752 | 赛车 753 | 球拍 754 | 散热器 755 | 收音机 756 | 射电望远镜 757 | 雨桶 758 | 休闲车 759 | 卷轴 760 | 反射式照相机 761 | 冰箱 762 | 遥控器 763 | 餐厅 764 | 左轮手枪 765 | 步枪 766 | 摇椅 767 | 电转烤肉架 768 | 橡皮 769 | 橄榄球 770 | 直尺 771 | 跑步鞋 772 | 保险柜 773 | 安全别针 774 | 盐瓶(调味用) 775 | 凉鞋 776 | 纱笼 777 | 萨克斯管 778 | 剑鞘 779 | 秤 780 | 校车 781 | 帆船 782 | 记分牌 783 | 屏幕 784 | 螺丝 785 | 螺丝刀 786 | 安全带 787 | 缝纫机 788 | 盾牌 789 | 皮鞋店 790 | 障子 791 | 购物篮 792 | 购物车 793 | 铁锹 794 | 浴帽 795 | 浴帘 796 | 滑雪板 797 | 滑雪面罩 798 | 睡袋 799 | 滑尺 800 | 滑动门 801 | 角子老虎机 802 | 潜水通气管 803 | 雪橇 804 | 扫雪机 805 | 皂液器 806 | 足球 807 | 袜子 808 | 碟式太阳能 809 | 宽边帽 810 | 汤碗 811 | 空格键 812 | 空间加热器 813 | 航天飞机 814 | 铲(搅拌或涂敷用的) 815 | 快艇 816 | 蜘蛛网 817 | 纺锤 818 | 跑车 819 | 聚光灯 820 | 舞台 821 | 蒸汽机车 822 | 钢拱桥 823 | 钢滚筒 824 | 听诊器 825 | 女用披肩 826 | 石头墙 827 | 秒表 828 | 火炉 829 | 过滤器 830 | 有轨电车 831 | 担架 832 | 沙发床 833 | 佛塔 834 | 潜艇 835 | 套装 836 | 日晷 837 | 太阳镜 838 | 太阳镜 839 | 防晒霜 840 | 悬索桥 841 | 拖把 842 | 运动衫 843 | 游泳裤 844 | 秋千 845 | 开关 846 | 注射器 847 | 台灯 848 | 坦克 849 | 磁带播放器 850 | 茶壶 851 | 泰迪 852 | 电视 853 | 网球 854 | 茅草 855 | 幕布 856 | 顶针 857 | 脱粒机 858 | 宝座 859 | 瓦屋顶 860 | 烤面包机 861 | 烟草店 862 | 马桶 863 | 火炬 864 | 图腾柱 865 | 拖车 866 | 玩具店 867 | 拖拉机 868 | 拖车 869 | 托盘 870 | 风衣 871 | 三轮车 872 | 三体船 873 | 三脚架 874 | 凯旋门 875 | 无轨电车 876 | 长号 877 | 浴盆 878 | 旋转式栅门 879 | 打字机键盘 880 | 伞 881 | 独轮车 882 | 直立式钢琴 883 | 真空吸尘器 884 | 花瓶 885 | 拱顶 886 | 天鹅绒 887 | 自动售货机 888 | 祭服 889 | 高架桥 890 | 小提琴 891 | 排球 892 | 松饼机 893 | 挂钟 894 | 钱包 895 | 衣柜 896 | 军用飞机 897 | 洗脸盆 898 | 洗衣机 899 | 水瓶 900 | 水壶 901 | 水塔 902 | 威士忌壶 903 | 哨子 904 | 假发 905 | 纱窗 906 | 百叶窗 907 | 温莎领带 908 | 葡萄酒瓶 909 | 飞机翅膀 910 | 炒菜锅 911 | 木制的勺子 912 | 毛织品 913 | 栅栏 914 | 沉船 915 | 双桅船 916 | 蒙古包 917 | 网站 918 | 漫画 919 | 纵横字谜 920 | 路标 921 | 交通信号灯 922 | 防尘罩 923 | 菜单 924 | 盘子 925 | 鳄梨酱 926 | 清汤 927 | 罐焖土豆烧肉 928 | 蛋糕 929 | 冰淇淋 930 | 雪糕 931 | 法式面包 932 | 百吉饼 933 | 椒盐脆饼 934 | 芝士汉堡 935 | 热狗 936 | 土豆泥 937 | 结球甘蓝 938 | 西兰花 939 | 菜花 940 | 绿皮密生西葫芦 941 | 西葫芦 942 | 小青南瓜 943 | 南瓜 944 | 黄瓜 945 | 朝鲜蓟 946 | 甜椒 947 | 刺棘蓟 948 | 蘑菇 949 | 绿苹果 950 | 草莓 951 | 橘子 952 | 柠檬 953 | 无花果 954 | 菠萝 955 | 香蕉 956 | 菠萝蜜 957 | 蛋奶冻苹果 958 | 石榴 959 | 干草 960 | 烤面条加干酪沙司 961 | 巧克力酱 962 | 面团 963 | 瑞士肉包 964 | 披萨 965 | 馅饼 966 | 卷饼 967 | 红葡萄酒 968 | 意大利浓咖啡 969 | 杯子 970 | 蛋酒 971 | 高山 972 | 泡泡 973 | 悬崖 974 | 珊瑚礁 975 | 间歇泉 976 | 湖边 977 | 海角 978 | 沙洲 979 | 海滨 980 | 峡谷 981 | 火山 982 | 棒球 983 | 新郎 984 | 潜水员 985 | 油菜 986 | 雏菊 987 | 杓兰 988 | 玉米 989 | 橡子 990 | 玫瑰果 991 | 七叶树果实 992 | 珊瑚菌 993 | 木耳 994 | 鹿花菌 995 | 鬼笔菌 996 | 地星 997 | 多叶奇果菌 998 | 牛肝菌 999 | 玉米穗 1000 | 卫生纸'''.split("\n") -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('..') 4 | import time 5 | import torch 6 | import pytorchtool 7 | import numpy as np 8 | 9 | from classes import class_names 10 | 11 | from PIL import Image 12 | from torchvision import models, transforms 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 15 | 16 | def process_img(path_img): 17 | # hard code 18 | norm_mean = [0.485, 0.456, 0.406] 19 | norm_std = [0.229, 0.224, 0.225] 20 | inference_transform = transforms.Compose([ 21 | transforms.Resize(256), 22 | transforms.CenterCrop((224, 224)), 23 | transforms.ToTensor(), 24 | transforms.Normalize(norm_mean, norm_std), 25 | ]) 26 | 27 | # tensor 28 | img_tensor = inference_transform(Image.open(path_img).convert('RGB')) 29 | img_tensor.unsqueeze_(0) # chw --> bchw 30 | 31 | return img_tensor 32 | 33 | class model: 34 | def __init__(self, model_name, use_gpu=False): 35 | self.model_name = model_name 36 | self.x = process_img('./pandas.jpg') 37 | self.use_gpu = use_gpu 38 | 39 | if self.model_name in 'inception': 40 | self.model_name = 'inception' 41 | self.path = "./model_weight/inception_v3/inception_v3_google-1a9a5a14.pth" 42 | 43 | model = models.Inception3(aux_logits=False, transform_input=False, 44 | init_weights=False) 45 | model.eval() 46 | self.model = model 47 | self.depth = 2 48 | elif self.model_name in 'alexnet': 49 | self.model_name = 'alexnet' 50 | self.path = "./model_weight/alexnet/alexnet-owt-4df8aa71.pth" 51 | 52 | model = models.alexnet(False) 53 | model.eval() 54 | self.model = model 55 | self.depth = -1 56 | elif self.model_name in 'resnet': 57 | self.model_name = 'resnet' 58 | self.path = './model_weight/resnet/resnet18-f37072fd.pth' 59 | model = models.resnet18(False) 60 | model.eval() 61 | self.model = model 62 | self.depth = 2 63 | else: 64 | print("Wrong model name") 65 | 66 | if self.use_gpu: 67 | self.model = self.model.to(0) 68 | # self.x = self.x.cuda() 69 | self.x = self.x.to(0) 70 | 71 | def load_weight(self): 72 | state_dict_read = torch.load(self.path) 73 | 74 | self.model.load_state_dict(state_dict_read, strict=False) 75 | 76 | def get_model(self): 77 | return self.model 78 | 79 | def get_input(self): 80 | return self.x 81 | 82 | def save_layers(self, depth=-1): 83 | pytorchtool.save_model(self.model, depth=depth) 84 | 85 | def inference(self): 86 | with torch.no_grad(): 87 | outputs = self.model(self.x) 88 | print("result: " + class_names[torch.argmax(outputs, 1)[0]]) 89 | 90 | def prof(self, depth=-1): 91 | with pytorchtool.Profile(self.model, use_cuda=self.use_gpu, 92 | depth=depth) as prof: 93 | outputs = self.model(self.x) 94 | print("result: " + class_names[torch.argmax(outputs, 1)[0]]) 95 | 96 | if not os.path.exists("./parameters/" + self.model_name): 97 | os.makedirs("./parameters/" + self.model_name) 98 | if self.use_gpu: 99 | prof.printCsv("./parameters/" + self.model_name + "/gpuPart.csv") 100 | else: 101 | prof.printCsv("./parameters/" + self.model_name + "/cpuPart.csv") 102 | 103 | 104 | if __name__ == "__main__": 105 | torch.randn(4).to(0) 106 | 107 | name = "res" 108 | start_init = time.time() 109 | m = model(name, use_gpu=True) 110 | print("模型结构初始化时间: ", time.time() - start_init) 111 | start_load = time.time() 112 | m.load_weight() 113 | print("模型参数加载时间: ", time.time() - start_load) 114 | 115 | m.inference() 116 | 117 | doPrepare = False 118 | doProf = True 119 | doInference = False 120 | doPartition = False 121 | doPartition2 = False 122 | 123 | if doPrepare: 124 | m.save_layers(depth=m.depth) 125 | elif doProf: 126 | m.prof(depth=m.depth) 127 | m.prof(depth=m.depth) 128 | elif doInference: 129 | start = time.time() 130 | m.inference() 131 | print("推理时间", time.time() - start) 132 | elif doPartition: 133 | ''' 134 | 使用Alexnet进行了切分测试 135 | ''' 136 | cModel = pytorchtool.Surgery(m.model, 0, depth=m.depth) 137 | cModel.setLayerState({"input": 1, "features.0": 2, "features.1": 2, "features.2": 2, "features.3": 2, 138 | "features.4": 2, "features.5": 2, "features.6": 2, "features.7": 2, 139 | "features.8": 2, "features.9": 2, "features.10": 2, "features.11": 2, 140 | "features.12": 2, "avgpool": 2, "classifier.0": 2, "classifier.1": 2, 141 | "classifier.2": 2, "classifier.3": 2, "classifier.4": 2, "classifier.5": 2, 142 | "classifier.6": 2, 'flatten': 2}) 143 | cModel.clearMiddleResult() 144 | cModel(m.x) 145 | cModel.recover() # 恢复m的forward函数,避免sModel对同一个模型嵌套修改 146 | print(cModel.getMiddleResult()) 147 | 148 | sModel = pytorchtool.Surgery(m.model, 2, depth=m.depth) 149 | sModel.setLayerState({"input": 1, "features.0": 2, "features.1": 2, "features.2": 2, "features.3": 2, 150 | "features.4": 2, "features.5": 2, "features.6": 2, "features.7": 2, 151 | "features.8": 2, "features.9": 2, "features.10": 2, "features.11": 2, 152 | "features.12": 2, "avgpool": 2, "classifier.0": 2, "classifier.1": 2, 153 | "classifier.2": 2, "classifier.3": 2, "classifier.4": 2, "classifier.5": 2, 154 | "classifier.6": 2, 'flatten': 2}) 155 | ''' 156 | 这里使用随机生成的相同size的数据代替原始输入数据, 157 | 实际使用时若将计算全部卸载到了服务端,则需要传入原始数据 158 | ''' 159 | sModel.setMiddleResult(cModel.getMiddleResult()) 160 | outputs = sModel(torch.rand(224, 224).unsqueeze_(0)) 161 | print("result: " + class_names[torch.argmax(outputs, 1)[0]]) 162 | elif doPartition2: 163 | ''' 164 | 使用Alexnet进行了切分测试 165 | ''' 166 | cModel = pytorchtool.Surgery(m.model, 0, depth=m.depth) 167 | cModel.setLayerState({"input": 1, "features.0": 2, "features.1": 2, "features.2": 2, "features.3": 2, 168 | "features.4": 2, "features.5": 2, "features.6": 2, "features.7": 2, 169 | "features.8": 2, "features.9": 2, "features.10": 2, "features.11": 2, 170 | "features.12": 2, "avgpool": 2, "classifier.0": 2, "classifier.1": 2, 171 | "classifier.2": 2, "classifier.3": 2, "classifier.4": 2, "classifier.5": 2, 172 | "classifier.6": 2, 'flatten': 2}) 173 | cModel.clearMiddleResult() 174 | cModel(m.x) 175 | cModel.recover() # 恢复m的forward函数,避免sModel对同一个模型嵌套修改 176 | print(cModel.getMiddleResult()) 177 | 178 | sModel = pytorchtool.Surgery2('alex', './parameters/alexnet/dag') 179 | for k, v in {"input": 1, "features.0": 2, "features.1": 2, "features.2": 2, "features.3": 2, 180 | "features.4": 2, "features.5": 2, "features.6": 2, "features.7": 2, 181 | "features.8": 2, "features.9": 2, "features.10": 2, "features.11": 2, 182 | "features.12": 2, "avgpool": 2, "classifier.0": 2, "classifier.1": 2, 183 | "classifier.2": 2, "classifier.3": 2, "classifier.4": 2, "classifier.5": 2, 184 | "classifier.6": 2, 'flatten': 2}.items(): 185 | if v == 2: 186 | sModel.loadLayer(k) 187 | 188 | start = time.time() 189 | outputs = sModel.inferencePart(cModel.getMiddleResult()) 190 | print("服务端时间", time.time() - start) 191 | print("result: " + class_names[torch.argmax(outputs, 1)[0]]) 192 | -------------------------------------------------------------------------------- /pandas.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENP/pytorchtool/73c293fd592a810c5a4c19e8110c2b4b46d9195e/pandas.jpg -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import torch 4 | 5 | path = "./model_weight" 6 | dir_name = "default" 7 | 8 | def save_weight_by_layer(module, name="", depth=-1): 9 | """根据depth遍历pytorch模块,保存模块的权重文件""" 10 | 11 | child_list = list(module.named_children()) 12 | ''' 13 | 遍历到叶子结点或depth指定的深度时保存权重; 14 | 否则继续向下遍历 15 | ''' 16 | if depth == 0 or len(child_list) == 0: 17 | torch.save(module.state_dict(), os.path.join(path, dir_name, name + ".pth")) 18 | torch.save(module, os.path.join(path, dir_name, name + ".pkl")) 19 | else: 20 | for child in child_list: 21 | save_weight_by_layer(child[1], child[0] if name=="" else name + "." + child[0], depth - 1) 22 | 23 | def save_model(model, depth=-1): 24 | '''分别保存模型各层权重到'$path/模型名/'目录下 25 | 参数: 26 | model: pytorch模型(已加载权重) 27 | depth: 模型层嵌套深度(-1表示全部展开) 28 | ''' 29 | global dir_name 30 | dir_name = model.__class__.__name__ 31 | 32 | if not os.path.exists(os.path.join(path, dir_name)): 33 | os.makedirs(os.path.join(path, dir_name)) 34 | 35 | save_weight_by_layer(model, name="", depth=depth) -------------------------------------------------------------------------------- /profile.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import sys 3 | import time 4 | import torch 5 | import pandas as pd 6 | import functools 7 | from collections import defaultdict 8 | 9 | from .walk import walk_modules 10 | 11 | class Profile(object): 12 | """PyTorch模型的逐层分析器,可以获取模型各层初始化、执行时间和输出数据大小""" 13 | 14 | def __init__(self, model, enabled=True, use_cuda=False, depth=-1): 15 | """ 16 | 参数: 17 | model:pytorch模型 18 | enabled:是否(True/False)启用分析 19 | use_cuda:是否(True/False)使用GPU 20 | depth:模型层嵌套深度 21 | """ 22 | self._model = model 23 | self.enabled = enabled 24 | self.use_cuda = use_cuda 25 | self.depth = depth 26 | 27 | 28 | self.entered = False 29 | self.exited = False 30 | self.traces = () 31 | self.information = defaultdict(list) 32 | 33 | # 输入层信息(默认输入数据size为3*224*224,float32) 34 | self.information["input"].extend([0, 0, 0.15, 0.01]) 35 | 36 | def __enter__(self): 37 | if not self.enabled: 38 | return self 39 | if self.entered: 40 | raise RuntimeError("pytorchtool profiler is not reentrant") 41 | self.entered = True 42 | self._forwards = {} # 存储初始forwards 43 | 44 | # 逐层初始化分析 45 | self.traces = tuple(map(self._load_weight, walk_modules(self._model, depth=self.depth))) 46 | # 逐层修改forwards 47 | tuple(map(self._hook_trace, self.traces)) 48 | return self 49 | 50 | def __exit__(self, exc_type, exc_val, exc_tb): 51 | if not self.enabled: 52 | return 53 | # 逐层恢复初始forwards 54 | tuple(map(self._remove_hook_trace, self.traces)) 55 | del self._forwards # remove unnecessary forwards 56 | self.exited = True 57 | 58 | def __str__(self): 59 | return str(pd.DataFrame.from_dict(self.information, orient='index', 60 | columns=['Parameters Loading Time(ms)', 'Model Loading Time(ms)', 'Data Size(MB)','Execute Time(ms)'])) 61 | 62 | def __call__(self, *args, **kwargs): 63 | return self._model(*args, **kwargs) 64 | 65 | def _load_weight(self, trace): 66 | (name, module) = trace 67 | 68 | start = time.time() 69 | module.load_state_dict(torch.load("./model_weight/" + 70 | self._model.__class__.__name__ + "/" + name + ".pth"), strict=False) 71 | loadingTime = (time.time() - start) * 1000 72 | self.information[name].append(loadingTime) 73 | 74 | return trace 75 | 76 | def _hook_trace(self, trace): 77 | (name, module) = trace 78 | ''' 79 | start = time.time() 80 | m = torch.load("./model_weight/" + 81 | self._model.__class__.__name__ + "/" + name + ".pkl") 82 | loadingTime = (time.time() - start) * 1000 83 | self.information[name].append(loadingTime) 84 | ''' 85 | _forward = module.forward 86 | self._forwards[name] = _forward 87 | 88 | @functools.wraps(_forward) 89 | def wrap_forward(*args, **kwargs): 90 | print(name) 91 | # 执行时间 92 | if self.use_cuda: 93 | start = torch.cuda.Event(enable_timing=True) 94 | end = torch.cuda.Event(enable_timing=True) 95 | 96 | start.record() 97 | output = _forward(*args, **kwargs) 98 | end.record() 99 | 100 | # 等待执行完成 101 | torch.cuda.synchronize() 102 | 103 | exec_time = start.elapsed_time(end) 104 | else: 105 | start = time.time() 106 | output = _forward(*args, **kwargs) 107 | # 转换为ms 108 | exec_time = (time.time() - start) * 1000 109 | 110 | # 输出数据大小(MB) 111 | data_size = sys.getsizeof(output.storage()) / 1024 / 1024 112 | 113 | self.information[name].append(data_size) 114 | self.information[name].append(exec_time) 115 | return output 116 | 117 | module.forward = wrap_forward 118 | return trace 119 | 120 | def _remove_hook_trace(self, trace): 121 | [name, module] = trace 122 | module.forward = self._forwards[name] 123 | 124 | def printCsv(self, filePath='./parameters/default.csv'): 125 | """将模型分析结果写入csv文件 126 | 参数: 127 | filePath:csv文件路径及文件名 128 | """ 129 | df = pd.DataFrame.from_dict(self.information, orient='index', 130 | columns=['Parameters Loading Time(ms)', 'Model Loading Time(ms)', 'Data Size(MB)','Execute Time(ms)']) 131 | df.to_csv(filePath) -------------------------------------------------------------------------------- /surgery.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import sys 3 | import torch 4 | import logging 5 | import functools 6 | from collections import defaultdict 7 | 8 | from .walk import walk_modules 9 | 10 | class Surgery(object): 11 | """对PyTorch模型进行处理,方便进行模型切分""" 12 | 13 | def __init__(self, model, mode, is_debug=False, depth=-1): 14 | """ 15 | 参数: 16 | model:初始化后的DNN模型 17 | mode:切分模式 0:客户端模式(执行中间层并存储输出)、2:客户端模式(根据中间层输出继续执行,得到最终结果) 18 | """ 19 | self._model = model 20 | self._mode = mode 21 | self._is_debug = is_debug 22 | self._depth = depth 23 | 24 | self._layerState = None 25 | 26 | # 存储初始forwards 27 | self._forwards = {} 28 | 29 | # 存储中间层输出 30 | self._middleResult = {} 31 | 32 | # 逐层修改forwards 33 | self.traces = tuple(map(self._hook_trace, walk_modules(self._model, depth=self._depth))) 34 | 35 | def recover(self): 36 | # 逐层恢复初始forwards 37 | tuple(map(self._remove_hook_trace, self.traces)) 38 | del self._forwards # remove unnecessary forwards 39 | 40 | def getMiddleResult(self): 41 | return self._middleResult 42 | 43 | def setMiddleResult(self, middleResult): 44 | if self._mode == 2: 45 | self._middleResult = middleResult 46 | 47 | def clearMiddleResult(self): 48 | if self._mode == 0: 49 | self._middleResult.clear() 50 | 51 | def setLayerState(self, layerState): 52 | self._layerState = layerState 53 | 54 | def __call__(self, *args, **kwargs): 55 | # 针对客户端或服务端完成全部计算的清空做特殊处理 56 | if self._layerState['input'] == 1: 57 | if self._mode == 0: 58 | if self._is_debug: 59 | logging.info("客户端传输原始输入") 60 | self._middleResult['input'] = args[0] 61 | return torch.rand(1,1000) 62 | elif self._mode == 2: 63 | if self._is_debug: 64 | logging.info("服务端接收原始输入") 65 | return self._model((self._middleResult['input'])) 66 | 67 | return self._model(*args, **kwargs) 68 | 69 | def _hook_trace(self, trace): 70 | (name, module) = trace 71 | _forward = module.forward 72 | self._forwards[name] = _forward 73 | 74 | @functools.wraps(_forward) 75 | def wrap_forward(*args, **kwargs): 76 | if self._layerState[name] != self._mode: 77 | # 非中间输出层直接返回原始数据 78 | if self._layerState[name] != 1: 79 | if self._is_debug: 80 | logging.debug("skip %s", name) 81 | return args[0] 82 | # 服务端模式获取层输出并返回 83 | elif self._mode == 2: 84 | if self._is_debug: 85 | logging.debug("middle %s", name) 86 | return self._middleResult[name] 87 | if self._is_debug: 88 | logging.debug("execute %s", name) 89 | output = _forward(*args, **kwargs) 90 | 91 | # 客户端模型下需要存储中间层输出 92 | if self._mode == 0 and self._layerState[name] == 1: 93 | if self._is_debug: 94 | logging.debug("save %s", name) 95 | self._middleResult[name] = output 96 | return output 97 | 98 | module.forward = wrap_forward 99 | return trace 100 | 101 | def _remove_hook_trace(self, trace): 102 | (name, module) = trace 103 | module.forward = self._forwards[name] -------------------------------------------------------------------------------- /surgery2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import logging 4 | from collections import defaultdict 5 | 6 | class Surgery2(object): 7 | """分层加载模型的处理模式""" 8 | def __init__(self, name, dag_path, is_debug = False): 9 | if name in 'inception': 10 | self.model_name = 'Inception3' 11 | 12 | elif name in 'alexnet': 13 | self.model_name = 'AlexNet' 14 | 15 | elif name in 'resnet': 16 | self.model_name = 'ResNet' 17 | 18 | else: 19 | raise RuntimeError("Wrong model name") 20 | 21 | self._is_debug = is_debug 22 | 23 | self._edge = defaultdict(list) 24 | self._output = defaultdict(None) 25 | self._layerModule = defaultdict(None) 26 | # 读取dag文件 27 | for line in open(dag_path, 'r'): 28 | line = line.strip('\n') 29 | nameList = line.split(' ') 30 | name = nameList[0] 31 | self._output[name] = None 32 | self._layerModule[name] = None 33 | 34 | if len(nameList) == 1: 35 | self._endlayerName = name 36 | for nextLayerName in nameList[1:]: 37 | self._edge[nextLayerName].append(name) 38 | 39 | def loadLayer(self, layerName): 40 | if(self._layerModule[layerName] is None): 41 | if self._is_debug: 42 | logging.debug("初始化层: %s", layerName) 43 | # 加载该层 44 | self._layerModule[layerName] = torch.load("../pytorchtool/model_weight/" + 45 | self.model_name + "/" + layerName + ".pkl") 46 | 47 | def inferencePart(self, middleResult): 48 | # 清空输出字典 49 | for k in self._output.keys(): 50 | self._output[k] = None 51 | # 中间输出赋值 52 | for k, v in middleResult.items(): 53 | self._output[k] = v 54 | # 获取最终结果 55 | return self._inferenceLayer(self._endlayerName) 56 | 57 | def _inferenceLayer(self, layerName): 58 | if self._output[layerName] is not None: 59 | return self._output[layerName] 60 | 61 | inputList = [] 62 | for lastLayerName in self._edge[layerName]: 63 | inputList.append(self._inferenceLayer(lastLayerName)) 64 | 65 | if len(inputList) == 1: 66 | layerInput = inputList[0] 67 | else: 68 | layerInput = inputList 69 | if self._is_debug: 70 | logging.debug("execute %s", layerName) 71 | self._output[layerName] = self._layerModule[layerName](layerInput) 72 | return self._output[layerName] 73 | -------------------------------------------------------------------------------- /walk.py: -------------------------------------------------------------------------------- 1 | def walk_modules(module, name="", depth=-1): 2 | """生成器。根据depth遍历pytorch模块,生成Trace元组""" 3 | 4 | child_list = list(module.named_children()) 5 | ''' 6 | 遍历到叶子结点或depth指定的深度时返回当前模块元组; 7 | 否则继续向下遍历 8 | ''' 9 | if depth == 0 or len(child_list) == 0: 10 | yield (name, module) 11 | else: 12 | for child in child_list: 13 | yield from walk_modules(child[1], child[0] if name=="" else name + "." + child[0], depth - 1) -------------------------------------------------------------------------------- /zebra.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENP/pytorchtool/73c293fd592a810c5a4c19e8110c2b4b46d9195e/zebra.jpg --------------------------------------------------------------------------------