├── u_2_net
│   ├── __init__.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── u2net.cpython-37.pyc
│   │   │   └── __init__.cpython-37.pyc
│   │   └── u2net.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── data_loader.cpython-37.pyc
│   │   └── my_u2net_test.cpython-37.pyc
│   ├── my_u2net_test.py
│   └── data_loader.py
├── helloword.py
├── bj.png
├── .gitignore
├── img
│   ├── 2in.jpg
│   ├── meinv.jpg
│   ├── meinv_id.png
│   ├── meinv_alpha.png
│   ├── meinv_id_grid.png
│   ├── meinv_alpha_resize.png
│   ├── meinv_id_landmarks.png
│   └── meinv_trimap_resize.png
├── m_web
│   ├── bj.png
│   ├── cs.jpg
│   ├── web.py
│   └── upload.py
├── to_background
│   ├── __pycache__
│   │   ├── to_background.cpython-37.pyc
│   │   └── to_standard_trimap.cpython-37.pyc
│   ├── to_standard_trimap.py
│   └── to_background.py
├── m_dlib
│   ├── ai_crop.py
│   └── face_marks.py
├── main.py
└── README.md

/u_2_net/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/helloword.py:
--------------------------------------------------------------------------------

print("hello world!")

--------------------------------------------------------------------------------
/bj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/bj.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### IntelliJ IDEA ###
.idea
*.iws
*.ipr
--------------------------------------------------------------------------------
/img/2in.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/2in.jpg
--------------------------------------------------------------------------------
/img/meinv.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv.jpg
--------------------------------------------------------------------------------
/m_web/bj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/m_web/bj.png
--------------------------------------------------------------------------------
/m_web/cs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/m_web/cs.jpg
--------------------------------------------------------------------------------
/img/meinv_id.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv_id.png
--------------------------------------------------------------------------------
/img/meinv_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv_alpha.png
--------------------------------------------------------------------------------
/u_2_net/model/__init__.py:
--------------------------------------------------------------------------------
from .u2net import U2NET
from .u2net import U2NETP
--------------------------------------------------------------------------------
/img/meinv_id_grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv_id_grid.png
--------------------------------------------------------------------------------
/img/meinv_alpha_resize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv_alpha_resize.png
--------------------------------------------------------------------------------
/img/meinv_id_landmarks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv_id_landmarks.png
--------------------------------------------------------------------------------
/img/meinv_trimap_resize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/img/meinv_trimap_resize.png
--------------------------------------------------------------------------------
/u_2_net/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/u_2_net/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/u_2_net/__pycache__/data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/u_2_net/__pycache__/data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/u_2_net/model/__pycache__/u2net.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/u_2_net/model/__pycache__/u2net.cpython-37.pyc
--------------------------------------------------------------------------------
/u_2_net/__pycache__/my_u2net_test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/u_2_net/__pycache__/my_u2net_test.cpython-37.pyc
--------------------------------------------------------------------------------
/u_2_net/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/u_2_net/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/to_background/__pycache__/to_background.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/to_background/__pycache__/to_background.cpython-37.pyc
--------------------------------------------------------------------------------
/to_background/__pycache__/to_standard_trimap.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itainf/aiphoto/HEAD/to_background/__pycache__/to_standard_trimap.cpython-37.pyc
--------------------------------------------------------------------------------
/m_dlib/ai_crop.py:
--------------------------------------------------------------------------------
import m_dlib.face_marks as fmarks
from PIL import Image


def crop_photo(path, target):
    shape, d = fmarks.predictor_face(path)

    # Half the width/height of a 2-inch ID photo (413 x 626 px)
    WIDTH_2IN = 413 / 2
    HEIGHT_2IN = 626 / 2

    # Centre of the detected face
    X_CENTER = d.left() + (d.right() - d.left()) / 2
    Y_CENTER = d.top() + (d.bottom() - d.top()) / 2

    im = Image.open(path)
    im = im.crop((X_CENTER - WIDTH_2IN, Y_CENTER - HEIGHT_2IN, X_CENTER + WIDTH_2IN, Y_CENTER + HEIGHT_2IN))
    im.save(target)


# Crop the image based on the detected face landmarks
# crop_photo("..//img//meinv_id.png", "..//img//2in.jpg")
--------------------------------------------------------------------------------
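
Note that crop_photo assumes the fixed 413 x 626 window always fits inside the source image; when the detected face sits close to a border, Image.crop silently fills the out-of-bounds region with black. A minimal sketch of a clamped variant (clamp_box is a hypothetical helper, not part of this repo, and it assumes the source image is at least 413 x 626 pixels):

def clamp_box(cx, cy, half_w, half_h, img_w, img_h):
    # Shift the crop window back inside the image instead of letting it overflow.
    left = min(max(cx - half_w, 0), img_w - 2 * half_w)
    top = min(max(cy - half_h, 0), img_h - 2 * half_h)
    return (left, top, left + 2 * half_w, top + 2 * half_h)

# im = im.crop(clamp_box(X_CENTER, Y_CENTER, WIDTH_2IN, HEIGHT_2IN, *im.size))
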
/m_web/web.py:
--------------------------------------------------------------------------------
import tornado.ioloop
import tornado.web
import m_web.upload as upload
import os


class MainHandler(tornado.web.RequestHandler):
    def get(self):
        self.write("Hello, world")


def make_app():
    return tornado.web.Application([
        (r"/", MainHandler),
        (r"/eam/fileLocal/upload", upload.UploadHandler),
        (r"/eam/fileLocal/static", tornado.web.StaticFileHandler, {"path": "/static"})
    ],
        static_path=os.path.dirname(os.path.dirname(__file__)) + "/static"
    )


if __name__ == "__main__":
    current_path = os.path.dirname(__file__)
    app = make_app()
    app.listen(8013)
    tornado.ioloop.IOLoop.current().start()
--------------------------------------------------------------------------------
/to_background/to_standard_trimap.py:
--------------------------------------------------------------------------------
from PIL import Image


def to_standard_trimap(alpha, trimap):
    # Build a standard trimap from the alpha matte:
    # dark pixels -> background (0), mid-tones -> unknown (128), bright -> foreground (255)
    print(alpha)
    image = Image.open(alpha)
    print(image)
    # image = image.convert("P")
    # image_file.save('meinv_resize_trimap.png')
    sp = image.size
    width = sp[0]
    height = sp[1]

    for yh in range(height):
        for xw in range(width):
            dot = (xw, yh)
            color_d_arr = image.getpixel(dot)
            color_d = color_d_arr[0]

            if 0 < color_d <= 60:
                image.putpixel(dot, (0, 0, 0))
            if 60 < color_d <= 200:
                image.putpixel(dot, (128, 128, 128))
            if 200 < color_d <= 255:
                image.putpixel(dot, (255, 255, 255))

    image.save(trimap)


# to_standard_trimap("..\\img\\trimap\\meinv_resize_trimap.png", "meinv_resize_bz_trimap.png")
#
#
# image = Image.open("meinv_resize_bz_trimap.png")
# data = image.getdata()
# np.savetxt("data4.txt", data, fmt='%d', delimiter=',')
--------------------------------------------------------------------------------
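
The pixel-by-pixel getpixel/putpixel loop above makes width x height round trips through the Python interpreter. A vectorized sketch with NumPy that applies the same thresholds in one pass (to_standard_trimap_np is a hypothetical name, and it assumes the alpha file is the RGB image saved by my_u2net_test):

import numpy as np
from PIL import Image


def to_standard_trimap_np(alpha, trimap):
    a = np.array(Image.open(alpha))[:, :, 0]  # first channel, like the loop above
    out = np.zeros_like(a)
    out[(a > 60) & (a <= 200)] = 128  # unknown band
    out[a > 200] = 255                # definite foreground
    Image.fromarray(out).convert("RGB").save(trimap)
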
/main.py:
--------------------------------------------------------------------------------

from u_2_net import my_u2net_test
from to_background import to_background
from to_background import to_standard_trimap
from m_dlib import ai_crop

import numpy as np
from PIL import Image

if __name__ == "__main__":
    org_img = "..\\aiphoto\\img\\meinv.jpg"
    alpha_img = "..\\aiphoto\\img\\meinv_alpha.png"
    alpha_resize_img = "..\\aiphoto\\img\\meinv_alpha_resize.png"

    # Get the alpha matte with u_2_net
    my_u2net_test.test_seg_trimap(org_img, alpha_img, alpha_resize_img)

    # Derive the trimap from the alpha matte
    trimap = "..\\aiphoto\\img\\meinv_trimap_resize.png"
    to_standard_trimap.to_standard_trimap(alpha_resize_img, trimap)

    # Give the ID photo a solid blue background
    id_image = "..\\aiphoto\\img\\meinv_id.png"
    to_background.to_background(org_img, trimap, id_image, "blue")
    # id_image = "..\\aiphoto\\img\\meinv_id_grid.png"
    # to_background.to_background_grid(org_img, trimap, id_image)
    # image = Image.open(id_image)
    # data = image.getdata()
    # np.savetxt("data6.txt", data, fmt='%d', delimiter=',')

    # 20200719
    # Crop the image based on the detected face landmarks
    ai_crop.crop_photo("..\\aiphoto\\img\\meinv_id.png", "..\\aiphoto\\img\\2in.jpg")
--------------------------------------------------------------------------------
/m_dlib/face_marks.py:
--------------------------------------------------------------------------------
import dlib
import cv2


def predictor_face(path):
    """
    :param path: path of the source image
    :return: shape (the 68 landmarks) and d (the face bounding box)
    """
    # Path of the landmark model -- change this to your own location
    predictor_path = "E:\\python\\az\\dlib-19.19.0.tar\\shape_predictor_68_face_landmarks.dat"

    # Create a face detector
    detector = dlib.get_frontal_face_detector()
    # Load the landmark predictor
    predictor = dlib.shape_predictor(predictor_path)
    # Load the image
    img = dlib.load_rgb_image(path)

    # Detect the faces
    dets = detector(img, 1)
    print("Number of faces detected: {}".format(len(dets)))
    # Take the first face
    d = dets[0]

    print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(0, d.left(), d.top(), d.right(), d.bottom()))

    # Get the landmarks/parts for the face in box d.
    shape = predictor(img, d)

    print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1)))
    return shape, d


# Draw the face landmarks on the image
def test_landmarks(path, target_path):
    img = cv2.imread(path)
    font = cv2.FONT_HERSHEY_SIMPLEX
    i = 0
    shape, d = predictor_face(path)
    for pt in shape.parts():
        i = i + 1
        pt_pos = (pt.x, pt.y)
        cv2.putText(img, str(i), pt_pos, font, 0.3, (0, 255, 0))

    cv2.imshow("image", img)
    cv2.imwrite(target_path, img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


# Print the face landmarks
# test_landmarks("..//img//meinv_id.png", "..//img//meinv_id_landmarks.png")
--------------------------------------------------------------------------------
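
predictor_face points at a local copy of dlib's 68-point landmark model, which is not shipped with the dlib package itself. A small sketch for fetching and unpacking it (the URL is dlib's official file server; the script drops the .dat file into the working directory, so predictor_path would need to be updated accordingly):

import bz2
import urllib.request

url = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
archive, _ = urllib.request.urlretrieve(url)
with bz2.open(archive) as src, open("shape_predictor_68_face_landmarks.dat", "wb") as dst:
    dst.write(src.read())
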
/README.md:
--------------------------------------------------------------------------------
[Zhihu column](https://www.zhihu.com/people/itlf/columns)

# aiphoto
I recently had errands to run, and everywhere I went I needed ID photos. Since I happened to be studying AI, face recognition and image recognition at the time, I decided to use those techniques to build an ID-photo generator.

Original photo on a blue background:

![image-20200715000133867](https://raw.githubusercontent.com/wiki/itainf/aiphoto/裁剪照片.assets/image-20200715000133867.png)

Key facial landmarks annotated:

![image-20200715000232798](https://raw.githubusercontent.com/wiki/itainf/aiphoto/裁剪照片.assets/image-20200715000232798.png)

Cropped 2-inch photo:

![image-20200715000251662](https://raw.githubusercontent.com/wiki/itainf/aiphoto/裁剪照片.assets/image-20200715000251662.png)


# Documentation

The documents below are the fastest way to get started with and understand the project.

1. [Setting up the Python environment](https://github.com/itainf/aiphoto/wiki/python%E7%8E%AF%E5%A2%83%E6%90%AD%E5%BB%BA)

2. [Portrait segmentation with a convolutional neural network](https://github.com/itainf/aiphoto/wiki/%E5%8D%B7%E7%A7%AF%E6%A8%A1%E5%9E%8B%E4%BA%BA%E5%83%8F%E5%88%86%E5%89%B2)

3. [Replacing the background colour with PyMatting](https://github.com/itainf/aiphoto/wiki/%E5%88%A9%E7%94%A8PyMatting%E7%B2%BE%E7%BB%86%E5%8C%96%E6%8A%A0%E5%9B%BE)


### Changelog

Updated 4 July 2020

Version: v20200704

This update adds portrait segmentation with a convolutional neural network model.

Documentation: [Portrait segmentation with a convolutional model](https://github.com/itainf/aiphoto/wiki/%E5%8D%B7%E7%A7%AF%E6%A8%A1%E5%9E%8B%E4%BA%BA%E5%83%8F%E5%88%86%E5%89%B2)



Updated 9 July 2020

Version: v20200709

This update adds background replacement with the PyMatting framework: the trimap is used to separate the foreground portrait, and the background is swapped for an ID-photo background.

Documentation: [Replacing the background with PyMatting](https://github.com/itainf/aiphoto/wiki/%E5%88%A9%E7%94%A8PyMatting%E7%B2%BE%E7%BB%86%E5%8C%96%E6%8A%A0%E5%9B%BE)


Updated 15 July 2020

Version: v20200715

This update adds cropping to a standard photo size with the dlib framework.

Documentation: [Cropping the photo](https://github.com/itainf/aiphoto/wiki/%E8%A3%81%E5%89%AA%E7%85%A7%E7%89%87)


Updated 26 July 2020

Version: v20200726

Documentation: [Taking and uploading a photo from the mini program to generate an ID photo](https://github.com/itainf/aiphoto/wiki/%E5%B0%8F%E7%A8%8B%E5%BA%8F%E6%8B%8D%E7%85%A7%E4%B8%8A%E4%BC%A0%E7%94%9F%E6%88%90%E8%AF%81%E4%BB%B6%E7%85%A7)
--------------------------------------------------------------------------------
"blue") 69 | print(id_image_org) 70 | self.write( "static/"+filename+"_meinv_id_2in.png") 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /u_2_net/my_u2net_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torchvision import transforms#, utils 4 | # import torch.optim as optim 5 | import numpy as np 6 | from u_2_net.data_loader import RescaleT 7 | from u_2_net.data_loader import ToTensorLab 8 | from u_2_net.model import U2NET # full size version 173.6 MB 9 | from PIL import Image 10 | import os 11 | 12 | # normalize the predicted SOD probability map 13 | def normPRED(d): 14 | ma = torch.max(d) 15 | mi = torch.min(d) 16 | dn = (d-mi)/(ma-mi) 17 | return dn 18 | 19 | 20 | def preprocess(image): 21 | label_3 = np.zeros(image.shape) 22 | label = np.zeros(label_3.shape[0:2]) 23 | 24 | if (3 == len(label_3.shape)): 25 | label = label_3[:, :, 0] 26 | elif (2 == len(label_3.shape)): 27 | label = label_3 28 | if (3 == len(image.shape) and 2 == len(label.shape)): 29 | label = label[:, :, np.newaxis] 30 | elif (2 == len(image.shape) and 2 == len(label.shape)): 31 | image = image[:, :, np.newaxis] 32 | label = label[:, :, np.newaxis] 33 | 34 | transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]) 35 | sample = transform({ 36 | 'imidx': np.array([0]), 37 | 'image': image, 38 | 'label': label 39 | }) 40 | 41 | return sample 42 | 43 | 44 | def pre_net(): 45 | # 采用n2net 模型数据 46 | model_name = 'u2net' 47 | path = os.path.dirname(__file__) 48 | print(path) 49 | model_dir = path+'/saved_models/'+ model_name + '/' + model_name + '.pth' 50 | print(model_dir) 51 | print("...load U2NET---173.6 MB") 52 | net = U2NET(3,1) 53 | # 指定cpu 54 | net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu'))) 55 | if torch.cuda.is_available(): 56 | net.cuda() 57 | net.eval() 58 | return net 59 | 60 | 61 | def pre_test_data(img): 62 | torch.cuda.empty_cache() 63 | sample = preprocess(img) 64 | inputs_test = sample['image'].unsqueeze(0) 65 | inputs_test = inputs_test.type(torch.FloatTensor) 66 | if torch.cuda.is_available(): 67 | inputs_test = Variable(inputs_test.cuda()) 68 | else: 69 | inputs_test = Variable(inputs_test) 70 | return inputs_test 71 | 72 | 73 | def get_im(pred): 74 | predict = pred 75 | predict = predict.squeeze() 76 | predict_np = predict.cpu().data.numpy() 77 | im = Image.fromarray(predict_np*255).convert('RGB') 78 | return im 79 | 80 | 81 | def test_seg_trimap(org,alpha,alpha_resize): 82 | # 将原始图片转换成 Alpha图 83 | # org:原始图片 84 | # org_trimap: 85 | # resize_trimap: 调整尺寸的trimap 86 | image = Image.open(org) 87 | print(image) 88 | img = np.array(image) 89 | net = pre_net() 90 | inputs_test = pre_test_data(img) 91 | d1, d2, d3, d4, d5, d6, d7 = net(inputs_test) 92 | # normalization 93 | pred = d1[:, 0, :, :] 94 | pred = normPRED(pred) 95 | # 将数据转换成图片 96 | im = get_im(pred) 97 | im.save(alpha) 98 | sp = image.size 99 | # 根据原始图片调整尺寸 100 | imo = im.resize((sp[0], sp[1]), resample=Image.BILINEAR) 101 | imo.save(alpha_resize) 102 | 103 | 104 | # if __name__ == "__main__": 105 | # test_seg_trimap("..\\img\\meinv.jpg","..\\img\\trimap\\meinv_alpha.png","..\\img\\trimap\\meinv_alpha_resize.png") 106 | # #pil_wait_blue() 107 | -------------------------------------------------------------------------------- /u_2_net/data_loader.py: -------------------------------------------------------------------------------- 1 | # data loader 2 | 
/m_web/upload.py:
--------------------------------------------------------------------------------
import os
import tornado.web
import shortuuid
from u_2_net import my_u2net_test
from to_background import to_background
from to_background import to_standard_trimap
from m_dlib import ai_crop

# import PILImageMy as mypil


class UploadHandler(tornado.web.RequestHandler):

    def post(self, *args, **kwargs):

        filename = shortuuid.uuid()
        print(os.path.dirname(__file__))
        parent_path = os.path.dirname(os.path.dirname(__file__))
        filePath = ""
        # Inspect the full structure of the upload; request.files is a dict, e.g.
        # {'file1':
        #     [{'filename': 'New Text Document.txt', 'body': b'61 60 -83\r\n-445 64 -259', 'content_type': 'text/plain'}],
        #  'file2': ...}
        print(self.request.files)
        filesDict = self.request.files
        for inputname in filesDict:
            # The outer loop walks the top-level keys, i.e. the name attribute
            # of each <input> field in the upload form
            http_file = filesDict[inputname]
            for fileObj in http_file:

                # The inner loop yields the complete file object;
                # save its body under the static folder as <uuid>.jpg
                filePath = os.path.join(parent_path, "static", filename + ".jpg")
                print(filePath)
                with open(filePath, 'wb') as f:
                    f.write(fileObj.body)

        org_img = filePath

        id_image = os.path.join(parent_path, "static", filename + "_meinv_id.png")
        # 20200719
        # Crop the image based on the detected face landmarks
        ai_crop.crop_photo(org_img, id_image)

        print(org_img)
        alpha_img = os.path.join(parent_path, "static", filename + "_alpha.png")
        print(alpha_img)
        alpha_resize_img = os.path.join(parent_path, "static", filename + "_alpha_resize.png")
        print(alpha_resize_img)

        # Get the alpha matte with u_2_net
        my_u2net_test.test_seg_trimap(id_image, alpha_img, alpha_resize_img)

        # Derive the trimap from the alpha matte
        trimap = os.path.join(parent_path, "static", filename + "_trimap_resize.png")
        to_standard_trimap.to_standard_trimap(alpha_resize_img, trimap)
        print(trimap)

        id_image_org = os.path.join(parent_path, "static", filename + "_meinv_id_2in.png")

        # Give the ID photo a solid blue background
        # to_standard_trimap.to_standard_trimap(alpha_resize_img, trimap)
        to_background.to_background(id_image, trimap, id_image_org, "blue")
        print(id_image_org)
        self.write("static/" + filename + "_meinv_id_2in.png")
--------------------------------------------------------------------------------
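
A quick way to exercise this handler from a script; a sketch that assumes the Tornado app from web.py is running locally on port 8013. The field name "file1" is arbitrary, since UploadHandler loops over every field in request.files:

import requests

with open("cs.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8013/eam/fileLocal/upload",
        files={"file1": ("cs.jpg", f, "image/jpeg")},
    )
print(resp.text)  # relative path of the generated photo, e.g. static/<uuid>_meinv_id_2in.png
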
/u_2_net/my_u2net_test.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
from torchvision import transforms  # , utils
# import torch.optim as optim
import numpy as np
from u_2_net.data_loader import RescaleT
from u_2_net.data_loader import ToTensorLab
from u_2_net.model import U2NET  # full size version 173.6 MB
from PIL import Image
import os


# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn


def preprocess(image):
    label_3 = np.zeros(image.shape)
    label = np.zeros(label_3.shape[0:2])

    if (3 == len(label_3.shape)):
        label = label_3[:, :, 0]
    elif (2 == len(label_3.shape)):
        label = label_3
    if (3 == len(image.shape) and 2 == len(label.shape)):
        label = label[:, :, np.newaxis]
    elif (2 == len(image.shape) and 2 == len(label.shape)):
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]

    transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    sample = transform({
        'imidx': np.array([0]),
        'image': image,
        'label': label
    })

    return sample


def pre_net():
    # Use the pretrained u2net model weights
    model_name = 'u2net'
    path = os.path.dirname(__file__)
    print(path)
    model_dir = path + '/saved_models/' + model_name + '/' + model_name + '.pth'
    print(model_dir)
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)
    # Load the weights on the CPU
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()
    return net


def pre_test_data(img):
    torch.cuda.empty_cache()
    sample = preprocess(img)
    inputs_test = sample['image'].unsqueeze(0)
    inputs_test = inputs_test.type(torch.FloatTensor)
    if torch.cuda.is_available():
        inputs_test = Variable(inputs_test.cuda())
    else:
        inputs_test = Variable(inputs_test)
    return inputs_test


def get_im(pred):
    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')
    return im


def test_seg_trimap(org, alpha, alpha_resize):
    # Turn the source image into an alpha matte.
    # org: source image
    # alpha: output alpha matte
    # alpha_resize: the alpha matte resized back to the source image size
    image = Image.open(org)
    print(image)
    img = np.array(image)
    net = pre_net()
    inputs_test = pre_test_data(img)
    d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
    # normalization
    pred = d1[:, 0, :, :]
    pred = normPRED(pred)
    # Convert the prediction back to an image
    im = get_im(pred)
    im.save(alpha)
    sp = image.size
    # Resize to match the source image
    imo = im.resize((sp[0], sp[1]), resample=Image.BILINEAR)
    imo.save(alpha_resize)


# if __name__ == "__main__":
#     test_seg_trimap("..\\img\\meinv.jpg", "..\\img\\trimap\\meinv_alpha.png", "..\\img\\trimap\\meinv_alpha_resize.png")
#     # pil_wait_blue()
--------------------------------------------------------------------------------
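
Variable has been a no-op wrapper since PyTorch 0.4, and nothing in this file needs gradients: running the forward pass under torch.no_grad() skips building the autograd graph, which noticeably cuts inference memory for a network of this depth. A sketch of the equivalent inference step (predict_alpha is a hypothetical name, not part of this repo):

import torch


def predict_alpha(net, inputs_test):
    with torch.no_grad():
        d1, *_ = net(inputs_test)
    pred = d1[:, 0, :, :]
    return (pred - pred.min()) / (pred.max() - pred.min())  # same min-max scaling as normPRED
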
/u_2_net/data_loader.py:
--------------------------------------------------------------------------------
# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

#==========================dataset load==========================
class RescaleT(object):

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        # img = transform.resize(image,(new_h,new_w),mode='constant')
        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
        lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class Rescale(object):

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        img = transform.resize(image, (new_h, new_w), mode='constant')
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class RandomCrop(object):

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]

        return {'imidx': imidx, 'image': image, 'label': label}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):

        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        tmpLbl = np.zeros(label.shape)

        image = image / np.max(image)
        if (np.max(label) < 1e-6):
            label = label
        else:
            label = label / np.max(label)

        if image.shape[2] == 1:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
        else:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        tmpLbl[:, :, 0] = label[:, :, 0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}


class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors."""
    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):

        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpLbl = np.zeros(label.shape)

        if (np.max(label) < 1e-6):
            label = label
        else:
            label = label / np.max(label)

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # normalize image to range [0,1]
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]))
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]))
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]))
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]))

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3])
            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4])
            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5])

        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])

        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image / np.max(image)
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        tmpLbl[:, :, 0] = label[:, :, 0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}


class SalObjDataset(Dataset):
    def __init__(self, img_name_list, lbl_name_list, transform=None):
        # self.root_dir = root_dir
        # self.image_name_list = glob.glob(image_dir+'*.png')
        # self.label_name_list = glob.glob(label_dir+'*.png')
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):

        # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
        # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])

        image = io.imread(self.image_name_list[idx])
        imname = self.image_name_list[idx]
        imidx = np.array([idx])

        if (0 == len(self.label_name_list)):
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        label = np.zeros(label_3.shape[0:2])
        if (3 == len(label_3.shape)):
            label = label_3[:, :, 0]
        elif (2 == len(label_3.shape)):
            label = label_3

        if (3 == len(image.shape) and 2 == len(label.shape)):
            label = label[:, :, np.newaxis]
        elif (2 == len(image.shape) and 2 == len(label.shape)):
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample
--------------------------------------------------------------------------------
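
The (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225) constants in ToTensor and ToTensorLab are the standard ImageNet channel means and standard deviations, which is also what the commented-out transforms.Normalize line refers to. Once the image has been scaled into [0, 1], the hand-written flag=0 branch is essentially equivalent to this on a CHW float tensor:

import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
chw_image = torch.rand(3, 320, 320)  # stand-in for a rescaled input image
out = normalize(chw_image)
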
/u_2_net/model/u2net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):

    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)

### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        #decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
--------------------------------------------------------------------------------
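
U2NET and U2NETP return seven sigmoid maps: the fused output d0 plus the six side outputs d1 through d6. The upstream U-2-Net training script supervises all seven against the same ground-truth mask (deep supervision), summing one binary cross-entropy term per output. A condensed sketch of that scheme (multi_bce_loss is a hypothetical name for illustration):

import torch.nn as nn

bce = nn.BCELoss(reduction='mean')


def multi_bce_loss(outputs, labels):
    # One BCE term per output map: fused d0 plus side outputs d1..d6.
    return sum(bce(d, labels) for d in outputs)

# d0, d1, d2, d3, d4, d5, d6 = net(inputs)
# loss = multi_bce_loss((d0, d1, d2, d3, d4, d5, d6), labels)
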