├── 1 (2).png ├── LICENSE ├── LSTM1.png ├── ML_learn ├── Random_Forest.py └── sk_fit.py ├── README.md ├── nets ├── images_encoder.py ├── model.py └── transformer_text_encoder.py └── train.py /1 (2).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kay-cottage/AI_Medical_System/ed68481f3277915a87f7c71398391a503e4895d1/1 (2).png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LSTM1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kay-cottage/AI_Medical_System/ed68481f3277915a87f7c71398391a503e4895d1/LSTM1.png -------------------------------------------------------------------------------- /ML_learn/Random_Forest.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | 6 | 7 | path = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" 8 | ''' 9 | # data format 10 | 6.2,2.9,4.3,1.3,Iris-versicolor 11 | 5.1,2.5,3.0,1.1,Iris-versicolor 12 | 5.7,2.8,4.1,1.3,Iris-versicolor 13 | 6.3,3.3,6.0,2.5,Iris-virginica 14 | 5.8,2.7,5.1,1.9,Iris-virginica 15 | 7.1,3.0,5.9,2.1,Iris-virginica 16 | 6.3,2.9,5.6,1.8,Iris-virginica 17 | . 18 | . 19 | . 20 | ''' 21 | 22 | headernames = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'Class'] 23 | 24 | 25 | dataset = pd.read_csv(path, names=headernames) 26 | print(dataset.head()) 27 | 28 | 29 | X = dataset.iloc[:, :-1].values 30 | y = dataset.iloc[:, 4].values 31 | 32 | 33 | from sklearn.model_selection import train_test_split 34 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30) 35 | 36 | 37 | from sklearn.ensemble import RandomForestClassifier 38 | classifier = RandomForestClassifier(n_estimators=50) 39 | classifier.fit(X_train, y_train) 40 | 41 | 42 | y_pred = classifier.predict(X_test) 43 | 44 | 45 | from sklearn.metrics import classification_report, confusion_matrix, accuracy_score 46 | result = confusion_matrix(y_test, y_pred) 47 | print("Confusion Matrix:") 48 | print(result) 49 | result1 = classification_report(y_test, y_pred) 50 | print("Classification Report:") 51 | print(result1) 52 | result2 = accuracy_score(y_test, y_pred) 53 | print("Accuracy:", result2) 54 | 55 | -------------------------------------------------------------------------------- /ML_learn/sk_fit.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import LogisticRegression 2 | from sklearn.ensemble import RandomForestClassifier 3 | #from xgboost import XGBClassifier 4 | import numpy as np 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | # note: this only silences the warnings; it does not fix their root cause 8 | 9 | # training data (features / attributes) 10 | x_train = np.array([[1,2,3], 11 | [1,5,4], 12 | [2,2,2], 13 | [4,5,6], 14 | [3,5,4], 15 | [1,7,2]]) 16 | # labels for the training data 17 | y_train = np.array([1, 0, 1, 1, 0, 0]) 18 | 19 | # test data 20 | x_test = np.array([[2,1,2], 21 | [3,2,6], 22 | [2,6,4]]) 23 | 24 | # choose and instantiate a model 25 | model = LogisticRegression() 26 | 27 | #model = RandomForestClassifier() 28 | 29 | #model=XGBClassifier() 30 | 31 | model.fit(x_train, y_train) 32 | 33 | # predicted labels 34 | print(model.predict(x_test)) 35 | 36 | print('---------------------------------------') 37 | 38 | # predicted probability of each label 39 | print(model.predict_proba(x_test)) 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI_Medical_System 2 | AI_Medical_System (idea sharing; undergraduate innovation / funded project: some thoughts on a fully automated AI disease-diagnosis pipeline) 3 | 4 | 5 | #-----------------------------------------------# 6 | 7 | 8 | update on 06/08/2022 by GW.kayak 9 | 10 | The full-strength option is an LSTM-CLIP sequential neural-network architecture, optionally combined with deep reinforcement learning. (Pros: high ceiling; with large datasets it can outperform classical machine learning. Cons: large model, high hardware requirements, complex structure that is hard to tame, needs a lot of data, and overfits easily on small datasets.) A DEMO of the principle / flow diagram follows: 11 | 12 | 13 |
![Imgur](https://github.com/kay-cottage/AI_Medical_System/blob/main/LSTM1.png) 14 | 15 | 16 | * The LSTM-CLIP-based multimodal autonomous diagnosis-and-treatment method comprises an electronic-medical-record (EMR) preprocessing module, a transformer text-encoder module, an image-encoder module, an image feature-extraction network module, an LSTM recurrent-network module, and a reinforcement-learning interaction module. EMR preprocessing module: collects the patient's textual records and imaging records and preprocesses them into data forms that the neural networks can take as input. 17 | 18 | 19 | * Encoder modules: an image encoder (Images encoder) and a text encoder (Text encoder). The image encoder encodes an input image into a high-dimensional vector carrying its semantic information; likewise, the text encoder extracts features from the patient's medical-record text and encodes it into a high-dimensional vector carrying the record's semantics. 20 | 21 | 22 | * Feature-extraction network module: an optional add-on that can be inserted into the decision network as needed for different practical applications. 23 | 24 | 25 | * LSTM recurrent-network module: the backbone of the decision network and the agent that interacts with the reinforcement-learning environment. It models the high-dimensional features produced by the encoders over time and outputs the corresponding diagnosis/treatment action. 26 | 27 | 28 | * Reinforcement-learning interaction module: built on the value-based DDQN algorithm, it provides the reinforcement-learning environment for the whole system. It receives rating feedback from patients and converts it into rewards that guide the agent (the decision network) toward the appropriate diagnosis/treatment actions. 29 | 30 | 31 | 32 | #-----------------------------------------------# 33 | 34 | 35 | update on 05/08/2022 by GW.kayak 36 | 37 | 38 | DEMO: the baseline backbone pipeline is shown below. (Pros: small model, low hardware requirements, works with modest amounts of data, performs well on small datasets, quick to get results. Cons: low ceiling; with large datasets it falls short of deep learning.) 39 | 40 | ![Imgur](https://github.com/kay-cottage/AI_Medical_System/blob/main/1%20(2).png) 41 | 42 | #-----------------------------------------------# 43 | 44 | 45 | update on 05/04/2022 by GW.kayak 46 | 47 | A hierarchical, DNS-like resolution workflow for disease diagnosis (initial screening of the disease category for departmental triage -> specialty-specific disease-screening network (orders the examinations) -> clinical decision network -> downstream clinical evaluation network) 48 | 49 | Screening network, decision network, sequential (temporal) network, reinforcement-learning network 50 | 51 | Several multimodal CLIP networks are ensembled as classifiers, and deep reinforcement learning learns from clinical practice through a real-world state/reward interaction loop (see my other repository for CLIP fine-tuning) 52 | 53 | The LSTM sequential network links the results of the upstream and downstream networks and makes the final decision 54 | -------------------------------------------------------------------------------- /nets/images_encoder.py: -------------------------------------------------------------------------------- 1 | # by kayak.gw 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from collections import OrderedDict 7 | #import clip 8 | from PIL import Image 9 | 10 | 11 | # vision_heads,layers,output_dim,heads,input_resolution,width 32 (3, 4, 6, 3) 1024 32 224 64 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1): 16 | super().__init__() 17 | 18 | # all conv layers have stride 1.
an avgpool is performed after the second convolution when stride > 1 19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | 22 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu(self.bn1(self.conv1(x))) 46 | out = self.relu(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.input_resolution = input_resolution 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d(width) 115 | self.avgpool = nn.AvgPool2d(2) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | # residual layers 119 | self._inplanes = width # this is a *mutable* variable used during construction 120 | self.layer1 = self._make_layer(width, layers[0]) 121 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 124 | 125 | embed_dim = width * 32 # the ResNet feature dimension 126 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 127 | 128 | def _make_layer(self, planes, blocks, stride=1): 129 | layers = [Bottleneck(self._inplanes, planes, stride)] 130 | self._inplanes = planes * Bottleneck.expansion 131 | for _ in range(1, blocks): 132 | layers.append(Bottleneck(self._inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | def stem(x): 138 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 139 | x = self.relu(bn(conv(x))) 140 | x = self.avgpool(x) 141 | return x 142 | 143 | x = x.type(self.conv1.weight.dtype) 144 | x = stem(x) 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | #print(x.shape) 150 | x = self.attnpool(x) 151 | # shape 1024 152 | return x 153 | ''' 154 | device = "cuda" if torch.cuda.is_available() else "cpu" 155 | model=ModifiedResNet(heads=32,layers=(3, 4, 6, 3) ,output_dim=1024) 156 | ''' 157 | 158 | def img_encoder(pretrained=False,**kwargs): 159 | """ 160 | Construct 161 | """ 162 | model=ModifiedResNet(heads=32,layers=(3, 4, 6, 3) ,output_dim=1024) 163 | if pretrained: 164 | print('Loading img model') 165 | model.load_state_dict(torch.load(r'..\model\RN50_Images_Encoder.pth')) 166 | print('Loaded successfully') 167 | return model 168 | 169 | -------------------------------------------------------------------------------- /nets/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable 2 | 3 | import torch 4 | from torch import Tensor 5 | import torch.nn as nn 6 | 7 | 8 | def channel_shuffle(x: Tensor, groups: int) -> Tensor: 9 | 10 | batch_size, num_channels, height, width = x.size() 11 | channels_per_group = num_channels // groups 12 | 13 | # reshape 14 | # [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width] 15 | x = x.view(batch_size, groups, channels_per_group, height, width) 16 | 17 | x = torch.transpose(x, 1, 2).contiguous() 18 | 19 | # flatten 20 | x = x.view(batch_size, -1, height, width) 21 | 22 | return x 23 | 24 | 25 | 
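# Illustrative sanity check for channel_shuffle (not part of the original file): with groups=2
# it interleaves the two halves of the channel dimension, e.g.
#   >>> t = torch.arange(8.).view(1, 8, 1, 1)        # channels 0..7
#   >>> channel_shuffle(t, 2).flatten().tolist()
#   [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
# i.e. channels [0..3] and [4..7] are interleaved, so information can mix across
# the two branches of each InvertedResidual block defined below.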
class InvertedResidual(nn.Module): 26 | def __init__(self, input_c: int, output_c: int, stride: int): 27 | super(InvertedResidual, self).__init__() 28 | 29 | if stride not in [1, 2]: 30 | raise ValueError("illegal stride value.") 31 | self.stride = stride 32 | 33 | assert output_c % 2 == 0 34 | branch_features = output_c // 2 35 | # when stride is 1, input_c should be twice branch_features 36 | # in Python '<<' is a bit shift; here it is just a quick way to multiply by 2 37 | assert (self.stride != 1) or (input_c == branch_features << 1) 38 | 39 | if self.stride == 2: 40 | self.branch1 = nn.Sequential( 41 | self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1), 42 | nn.BatchNorm2d(input_c), 43 | nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 44 | nn.BatchNorm2d(branch_features), 45 | nn.ReLU(inplace=True) 46 | ) 47 | else: 48 | self.branch1 = nn.Sequential() 49 | 50 | self.branch2 = nn.Sequential( 51 | nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1, 52 | stride=1, padding=0, bias=False), 53 | nn.BatchNorm2d(branch_features), 54 | nn.ReLU(inplace=True), 55 | self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1), 56 | nn.BatchNorm2d(branch_features), 57 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 58 | nn.BatchNorm2d(branch_features), 59 | nn.ReLU(inplace=True) 60 | ) 61 | 62 | @staticmethod 63 | def depthwise_conv(input_c: int, 64 | output_c: int, 65 | kernel_s: int, 66 | stride: int = 1, 67 | padding: int = 0, 68 | bias: bool = False) -> nn.Conv2d: 69 | return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s, 70 | stride=stride, padding=padding, bias=bias, groups=input_c) 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | if self.stride == 1: 74 | x1, x2 = x.chunk(2, dim=1) 75 | out = torch.cat((x1, self.branch2(x2)), dim=1) 76 | else: 77 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 78 | 79 | out = channel_shuffle(out, 2) 80 | 81 | return out 82 | 83 | 84 | class ShuffleNetV2(nn.Module): 85 | def __init__(self, 86 | stages_repeats: List[int], 87 | stages_out_channels: List[int], 88 | num_classes: int = 1000, 89 | inverted_residual: Callable[..., nn.Module] = InvertedResidual): 90 | super(ShuffleNetV2, self).__init__() 91 | 92 | if len(stages_repeats) != 3: 93 | raise ValueError("expected stages_repeats as list of 3 positive ints") 94 | if len(stages_out_channels) != 5: 95 | raise ValueError("expected stages_out_channels as list of 5 positive ints") 96 | self._stage_out_channels = stages_out_channels 97 | 98 | # input RGB image 99 | input_channels = 3 100 | output_channels = self._stage_out_channels[0] 101 | 102 | self.conv1 = nn.Sequential( 103 | nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False), 104 | nn.BatchNorm2d(output_channels), 105 | nn.ReLU(inplace=True) 106 | ) 107 | input_channels = output_channels 108 | 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | 111 | # Static annotations for mypy 112 | self.stage2: nn.Sequential 113 | self.stage3: nn.Sequential 114 | self.stage4: nn.Sequential 115 | 116 | stage_names = ["stage{}".format(i) for i in [2, 3, 4]] 117 | for name, repeats, output_channels in zip(stage_names, stages_repeats, 118 | self._stage_out_channels[1:]): 119 | seq = [inverted_residual(input_channels, output_channels, 2)] 120 | for i in range(repeats - 1): 121 | seq.append(inverted_residual(output_channels,
output_channels, 1)) 122 | setattr(self, name, nn.Sequential(*seq)) 123 | input_channels = output_channels 124 | 125 | output_channels = self._stage_out_channels[-1] 126 | self.conv5 = nn.Sequential( 127 | nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False), 128 | nn.BatchNorm2d(output_channels), 129 | nn.ReLU(inplace=True) 130 | ) 131 | 132 | self.fc = nn.Linear(output_channels, num_classes) 133 | 134 | def _forward_impl(self, x: Tensor) -> Tensor: 135 | # See note [TorchScript super()] 136 | x = self.conv1(x) 137 | x = self.maxpool(x) 138 | x = self.stage2(x) 139 | x = self.stage3(x) 140 | x = self.stage4(x) 141 | x = self.conv5(x) 142 | x = x.mean([2, 3]) # global pool 143 | x = self.fc(x) 144 | return x 145 | 146 | def forward(self, x: Tensor) -> Tensor: 147 | return self._forward_impl(x) 148 | 149 | 150 | def shufflenet_v2_x1_0(num_classes=1000): 151 | """ 152 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 153 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 154 | `. 155 | weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth 156 | 157 | :param num_classes: 158 | :return: 159 | """ 160 | model = ShuffleNetV2(stages_repeats=[4, 8, 4], 161 | stages_out_channels=[24, 116, 232, 464, 1024], 162 | num_classes=num_classes) 163 | 164 | return model 165 | 166 | 167 | def shufflenet_v2_x0_5(num_classes=1000): 168 | """ 169 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 170 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 171 | `. 172 | weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth 173 | 174 | :param num_classes: 175 | :return: 176 | """ 177 | model = ShuffleNetV2(stages_repeats=[4, 8, 4], 178 | stages_out_channels=[24, 48, 96, 192, 1024], 179 | num_classes=num_classes) 180 | 181 | return model 182 | -------------------------------------------------------------------------------- /nets/transformer_text_encoder.py: -------------------------------------------------------------------------------- 1 | # by kayak.gw 2 | import torch 3 | #import clip 4 | from PIL import Image 5 | import time 6 | #from text_encoder import C 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | from collections import OrderedDict 12 | from typing import Tuple, Union 13 | 14 | 15 | 16 | class LayerNorm(nn.LayerNorm): 17 | """Subclass torch's LayerNorm to handle fp16.""" 18 | 19 | def forward(self, x: torch.Tensor): 20 | orig_type = x.dtype 21 | ret = super().forward(x.type(torch.float32)) 22 | return ret.type(orig_type) 23 | 24 | 25 | class QuickGELU(nn.Module): 26 | def forward(self, x: torch.Tensor): 27 | return x * torch.sigmoid(1.702 * x) 28 | 29 | 30 | class ResidualAttentionBlock(nn.Module): 31 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 32 | super().__init__() 33 | 34 | self.attn = nn.MultiheadAttention(d_model, n_head) 35 | self.ln_1 = LayerNorm(d_model) 36 | self.mlp = nn.Sequential(OrderedDict([ 37 | ("c_fc", nn.Linear(d_model, d_model * 4)), 38 | ("gelu", QuickGELU()), 39 | ("c_proj", nn.Linear(d_model * 4, d_model)) 40 | ])) 41 | self.ln_2 = LayerNorm(d_model) 42 | self.attn_mask = attn_mask 43 | 44 | def attention(self, x: torch.Tensor): 45 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 46 | return self.attn(x, x, x, need_weights=False, 
attn_mask=self.attn_mask)[0] 47 | 48 | def forward(self, x: torch.Tensor): 49 | x = x + self.attention(self.ln_1(x)) 50 | x = x + self.mlp(self.ln_2(x)) 51 | return x 52 | 53 | 54 | class Transformer(nn.Module): 55 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 56 | super().__init__() 57 | self.width = width 58 | self.layers = layers 59 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 60 | 61 | def forward(self, x: torch.Tensor): 62 | return self.resblocks(x) 63 | 64 | 65 | 66 | class Text_Encoder(nn.Module): 67 | def __init__(self, 68 | embed_dim: int, 69 | # text 70 | context_length: int, 71 | vocab_size: int, 72 | transformer_width: int, 73 | transformer_heads: int, 74 | transformer_layers: int 75 | ): 76 | super().__init__() 77 | 78 | self.context_length = context_length 79 | 80 | 81 | 82 | self.transformer = Transformer( 83 | width=transformer_width, 84 | layers=transformer_layers, 85 | heads=transformer_heads, 86 | attn_mask=self.build_attention_mask() 87 | ) 88 | #512 12 8 tensor( 89 | #print(transformer_width,transformer_layers,transformer_heads,self.build_attention_mask()) 90 | 91 | self.vocab_size = vocab_size 92 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 93 | #print(vocab_size, transformer_width,self.context_length) 94 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 95 | self.ln_final = LayerNorm(transformer_width) 96 | 97 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 98 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 99 | 100 | self.initialize_parameters() 101 | 102 | def initialize_parameters(self): 103 | nn.init.normal_(self.token_embedding.weight, std=0.02) 104 | nn.init.normal_(self.positional_embedding, std=0.01) 105 | 106 | 107 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 108 | attn_std = self.transformer.width ** -0.5 109 | fc_std = (2 * self.transformer.width) ** -0.5 110 | for block in self.transformer.resblocks: 111 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 112 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 113 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 114 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 115 | 116 | if self.text_projection is not None: 117 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 118 | 119 | def build_attention_mask(self): 120 | # lazily create causal attention mask, with full attention between the vision tokens 121 | # pytorch uses additive attention mask; fill with -inf 122 | mask = torch.empty(self.context_length, self.context_length) 123 | mask.fill_(float("-inf")) 124 | mask.triu_(1) # zero out the lower diagonal 125 | return mask 126 | 127 | @property 128 | def dtype(self): 129 | return torch.FloatTensor 130 | 131 | def forward(self, text): 132 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 133 | #x1 torch.Size([1, 77, 512]) 134 | #x2 torch.Size([1, 77, 512]) 135 | #x3 torch.Size([77, 1, 512]) 136 | #print('x1',x.shape) 137 | x = x + self.positional_embedding.type(self.dtype) 138 | #print('x2',x.shape) 139 | x = x.permute(1, 0, 2) # NLD -> LND 140 | x = self.transformer(x) 141 | #print('x3',x.shape) 142 | x = x.permute(1, 0, 2) # LND -> NLD 143 | x = self.ln_final(x).type(self.dtype) 144 | 145 | # x.shape = [batch_size, n_ctx, transformer.width] 146 | # take 
features from the eot embedding (eot_token is the highest number in each sequence) 147 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 148 | 149 | return x 150 | 151 | def save_model(): 152 | model=Text_Encoder(1024,77,49408,512,8,12) 153 | torch.save(model.state_dict(), "text_encoder.pth") 154 | 155 | def load_model(model_path): 156 | return torch.load(model_path) 157 | 158 | 159 | def text_encoder(pretrained=False,**kwargs): 160 | """ 161 | Construct 162 | """ 163 | model=Text_Encoder(1024,77,49408,512,8,12) 164 | if pretrained: 165 | print('Loading text model') 166 | model.load_state_dict(torch.load(r'..\model\transformer_block_weight.pth')) 167 | print('Loaded successfully') 168 | return model 169 | 170 | ''' 171 | # DEMO 172 | save_model() 173 | model=load_model("text_encoder.pth") 174 | for k,v in model.items(): 175 | print(v.shape) 176 | 177 | #DEMO 178 | from images_encoder import ModifiedResNet 179 | img=torch.rand(1,3,224,224) 180 | model1=ModifiedResNet(heads=32,layers=(3, 4, 6, 3) ,output_dim=1024) 181 | img_f=model1(img) 182 | 183 | model=Text_Encoder(1024,77,49408,512,8,12) 184 | model.load_state_dict(torch.load(r'transformer_block_weight.pth')) 185 | a=torch.rand(1,77).type(torch.LongTensor) 186 | x = model(a) 187 | 188 | print(x.shape) 189 | 190 | 191 | out=torch.add(img_f,x) 192 | fc=nn.Linear(1024,10) 193 | 194 | print(fc(out)) 195 | ''' 196 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # by kayak.gw 2 | import torch 3 | import torch.nn as nn 4 | from torchvision import transforms, datasets, utils 5 | 6 | import numpy as np 7 | import torch.optim as optim 8 | from nets.model import net,Net 9 | 10 | 11 | img=torch.rand(8,3,224,224) 12 | text=torch.rand(1,77).type(torch.LongTensor) 13 | model=net() 14 | a=model(img,text) 15 | print(a) 16 | --------------------------------------------------------------------------------
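Note: train.py imports `net, Net` from nets.model, but the model.py included above only defines ShuffleNetV2 and its constructors; the fusion network that takes (img, text) is not part of this tree. A minimal sketch of such a module, in the spirit of the commented DEMO at the bottom of transformer_text_encoder.py (1024-d image and text features added element-wise, then a linear head), is given below. The class name, feature dimension and number of classes are illustrative assumptions, not the author's implementation:

import torch
from torch import nn
from nets.images_encoder import img_encoder
from nets.transformer_text_encoder import text_encoder


class Net(nn.Module):
    # hypothetical fusion head: both encoders emit 1024-d features (output_dim / embed_dim above)
    def __init__(self, num_classes=10, pretrained=False):
        super().__init__()
        self.image_encoder = img_encoder(pretrained=pretrained)
        self.text_encoder = text_encoder(pretrained=pretrained)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, img, text):
        img_f = self.image_encoder(img)   # [B, 1024]
        txt_f = self.text_encoder(text)   # [B, 1024]
        return self.fc(img_f + txt_f)     # [B, num_classes]


# usage sketch (batch sizes of img and text must match for the element-wise sum):
# img = torch.rand(2, 3, 224, 224)
# text = torch.randint(0, 49408, (2, 77))   # token ids (LongTensor), not torch.rand
# print(Net()(img, text).shape)             # torch.Size([2, 10])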