├── .gitignore
├── .vscode
│   └── settings.json
├── README.md
├── SETR.png
├── SETR
│   ├── transformer_model.py
│   └── transformer_seg.py
├── main.py
├── task_minst.py
└── tast_car_seg.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | checkpoints
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "/Users/xingzhaohu/.local/share/virtualenvs/ml-5foBrNl9/bin/python"
3 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## SETR - PyTorch
5 | 
6 | Since the original paper has no official code, I implemented SETR-Progressive Upsampling (SETR-PUP) in PyTorch.
7 | 
8 | Original paper: *Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers*.
9 | 
10 | ## ViT
11 | The ViT model (class `Vit`) is also implemented, and you can use it for image classification.
12 |
13 | ## Usage: SETR
14 |
15 | ```python
16 | from SETR.transformer_seg import SETRModel
17 | import torch
18 |
19 | if __name__ == "__main__":
20 | net = SETRModel(patch_size=(32, 32),
21 | in_channels=3,
22 | out_channels=1,
23 | hidden_size=1024,
24 | num_hidden_layers=8,
25 | num_attention_heads=16,
26 | decode_features=[512, 256, 128, 64])
27 | t1 = torch.rand(1, 3, 256, 256)
28 | print("input: " + str(t1.shape))
29 |
30 | # print(net)
31 | print("output: " + str(net(t1).shape))
32 |
33 | ```
34 | If the output shape is (1, 1, 256, 256), the model is working as intended (see "How the output sizes work out" at the end of this README).
35 |
36 | ## Usage: ViT
37 | ```python
38 | from SETR.transformer_seg import Vit
39 | import torch
40 |
41 | if __name__ == "__main__":
42 | model = Vit(patch_size=(7, 7),
43 | in_channels=1,
44 | out_class=10,
45 | hidden_size=1024,
46 | num_hidden_layers=1,
47 | num_attention_heads=16)
48 | print(model)
49 | t1 = torch.rand(1, 1, 28, 28)
50 | print("input: " + str(t1.shape))
51 |
52 | print("output: " + str(model(t1).shape))
53 | ```
54 | The output shape is (1, 10).
55 |
56 | ## Current examples
57 | 1. `task_minst.py`: the simplest example, using the `Vit` model to classify the MNIST dataset.
58 | 2. `tast_car_seg.py`: a simple segmentation example. Data download: https://www.kaggle.com/c/carvana-image-masking-challenge/data
59 | 
60 | ## More
61 | More examples will be added later.
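62 | 
63 | ## How the output sizes work out
64 | A brief sketch of the shape arithmetic (derived from the code in `SETR/transformer_seg.py`; the variable names below are only illustrative):
65 | 
66 | ```python
67 | # SETRModel(patch_size=(32, 32), sample_rate=4) on a 256x256 input:
68 | patch = 32
69 | sample_v = 2 ** 4                  # 16, from sample_rate=4
70 | grid = 256 // patch                # 8x8 grid of patches fed to the encoder
71 | feat = grid * (patch // sample_v)  # 16: encoder features are (B, hidden, 16, 16)
72 | out = feat * 2 ** 4                # 256: Decoder2D upsamples x2 four times
73 | print(out)                         # 256 -> hence the (1, 1, 256, 256) output above
74 | ```
75 | 
76 | The output resolution matches the input exactly when `sample_rate=4`, since the decoder's four x2 upsamplings undo the 2**4 downsampling. For `Vit`, a 28x28 input with `patch_size=(7, 7)` gives 16 patches; the encoder output (1, 16, 1024) is mean-pooled over the patch dimension and projected to the 10 classes.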
--------------------------------------------------------------------------------
/SETR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/SETR-pytorch/d4183440a6362795c9b0f2913ae805837470725a/SETR.png
--------------------------------------------------------------------------------
/SETR/transformer_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import CrossEntropyLoss, MSELoss
8 | from einops import rearrange
9 |
10 | def swish(x):
11 | return x * torch.sigmoid(x)
12 |
13 | def gelu(x):
14 |     """Gaussian Error Linear Unit (GELU) activation, exact erf formulation.
15 |     """
16 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
17 |
18 | def mish(x):
19 | return x * torch.tanh(nn.functional.softplus(x))
20 |
21 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "mish": mish}
22 |
23 | class TransConfig(object):
24 |
25 | def __init__(
26 | self,
27 | patch_size,
28 | in_channels,
29 | out_channels,
30 | sample_rate=4,
31 | hidden_size=768,
32 | num_hidden_layers=8,
33 | num_attention_heads=6,
34 | intermediate_size=1024,
35 | hidden_act="gelu",
36 | hidden_dropout_prob=0.1,
37 | attention_probs_dropout_prob=0.1,
38 | max_position_embeddings=512,
39 | initializer_range=0.02,
40 | layer_norm_eps=1e-12,
41 | ):
42 | self.sample_rate = sample_rate
43 | self.patch_size = patch_size
44 | self.in_channels = in_channels
45 | self.out_channels = out_channels
46 | self.hidden_size = hidden_size
47 | self.num_hidden_layers = num_hidden_layers
48 | self.num_attention_heads = num_attention_heads
49 | self.hidden_act = hidden_act
50 | self.intermediate_size = intermediate_size
51 | self.hidden_dropout_prob = hidden_dropout_prob
52 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
53 | self.max_position_embeddings = max_position_embeddings
54 | self.initializer_range = initializer_range
55 | self.layer_norm_eps = layer_norm_eps
56 |
57 |
58 | class TransLayerNorm(nn.Module):
59 | def __init__(self, hidden_size, eps=1e-12):
60 | """Construct a layernorm module in the TF style (epsilon inside the square root).
61 | """
62 | super(TransLayerNorm, self).__init__()
63 |
64 | self.gamma = nn.Parameter(torch.ones(hidden_size))
65 | self.beta = nn.Parameter(torch.zeros(hidden_size))
66 | self.variance_epsilon = eps
67 |
68 |
69 | def forward(self, x):
70 |
71 | u = x.mean(-1, keepdim=True)
72 | s = (x - u).pow(2).mean(-1, keepdim=True)
73 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
74 | return self.gamma * x + self.beta
75 |
76 | class TransEmbeddings(nn.Module):
77 |     """Add learned position embeddings to the input patch embeddings.
78 |     """
79 |
80 | def __init__(self, config):
81 | super().__init__()
82 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
83 |
84 | self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
85 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
86 |
87 | def forward(self, input_ids):
88 | input_shape = input_ids.size()
89 |
90 | seq_length = input_shape[1]
91 | device = input_ids.device
92 |
93 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
94 | position_ids = position_ids.unsqueeze(0).expand(input_shape[:2])
95 |
96 | position_embeddings = self.position_embeddings(position_ids)
97 |
98 | embeddings = input_ids + position_embeddings
99 | embeddings = self.LayerNorm(embeddings)
100 | embeddings = self.dropout(embeddings)
101 | return embeddings
102 |
103 | class TransSelfAttention(nn.Module):
104 | def __init__(self, config: TransConfig):
105 | super().__init__()
106 | if config.hidden_size % config.num_attention_heads != 0:
107 | raise ValueError(
108 | "The hidden size (%d) is not a multiple of the number of attention "
109 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)
110 | )
111 |
112 | self.num_attention_heads = config.num_attention_heads
113 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
114 | self.all_head_size = self.num_attention_heads * self.attention_head_size
115 |
116 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
117 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
118 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
119 |
120 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
121 |
122 | def transpose_for_scores(self, x):
123 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
124 | x = x.view(*new_x_shape)
125 |
126 |         # final shape: (batch_size, num_attention_heads, seq_len, head_size)
127 | return x.permute(0, 2, 1, 3)
128 |
129 | def forward(
130 | self,
131 | hidden_states
132 | ):
133 | mixed_query_layer = self.query(hidden_states)
134 | mixed_key_layer = self.key(hidden_states)
135 | mixed_value_layer = self.value(hidden_states)
136 |
137 | query_layer = self.transpose_for_scores(mixed_query_layer)
138 | key_layer = self.transpose_for_scores(mixed_key_layer)
139 | value_layer = self.transpose_for_scores(mixed_value_layer)
140 |
141 | # Take the dot product between "query" and "key" to get the raw attention scores.
142 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
143 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
144 |
145 |         # No attention mask is applied here: every patch can attend to every other patch.
146 | 
147 |
148 | # Normalize the attention scores to probabilities.
149 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
150 |
151 | # This is actually dropping out entire tokens to attend to, which might
152 | # seem a bit unusual, but is taken from the original Transformer paper.
153 | attention_probs = self.dropout(attention_probs)
154 |
155 |         # Weight the values by the attention probabilities.
156 |         context_layer = torch.matmul(attention_probs, value_layer)
157 |         # Reshape the weighted values back to (batch_size, seq_len, hidden_size).
158 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
159 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
160 | context_layer = context_layer.view(*new_context_layer_shape)
161 |
162 | return context_layer
163 |
164 |
165 | class TransSelfOutput(nn.Module):
166 | def __init__(self, config):
167 | super().__init__()
168 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
169 | self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
170 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
171 |
172 | def forward(self, hidden_states, input_tensor):
173 | hidden_states = self.dense(hidden_states)
174 | hidden_states = self.dropout(hidden_states)
175 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
176 | return hidden_states
177 |
178 |
179 | class TransAttention(nn.Module):
180 | def __init__(self, config):
181 | super().__init__()
182 | self.self = TransSelfAttention(config)
183 | self.output = TransSelfOutput(config)
184 |
185 | def forward(
186 | self,
187 | hidden_states,
188 | ):
189 | self_outputs = self.self(hidden_states)
190 | attention_output = self.output(self_outputs, hidden_states)
191 |
192 | return attention_output
193 |
194 |
195 | class TransIntermediate(nn.Module):
196 | def __init__(self, config):
197 | super().__init__()
198 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
199 |         self.intermediate_act_fn = ACT2FN[config.hidden_act]  # gelu by default
200 |
201 | def forward(self, hidden_states):
202 | hidden_states = self.dense(hidden_states)
203 | hidden_states = self.intermediate_act_fn(hidden_states)
204 | return hidden_states
205 |
206 | class TransOutput(nn.Module):
207 | def __init__(self, config):
208 | super().__init__()
209 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
210 | self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
211 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
212 |
213 | def forward(self, hidden_states, input_tensor):
214 | hidden_states = self.dense(hidden_states)
215 | hidden_states = self.dropout(hidden_states)
216 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
217 | return hidden_states
218 |
219 |
220 | class TransLayer(nn.Module):
221 | def __init__(self, config):
222 | super().__init__()
223 | self.attention = TransAttention(config)
224 | self.intermediate = TransIntermediate(config)
225 | self.output = TransOutput(config)
226 |
227 | def forward(
228 | self,
229 | hidden_states
230 | ):
231 | attention_output = self.attention(hidden_states)
232 | intermediate_output = self.intermediate(attention_output)
233 | layer_output = self.output(intermediate_output, attention_output)
234 | return layer_output
235 |
236 |
237 | class TransEncoder(nn.Module):
238 | def __init__(self, config):
239 | super().__init__()
240 | self.layer = nn.ModuleList([TransLayer(config) for _ in range(config.num_hidden_layers)])
241 |
242 | def forward(
243 | self,
244 | hidden_states,
245 | output_all_encoded_layers=True,
246 | ):
247 | all_encoder_layers = []
248 |
249 | for i, layer_module in enumerate(self.layer):
250 |
251 | layer_output = layer_module(hidden_states)
252 | hidden_states = layer_output
253 | if output_all_encoded_layers:
254 | all_encoder_layers.append(hidden_states)
255 | if not output_all_encoded_layers:
256 | all_encoder_layers.append(hidden_states)
257 |
258 | return all_encoder_layers
259 |
260 | class InputDense2d(nn.Module):
261 | def __init__(self, config):
262 | super(InputDense2d, self).__init__()
263 | self.dense = nn.Linear(config.patch_size[0] * config.patch_size[1] * config.in_channels, config.hidden_size)
264 | self.transform_act_fn = ACT2FN[config.hidden_act]
265 | self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
266 |
267 | def forward(self, hidden_states):
268 | hidden_states = self.dense(hidden_states)
269 | hidden_states = self.transform_act_fn(hidden_states)
270 | hidden_states = self.LayerNorm(hidden_states)
271 | return hidden_states
272 |
273 | class InputDense3d(nn.Module):
274 | def __init__(self, config):
275 | super(InputDense3d, self).__init__()
276 | self.dense = nn.Linear(config.patch_size[0] * config.patch_size[1] * config.patch_size[2] * config.in_channels, config.hidden_size)
277 | self.transform_act_fn = ACT2FN[config.hidden_act]
278 | self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
279 |
280 | def forward(self, hidden_states):
281 | hidden_states = self.dense(hidden_states)
282 | hidden_states = self.transform_act_fn(hidden_states)
283 | hidden_states = self.LayerNorm(hidden_states)
284 | return hidden_states
285 |
286 | class TransModel2d(nn.Module):
287 |
288 | def __init__(self, config):
289 | super(TransModel2d, self).__init__()
290 | self.config = config
291 | self.dense = InputDense2d(config)
292 | self.embeddings = TransEmbeddings(config)
293 | self.encoder = TransEncoder(config)
294 |
295 | def forward(
296 | self,
297 | input_ids,
298 | output_all_encoded_layers=True,
299 |
300 | ):
301 | dense_out = self.dense(input_ids)
302 | embedding_output = self.embeddings(
303 | input_ids=dense_out
304 | )
305 | encoder_layers = self.encoder(
306 | embedding_output,
307 | output_all_encoded_layers=output_all_encoded_layers,
308 | )
309 | 
310 | 
311 |         if not output_all_encoded_layers:
312 |             # Return only the last encoder layer instead of the full list.
313 |             encoder_layers = encoder_layers[-1]
314 | return encoder_layers
315 |
316 |
317 | class TransModel3d(nn.Module):
318 |
319 | def __init__(self, config):
320 | super(TransModel3d, self).__init__()
321 | self.config = config
322 | self.dense = InputDense3d(config)
323 | self.embeddings = TransEmbeddings(config)
324 | self.encoder = TransEncoder(config)
325 |
326 | def forward(
327 | self,
328 | input_ids,
329 | output_all_encoded_layers=True,
330 |
331 | ):
332 | dense_out = self.dense(input_ids)
333 | embedding_output = self.embeddings(
334 | input_ids=dense_out
335 | )
336 | encoder_layers = self.encoder(
337 | embedding_output,
338 | output_all_encoded_layers=output_all_encoded_layers,
339 | )
340 | 
341 | 
342 |         if not output_all_encoded_layers:
343 |             # Return only the last encoder layer instead of the full list.
344 |             encoder_layers = encoder_layers[-1]
345 | return encoder_layers
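346 | 
347 | 
348 | if __name__ == "__main__":
349 |     # Illustrative smoke test (a sketch, not part of the original API):
350 |     # encode a batch of 64 flattened 32x32 RGB patches with a small config.
351 |     config = TransConfig(patch_size=(32, 32), in_channels=3, out_channels=1,
352 |                          hidden_size=768, num_hidden_layers=2, num_attention_heads=6)
353 |     model = TransModel2d(config)
354 |     x = torch.rand(2, 64, 32 * 32 * 3)  # (batch, num_patches, patch_dim)
355 |     out = model(x, output_all_encoded_layers=False)
356 |     print(out.shape)  # expected: torch.Size([2, 64, 768])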
--------------------------------------------------------------------------------
/SETR/transformer_seg.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os
4 | import numpy as np
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import CrossEntropyLoss, MSELoss
9 | from einops import rearrange
10 | from SETR.transformer_model import TransModel2d, TransConfig
11 | 
12 |
13 | class Encoder2D(nn.Module):
14 | def __init__(self, config: TransConfig, is_segmentation=True):
15 | super().__init__()
16 | self.config = config
17 | self.out_channels = config.out_channels
18 | self.bert_model = TransModel2d(config)
19 | sample_rate = config.sample_rate
20 | sample_v = int(math.pow(2, sample_rate))
21 |         assert config.patch_size[0] * config.patch_size[1] * config.hidden_size % (sample_v**2) == 0, "patch_size[0] * patch_size[1] * hidden_size must be divisible by (2**sample_rate)**2"
22 | self.final_dense = nn.Linear(config.hidden_size, config.patch_size[0] * config.patch_size[1] * config.hidden_size // (sample_v**2))
23 | self.patch_size = config.patch_size
24 | self.hh = self.patch_size[0] // sample_v
25 | self.ww = self.patch_size[1] // sample_v
26 |
27 | self.is_segmentation = is_segmentation
28 |     def forward(self, x):
29 |         # x: (b, c, h, w)
30 |         b, c, h, w = x.shape
31 |         assert self.config.in_channels == c, "in_channels does not match the input image channels"
32 | p1 = self.patch_size[0]
33 | p2 = self.patch_size[1]
34 |
35 |         # The input height and width must be divisible by the patch size.
36 |         if h % p1 != 0:
37 |             raise ValueError("image height must be divisible by patch_size[0]")
38 |         if w % p2 != 0:
39 |             raise ValueError("image width must be divisible by patch_size[1]")
40 | 
41 | hh = h // p1
42 | ww = w // p2
43 |
44 | x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p1, p2 = p2)
45 |
46 |         encode_x = self.bert_model(x)[-1]  # take the last encoder layer
47 | if not self.is_segmentation:
48 | return encode_x
49 |
50 | x = self.final_dense(encode_x)
51 | x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", p1 = self.hh, p2 = self.ww, h = hh, w = ww, c = self.config.hidden_size)
52 | return encode_x, x
53 |
54 |
55 | class PreTrainModel(nn.Module):
56 | def __init__(self, patch_size,
57 | in_channels,
58 | out_class,
59 | hidden_size=1024,
60 | num_hidden_layers=8,
61 | num_attention_heads=16,
62 | decode_features=[512, 256, 128, 64]):
63 | super().__init__()
64 | config = TransConfig(patch_size=patch_size,
65 | in_channels=in_channels,
66 | out_channels=0,
67 | hidden_size=hidden_size,
68 | num_hidden_layers=num_hidden_layers,
69 | num_attention_heads=num_attention_heads)
70 | self.encoder_2d = Encoder2D(config, is_segmentation=False)
71 | self.cls = nn.Linear(hidden_size, out_class)
72 |
73 | def forward(self, x):
74 | encode_img = self.encoder_2d(x)
75 | encode_pool = encode_img.mean(dim=1)
76 | out = self.cls(encode_pool)
77 | return out
78 |
79 | class Vit(nn.Module):
80 | def __init__(self, patch_size,
81 | in_channels,
82 | out_class,
83 | hidden_size=1024,
84 | num_hidden_layers=8,
85 | num_attention_heads=16,
86 | sample_rate=4,
87 | ):
88 | super().__init__()
89 | config = TransConfig(patch_size=patch_size,
90 | in_channels=in_channels,
91 | out_channels=0,
92 | sample_rate=sample_rate,
93 | hidden_size=hidden_size,
94 | num_hidden_layers=num_hidden_layers,
95 | num_attention_heads=num_attention_heads)
96 | self.encoder_2d = Encoder2D(config, is_segmentation=False)
97 | self.cls = nn.Linear(hidden_size, out_class)
98 |
99 | def forward(self, x):
100 | encode_img = self.encoder_2d(x)
101 |
102 | encode_pool = encode_img.mean(dim=1)
103 | out = self.cls(encode_pool)
104 | return out
105 |
106 | class Decoder2D(nn.Module):
107 | def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
108 | super().__init__()
109 | self.decoder_1 = nn.Sequential(
110 | nn.Conv2d(in_channels, features[0], 3, padding=1),
111 | nn.BatchNorm2d(features[0]),
112 | nn.ReLU(inplace=True),
113 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
114 | )
115 | self.decoder_2 = nn.Sequential(
116 | nn.Conv2d(features[0], features[1], 3, padding=1),
117 | nn.BatchNorm2d(features[1]),
118 | nn.ReLU(inplace=True),
119 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
120 | )
121 | self.decoder_3 = nn.Sequential(
122 | nn.Conv2d(features[1], features[2], 3, padding=1),
123 | nn.BatchNorm2d(features[2]),
124 | nn.ReLU(inplace=True),
125 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
126 | )
127 | self.decoder_4 = nn.Sequential(
128 | nn.Conv2d(features[2], features[3], 3, padding=1),
129 | nn.BatchNorm2d(features[3]),
130 | nn.ReLU(inplace=True),
131 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
132 | )
133 |
134 | self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)
135 |
136 | def forward(self, x):
137 | x = self.decoder_1(x)
138 | x = self.decoder_2(x)
139 | x = self.decoder_3(x)
140 | x = self.decoder_4(x)
141 | x = self.final_out(x)
142 | return x
143 |
144 | class SETRModel(nn.Module):
145 | def __init__(self, patch_size=(32, 32),
146 | in_channels=3,
147 | out_channels=1,
148 | hidden_size=1024,
149 | num_hidden_layers=8,
150 | num_attention_heads=16,
151 | decode_features=[512, 256, 128, 64],
152 | sample_rate=4,):
153 | super().__init__()
154 | config = TransConfig(patch_size=patch_size,
155 | in_channels=in_channels,
156 | out_channels=out_channels,
157 | sample_rate=sample_rate,
158 | hidden_size=hidden_size,
159 | num_hidden_layers=num_hidden_layers,
160 | num_attention_heads=num_attention_heads)
161 | self.encoder_2d = Encoder2D(config)
162 | self.decoder_2d = Decoder2D(in_channels=config.hidden_size, out_channels=config.out_channels, features=decode_features)
163 |
164 | def forward(self, x):
165 | _, final_x = self.encoder_2d(x)
166 | x = self.decoder_2d(final_x)
167 | return x
168 |
169 |
170 |
171 |
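172 | if __name__ == "__main__":
173 |     # Illustrative example: shape walk-through for the default arguments.
174 |     # patch_size=(32, 32) with sample_rate=4 gives sample_v = 16, so a
175 |     # 256x256 input becomes an 8x8 patch grid; the encoder output is
176 |     # reshaped to (b, 1024, 8 * 32 // 16, 8 * 32 // 16) = (b, 1024, 16, 16)
177 |     # and Decoder2D's four x2 bilinear upsamplings restore 256x256.
178 |     model = SETRModel()
179 |     t = torch.rand(1, 3, 256, 256)
180 |     print(model(t).shape)  # expected: torch.Size([1, 1, 256, 256])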
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from SETR.transformer_seg import SETRModel, Vit
2 | import torch
3 |
4 | if __name__ == "__main__":
5 | net = SETRModel(patch_size=(32, 32),
6 | in_channels=3,
7 | out_channels=1,
8 | hidden_size=1024,
9 | sample_rate=5,
10 | num_hidden_layers=1,
11 | num_attention_heads=16,
12 | decode_features=[512, 256, 128, 64])
13 | t1 = torch.rand(1, 3, 512, 512)
14 | print("input: " + str(t1.shape))
15 |
16 | print("output: " + str(net(t1).shape))
17 |
18 |
19 | model = Vit(patch_size=(32, 32),
20 | in_channels=1,
21 | out_class=10,
22 | sample_rate=4,
23 | hidden_size=1024,
24 | num_hidden_layers=1,
25 | num_attention_heads=16)
26 |
27 | t1 = torch.rand(1, 1, 512, 512)
28 | print("input: " + str(t1.shape))
29 |
30 | print("output: " + str(model(t1).shape))
31 |
32 |
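33 | # Note: with sample_rate=5 the encoder feature map for the 512x512 input is
34 | # (512/32) * (32 / 2**5) = 16 per side, and Decoder2D's four x2 upsamplings
35 | # give a 256x256 mask, i.e. half the input resolution. Use sample_rate=4 if
36 | # the output should match the 512x512 input.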
--------------------------------------------------------------------------------
/task_minst.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from SETR.transformer_seg import Vit
3 | import torchvision
4 | 
5 | import torch.nn as nn
6 | from torchvision import datasets, transforms
7 | import matplotlib.pyplot as plt
8 | from tqdm import tqdm
9 |
10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 | print("device is " + str(device))
12 |
13 | def compute_acc(model, test_dataloader):
14 |     # Evaluate in eval mode so dropout does not distort the accuracy.
15 |     model.eval()
16 |     with torch.no_grad():
17 |         right_num = 0
18 |         total_num = 0
19 |         for in_data, label in tqdm(test_dataloader, total=len(test_dataloader)):
20 |             in_data = in_data.to(device)
21 |             label = label.to(device)
22 |             total_num += len(in_data)
23 |             out = model(in_data)
24 |             pred = out.argmax(dim=-1)
25 |             right_num += int((pred == label).sum())
26 |     model.train()
27 |     return right_num / total_num
28 |
29 | if __name__ == "__main__":
30 |
31 | model = Vit(patch_size=(7, 7),
32 | in_channels=1,
33 | out_class=10,
34 | hidden_size=1024,
35 | num_hidden_layers=1,
36 | num_attention_heads=16,
37 | )
38 | print(model)
39 | model.to(device)
40 |
41 | transform = transforms.Compose([transforms.ToTensor(),
42 | transforms.Normalize(mean=[0.5],std=[0.5])])
43 | data_train = datasets.MNIST(root = "./data/",
44 | transform=transform,
45 | train = True,
46 | download = True)
47 |
48 | data_test = datasets.MNIST(root="./data/",
49 | transform = transform,
50 | train = False)
51 |
52 | data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
53 | batch_size = 64,
54 | shuffle = True)
55 |
56 | data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
57 | batch_size = 32,
58 | shuffle = True)
59 |
60 | optimizer = torch.optim.Adam(model.parameters())
61 | loss_func = nn.CrossEntropyLoss()
62 | report_loss = 0
63 | step = 0
64 | best_acc = 0.0
65 |
66 | for in_data, label in tqdm(data_loader_train, total=len(data_loader_train)):
67 | batch_size = len(in_data)
68 | in_data = in_data.to(device)
69 | label = label.to(device)
70 | optimizer.zero_grad()
71 | step += 1
72 | out = model(in_data)
73 | loss = loss_func(out, label)
74 | loss.backward()
75 | optimizer.step()
76 | report_loss += loss.item()
77 | if step % 10 == 0:
78 | print("report_loss is : " + str(report_loss))
79 | report_loss = 0
80 | acc = compute_acc(model, data_loader_test)
81 | if acc > best_acc:
82 | best_acc = acc
83 | torch.save(model.state_dict(), "./checkpoints/mnist_model.pkl")
84 |
85 | print("acc is " + str(acc) + ", best acc is " + str(best_acc))
86 |
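87 | # Note: the ./checkpoints directory must exist before training, e.g. create it
88 | # with os.makedirs("./checkpoints", exist_ok=True), or torch.save above will fail.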
--------------------------------------------------------------------------------
/tast_car_seg.py:
--------------------------------------------------------------------------------
1 | # data_url : https://www.kaggle.com/c/carvana-image-masking-challenge/data
2 | import torch
3 | import numpy as np
4 | from SETR.transformer_seg import SETRModel
5 | from PIL import Image
6 | import glob
7 | import torch.nn as nn
8 | import matplotlib.pyplot as plt
9 | from torch.utils.data import DataLoader, Dataset
10 | from tqdm import tqdm
11 |
12 | img_url = sorted(glob.glob("./segmentation_car/imgs/*"))
13 | mask_url = sorted(glob.glob("./segmentation_car/masks/*"))
14 | # print(img_url)
15 | train_size = int(len(img_url) * 0.8)
16 | train_img_url = img_url[:train_size]
17 | train_mask_url = mask_url[:train_size]
18 | val_img_url = img_url[train_size:]
19 | val_mask_url = mask_url[train_size:]
20 |
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 | print("device is " + str(device))
23 | epoches = 100
24 | out_channels = 1
25 |
26 | def build_model():
27 | model = SETRModel(patch_size=(16, 16),
28 | in_channels=3,
29 | out_channels=1,
30 | hidden_size=1024,
31 | num_hidden_layers=6,
32 | num_attention_heads=16,
33 | decode_features=[512, 256, 128, 64])
34 | return model
35 |
36 | class CarDataset(Dataset):
37 | def __init__(self, img_url, mask_url):
38 | super(CarDataset, self).__init__()
39 | self.img_url = img_url
40 | self.mask_url = mask_url
41 |
42 | def __getitem__(self, idx):
43 | img = Image.open(self.img_url[idx])
44 | img = img.resize((256, 256))
45 | img_array = np.array(img, dtype=np.float32) / 255
46 | mask = Image.open(self.mask_url[idx])
47 | mask = mask.resize((256, 256))
48 | mask = np.array(mask, dtype=np.float32)
49 | img_array = img_array.transpose(2, 0, 1)
50 |
51 | return torch.tensor(img_array.copy()), torch.tensor(mask.copy())
52 |
53 | def __len__(self):
54 | return len(self.img_url)
55 |
56 | def compute_dice(input, target):
57 | eps = 0.0001
58 |     # `input` is the model output after sigmoid; binarize both masks at 0.5.
59 | input = (input > 0.5).float()
60 | target = (target > 0.5).float()
61 |
62 | # inter = torch.dot(input.view(-1), target.view(-1)) + eps
63 | inter = torch.sum(target.view(-1) * input.view(-1)) + eps
64 |
65 | 
66 | union = torch.sum(input) + torch.sum(target) + eps
67 |
68 | t = (2 * inter.float()) / union.float()
69 | return t
70 |
71 | def predict():
72 | model = build_model()
73 | model.load_state_dict(torch.load("./checkpoints/SETR_car.pkl", map_location="cpu"))
74 | print(model)
75 |
76 | import matplotlib.pyplot as plt
77 | val_dataset = CarDataset(val_img_url, val_mask_url)
78 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
79 | with torch.no_grad():
80 | for img, mask in val_loader:
81 | pred = torch.sigmoid(model(img))
82 | pred = (pred > 0.5).int()
83 | plt.subplot(1, 3, 1)
84 | print(img.shape)
85 | img = img.permute(0, 2, 3, 1)
86 | plt.imshow(img[0])
87 | plt.subplot(1, 3, 2)
88 | plt.imshow(pred[0].squeeze(0), cmap="gray")
89 | plt.subplot(1, 3, 3)
90 | plt.imshow(mask[0], cmap="gray")
91 | plt.show()
92 |
93 | if __name__ == "__main__":
94 |
95 | model = build_model()
96 | model.to(device)
97 |
98 | train_dataset = CarDataset(train_img_url, train_mask_url)
99 | train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True)
100 |
101 | val_dataset = CarDataset(val_img_url, val_mask_url)
102 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
103 |
104 | loss_func = nn.BCEWithLogitsLoss()
105 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
106 |
107 | step = 0
108 | report_loss = 0.0
109 | for epoch in range(epoches):
110 | print("epoch is " + str(epoch))
111 |
112 | for img, mask in tqdm(train_loader, total=len(train_loader)):
113 | optimizer.zero_grad()
114 | step += 1
115 | img = img.to(device)
116 | mask = mask.to(device)
117 |
118 |             pred_img = model(img)  # pred_img: (batch, out_channels, H, W)
119 |             if out_channels == 1:
120 |                 pred_img = pred_img.squeeze(1)  # drop the channel dim -> (batch, H, W)
121 |
122 | loss = loss_func(pred_img, mask)
123 | report_loss += loss.item()
124 | loss.backward()
125 | optimizer.step()
126 |
127 | if step % 1000 == 0:
128 | dice = 0.0
129 | n = 0
130 | model.eval()
131 | with torch.no_grad():
132 | print("report_loss is " + str(report_loss))
133 | report_loss = 0.0
134 | for val_img, val_mask in tqdm(val_loader, total=len(val_loader)):
135 | n += 1
136 | val_img = val_img.to(device)
137 | val_mask = val_mask.to(device)
138 | pred_img = torch.sigmoid(model(val_img))
139 | if out_channels == 1:
140 | pred_img = pred_img.squeeze(1)
141 | cur_dice = compute_dice(pred_img, val_mask)
142 | dice += cur_dice
143 | dice = dice / n
144 | print("mean dice is " + str(dice))
145 | torch.save(model.state_dict(), "./checkpoints/SETR_car.pkl")
146 | model.train()
147 |
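148 | # Note: this script expects the Carvana data to be unpacked into
149 | # ./segmentation_car/imgs/ and ./segmentation_car/masks/ (with matching
150 | # sorted filenames), and the ./checkpoints directory to exist before saving.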
--------------------------------------------------------------------------------