├── .DS_Store
├── .gitignore
├── .vscode
│   └── settings.json
├── README.md
├── SETR.png
├── SETR
│   ├── __pycache__
│   │   ├── transformer_model.cpython-37.pyc
│   │   └── transformer_seg.cpython-37.pyc
│   ├── transformer_model.py
│   └── transformer_seg.py
├── main.py
├── task_minst.py
└── tast_car_seg.py

--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/SETR-pytorch/d4183440a6362795c9b0f2913ae805837470725a/.DS_Store

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data
checkpoints

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.pythonPath": "/Users/xingzhaohu/.local/share/virtualenvs/ml-5foBrNl9/bin/python"
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## SETR - PyTorch

Since the original paper ("Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers") has no official code, I implemented SETR Progressive Upsampling (SETR-PUP) in PyTorch.

Original paper: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers (https://arxiv.org/abs/2012.15840).

## ViT

The ViT model (class `Vit`) is also implemented, and you can use it for image classification.

## Usage: SETR

```python
from SETR.transformer_seg import SETRModel
import torch

if __name__ == "__main__":
    net = SETRModel(patch_size=(32, 32),
                    in_channels=3,
                    out_channels=1,
                    hidden_size=1024,
                    num_hidden_layers=8,
                    num_attention_heads=16,
                    decode_features=[512, 256, 128, 64])
    t1 = torch.rand(1, 3, 256, 256)
    print("input: " + str(t1.shape))

    # print(net)
    print("output: " + str(net(t1).shape))
```

If the output shape is (1, 1, 256, 256), the code runs successfully.

## Usage: ViT

```python
from SETR.transformer_seg import Vit
import torch

if __name__ == "__main__":
    model = Vit(patch_size=(7, 7),
                in_channels=1,
                out_class=10,
                hidden_size=1024,
                num_hidden_layers=1,
                num_attention_heads=16)
    print(model)
    t1 = torch.rand(1, 1, 28, 28)
    print("input: " + str(t1.shape))

    print("output: " + str(model(t1).shape))
```

The output shape is (1, 10).
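
## Note on sample_rate

`sample_rate` is easy to miss: inside `Encoder2D` each patch side is divided by `2**sample_rate` before the feature map is handed to the decoder, while `Decoder2D` always upsamples by a fixed `2**4 = 16`. The output resolution is therefore `input_size * 2**(4 - sample_rate)`: the default `sample_rate=4` reproduces the input resolution, and `sample_rate=5` (as in `main.py`) halves it. A minimal sketch of this behavior, as I read it from `SETR/transformer_seg.py` (the shapes follow from the code, not from the paper):

```python
from SETR.transformer_seg import SETRModel
import torch

# With sample_rate=5, a 512x512 input comes out at 256x256.
net = SETRModel(patch_size=(32, 32),
                in_channels=3,
                out_channels=1,
                hidden_size=1024,
                num_hidden_layers=1,
                num_attention_heads=16,
                sample_rate=5)
t1 = torch.rand(1, 3, 512, 512)
print(net(t1).shape)  # expected: torch.Size([1, 1, 256, 256])
```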

## Current examples

1. `task_minst.py`: the simplest example, using the `Vit` model to classify the MNIST dataset.
2. `tast_car_seg.py`: a simple segmentation task on the Carvana dataset (see the expected data layout below). Data download: https://www.kaggle.com/c/carvana-image-masking-challenge/data
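
The script globs `./segmentation_car/imgs/*` and `./segmentation_car/masks/*` and pairs the sorted results by index, so images and masks must sort into the same order. A layout like the following should work (the file names are only an illustration of the Kaggle archive's naming, not something the code checks):

```
segmentation_car/
├── imgs/
│   ├── 0cdf5b5d0ce1_01.jpg
│   └── ...
└── masks/
    ├── 0cdf5b5d0ce1_01_mask.gif
    └── ...
```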

## More

More examples will be added later.

--------------------------------------------------------------------------------
/SETR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/SETR-pytorch/d4183440a6362795c9b0f2913ae805837470725a/SETR.png

--------------------------------------------------------------------------------
/SETR/__pycache__/transformer_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/SETR-pytorch/d4183440a6362795c9b0f2913ae805837470725a/SETR/__pycache__/transformer_model.cpython-37.pyc

--------------------------------------------------------------------------------
/SETR/__pycache__/transformer_seg.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/SETR-pytorch/d4183440a6362795c9b0f2913ae805837470725a/SETR/__pycache__/transformer_seg.cpython-37.pyc

--------------------------------------------------------------------------------
/SETR/transformer_model.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn

def swish(x):
    return x * torch.sigmoid(x)

def gelu(x):
    """Gaussian Error Linear Unit (exact erf formulation)."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "mish": mish}

class TransConfig(object):

    def __init__(
        self,
        patch_size,
        in_channels,
        out_channels,
        sample_rate=4,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=6,
        intermediate_size=1024,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
    ):
        self.sample_rate = sample_rate
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps


class TransLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(TransLayerNorm, self).__init__()

        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta

class TransEmbeddings(nn.Module):
    """Add learned position embeddings to the patch embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        input_shape = input_ids.size()

        seq_length = input_shape[1]
        device = input_ids.device

        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(input_shape[:2])

        position_embeddings = self.position_embeddings(position_ids)

        embeddings = input_ids + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class TransSelfAttention(nn.Module):
    def __init__(self, config: TransConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)

        # Final shape: (batch_size, num_attention_heads, seq_len, head_size).
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states
    ):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # No attention mask is applied here: every patch may attend to every other patch.

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weight the values by the attention probabilities.
        context_layer = torch.matmul(attention_probs, value_layer)
        # Reshape the weighted values back to (batch_size, seq_len, hidden_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class TransSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class TransAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = TransSelfAttention(config)
        self.output = TransSelfOutput(config)

    def forward(
        self,
        hidden_states,
    ):
        self_outputs = self.self(hidden_states)
        attention_output = self.output(self_outputs, hidden_states)

        return attention_output


class TransIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]  # selected by config.hidden_act ("gelu" by default)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class TransOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class TransLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = TransAttention(config)
        self.intermediate = TransIntermediate(config)
        self.output = TransOutput(config)

    def forward(
        self,
        hidden_states
    ):
        attention_output = self.attention(hidden_states)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

class TransEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([TransLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        output_all_encoded_layers=True,
    ):
        all_encoder_layers = []

        for i, layer_module in enumerate(self.layer):
            layer_output = layer_module(hidden_states)
            hidden_states = layer_output
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)

        return all_encoder_layers

class InputDense2d(nn.Module):
    def __init__(self, config):
        super(InputDense2d, self).__init__()
        self.dense = nn.Linear(config.patch_size[0] * config.patch_size[1] * config.in_channels, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class InputDense3d(nn.Module):
    def __init__(self, config):
        super(InputDense3d, self).__init__()
        self.dense = nn.Linear(config.patch_size[0] * config.patch_size[1] * config.patch_size[2] * config.in_channels, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class TransModel2d(nn.Module):

    def __init__(self, config):
        super(TransModel2d, self).__init__()
        self.config = config
        self.dense = InputDense2d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
    ):
        dense_out = self.dense(input_ids)
        embedding_output = self.embeddings(
            input_ids=dense_out
        )
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )

        if not output_all_encoded_layers:
            # If all encoder layers are not needed, return only the last one.
            encoder_layers = encoder_layers[-1]
        return encoder_layers


class TransModel3d(nn.Module):

    def __init__(self, config):
        super(TransModel3d, self).__init__()
        self.config = config
        self.dense = InputDense3d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
    ):
        dense_out = self.dense(input_ids)
        embedding_output = self.embeddings(
            input_ids=dense_out
        )
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )

        if not output_all_encoded_layers:
            # If all encoder layers are not needed, return only the last one.
            encoder_layers = encoder_layers[-1]
        return encoder_layers

--------------------------------------------------------------------------------
/SETR/transformer_seg.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn
from einops import rearrange
from SETR.transformer_model import TransModel2d, TransConfig

class Encoder2D(nn.Module):
    def __init__(self, config: TransConfig, is_segmentation=True):
        super().__init__()
        self.config = config
        self.out_channels = config.out_channels
        self.bert_model = TransModel2d(config)
        sample_rate = config.sample_rate
        sample_v = int(math.pow(2, sample_rate))
        assert config.patch_size[0] * config.patch_size[1] * config.hidden_size % (sample_v**2) == 0, \
            "patch_size[0] * patch_size[1] * hidden_size must be divisible by (2 ** sample_rate) ** 2"
        self.final_dense = nn.Linear(config.hidden_size, config.patch_size[0] * config.patch_size[1] * config.hidden_size // (sample_v**2))
        self.patch_size = config.patch_size
        self.hh = self.patch_size[0] // sample_v
        self.ww = self.patch_size[1] // sample_v

        self.is_segmentation = is_segmentation

    def forward(self, x):
        # x: (b, c, h, w)
        b, c, h, w = x.shape
        assert self.config.in_channels == c, "in_channels does not match the channel dimension of the input image"
        p1 = self.patch_size[0]
        p2 = self.patch_size[1]

        if h % p1 != 0:
            raise ValueError("the input height must be divisible by patch_size[0]")
        if w % p2 != 0:
            raise ValueError("the input width must be divisible by patch_size[1]")
        hh = h // p1
        ww = w // p2

        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p1, p2 = p2)

        encode_x = self.bert_model(x)[-1]  # take the output of the last encoder layer
        if not self.is_segmentation:
            return encode_x

        x = self.final_dense(encode_x)
        x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", p1 = self.hh, p2 = self.ww, h = hh, w = ww, c = self.config.hidden_size)
        return encode_x, x


class PreTrainModel(nn.Module):
    def __init__(self, patch_size,
                 in_channels,
                 out_class,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 decode_features=[512, 256, 128, 64]):
        super().__init__()
        config = TransConfig(patch_size=patch_size,
                             in_channels=in_channels,
                             out_channels=0,
                             hidden_size=hidden_size,
                             num_hidden_layers=num_hidden_layers,
                             num_attention_heads=num_attention_heads)
        self.encoder_2d = Encoder2D(config, is_segmentation=False)
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        encode_img = self.encoder_2d(x)
        encode_pool = encode_img.mean(dim=1)
        out = self.cls(encode_pool)
        return out

class Vit(nn.Module):
    def __init__(self, patch_size,
                 in_channels,
                 out_class,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 sample_rate=4,
                 ):
        super().__init__()
        config = TransConfig(patch_size=patch_size,
                             in_channels=in_channels,
                             out_channels=0,
                             sample_rate=sample_rate,
                             hidden_size=hidden_size,
                             num_hidden_layers=num_hidden_layers,
                             num_attention_heads=num_attention_heads)
        self.encoder_2d = Encoder2D(config, is_segmentation=False)
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        encode_img = self.encoder_2d(x)
        encode_pool = encode_img.mean(dim=1)
        out = self.cls(encode_pool)
        return out

class Decoder2D(nn.Module):
    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()
        self.decoder_1 = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 3, padding=1),
            nn.BatchNorm2d(features[0]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.decoder_2 = nn.Sequential(
            nn.Conv2d(features[0], features[1], 3, padding=1),
            nn.BatchNorm2d(features[1]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.decoder_3 = nn.Sequential(
            nn.Conv2d(features[1], features[2], 3, padding=1),
            nn.BatchNorm2d(features[2]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.decoder_4 = nn.Sequential(
            nn.Conv2d(features[2], features[3], 3, padding=1),
            nn.BatchNorm2d(features[3]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    def forward(self, x):
        x = self.decoder_1(x)
        x = self.decoder_2(x)
        x = self.decoder_3(x)
        x = self.decoder_4(x)
        x = self.final_out(x)
        return x

class SETRModel(nn.Module):
    def __init__(self, patch_size=(32, 32),
                 in_channels=3,
                 out_channels=1,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 decode_features=[512, 256, 128, 64],
                 sample_rate=4,):
        super().__init__()
        config = TransConfig(patch_size=patch_size,
                             in_channels=in_channels,
                             out_channels=out_channels,
                             sample_rate=sample_rate,
                             hidden_size=hidden_size,
                             num_hidden_layers=num_hidden_layers,
                             num_attention_heads=num_attention_heads)
        self.encoder_2d = Encoder2D(config)
        self.decoder_2d = Decoder2D(in_channels=config.hidden_size, out_channels=config.out_channels, features=decode_features)

    def forward(self, x):
        _, final_x = self.encoder_2d(x)
        x = self.decoder_2d(final_x)
        return x

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from SETR.transformer_seg import SETRModel, Vit
import torch

if __name__ == "__main__":
    # With sample_rate=5 the output resolution is half the input:
    # a (1, 3, 512, 512) input yields a (1, 1, 256, 256) output.
    net = SETRModel(patch_size=(32, 32),
                    in_channels=3,
                    out_channels=1,
                    hidden_size=1024,
                    sample_rate=5,
                    num_hidden_layers=1,
                    num_attention_heads=16,
                    decode_features=[512, 256, 128, 64])
    t1 = torch.rand(1, 3, 512, 512)
    print("input: " + str(t1.shape))

    print("output: " + str(net(t1).shape))

    model = Vit(patch_size=(32, 32),
                in_channels=1,
                out_class=10,
                sample_rate=4,
                hidden_size=1024,
                num_hidden_layers=1,
                num_attention_heads=16)

    t1 = torch.rand(1, 1, 512, 512)
    print("input: " + str(t1.shape))

    print("output: " + str(model(t1).shape))

--------------------------------------------------------------------------------
/task_minst.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from tqdm import tqdm

from SETR.transformer_seg import Vit

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device is " + str(device))

def compute_acc(model, test_dataloader):
    model.eval()  # disable dropout during evaluation
    with torch.no_grad():
        right_num = 0
        total_num = 0
        for in_data, label in tqdm(test_dataloader, total=len(test_dataloader)):
            in_data = in_data.to(device)
            label = label.to(device)
            total_num += len(in_data)
            out = model(in_data)
            pred = out.argmax(dim=-1)
            right_num += (pred == label).sum().item()

    model.train()  # restore training mode before returning to the training loop
    return right_num / total_num

if __name__ == "__main__":

    model = Vit(patch_size=(7, 7),
                in_channels=1,
                out_class=10,
                hidden_size=1024,
                num_hidden_layers=1,
                num_attention_heads=16,
                )
    print(model)
    model.to(device)

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])
    data_train = datasets.MNIST(root="./data/",
                                transform=transform,
                                train=True,
                                download=True)

    data_test = datasets.MNIST(root="./data/",
                               transform=transform,
                               train=False)

    data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                    batch_size=64,
                                                    shuffle=True)

    data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                                   batch_size=32,
                                                   shuffle=True)

    optimizer = torch.optim.Adam(model.parameters())
    loss_func = nn.CrossEntropyLoss()
    report_loss = 0
    step = 0
    best_acc = 0.0

    # The checkpoints directory is git-ignored, so create it if it does not exist.
    os.makedirs("./checkpoints", exist_ok=True)

    for in_data, label in tqdm(data_loader_train, total=len(data_loader_train)):
        batch_size = len(in_data)
        in_data = in_data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        step += 1
        out = model(in_data)
        loss = loss_func(out, label)
        loss.backward()
        optimizer.step()
        report_loss += loss.item()
        if step % 10 == 0:
            print("report_loss is : " + str(report_loss))
            report_loss = 0
            acc = compute_acc(model, data_loader_test)
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), "./checkpoints/mnist_model.pkl")

            print("acc is " + str(acc) + ", best acc is " + str(best_acc))

--------------------------------------------------------------------------------
/tast_car_seg.py:
--------------------------------------------------------------------------------
# data url: https://www.kaggle.com/c/carvana-image-masking-challenge/data
import glob
import os

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from SETR.transformer_seg import SETRModel

img_url = sorted(glob.glob("./segmentation_car/imgs/*"))
mask_url = sorted(glob.glob("./segmentation_car/masks/*"))
# print(img_url)
train_size = int(len(img_url) * 0.8)
train_img_url = img_url[:train_size]
train_mask_url = mask_url[:train_size]
val_img_url = img_url[train_size:]
val_mask_url = mask_url[train_size:]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device is " + str(device))
epochs = 100
out_channels = 1

def build_model():
    model = SETRModel(patch_size=(16, 16),
                      in_channels=3,
                      out_channels=1,
                      hidden_size=1024,
                      num_hidden_layers=6,
                      num_attention_heads=16,
                      decode_features=[512, 256, 128, 64])
    return model

class CarDataset(Dataset):
    def __init__(self, img_url, mask_url):
        super(CarDataset, self).__init__()
        self.img_url = img_url
        self.mask_url = mask_url

    def __getitem__(self, idx):
        img = Image.open(self.img_url[idx])
        img = img.resize((256, 256))
        img_array = np.array(img, dtype=np.float32) / 255
        mask = Image.open(self.mask_url[idx])
        mask = mask.resize((256, 256))
        mask = np.array(mask, dtype=np.float32)
        img_array = img_array.transpose(2, 0, 1)

        return torch.tensor(img_array.copy()), torch.tensor(mask.copy())

    def __len__(self):
        return len(self.img_url)

def compute_dice(input, target):
    eps = 0.0001
    # `input` is the model output after sigmoid.
    input = (input > 0.5).float()
    target = (target > 0.5).float()

    inter = torch.sum(target.view(-1) * input.view(-1)) + eps
    union = torch.sum(input) + torch.sum(target) + eps

    t = (2 * inter.float()) / union.float()
    return t

def predict():
    model = build_model()
    model.load_state_dict(torch.load("./checkpoints/SETR_car.pkl", map_location="cpu"))
    print(model)

    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for img, mask in val_loader:
            pred = torch.sigmoid(model(img))
            pred = (pred > 0.5).int()
            plt.subplot(1, 3, 1)
            print(img.shape)
            img = img.permute(0, 2, 3, 1)
            plt.imshow(img[0])
            plt.subplot(1, 3, 2)
            plt.imshow(pred[0].squeeze(0), cmap="gray")
            plt.subplot(1, 3, 3)
            plt.imshow(mask[0], cmap="gray")
            plt.show()

if __name__ == "__main__":

    model = build_model()
    model.to(device)

    train_dataset = CarDataset(train_img_url, train_mask_url)
    train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True)

    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    loss_func = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

    # The checkpoints directory is git-ignored, so create it if it does not exist.
    os.makedirs("./checkpoints", exist_ok=True)

    step = 0
    report_loss = 0.0
    for epoch in range(epochs):
        print("epoch is " + str(epoch))

        for img, mask in tqdm(train_loader, total=len(train_loader)):
            optimizer.zero_grad()
            step += 1
            img = img.to(device)
            mask = mask.to(device)

            pred_img = model(img)  # pred_img: (batch, out_channels, H, W)
            if out_channels == 1:
                pred_img = pred_img.squeeze(1)  # squeeze the channel dimension to match the mask shape

            loss = loss_func(pred_img, mask)
            report_loss += loss.item()
            loss.backward()
            optimizer.step()

            if step % 1000 == 0:
                dice = 0.0
                n = 0
                model.eval()
                with torch.no_grad():
                    print("report_loss is " + str(report_loss))
                    report_loss = 0.0
                    for val_img, val_mask in tqdm(val_loader, total=len(val_loader)):
                        n += 1
                        val_img = val_img.to(device)
                        val_mask = val_mask.to(device)
                        pred_img = torch.sigmoid(model(val_img))
                        if out_channels == 1:
                            pred_img = pred_img.squeeze(1)
                        cur_dice = compute_dice(pred_img, val_mask)
                        dice += cur_dice
                    dice = dice / n
                    print("mean dice is " + str(dice))
                    torch.save(model.state_dict(), "./checkpoints/SETR_car.pkl")
                model.train()

--------------------------------------------------------------------------------