├── EDSR
├── common.py
└── edsr.py
├── LICENSE
├── README.md
├── adaptive_gridsampler
├── adaptive_gridsampler_cuda.cpp
├── adaptive_gridsampler_kernel.cu
├── adaptive_gridsampler_kernel.cuh
├── gridsampler.py
├── helper_cuda.h
├── helper_string.h
└── setup.py
├── figs
├── overview.png
└── qualitative.png
├── modules.py
├── run.py
└── utils.py
/EDSR/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a 2-D convolution with 'same' padding for odd kernel sizes.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: square kernel side length (odd sizes keep spatial dims).
        bias: whether the convolution carries a learnable bias term.

    Returns:
        An ``nn.Conv2d`` module, stride 1, padding ``kernel_size // 2``.
    """
    same_padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=same_padding, bias=bias)
12 |
13 |
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that subtracts (or adds) the dataset RGB mean.

    With ``sign=-1`` the layer removes ``rgb_range * rgb_mean`` per channel
    (after dividing by ``rgb_std``); with ``sign=1`` it adds the mean back.
    All parameters are frozen — this layer is never trained.
    """

    def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        mean = torch.Tensor(rgb_mean)
        # Per-channel identity scaled by 1/std: a pure channel-wise affine map.
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * mean / std
        # Freeze both parameters (weight and bias are the only ones).
        self.weight.requires_grad = False
        self.bias.requires_grad = False
22 |
23 |
class BasicBlock(nn.Sequential):
    """Conv -> optional BatchNorm -> optional activation, as a Sequential.

    NOTE(review): ``stride`` is accepted but never forwarded to ``conv`` —
    kept as-is to preserve the original call signature and behavior.
    """

    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):

        layers = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        layers += [nn.BatchNorm2d(out_channels)] if bn else []
        layers += [act] if act is not None else []
        super(BasicBlock, self).__init__(*layers)
36 |
37 |
class ResBlock(nn.Module):
    """EDSR residual block: conv-(bn)-act-conv-(bn) with an identity skip.

    The body output is scaled by ``res_scale`` before the skip connection —
    small scales (e.g. 0.1) stabilize training of very deep/wide variants.
    """

    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        # Unrolled construction; module creation order matches the original
        # two-iteration loop so weight initialization is identical.
        layers = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))
        layers.append(act)
        layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        # Scaled residual plus identity skip.
        return self.body(x) * self.res_scale + x
60 |
61 |
class Upsampler(nn.Sequential):
    """Sub-pixel upsampling head built from conv + PixelShuffle stages.

    Supports power-of-two scales (chained x2 stages) and scale 3 (a single
    x3 stage); any other scale raises ``NotImplementedError``. ``act`` may
    be ``False`` (none), ``'relu'`` or ``'prelu'``.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        layers = []

        def push_stage(factor):
            # One sub-pixel stage: expand channels by factor^2, then shuffle.
            layers.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            layers.append(nn.PixelShuffle(factor))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                layers.append(nn.ReLU(True))
            elif act == 'prelu':
                layers.append(nn.PReLU(n_feats))

        if (scale & (scale - 1)) == 0:  # scale is a power of two
            for _ in range(int(math.log(scale, 2))):
                push_stage(2)
        elif scale == 3:
            push_stage(3)
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*layers)
90 |
--------------------------------------------------------------------------------
/EDSR/edsr.py:
--------------------------------------------------------------------------------
1 | from EDSR import common
2 |
3 | import torch
4 | import torch.nn as nn
5 |
# Download URLs for the official pre-trained EDSR checkpoints, keyed by
# architecture as 'r{n_resblocks}f{n_feats}x{scale}' (baseline: r16f64,
# full model: r32f256).
url = {
    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
    'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
    'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
    'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
    'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
    'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}
14 |
15 |
class EDSR(nn.Module):
    """Enhanced Deep Residual Network for single-image super-resolution.

    Pipeline: subtract dataset mean -> head conv -> stack of residual
    blocks with a global skip -> sub-pixel upsampler tail -> add mean back.

    Args:
        n_resblocks: number of residual blocks in the body.
        n_feats: number of feature channels throughout the network.
        scale: upsampling factor (power of two, or 3).
        conv: convolution factory with signature
            ``conv(in_channels, out_channels, kernel_size, bias=...)``.
    """

    def __init__(self, n_resblocks=16, n_feats=64, scale=4, conv=common.default_conv):
        super(EDSR, self).__init__()

        kernel_size = 3
        act = nn.ReLU(True)
        # FIX: use dict.get so constructing an architecture without a
        # released checkpoint (e.g. n_resblocks=20) no longer raises
        # KeyError; self.url is None in that case.
        self.url = url.get('r{}f{}x{}'.format(n_resblocks, n_feats, scale))
        # MeanShift assumes inputs normalized to [0, 1] (rgb_range=1).
        self.sub_mean = common.MeanShift(1)
        self.add_mean = common.MeanShift(1, sign=1)

        # Head: lift RGB input to the feature space.
        m_head = [conv(3, n_feats, kernel_size)]

        # Body: residual blocks (res_scale=0.1 stabilizes training) plus a
        # final conv before the global skip connection.
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=0.1
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # Tail: sub-pixel upsampling followed by projection back to RGB.
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        # Global residual skip around the whole body.
        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=True):
        """Copy parameters from ``state_dict``, tolerating tail mismatches.

        Tail parameters may legitimately differ in shape when loading a
        checkpoint trained for another scale, so size mismatches and
        unexpected keys are ignored for any name containing 'tail'.

        Raises:
            RuntimeError: size mismatch on a non-tail parameter.
            KeyError: unexpected non-tail key when ``strict`` is True.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CAR-pytorch
2 |
3 | Pytorch implementation of paper **"Learned Image Downscaling for Upscaling using Content Adaptive Resampler"**
4 |
5 | 
6 |
7 | ## Installation
8 |
9 | # get CAR-pytorch source
10 | git clone https://github.com/sunwj/CAR.git
11 | cd CAR
12 |
13 | # compile the code of the resampler
14 | cd adaptive_gridsampler
15 | python3 setup.py build_ext --inplace
16 |
17 | ### Python requirements
18 | Currently, the code only supports python3 and machine with NVIDIA GPU (and the CUDA development toolkit) installed
19 |
20 | * numpy
21 | * scipy
22 | * pytorch (== 1.3.1)
23 | * Pillow
24 | * tqdm
25 |
26 | ### Pre-trained models
27 | You can download the pre-trained models for 2x and 4x downscaling and super-resolution from [here](https://mega.nz/#!XzIm3YhT!jbIOOOGBOiKtv3VAOD782Mz7nK1L_kma-BzR-RhboW4).
28 |
29 | ## Inference
30 | python3 run.py --scale 4 --img_dir path_to_images --model_dir path_to_pretrained_models \
31 | --output_dir path_to_output
32 |
33 | ## Sample results
34 | 
35 |
36 | You can download HR images of benchmark datasets, i.e., the Set5, Set14, B100 and Urban100 from [here](https://mega.nz/#!znBRCSJA!_qwJMP5VDe3yleiK8m0QXrpHLee9AS8vzT03lAOorP0).
37 |
38 | If you find our work useful in your research or publication, please cite our work:
39 |
40 | Wanjie Sun, Zhenzhong Chen. **"Learned Image Downscaling for Upscaling using Content Adaptive Resampler"**. arXiv preprint arXiv:1907.12904, 2019.
41 |
42 | ```
43 | @article{sun2020learned,
44 | title={Learned image downscaling for upscaling using content adaptive resampler},
45 | author={Sun, Wanjie and Chen, Zhenzhong},
46 | journal={IEEE Transactions on Image Processing},
47 | volume={29},
48 | pages={4027--4040},
49 | year={2020},
50 | publisher={IEEE}
51 | }
52 | ```
53 |
54 | ## Acknowledgements
55 | EDSR code is provided by [thstkdgus35/EDSR-PyTorch](https://github.com/thstkdgus35/EDSR-PyTorch).
56 |
--------------------------------------------------------------------------------
/adaptive_gridsampler/adaptive_gridsampler_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 |
4 | #include "adaptive_gridsampler_kernel.cuh"
5 |
// Python-facing forward entry point (bound as "forward" in the pybind module
// below). Thin pass-through to the CUDA launcher; `output` is written in place.
// The integer return value is a constant success flag kept for the binding.
int adaptive_gridsampler_cuda_forward(at::Tensor& img, at::Tensor& kernels, at::Tensor& offsets_h, at::Tensor& offsets_v, int offset_unit, int padding, at::Tensor& output)
{
    adaptive_gridsampler_kernel_forward(img, kernels, offsets_h, offsets_v, offset_unit, padding, output);
    return 1;
}
11 |
// Python-facing backward entry point (bound as "backward" below). Mirrors the
// forward wrapper: delegates to the CUDA launcher, which fills the three
// gradInput_* tensors in place. Returns a constant success flag.
int adaptive_gridsampler_cuda_backward(at::Tensor& img, at::Tensor& kernels, at::Tensor& offsets_h, at::Tensor& offsets_v, int offset_unit, at::Tensor& gradOutput, int padding,
at::Tensor& gradInput_kernels, at::Tensor& gradInput_offsets_h, at::Tensor& gradInput_offsets_v)
{
    adaptive_gridsampler_kernel_backward(img, kernels, offsets_h, offsets_v, offset_unit, gradOutput, padding, gradInput_kernels, gradInput_offsets_h, gradInput_offsets_v);
    return 1;
}
18 |
// Expose the two wrappers to Python as adaptive_gridsampler_cuda.forward /
// .backward; gridsampler.py imports `forward` from this extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &adaptive_gridsampler_cuda_forward, "adaptive gridsampler forward (CUDA)");
    m.def("backward", &adaptive_gridsampler_cuda_backward, "adaptive gridsampler backward (CUDA)");
}
--------------------------------------------------------------------------------
/adaptive_gridsampler/adaptive_gridsampler_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <torch/types.h>
2 | #include <cuda_runtime.h>
3 |
4 | #include "helper_cuda.h"
5 |
6 | #define BLOCK_SIZE 256
7 |
// Forward pass: one thread per output element (b, c, y, x).
// Each output value is a weighted sum over a k_size x k_size resampling
// kernel; every tap is displaced by a learned fractional offset and the
// source pixel is fetched with bilinear interpolation from the padded image.
//
// FIX: the template parameter list and the accessor/cast template arguments
// had been stripped by markup mangling (`template`, `PackedTensorAccessor32`,
// `static_cast(...)` with no angle-bracketed arguments), which cannot
// compile; restored here.
template <typename scalar_t>
__global__ void kernel_adaptive_gridsampler_update_output(
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> img,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> kernels,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> offsets_h,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> offsets_v,
    const int offset_unit,
    const int padding,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> output,
    const size_t n)
{
    auto global_idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(global_idx >= n) return;

    auto dim_b = output.size(0);
    auto dim_c = output.size(1);
    auto dim_h = output.size(2);
    auto dim_w = output.size(3);

    // Decompose the flat thread index into (batch, channel, row, col).
    auto idb = (global_idx / (dim_c * dim_h * dim_w)) % dim_b;
    auto idc = (global_idx / (dim_h * dim_w)) % dim_c;
    auto idy = (global_idx / dim_w) % dim_h;
    auto idx = global_idx % dim_w;

    if(idx >= dim_w || idy >= dim_h)
        return;

    // kernels has k_size^2 channels: one weight per tap of the square kernel.
    int k_size = sqrt(float(kernels.size(1)));
    // Extent of the *unpadded* input image.
    float w = float(img.size(3) - 2 * padding);
    float h = float(img.size(2) - 2 * padding);

    scalar_t result = 0;
    for(int k_y = 0; k_y < k_size; ++k_y)
    {
        for(int k_x = 0; k_x < k_size; ++k_x)
        {
            // Learned sub-pixel displacement of this tap, scaled by offset_unit.
            scalar_t offset_h = offsets_h[idb][k_size * k_y + k_x][idy][idx] * offset_unit;
            scalar_t offset_v = offsets_v[idb][k_size * k_y + k_x][idy][idx] * offset_unit;

            // Continuous sampling position in the padded input image.
            scalar_t p_x = static_cast<scalar_t>(idx + 0.5) / dim_w * w + k_x + offset_h - 0.5;
            scalar_t p_y = static_cast<scalar_t>(idy + 0.5) / dim_h * h + k_y + offset_v - 0.5;
            scalar_t alpha = p_x - floor(p_x);
            scalar_t beta = p_y - floor(p_y);

            // Clamp the four bilinear neighbours to the padded image bounds.
            int xL = max(min(int(floor(p_x)), int(w + 2 * padding - 1)), 0);
            int xR = max(min(xL + 1, int(w + 2 * padding - 1)), 0);
            int yT = max(min(int(floor(p_y)), int(h + 2 * padding - 1)), 0);
            int yB = max(min(yT + 1, int(h + 2 * padding - 1)), 0);

            // Bilinear interpolation of the source pixel.
            scalar_t val = 0;
            val += (1 - alpha) * (1 - beta) * img[idb][idc][yT][xL];
            val += alpha * (1 - beta) * img[idb][idc][yT][xR];
            val += (1 - alpha) * beta * img[idb][idc][yB][xL];
            val += alpha * beta * img[idb][idc][yB][xR];

            result += val * kernels[idb][k_size * k_y + k_x][idy][idx];
        }
    }
    output[idb][idc][idy][idx] = result;
}
68 |
// Host-side launcher for the forward kernel: one thread per output element.
// FIX: the `<float>` kernel instantiation and the `packed_accessor32<...>`
// template arguments had been stripped by markup mangling; restored here
// (float is the only dtype this extension is used with — TODO confirm).
void adaptive_gridsampler_kernel_forward(const torch::Tensor& img, const torch::Tensor& kernels, const torch::Tensor& offsets_h, const torch::Tensor& offsets_v, const int offset_unit, const int padding, torch::Tensor& output)
{
    kernel_adaptive_gridsampler_update_output<float><<<(output.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
        img.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), kernels.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        offsets_h.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), offsets_v.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), offset_unit, padding,
        output.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), output.numel());

    // Surface any launch error immediately with a readable message.
    checkCudaErrors(cudaGetLastError());
}
78 |
// Backward pass: one thread per element of the kernel-weight gradient
// (b, tap, y, x). Recomputes the forward sampling position for that tap and
// accumulates gradients w.r.t. the kernel weights and both offset maps,
// summing over the image channels.
//
// FIXES:
//  * bounds check compared idx against dim_h; must be idy (row index),
//    mirroring the forward kernel;
//  * stripped template parameter list / accessor template arguments restored
//    (same markup mangling as the forward kernel).
template <typename scalar_t>
__global__ void kernel_adaptive_gridsampler_backward(
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> img,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> kernels,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> offsets_h,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> offsets_v,
    const int offset_unit,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> gradOutput,
    const int padding,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> gradInput_kernels,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> gradInput_offsets_h,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> gradInput_offsets_v,
    const size_t n)
{
    auto global_idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(global_idx >= n) return;

    auto dim_b = gradInput_kernels.size(0);
    auto dim_c = gradInput_kernels.size(1);
    auto dim_h = gradInput_kernels.size(2);
    auto dim_w = gradInput_kernels.size(3);

    auto idb = (global_idx / (dim_c * dim_h * dim_w)) % dim_b;
    auto idc = (global_idx / (dim_h * dim_w)) % dim_c;
    auto idy = (global_idx / dim_w) % dim_h;
    auto idx = global_idx % dim_w;

    // FIX: was `idx >= dim_h`; row index idy must be tested against dim_h.
    if(idx >= dim_w || idy >= dim_h)
        return;

    // dim_c == k_size^2: channel index encodes the (k_y, k_x) tap.
    int k_size = sqrt(float(dim_c));
    int k_y = idc / k_size;
    int k_x = idc % k_size;

    scalar_t offset_h = offsets_h[idb][idc][idy][idx] * offset_unit;
    scalar_t offset_v = offsets_v[idb][idc][idy][idx] * offset_unit;

    // Extent of the *unpadded* input image.
    float w = float(img.size(3) - 2 * padding);
    float h = float(img.size(2) - 2 * padding);

    // Recompute the forward sampling position and bilinear weights.
    scalar_t p_x = static_cast<scalar_t>(idx + 0.5) / dim_w * w + k_x + offset_h - 0.5;
    scalar_t p_y = static_cast<scalar_t>(idy + 0.5) / dim_h * h + k_y + offset_v - 0.5;
    scalar_t alpha = p_x - floor(p_x);
    scalar_t beta = p_y - floor(p_y);

    int xL = max(min(int(floor(p_x)), int(w + 2 * padding - 1)), 0);
    int xR = max(min(xL + 1, int(w + 2 * padding - 1)), 0);
    int yT = max(min(int(floor(p_y)), int(h + 2 * padding - 1)), 0);
    int yB = max(min(yT + 1, int(h + 2 * padding - 1)), 0);

    scalar_t grad_kernels = 0;
    scalar_t grad_offset_h = 0;
    scalar_t grad_offset_v = 0;
    for(int c = 0; c < img.size(1); ++c)
    {
        scalar_t c_tl = img[idb][c][yT][xL];
        scalar_t c_tr = img[idb][c][yT][xR];
        scalar_t c_bl = img[idb][c][yB][xL];
        scalar_t c_br = img[idb][c][yB][xR];

        // d(output)/d(kernel weight) is the bilinearly interpolated sample.
        scalar_t grad = 0;
        grad += (1 - alpha) * (1 - beta) * c_tl;
        grad += alpha * (1 - beta) * c_tr;
        grad += (1 - alpha) * beta * c_bl;
        grad += alpha * beta * c_br;
        grad_kernels += grad * gradOutput[idb][c][idy][idx];

        // d(output)/d(offset_h): derivative of the bilinear weights w.r.t. p_x,
        // chained with offset_unit (p_x is linear in offset_h * offset_unit).
        grad = (beta - 1) * c_tl + (1 - beta) * c_tr - beta * c_bl + beta * c_br;
        grad_offset_h += kernels[idb][idc][idy][idx] * grad * gradOutput[idb][c][idy][idx] * offset_unit;

        // d(output)/d(offset_v): same, w.r.t. p_y.
        grad = (alpha - 1) * c_tl - alpha * c_tr + (1 - alpha) * c_bl + alpha * c_br;
        grad_offset_v += kernels[idb][idc][idy][idx] * grad * gradOutput[idb][c][idy][idx] * offset_unit;
    }

    gradInput_kernels[idb][idc][idy][idx] = grad_kernels;

    gradInput_offsets_h[idb][idc][idy][idx] = grad_offset_h;
    gradInput_offsets_v[idb][idc][idy][idx] = grad_offset_v;
}
157 |
// Host-side launcher for the backward kernel: one thread per element of the
// kernel-weight gradient tensor.
// FIX: restored the `<float>` instantiation and `packed_accessor32<...>`
// template arguments stripped by markup mangling (same as the forward launcher).
void adaptive_gridsampler_kernel_backward(const torch::Tensor& img, const torch::Tensor& kernels, const torch::Tensor& offsets_h, const torch::Tensor& offsets_v, const int offset_unit, const torch::Tensor& gradOutput, const int padding,
torch::Tensor& gradInput_kernels, torch::Tensor& gradInput_offsets_h, torch::Tensor& gradInput_offsets_v)
{
    kernel_adaptive_gridsampler_backward<float><<<(gradInput_kernels.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0>>>(
        img.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), kernels.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        offsets_h.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), offsets_v.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        offset_unit,
        gradOutput.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        padding,
        gradInput_kernels.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        gradInput_offsets_h.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), gradInput_offsets_v.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        gradInput_kernels.numel());

    // Surface any launch error immediately with a readable message.
    checkCudaErrors(cudaGetLastError());
}
--------------------------------------------------------------------------------
/adaptive_gridsampler/adaptive_gridsampler_kernel.cuh:
--------------------------------------------------------------------------------
#ifndef ADAPTIVE_GRIDSAMPLER_KERNEL_CUH
#define ADAPTIVE_GRIDSAMPLER_KERNEL_CUH

// FIX: the include target had been stripped by markup mangling (bare
// `#include`); torch/extension.h provides torch::Tensor for the declarations.
#include <torch/extension.h>

// Launch the forward resampling kernel; writes the downscaled image into
// `output` in place.
void adaptive_gridsampler_kernel_forward(const torch::Tensor& img, const torch::Tensor& kernels, const torch::Tensor& offsets_h, const torch::Tensor& offsets_v, const int offset_unit, const int padding, torch::Tensor& output);
// Launch the backward kernel; fills the three gradInput_* tensors in place.
void adaptive_gridsampler_kernel_backward(const torch::Tensor& img, const torch::Tensor& kernels, const torch::Tensor& offsets_h, const torch::Tensor& offsets_v, const int offset_unit, const torch::Tensor& gradOutput, const int padding, torch::Tensor& gradInput_kernels, torch::Tensor& gradInput_offsets_h, torch::Tensor& gradInput_offsets_v);

#endif
--------------------------------------------------------------------------------
/adaptive_gridsampler/gridsampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Function, gradcheck
5 |
6 | from .adaptive_gridsampler_cuda import forward
7 |
8 |
class GridSamplerFunction(Function):
    """autograd.Function wrapper around the CUDA adaptive grid sampler.

    Only the forward pass is implemented; calling backward raises
    NotImplementedError (inference-only use).
    """

    @staticmethod
    def forward(ctx, img, kernels, offsets_h, offsets_v, offset_unit, padding, downscale_factor):
        assert isinstance(downscale_factor, int)
        assert isinstance(padding, int)

        # Stash the scalar arguments on the context for a potential backward.
        ctx.padding = padding
        ctx.offset_unit = offset_unit

        batch, channels, height, width = img.size()
        out_h = height // downscale_factor
        out_w = width // downscale_factor
        # The per-pixel kernel maps must match the downscaled resolution.
        assert out_h == kernels.size(2)
        assert out_w == kernels.size(3)

        # Reflection-pad so the sampling window never reads out of bounds.
        padded = nn.ReflectionPad2d(padding)(img)
        # ctx.save_for_backward(img, kernels, offsets_h, offsets_v)

        # The CUDA extension fills `output` in place.
        output = padded.new(batch, channels, out_h, out_w).zero_()
        forward(padded, kernels, offsets_h, offsets_v, offset_unit, padding, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError
33 |
34 |
class Downsampler(nn.Module):
    """nn.Module front-end for the content-adaptive resampler.

    Holds the downscaling factor and resampling-kernel size, and forwards
    every call to GridSamplerFunction with padding = k_size // 2.
    """

    def __init__(self, ds, k_size):
        super(Downsampler, self).__init__()
        self.ds = ds          # integer downscaling factor
        self.k_size = k_size  # spatial size of the square resampling kernel

    def forward(self, img, kernels, offsets_h, offsets_v, offset_unit):
        # One kernel weight per tap: the channel dim must be k_size^2.
        assert kernels.size(1) == self.k_size ** 2
        return GridSamplerFunction.apply(
            img, kernels, offsets_h, offsets_v, offset_unit,
            self.k_size // 2, self.ds)
44 |
--------------------------------------------------------------------------------
/adaptive_gridsampler/helper_cuda.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
3 | *
4 | * Please refer to the NVIDIA end user license agreement (EULA) associated
5 | * with this source code for terms and conditions that govern your use of
6 | * this software. Any use, reproduction, disclosure, or distribution of
7 | * this software and related documentation outside the terms of the EULA
8 | * is strictly prohibited.
9 | *
10 | */
11 |
12 | ////////////////////////////////////////////////////////////////////////////////
13 | // These are CUDA Helper functions for initialization and error checking
14 |
15 | #ifndef HELPER_CUDA_H
16 | #define HELPER_CUDA_H
17 |
18 | #pragma once
19 |
20 | #include <stdio.h>
21 | #include <stdlib.h>
22 | #include <string.h>
23 |
24 | #include "helper_string.h"
25 |
26 | //#include
27 | //#include
28 | //#include
29 |
30 | // Note, it is required that your SDK sample to include the proper header files, please
31 | // refer the CUDA examples for examples of the needed CUDA headers, which may change depending
32 | // on which CUDA functions are used.
33 |
34 | // CUDA Runtime error messages
35 | #ifdef __DRIVER_TYPES_H__
36 | static const char *_cudaGetErrorEnum(cudaError_t error)
37 | {
38 | switch (error)
39 | {
40 | case cudaSuccess:
41 | return "cudaSuccess";
42 |
43 | case cudaErrorMissingConfiguration:
44 | return "cudaErrorMissingConfiguration";
45 |
46 | case cudaErrorMemoryAllocation:
47 | return "cudaErrorMemoryAllocation";
48 |
49 | case cudaErrorInitializationError:
50 | return "cudaErrorInitializationError";
51 |
52 | case cudaErrorLaunchFailure:
53 | return "cudaErrorLaunchFailure";
54 |
55 | case cudaErrorPriorLaunchFailure:
56 | return "cudaErrorPriorLaunchFailure";
57 |
58 | case cudaErrorLaunchTimeout:
59 | return "cudaErrorLaunchTimeout";
60 |
61 | case cudaErrorLaunchOutOfResources:
62 | return "cudaErrorLaunchOutOfResources";
63 |
64 | case cudaErrorInvalidDeviceFunction:
65 | return "cudaErrorInvalidDeviceFunction";
66 |
67 | case cudaErrorInvalidConfiguration:
68 | return "cudaErrorInvalidConfiguration";
69 |
70 | case cudaErrorInvalidDevice:
71 | return "cudaErrorInvalidDevice";
72 |
73 | case cudaErrorInvalidValue:
74 | return "cudaErrorInvalidValue";
75 |
76 | case cudaErrorInvalidPitchValue:
77 | return "cudaErrorInvalidPitchValue";
78 |
79 | case cudaErrorInvalidSymbol:
80 | return "cudaErrorInvalidSymbol";
81 |
82 | case cudaErrorMapBufferObjectFailed:
83 | return "cudaErrorMapBufferObjectFailed";
84 |
85 | case cudaErrorUnmapBufferObjectFailed:
86 | return "cudaErrorUnmapBufferObjectFailed";
87 |
88 | case cudaErrorInvalidHostPointer:
89 | return "cudaErrorInvalidHostPointer";
90 |
91 | case cudaErrorInvalidDevicePointer:
92 | return "cudaErrorInvalidDevicePointer";
93 |
94 | case cudaErrorInvalidTexture:
95 | return "cudaErrorInvalidTexture";
96 |
97 | case cudaErrorInvalidTextureBinding:
98 | return "cudaErrorInvalidTextureBinding";
99 |
100 | case cudaErrorInvalidChannelDescriptor:
101 | return "cudaErrorInvalidChannelDescriptor";
102 |
103 | case cudaErrorInvalidMemcpyDirection:
104 | return "cudaErrorInvalidMemcpyDirection";
105 |
106 | case cudaErrorAddressOfConstant:
107 | return "cudaErrorAddressOfConstant";
108 |
109 | case cudaErrorTextureFetchFailed:
110 | return "cudaErrorTextureFetchFailed";
111 |
112 | case cudaErrorTextureNotBound:
113 | return "cudaErrorTextureNotBound";
114 |
115 | case cudaErrorSynchronizationError:
116 | return "cudaErrorSynchronizationError";
117 |
118 | case cudaErrorInvalidFilterSetting:
119 | return "cudaErrorInvalidFilterSetting";
120 |
121 | case cudaErrorInvalidNormSetting:
122 | return "cudaErrorInvalidNormSetting";
123 |
124 | case cudaErrorMixedDeviceExecution:
125 | return "cudaErrorMixedDeviceExecution";
126 |
127 | case cudaErrorCudartUnloading:
128 | return "cudaErrorCudartUnloading";
129 |
130 | case cudaErrorUnknown:
131 | return "cudaErrorUnknown";
132 |
133 | case cudaErrorNotYetImplemented:
134 | return "cudaErrorNotYetImplemented";
135 |
136 | case cudaErrorMemoryValueTooLarge:
137 | return "cudaErrorMemoryValueTooLarge";
138 |
139 | case cudaErrorInvalidResourceHandle:
140 | return "cudaErrorInvalidResourceHandle";
141 |
142 | case cudaErrorNotReady:
143 | return "cudaErrorNotReady";
144 |
145 | case cudaErrorInsufficientDriver:
146 | return "cudaErrorInsufficientDriver";
147 |
148 | case cudaErrorSetOnActiveProcess:
149 | return "cudaErrorSetOnActiveProcess";
150 |
151 | case cudaErrorInvalidSurface:
152 | return "cudaErrorInvalidSurface";
153 |
154 | case cudaErrorNoDevice:
155 | return "cudaErrorNoDevice";
156 |
157 | case cudaErrorECCUncorrectable:
158 | return "cudaErrorECCUncorrectable";
159 |
160 | case cudaErrorSharedObjectSymbolNotFound:
161 | return "cudaErrorSharedObjectSymbolNotFound";
162 |
163 | case cudaErrorSharedObjectInitFailed:
164 | return "cudaErrorSharedObjectInitFailed";
165 |
166 | case cudaErrorUnsupportedLimit:
167 | return "cudaErrorUnsupportedLimit";
168 |
169 | case cudaErrorDuplicateVariableName:
170 | return "cudaErrorDuplicateVariableName";
171 |
172 | case cudaErrorDuplicateTextureName:
173 | return "cudaErrorDuplicateTextureName";
174 |
175 | case cudaErrorDuplicateSurfaceName:
176 | return "cudaErrorDuplicateSurfaceName";
177 |
178 | case cudaErrorDevicesUnavailable:
179 | return "cudaErrorDevicesUnavailable";
180 |
181 | case cudaErrorInvalidKernelImage:
182 | return "cudaErrorInvalidKernelImage";
183 |
184 | case cudaErrorNoKernelImageForDevice:
185 | return "cudaErrorNoKernelImageForDevice";
186 |
187 | case cudaErrorIncompatibleDriverContext:
188 | return "cudaErrorIncompatibleDriverContext";
189 |
190 | case cudaErrorPeerAccessAlreadyEnabled:
191 | return "cudaErrorPeerAccessAlreadyEnabled";
192 |
193 | case cudaErrorPeerAccessNotEnabled:
194 | return "cudaErrorPeerAccessNotEnabled";
195 |
196 | case cudaErrorDeviceAlreadyInUse:
197 | return "cudaErrorDeviceAlreadyInUse";
198 |
199 | case cudaErrorProfilerDisabled:
200 | return "cudaErrorProfilerDisabled";
201 |
202 | case cudaErrorProfilerNotInitialized:
203 | return "cudaErrorProfilerNotInitialized";
204 |
205 | case cudaErrorProfilerAlreadyStarted:
206 | return "cudaErrorProfilerAlreadyStarted";
207 |
208 | case cudaErrorProfilerAlreadyStopped:
209 | return "cudaErrorProfilerAlreadyStopped";
210 |
211 | #if __CUDA_API_VERSION >= 0x4000
212 |
213 | case cudaErrorAssert:
214 | return "cudaErrorAssert";
215 |
216 | case cudaErrorTooManyPeers:
217 | return "cudaErrorTooManyPeers";
218 |
219 | case cudaErrorHostMemoryAlreadyRegistered:
220 | return "cudaErrorHostMemoryAlreadyRegistered";
221 |
222 | case cudaErrorHostMemoryNotRegistered:
223 | return "cudaErrorHostMemoryNotRegistered";
224 | #endif
225 |
226 | case cudaErrorStartupFailure:
227 | return "cudaErrorStartupFailure";
228 |
229 | case cudaErrorApiFailureBase:
230 | return "cudaErrorApiFailureBase";
231 | }
232 |
233 | return "";
234 | }
235 | #endif
236 |
#ifdef __cuda_cuda_h__
// CUDA Driver API errors
// Translate a CUresult status into its symbolic enum name.
// Unrecognized values yield the empty string.
static const char *_cudaGetErrorEnum(CUresult error)
{
    switch (error)
    {
        case CUDA_SUCCESS: return "CUDA_SUCCESS";
        case CUDA_ERROR_INVALID_VALUE: return "CUDA_ERROR_INVALID_VALUE";
        case CUDA_ERROR_OUT_OF_MEMORY: return "CUDA_ERROR_OUT_OF_MEMORY";
        case CUDA_ERROR_NOT_INITIALIZED: return "CUDA_ERROR_NOT_INITIALIZED";
        case CUDA_ERROR_DEINITIALIZED: return "CUDA_ERROR_DEINITIALIZED";
        case CUDA_ERROR_PROFILER_DISABLED: return "CUDA_ERROR_PROFILER_DISABLED";
        case CUDA_ERROR_PROFILER_NOT_INITIALIZED: return "CUDA_ERROR_PROFILER_NOT_INITIALIZED";
        case CUDA_ERROR_PROFILER_ALREADY_STARTED: return "CUDA_ERROR_PROFILER_ALREADY_STARTED";
        case CUDA_ERROR_PROFILER_ALREADY_STOPPED: return "CUDA_ERROR_PROFILER_ALREADY_STOPPED";
        case CUDA_ERROR_NO_DEVICE: return "CUDA_ERROR_NO_DEVICE";
        case CUDA_ERROR_INVALID_DEVICE: return "CUDA_ERROR_INVALID_DEVICE";
        case CUDA_ERROR_INVALID_IMAGE: return "CUDA_ERROR_INVALID_IMAGE";
        case CUDA_ERROR_INVALID_CONTEXT: return "CUDA_ERROR_INVALID_CONTEXT";
        case CUDA_ERROR_CONTEXT_ALREADY_CURRENT: return "CUDA_ERROR_CONTEXT_ALREADY_CURRENT";
        case CUDA_ERROR_MAP_FAILED: return "CUDA_ERROR_MAP_FAILED";
        case CUDA_ERROR_UNMAP_FAILED: return "CUDA_ERROR_UNMAP_FAILED";
        case CUDA_ERROR_ARRAY_IS_MAPPED: return "CUDA_ERROR_ARRAY_IS_MAPPED";
        case CUDA_ERROR_ALREADY_MAPPED: return "CUDA_ERROR_ALREADY_MAPPED";
        case CUDA_ERROR_NO_BINARY_FOR_GPU: return "CUDA_ERROR_NO_BINARY_FOR_GPU";
        case CUDA_ERROR_ALREADY_ACQUIRED: return "CUDA_ERROR_ALREADY_ACQUIRED";
        case CUDA_ERROR_NOT_MAPPED: return "CUDA_ERROR_NOT_MAPPED";
        case CUDA_ERROR_NOT_MAPPED_AS_ARRAY: return "CUDA_ERROR_NOT_MAPPED_AS_ARRAY";
        case CUDA_ERROR_NOT_MAPPED_AS_POINTER: return "CUDA_ERROR_NOT_MAPPED_AS_POINTER";
        case CUDA_ERROR_ECC_UNCORRECTABLE: return "CUDA_ERROR_ECC_UNCORRECTABLE";
        case CUDA_ERROR_UNSUPPORTED_LIMIT: return "CUDA_ERROR_UNSUPPORTED_LIMIT";
        case CUDA_ERROR_CONTEXT_ALREADY_IN_USE: return "CUDA_ERROR_CONTEXT_ALREADY_IN_USE";
        case CUDA_ERROR_INVALID_SOURCE: return "CUDA_ERROR_INVALID_SOURCE";
        case CUDA_ERROR_FILE_NOT_FOUND: return "CUDA_ERROR_FILE_NOT_FOUND";
        case CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND: return "CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND";
        case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED: return "CUDA_ERROR_SHARED_OBJECT_INIT_FAILED";
        case CUDA_ERROR_OPERATING_SYSTEM: return "CUDA_ERROR_OPERATING_SYSTEM";
        case CUDA_ERROR_INVALID_HANDLE: return "CUDA_ERROR_INVALID_HANDLE";
        case CUDA_ERROR_NOT_FOUND: return "CUDA_ERROR_NOT_FOUND";
        case CUDA_ERROR_NOT_READY: return "CUDA_ERROR_NOT_READY";
        case CUDA_ERROR_LAUNCH_FAILED: return "CUDA_ERROR_LAUNCH_FAILED";
        case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES: return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES";
        case CUDA_ERROR_LAUNCH_TIMEOUT: return "CUDA_ERROR_LAUNCH_TIMEOUT";
        case CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING: return "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING";
        case CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED: return "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED";
        case CUDA_ERROR_PEER_ACCESS_NOT_ENABLED: return "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED";
        case CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE: return "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE";
        case CUDA_ERROR_CONTEXT_IS_DESTROYED: return "CUDA_ERROR_CONTEXT_IS_DESTROYED";
        case CUDA_ERROR_ASSERT: return "CUDA_ERROR_ASSERT";
        case CUDA_ERROR_TOO_MANY_PEERS: return "CUDA_ERROR_TOO_MANY_PEERS";
        case CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED: return "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED";
        case CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED: return "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED";
        case CUDA_ERROR_UNKNOWN: return "CUDA_ERROR_UNKNOWN";
    }

    return "";
}
#endif
388 |
#ifdef CUBLAS_API_H_
// cuBLAS API errors
// Translate a cublasStatus_t into its symbolic enum name ("" if unknown).
static const char *_cudaGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
        case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
        case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
        case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
        case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
        case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
        case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
        case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
        case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
    }

    return "";
}
#endif
423 |
#ifdef _CUFFT_H_
// cuFFT API errors
// Translate a cufftResult into its symbolic enum name ("" if unknown).
static const char *_cudaGetErrorEnum(cufftResult error)
{
    switch (error)
    {
        case CUFFT_SUCCESS: return "CUFFT_SUCCESS";
        case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN";
        case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED";
        case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE";
        case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE";
        case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR";
        case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED";
        case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED";
        case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE";
        case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA";
    }

    return "";
}
#endif
464 |
465 |
#ifdef CUSPARSEAPI
// cuSPARSE API errors
// Translate a cusparseStatus_t into its symbolic enum name ("" if unknown).
static const char *_cudaGetErrorEnum(cusparseStatus_t error)
{
    switch (error)
    {
        case CUSPARSE_STATUS_SUCCESS: return "CUSPARSE_STATUS_SUCCESS";
        case CUSPARSE_STATUS_NOT_INITIALIZED: return "CUSPARSE_STATUS_NOT_INITIALIZED";
        case CUSPARSE_STATUS_ALLOC_FAILED: return "CUSPARSE_STATUS_ALLOC_FAILED";
        case CUSPARSE_STATUS_INVALID_VALUE: return "CUSPARSE_STATUS_INVALID_VALUE";
        case CUSPARSE_STATUS_ARCH_MISMATCH: return "CUSPARSE_STATUS_ARCH_MISMATCH";
        case CUSPARSE_STATUS_MAPPING_ERROR: return "CUSPARSE_STATUS_MAPPING_ERROR";
        case CUSPARSE_STATUS_EXECUTION_FAILED: return "CUSPARSE_STATUS_EXECUTION_FAILED";
        case CUSPARSE_STATUS_INTERNAL_ERROR: return "CUSPARSE_STATUS_INTERNAL_ERROR";
        case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED: return "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    }

    return "";
}
#endif
503 |
#ifdef CURAND_H_
// cuRAND API errors
// Translate a curandStatus_t into its symbolic enum name ("" if unknown).
static const char *_cudaGetErrorEnum(curandStatus_t error)
{
    switch (error)
    {
        case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS";
        case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH";
        case CURAND_STATUS_NOT_INITIALIZED: return "CURAND_STATUS_NOT_INITIALIZED";
        case CURAND_STATUS_ALLOCATION_FAILED: return "CURAND_STATUS_ALLOCATION_FAILED";
        case CURAND_STATUS_TYPE_ERROR: return "CURAND_STATUS_TYPE_ERROR";
        case CURAND_STATUS_OUT_OF_RANGE: return "CURAND_STATUS_OUT_OF_RANGE";
        case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
        case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
        case CURAND_STATUS_LAUNCH_FAILURE: return "CURAND_STATUS_LAUNCH_FAILURE";
        case CURAND_STATUS_PREEXISTING_FAILURE: return "CURAND_STATUS_PREEXISTING_FAILURE";
        case CURAND_STATUS_INITIALIZATION_FAILED: return "CURAND_STATUS_INITIALIZATION_FAILED";
        case CURAND_STATUS_ARCH_MISMATCH: return "CURAND_STATUS_ARCH_MISMATCH";
        case CURAND_STATUS_INTERNAL_ERROR: return "CURAND_STATUS_INTERNAL_ERROR";
    }

    return "";
}
#endif
553 |
#ifdef NV_NPPIDEFS_H
// NPP API errors
// Translate an NppStatus (error, success, or warning code) into its
// symbolic enum name ("" if unknown).
static const char *_cudaGetErrorEnum(NppStatus error)
{
    switch (error)
    {
        case NPP_NOT_SUPPORTED_MODE_ERROR: return "NPP_NOT_SUPPORTED_MODE_ERROR";
        case NPP_ROUND_MODE_NOT_SUPPORTED_ERROR: return "NPP_ROUND_MODE_NOT_SUPPORTED_ERROR";
        case NPP_RESIZE_NO_OPERATION_ERROR: return "NPP_RESIZE_NO_OPERATION_ERROR";
        case NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY: return "NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY";
        case NPP_BAD_ARG_ERROR: return "NPP_BAD_ARG_ERROR";
        case NPP_LUT_NUMBER_OF_LEVELS_ERROR: return "NPP_LUT_NUMBER_OF_LEVELS_ERROR";
        case NPP_TEXTURE_BIND_ERROR: return "NPP_TEXTURE_BIND_ERROR";
        case NPP_COEFF_ERROR: return "NPP_COEFF_ERROR";
        case NPP_RECT_ERROR: return "NPP_RECT_ERROR";
        case NPP_QUAD_ERROR: return "NPP_QUAD_ERROR";
        case NPP_WRONG_INTERSECTION_ROI_ERROR: return "NPP_WRONG_INTERSECTION_ROI_ERROR";
        case NPP_NOT_EVEN_STEP_ERROR: return "NPP_NOT_EVEN_STEP_ERROR";
        case NPP_INTERPOLATION_ERROR: return "NPP_INTERPOLATION_ERROR";
        case NPP_RESIZE_FACTOR_ERROR: return "NPP_RESIZE_FACTOR_ERROR";
        case NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR: return "NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR";
        case NPP_MEMFREE_ERR: return "NPP_MEMFREE_ERR";
        case NPP_MEMSET_ERR: return "NPP_MEMSET_ERR";
        case NPP_MEMCPY_ERROR: return "NPP_MEMCPY_ERROR";
        case NPP_MEM_ALLOC_ERR: return "NPP_MEM_ALLOC_ERR";
        case NPP_HISTO_NUMBER_OF_LEVELS_ERROR: return "NPP_HISTO_NUMBER_OF_LEVELS_ERROR";
        case NPP_MIRROR_FLIP_ERR: return "NPP_MIRROR_FLIP_ERR";
        case NPP_INVALID_INPUT: return "NPP_INVALID_INPUT";
        case NPP_ALIGNMENT_ERROR: return "NPP_ALIGNMENT_ERROR";
        case NPP_STEP_ERROR: return "NPP_STEP_ERROR";
        case NPP_SIZE_ERROR: return "NPP_SIZE_ERROR";
        case NPP_POINTER_ERROR: return "NPP_POINTER_ERROR";
        case NPP_NULL_POINTER_ERROR: return "NPP_NULL_POINTER_ERROR";
        case NPP_CUDA_KERNEL_EXECUTION_ERROR: return "NPP_CUDA_KERNEL_EXECUTION_ERROR";
        case NPP_NOT_IMPLEMENTED_ERROR: return "NPP_NOT_IMPLEMENTED_ERROR";
        case NPP_ERROR: return "NPP_ERROR";
        case NPP_SUCCESS: return "NPP_SUCCESS";
        case NPP_WARNING: return "NPP_WARNING";
        case NPP_WRONG_INTERSECTION_QUAD_WARNING: return "NPP_WRONG_INTERSECTION_QUAD_WARNING";
        case NPP_MISALIGNED_DST_ROI_WARNING: return "NPP_MISALIGNED_DST_ROI_WARNING";
        case NPP_AFFINE_QUAD_INCORRECT_WARNING: return "NPP_AFFINE_QUAD_INCORRECT_WARNING";
        case NPP_DOUBLE_SIZE_WARNING: return "NPP_DOUBLE_SIZE_WARNING";
        case NPP_ODD_ROI_WARNING: return "NPP_ODD_ROI_WARNING";
        case NPP_WRONG_INTERSECTION_ROI_WARNING: return "NPP_WRONG_INTERSECTION_ROI_WARNING";
    }

    return "";
}
#endif
678 |
// Generic CUDA status check used by the checkCudaErrors() macro.
// When `result` is a non-success status, prints a diagnostic naming the
// call site (file/line), the numeric code, its symbolic name, and the
// failing expression text, then returns true; returns false on success.
template< typename T >
bool check(T result, char const *const func, const char *const file, int const line)
{
    if (result)
    {
        // Bug fix: the original "static_cast(result)" was missing its target
        // type (the <unsigned int> had been stripped) and does not compile.
        fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n",
                file, line, static_cast<unsigned int>(result), _cudaGetErrorEnum(result), func);
        return true;
    }
    else
    {
        return false;
    }
}
711 |
#ifdef __DRIVER_TYPES_H__
// This will output the proper CUDA error strings in the event that a CUDA host call returns an error
#define checkCudaErrors(val)           check ( (val), #val, __FILE__, __LINE__ )

// This will output the proper error string when calling cudaGetLastError
#define getLastCudaError(msg)      __getLastCudaError (msg, __FILE__, __LINE__)

// Check the sticky "last error" state of the CUDA runtime; if an error is
// pending, print it (with the caller-supplied message and the call site
// injected by the getLastCudaError() macro) and terminate the process.
inline void __getLastCudaError(const char *errorMessage, const char *file, const int line)
{
    cudaError_t err = cudaGetLastError();

    if (cudaSuccess != err)
    {
        // Fatal by design: samples treat an unexpected kernel/runtime error
        // as unrecoverable.
        fprintf(stderr, "%s(%i) : getLastCudaError() CUDA error : %s : (%d) %s.\n",
            file, line, errorMessage, (int)err, cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}
#endif
731 |
#ifndef MAX
// Fully parenthesized so the macro stays correct when its arguments are
// compound expressions (the old "(a > b ? a : b)" mis-parsed e.g.
// MAX(x & 1, y)). Note both arguments are still evaluated twice.
#define MAX(a,b) (((a) > (b)) ? (a) : (b))
#endif
735 |
// Beginning of GPU Architecture definitions
// Map a compute-capability version (major.minor) to the number of CUDA
// cores per streaming multiprocessor for that architecture.
inline int _ConvertSMVer2Cores(int major, int minor)
{
    // 0xMm (hexadecimal notation), M = SM Major version, m = SM minor version
    typedef struct
    {
        int SM;
        int Cores;
    } sSMtoCores;

    // Extended beyond the original Tesla..Kepler entries with the published
    // cores/SM figures for later architectures.
    sSMtoCores nGpuArchCoresPerSM[] =
    {
        { 0x10,   8 }, // Tesla   (SM 1.0) G80 class
        { 0x11,   8 }, // Tesla   (SM 1.1) G8x class
        { 0x12,   8 }, // Tesla   (SM 1.2) G9x class
        { 0x13,   8 }, // Tesla   (SM 1.3) GT200 class
        { 0x20,  32 }, // Fermi   (SM 2.0) GF100 class
        { 0x21,  48 }, // Fermi   (SM 2.1) GF10x class
        { 0x30, 192 }, // Kepler  (SM 3.0) GK10x class
        { 0x32, 192 }, // Kepler  (SM 3.2)
        { 0x35, 192 }, // Kepler  (SM 3.5) GK11x class
        { 0x37, 192 }, // Kepler  (SM 3.7)
        { 0x50, 128 }, // Maxwell (SM 5.0)
        { 0x52, 128 }, // Maxwell (SM 5.2)
        { 0x53, 128 }, // Maxwell (SM 5.3)
        { 0x60,  64 }, // Pascal  (SM 6.0)
        { 0x61, 128 }, // Pascal  (SM 6.1)
        { 0x62, 128 }, // Pascal  (SM 6.2)
        { 0x70,  64 }, // Volta   (SM 7.0)
        { 0x72,  64 }, // Xavier  (SM 7.2)
        { 0x75,  64 }, // Turing  (SM 7.5)
        { 0x80,  64 }, // Ampere  (SM 8.0)
        { 0x86, 128 }, // Ampere  (SM 8.6)
        { -1, -1 }
    };

    int index = 0;

    while (nGpuArchCoresPerSM[index].SM != -1)
    {
        if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor))
        {
            return nGpuArchCoresPerSM[index].Cores;
        }

        index++;
    }

    // Unknown SM: fall back to the newest known entry. The old code
    // hard-coded table index [7], which silently broke as soon as the
    // table changed; index-1 always names the last real entry.
    printf("MapSMtoCores for SM %d.%d is undefined. Default to use %d Cores/SM\n", major, minor, nGpuArchCoresPerSM[index - 1].Cores);
    return nGpuArchCoresPerSM[index - 1].Cores;
}
// end of GPU Architecture definitions
776 |
777 | #ifdef __CUDA_RUNTIME_H__
778 | // General GPU Device CUDA Initialization
779 | inline int gpuDeviceInit(int devID)
780 | {
781 | int deviceCount;
782 | checkCudaErrors(cudaGetDeviceCount(&deviceCount));
783 |
784 | if (deviceCount == 0)
785 | {
786 | fprintf(stderr, "gpuDeviceInit() CUDA error: no devices supporting CUDA.\n");
787 | exit(EXIT_FAILURE);
788 | }
789 |
790 | if (devID < 0)
791 | {
792 | devID = 0;
793 | }
794 |
795 | if (devID > deviceCount-1)
796 | {
797 | fprintf(stderr, "\n");
798 | fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n", deviceCount);
799 | fprintf(stderr, ">> gpuDeviceInit (-device=%d) is not a valid GPU device. <<\n", devID);
800 | fprintf(stderr, "\n");
801 | return -devID;
802 | }
803 |
804 | cudaDeviceProp deviceProp;
805 | checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));
806 |
807 | if (deviceProp.computeMode == cudaComputeModeProhibited)
808 | {
809 | fprintf(stderr, "Error: device is running in , no threads can use ::cudaSetDevice().\n");
810 | return -1;
811 | }
812 |
813 | if (deviceProp.major < 1)
814 | {
815 | fprintf(stderr, "gpuDeviceInit(): GPU device does not support CUDA.\n");
816 | exit(EXIT_FAILURE);
817 | }
818 |
819 | checkCudaErrors(cudaSetDevice(devID));
820 | printf("gpuDeviceInit() CUDA Device [%d]: \"%s\n", devID, deviceProp.name);
821 |
822 | return devID;
823 | }
824 |
825 | // This function returns the best GPU (with maximum GFLOPS)
826 | inline int gpuGetMaxGflopsDeviceId()
827 | {
828 | int current_device = 0, sm_per_multiproc = 0;
829 | int max_compute_perf = 0, max_perf_device = 0;
830 | int device_count = 0, best_SM_arch = 0;
831 | cudaDeviceProp deviceProp;
832 | cudaGetDeviceCount(&device_count);
833 |
834 | // Find the best major SM Architecture GPU device
835 | while (current_device < device_count)
836 | {
837 | cudaGetDeviceProperties(&deviceProp, current_device);
838 |
839 | // If this GPU is not running on Compute Mode prohibited, then we can add it to the list
840 | if (deviceProp.computeMode != cudaComputeModeProhibited)
841 | {
842 | if (deviceProp.major > 0 && deviceProp.major < 9999)
843 | {
844 | best_SM_arch = MAX(best_SM_arch, deviceProp.major);
845 | }
846 | }
847 |
848 | current_device++;
849 | }
850 |
851 | // Find the best CUDA capable GPU device
852 | current_device = 0;
853 |
854 | while (current_device < device_count)
855 | {
856 | cudaGetDeviceProperties(&deviceProp, current_device);
857 |
858 | // If this GPU is not running on Compute Mode prohibited, then we can add it to the list
859 | if (deviceProp.computeMode != cudaComputeModeProhibited)
860 | {
861 | if (deviceProp.major == 9999 && deviceProp.minor == 9999)
862 | {
863 | sm_per_multiproc = 1;
864 | }
865 | else
866 | {
867 | sm_per_multiproc = _ConvertSMVer2Cores(deviceProp.major, deviceProp.minor);
868 | }
869 |
870 | int compute_perf = deviceProp.multiProcessorCount * sm_per_multiproc * deviceProp.clockRate;
871 |
872 | if (compute_perf > max_compute_perf)
873 | {
874 | // If we find GPU with SM major > 2, search only these
875 | if (best_SM_arch > 2)
876 | {
877 | // If our device==dest_SM_arch, choose this, or else pass
878 | if (deviceProp.major == best_SM_arch)
879 | {
880 | max_compute_perf = compute_perf;
881 | max_perf_device = current_device;
882 | }
883 | }
884 | else
885 | {
886 | max_compute_perf = compute_perf;
887 | max_perf_device = current_device;
888 | }
889 | }
890 | }
891 |
892 | ++current_device;
893 | }
894 |
895 | return max_perf_device;
896 | }
897 |
898 |
// Initialization code to find the best CUDA Device
// Honors an optional "-device=N" command-line flag; otherwise selects the
// device with the highest estimated GFLOPS and makes it current. Exits the
// process on an invalid or unusable requested device. Returns the chosen
// device id.
inline int findCudaDevice(int argc, const char **argv)
{
    cudaDeviceProp deviceProp;
    int devID = 0;

    // If the command-line has a device number specified, use it
    if (checkCmdLineFlag(argc, argv, "device"))
    {
        devID = getCmdLineArgumentInt(argc, argv, "device=");

        if (devID < 0)
        {
            printf("Invalid command line parameter\n ");
            exit(EXIT_FAILURE);
        }
        else
        {
            // gpuDeviceInit() validates the id and calls cudaSetDevice();
            // a negative return means the request could not be honored.
            devID = gpuDeviceInit(devID);

            if (devID < 0)
            {
                printf("exiting...\n");
                exit(EXIT_FAILURE);
            }
        }
    }
    else
    {
        // Otherwise pick the device with highest Gflops/s
        devID = gpuGetMaxGflopsDeviceId();
        checkCudaErrors(cudaSetDevice(devID));
        checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));
        printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", devID, deviceProp.name, deviceProp.major, deviceProp.minor);
    }

    return devID;
}
937 |
938 | // General check for CUDA GPU SM Capabilities
939 | inline bool checkCudaCapabilities(int major_version, int minor_version)
940 | {
941 | cudaDeviceProp deviceProp;
942 | deviceProp.major = 0;
943 | deviceProp.minor = 0;
944 | int dev;
945 |
946 | checkCudaErrors(cudaGetDevice(&dev));
947 | checkCudaErrors(cudaGetDeviceProperties(&deviceProp, dev));
948 |
949 | if ((deviceProp.major > major_version) ||
950 | (deviceProp.major == major_version && deviceProp.minor >= minor_version))
951 | {
952 | printf("> Device %d: <%16s >, Compute SM %d.%d detected\n", dev, deviceProp.name, deviceProp.major, deviceProp.minor);
953 | return true;
954 | }
955 | else
956 | {
957 | printf("No GPU device was found that can support CUDA compute capability %d.%d.\n", major_version, minor_version);
958 | return false;
959 | }
960 | }
961 | #endif
962 |
963 | // end of CUDA Helper Functions
964 |
965 |
966 | #endif
--------------------------------------------------------------------------------
/adaptive_gridsampler/helper_string.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
3 | *
4 | * Please refer to the NVIDIA end user license agreement (EULA) associated
5 | * with this source code for terms and conditions that govern your use of
6 | * this software. Any use, reproduction, disclosure, or distribution of
7 | * this software and related documentation outside the terms of the EULA
8 | * is strictly prohibited.
9 | *
10 | */
11 |
12 | // These are helper functions for the SDK samples (string parsing, timers, etc)
13 | #ifndef STRING_HELPER_H
14 | #define STRING_HELPER_H
15 |
16 | #include
17 | #include
18 | #include
19 | #include
20 |
21 | #ifdef _WIN32
22 | #ifndef STRCASECMP
23 | #define STRCASECMP _stricmp
24 | #endif
25 | #ifndef STRNCASECMP
26 | #define STRNCASECMP _strnicmp
27 | #endif
28 | #ifndef STRCPY
29 | #define STRCPY(sFilePath, nLength, sPath) strcpy_s(sFilePath, nLength, sPath)
30 | #endif
31 |
32 | #ifndef FOPEN
33 | #define FOPEN(fHandle,filename,mode) fopen_s(&fHandle, filename, mode)
34 | #endif
35 | #ifndef FOPEN_FAIL
36 | #define FOPEN_FAIL(result) (result != 0)
37 | #endif
38 | #ifndef SSCANF
39 | #define SSCANF sscanf_s
40 | #endif
41 |
42 | #else
43 | #include
44 | #include
45 |
46 | #ifndef STRCASECMP
47 | #define STRCASECMP strcasecmp
48 | #endif
49 | #ifndef STRNCASECMP
50 | #define STRNCASECMP strncasecmp
51 | #endif
52 | #ifndef STRCPY
53 | #define STRCPY(sFilePath, nLength, sPath) strcpy(sFilePath, sPath)
54 | #endif
55 |
56 | #ifndef FOPEN
57 | #define FOPEN(fHandle,filename,mode) (fHandle = fopen(filename, mode))
58 | #endif
59 | #ifndef FOPEN_FAIL
60 | #define FOPEN_FAIL(result) (result == NULL)
61 | #endif
62 | #ifndef SSCANF
63 | #define SSCANF sscanf
64 | #endif
65 | #endif
66 |
// CUDA Utility Helper Functions

// Return the index of the first character of `string` that is not
// `delimiter` (i.e. skip the leading "-"/"--" of a command-line flag).
// Returns 0 when the string consists solely of delimiters.
inline int stringRemoveDelimiter(char delimiter, const char *string)
{
    int string_start = 0;

    while (string[string_start] == delimiter)
    {
        string_start++;
    }

    // Bug fix: the old bound was strlen(string)-1, which wrongly treated a
    // single-character flag such as "-a" as empty and returned 0.
    if (string_start >= (int)strlen(string))
    {
        return 0;
    }

    return string_start;
}
84 |
// Locate the extension of `filename` (text after the last '.').
// On success *extension points into `filename` at the first extension
// character and the same index is returned; otherwise *extension is NULL
// and 0 is returned.
inline int getFileExtension(char *filename, char **extension)
{
    int string_length = (int)strlen(filename);

    // Scan backwards for the '.'; the post-decrement means that when the
    // loop stops, string_length sits one position *before* the '.'.
    while (filename[string_length--] != '.') {
        if (string_length == 0)
            break;
    }
    // Jump from just-before-the-dot to just-after-the-dot.
    if (string_length > 0) string_length += 2;

    // NOTE(review): when the '.' is at index 0 or 1 (e.g. "a.txt"),
    // string_length ends up 0 here and the extension is reported as NULL —
    // looks like an off-by-one edge case; confirm before relying on it.
    if (string_length == 0)
        *extension = NULL;
    else
        *extension = &filename[string_length];

    return string_length;
}
102 |
103 |
104 | inline int checkCmdLineFlag(const int argc, const char **argv, const char *string_ref)
105 | {
106 | bool bFound = false;
107 |
108 | if (argc >= 1)
109 | {
110 | for (int i=1; i < argc; i++)
111 | {
112 | int string_start = stringRemoveDelimiter('-', argv[i]);
113 | const char *string_argv = &argv[i][string_start];
114 |
115 | const char *equal_pos = strchr(string_argv, '=');
116 | int argv_length = (int)(equal_pos == 0 ? strlen(string_argv) : equal_pos - string_argv);
117 |
118 | int length = (int)strlen(string_ref);
119 |
120 | if (length == argv_length && !STRNCASECMP(string_argv, string_ref, length))
121 | {
122 |
123 | bFound = true;
124 | continue;
125 | }
126 | }
127 | }
128 |
129 | return (int)bFound;
130 | }
131 |
132 | inline int getCmdLineArgumentInt(const int argc, const char **argv, const char *string_ref)
133 | {
134 | bool bFound = false;
135 | int value = -1;
136 |
137 | if (argc >= 1)
138 | {
139 | for (int i=1; i < argc; i++)
140 | {
141 | int string_start = stringRemoveDelimiter('-', argv[i]);
142 | const char *string_argv = &argv[i][string_start];
143 | int length = (int)strlen(string_ref);
144 |
145 | if (!STRNCASECMP(string_argv, string_ref, length))
146 | {
147 | if (length+1 <= (int)strlen(string_argv))
148 | {
149 | int auto_inc = (string_argv[length] == '=') ? 1 : 0;
150 | value = atoi(&string_argv[length + auto_inc]);
151 | }
152 | else
153 | {
154 | value = 0;
155 | }
156 |
157 | bFound = true;
158 | continue;
159 | }
160 | }
161 | }
162 |
163 | if (bFound)
164 | {
165 | return value;
166 | }
167 | else
168 | {
169 | return 0;
170 | }
171 | }
172 |
173 | inline float getCmdLineArgumentFloat(const int argc, const char **argv, const char *string_ref)
174 | {
175 | bool bFound = false;
176 | float value = -1;
177 |
178 | if (argc >= 1)
179 | {
180 | for (int i=1; i < argc; i++)
181 | {
182 | int string_start = stringRemoveDelimiter('-', argv[i]);
183 | const char *string_argv = &argv[i][string_start];
184 | int length = (int)strlen(string_ref);
185 |
186 | if (!STRNCASECMP(string_argv, string_ref, length))
187 | {
188 | if (length+1 <= (int)strlen(string_argv))
189 | {
190 | int auto_inc = (string_argv[length] == '=') ? 1 : 0;
191 | value = (float)atof(&string_argv[length + auto_inc]);
192 | }
193 | else
194 | {
195 | value = 0.f;
196 | }
197 |
198 | bFound = true;
199 | continue;
200 | }
201 | }
202 | }
203 |
204 | if (bFound)
205 | {
206 | return value;
207 | }
208 | else
209 | {
210 | return 0;
211 | }
212 | }
213 |
// Locate a "-<string_ref>=<value>" argument and point *string_retval at the
// <value> text inside argv's own storage (no copy is made). The last
// matching argument wins. Returns true when found; otherwise
// *string_retval is set to NULL.
inline bool getCmdLineArgumentString(const int argc, const char **argv,
                                     const char *string_ref, char **string_retval)
{
    bool bFound = false;

    if (argc >= 1)
    {
        for (int i=1; i < argc; i++)
        {
            int string_start = stringRemoveDelimiter('-', argv[i]);
            char *string_argv = (char *)&argv[i][string_start];
            int length = (int)strlen(string_ref);

            if (!STRNCASECMP(string_argv, string_ref, length))
            {
                // NOTE(review): assumes the matching argument has the form
                // "name=value"; for a bare "--name" this points one past the
                // terminating NUL — confirm callers always pass name=value.
                *string_retval = &string_argv[length+1];
                bFound = true;
                continue;
            }
        }
    }

    if (!bFound)
    {
        *string_retval = NULL;
    }

    return bFound;
}
243 |
244 | //////////////////////////////////////////////////////////////////////////////
245 | //! Find the path for a file assuming that
246 | //! files are found in the searchPath.
247 | //!
248 | //! @return the path if succeeded, otherwise 0
249 | //! @param filename name of the file
250 | //! @param executable_path optional absolute path of the executable
251 | //////////////////////////////////////////////////////////////////////////////
252 | inline char *sdkFindFilePath(const char *filename, const char *executable_path)
253 | {
254 | // defines a variable that is replaced with the name of the executable
255 |
256 | // Typical relative search paths to locate needed companion files (e.g. sample input data, or JIT source files)
257 | // The origin for the relative search may be the .exe file, a .bat file launching an .exe, a browser .exe launching the .exe or .bat, etc
258 | const char *searchPath[] =
259 | {
260 | "./", // same dir
261 | "./common/", // "/common/" subdir
262 | "./common/data/", // "/common/data/" subdir
263 | "./data/", // "/data/" subdir
264 | "./src/", // "/src/" subdir
265 | "./src//data/", // "/src//data/" subdir
266 | "./inc/", // "/inc/" subdir
267 | "./0_Simple/", // "/0_Simple/" subdir
268 | "./1_Utilities/", // "/1_Utilities/" subdir
269 | "./2_Graphics/", // "/2_Graphics/" subdir
270 | "./3_Imaging/", // "/3_Imaging/" subdir
271 | "./4_Financial/", // "/4_Financial/" subdir
272 | "./5_Simulations/", // "/5_Simulations/" subdir
273 | "./6_Advanced/", // "/6_Advanced/" subdir
274 | "./7_CUDALibraries/", // "/7_CUDALibraries/" subdir
275 |
276 | "../", // up 1 in tree
277 | "../common/", // up 1 in tree, "/common/" subdir
278 | "../common/data/", // up 1 in tree, "/common/data/" subdir
279 | "../data/", // up 1 in tree, "/data/" subdir
280 | "../src/", // up 1 in tree, "/src/" subdir
281 | "../inc/", // up 1 in tree, "/inc/" subdir
282 | "../C/src//", // up 1 in tree, "/C/src//" subdir
283 | "../C/src//data/", // up 1 in tree, "/C/src//data/" subdir
284 | "../C/src//src/", // up 1 in tree, "/C/src//src/" subdir
285 | "../C/src//inc/", // up 1 in tree, "/C/src//inc/" subdir
286 | "../C/", // up 1 in tree
287 | "../C/common/", // up 1 in tree, "/common/" subdir
288 | "../C/common/data/", // up 1 in tree, "/common/data/" subdir
289 | "../C/data/", // up 1 in tree, "/data/" subdir
290 | "../C/src/", // up 1 in tree, "/src/" subdir
291 | "../C/inc/", // up 1 in tree, "/inc/" subdir
292 | "../C/0_Simple//data/", // up 1 in tree, "/0_Simple//" subdir
293 | "../C/1_Utilities//data/", // up 1 in tree, "/1_Utilities//" subdir
294 | "../C/2_Graphics//data/", // up 1 in tree, "/2_Graphics//" subdir
295 | "../C/3_Imaging//data/", // up 1 in tree, "/3_Imaging//" subdir
296 | "../C/4_Financial//data/", // up 1 in tree, "/4_Financial//" subdir
297 | "../C/5_Simulations//data/", // up 1 in tree, "/5_Simulations//" subdir
298 | "../C/6_Advanced//data/", // up 1 in tree, "/6_Advanced//" subdir
299 | "../C/7_CUDALibraries//data/", // up 1 in tree, "/7_CUDALibraries//" subdir
300 |
301 | "../0_Simple//data/", // up 1 in tree, "/0_Simple//" subdir
302 | "../1_Utilities//data/", // up 1 in tree, "/1_Utilities//" subdir
303 | "../2_Graphics//data/", // up 1 in tree, "/2_Graphics//" subdir
304 | "../3_Imaging//data/", // up 1 in tree, "/3_Imaging//" subdir
305 | "../4_Financial//data/", // up 1 in tree, "/4_Financial//" subdir
306 | "../5_Simulations//data/", // up 1 in tree, "/5_Simulations//" subdir
307 | "../6_Advanced//data/", // up 1 in tree, "/6_Advanced//" subdir
308 | "../7_CUDALibraries//data/", // up 1 in tree, "/7_CUDALibraries//" subdir
309 | "../../", // up 2 in tree
310 | "../../common/", // up 2 in tree, "/common/" subdir
311 | "../../common/data/", // up 2 in tree, "/common/data/" subdir
312 | "../../data/", // up 2 in tree, "/data/" subdir
313 | "../../src/", // up 2 in tree, "/src/" subdir
314 | "../../inc/", // up 2 in tree, "/inc/" subdir
315 | "../../sandbox//data/", // up 2 in tree, "/sandbox//" subdir
316 | "../../0_Simple//data/", // up 2 in tree, "/0_Simple//" subdir
317 | "../../1_Utilities//data/", // up 2 in tree, "/1_Utilities//" subdir
318 | "../../2_Graphics//data/", // up 2 in tree, "/2_Graphics//" subdir
319 | "../../3_Imaging//data/", // up 2 in tree, "/3_Imaging//" subdir
320 | "../../4_Financial//data/", // up 2 in tree, "/4_Financial//" subdir
321 | "../../5_Simulations//data/", // up 2 in tree, "/5_Simulations//" subdir
322 | "../../6_Advanced//data/", // up 2 in tree, "/6_Advanced//" subdir
323 | "../../7_CUDALibraries//data/", // up 2 in tree, "/7_CUDALibraries//" subdir
324 | "../../../", // up 3 in tree
325 | "../../../src//", // up 3 in tree, "/src//" subdir
326 | "../../../src//data/", // up 3 in tree, "/src//data/" subdir
327 | "../../../src//src/", // up 3 in tree, "/src//src/" subdir
328 | "../../../src//inc/", // up 3 in tree, "/src//inc/" subdir
329 | "../../../sandbox//", // up 3 in tree, "/sandbox//" subdir
330 | "../../../sandbox//data/", // up 3 in tree, "/sandbox//data/" subdir
331 | "../../../sandbox//src/", // up 3 in tree, "/sandbox//src/" subdir
332 | "../../../sandbox//inc/", // up 3 in tree, "/sandbox//inc/" subdir
333 | "../../../0_Simple//data/", // up 3 in tree, "/0_Simple//" subdir
334 | "../../../1_Utilities//data/", // up 3 in tree, "/1_Utilities//" subdir
335 | "../../../2_Graphics//data/", // up 3 in tree, "/2_Graphics//" subdir
336 | "../../../3_Imaging//data/", // up 3 in tree, "/3_Imaging//" subdir
337 | "../../../4_Financial//data/", // up 3 in tree, "/4_Financial//" subdir
338 | "../../../5_Simulations//data/",// up 3 in tree, "/5_Simulations//" subdir
339 | "../../../6_Advanced//data/", // up 3 in tree, "/6_Advanced//" subdir
340 | "../../../7_CUDALibraries//data/", // up 3 in tree, "/7_CUDALibraries//" subdir
341 | "../../../common/", // up 3 in tree, "../../../common/" subdir
342 | "../../../common/data/", // up 3 in tree, "../../../common/data/" subdir
343 | "../../../data/", // up 3 in tree, "../../../data/" subdir
344 | };
345 |
346 | // Extract the executable name
347 | std::string executable_name;
348 |
349 | if (executable_path != 0)
350 | {
351 | executable_name = std::string(executable_path);
352 |
353 | #ifdef _WIN32
354 | // Windows path delimiter
355 | size_t delimiter_pos = executable_name.find_last_of('\\');
356 | executable_name.erase(0, delimiter_pos + 1);
357 |
358 | if (executable_name.rfind(".exe") != std::string::npos)
359 | {
360 | // we strip .exe, only if the .exe is found
361 | executable_name.resize(executable_name.size() - 4);
362 | }
363 |
364 | #else
365 | // Linux & OSX path delimiter
366 | size_t delimiter_pos = executable_name.find_last_of('/');
367 | executable_name.erase(0,delimiter_pos+1);
368 | #endif
369 | }
370 |
371 | // Loop over all search paths and return the first hit
372 | for (unsigned int i = 0; i < sizeof(searchPath)/sizeof(char *); ++i)
373 | {
374 | std::string path(searchPath[i]);
375 | size_t executable_name_pos = path.find("");
376 |
377 | // If there is executable_name variable in the searchPath
378 | // replace it with the value
379 | if (executable_name_pos != std::string::npos)
380 | {
381 | if (executable_path != 0)
382 | {
383 | path.replace(executable_name_pos, strlen(""), executable_name);
384 | }
385 | else
386 | {
387 | // Skip this path entry if no executable argument is given
388 | continue;
389 | }
390 | }
391 |
392 | #ifdef _DEBUG
393 | printf("sdkFindFilePath <%s> in %s\n", filename, path.c_str());
394 | #endif
395 |
396 | // Test if the file exists
397 | path.append(filename);
398 | FILE *fp;
399 | FOPEN(fp, path.c_str(), "rb");
400 |
401 | if (fp != NULL)
402 | {
403 | fclose(fp);
404 | // File found
405 | // returning an allocated array here for backwards compatibility reasons
406 | char *file_path = (char *) malloc(path.length() + 1);
407 | STRCPY(file_path, path.length() + 1, path.c_str());
408 | return file_path;
409 | }
410 |
411 | if (fp)
412 | {
413 | fclose(fp);
414 | }
415 | }
416 |
417 | // File not found
418 | return 0;
419 | }
420 |
421 | #endif
--------------------------------------------------------------------------------
/adaptive_gridsampler/setup.py:
--------------------------------------------------------------------------------
# Build script for the adaptive_gridsampler CUDA extension: compiles
# adaptive_gridsampler_cuda.cpp + adaptive_gridsampler_kernel.cu into an
# importable `adaptive_gridsampler_cuda` module.
# Usage: python setup.py install  (or: build_ext --inplace)
import os
import torch

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Host compiler flags; the C++ sources require C++11.
cxx_args = ['-std=c++11']

# Emit SASS for Pascal (sm_60/61) and Volta (sm_70); the final compute_70
# entry also embeds PTX so newer architectures can JIT-compile it.
nvcc_args = [
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
    '-gencode', 'arch=compute_70,code=compute_70'
]

# NOTE(review): `os` and `torch` appear unused here; presumably `import torch`
# was intended to ensure torch is importable before building — confirm.
setup(
    name='adaptive_gridsampler_cuda',
    ext_modules=[
        CUDAExtension('adaptive_gridsampler_cuda', ['adaptive_gridsampler_cuda.cpp', 'adaptive_gridsampler_kernel.cu'], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={'build_ext': BuildExtension}
)
23 |
--------------------------------------------------------------------------------
/figs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwj/CAR/05b22776b9f690dac94ced8baeb455bb722c0997/figs/overview.png
--------------------------------------------------------------------------------
/figs/qualitative.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwj/CAR/05b22776b9f690dac94ced8baeb455bb722c0997/figs/qualitative.png
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | LEAKY_FACTOR = 0.2
9 | MULT_FACTOR = 1
10 |
11 |
12 | # TEST PASSED
class PixelUnShuffle(nn.Module):
    """
    Inverse process of pytorch pixel shuffle module.

    Rearranges a (B, C, H, W) tensor into (B, C * s**2, H // s, W // s),
    the exact inverse of nn.PixelShuffle: output channel index is
    c * s**2 + row_offset * s + col_offset.
    """
    def __init__(self, down_scale):
        """
        :param down_scale: int, down scale factor
        :raises ValueError: if down_scale is not an int
        """
        super(PixelUnShuffle, self).__init__()

        if not isinstance(down_scale, int):
            raise ValueError('Down scale factor must be an integer number')
        self.down_scale = down_scale

    def forward(self, input):
        """
        :param input: tensor of shape (batch size, channels, height, width);
            height and width must be divisible by down_scale
        :return: tensor of shape (batch size, channels * down_scale**2,
            height // down_scale, width // down_scale)
        """
        b, c, h, w = input.size()
        assert h % self.down_scale == 0
        assert w % self.down_scale == 0

        oc = c * self.down_scale ** 2
        # floor division instead of the float-divide + int() round-trip
        oh = h // self.down_scale
        ow = w // self.down_scale

        # split each spatial dim into (out, offset), move offsets into channels
        output_reshaped = input.reshape(b, c, oh, self.down_scale, ow, self.down_scale)
        output = output_reshaped.permute(0, 1, 3, 5, 2, 4).reshape(b, oc, oh, ow)

        return output
44 |
45 |
class DownsampleBlock(nn.Module):
    """Downsample by `scale`: pixel-unshuffle, then a conv to mix channels."""

    def __init__(self, scale, input_channels, output_channels, ksize=1):
        super(DownsampleBlock, self).__init__()
        unshuffled_channels = input_channels * (scale ** 2)
        self.downsample = nn.Sequential(
            PixelUnShuffle(scale),
            nn.Conv2d(unshuffled_channels, output_channels,
                      kernel_size=ksize, stride=1, padding=ksize // 2),
        )

    def forward(self, input):
        return self.downsample(input)
56 |
57 |
class UpsampleBlock(nn.Module):
    """Upsample by `scale`: conv to scale**2 x channels, then pixel-shuffle.

    Bug fix: the conv's kernel_size was hard-coded to 1 while padding used
    ksize // 2, so any ksize > 1 produced spatially enlarged (wrong-sized)
    outputs. The conv now uses `ksize` consistently; behavior for the default
    ksize=1 (the only value used elsewhere in this file) is unchanged.
    """

    def __init__(self, scale, input_channels, output_channels, ksize=1):
        super(UpsampleBlock, self).__init__()
        self.upsample = nn.Sequential(
            nn.Conv2d(input_channels, output_channels * (scale ** 2),
                      kernel_size=ksize, stride=1, padding=ksize // 2),
            nn.PixelShuffle(scale)
        )

    def forward(self, input):
        return self.upsample(input)
68 |
69 |
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded convs (optional instance norm,
    LeakyReLU between them); the transform is added back onto the input."""

    def __init__(self, input_channels, channels, ksize=3,
                 use_instance_norm=False, affine=False):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.ksize = ksize
        pad = self.ksize // 2

        # Build the layer list branch-free where possible; module order (and
        # therefore Sequential indices) matches for each configuration.
        layers = [
            nn.ReflectionPad2d(pad),
            nn.Conv2d(input_channels, channels, kernel_size=self.ksize, stride=1),
        ]
        if use_instance_norm:
            layers.append(nn.InstanceNorm2d(channels, affine=affine))
        layers += [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad2d(pad),
            nn.Conv2d(channels, channels, kernel_size=self.ksize, stride=1),
        ]
        if use_instance_norm:
            layers.append(nn.InstanceNorm2d(channels))

        self.transform = nn.Sequential(*layers)

    def forward(self, input):
        return input + self.transform(input) * MULT_FACTOR
98 |
99 |
class NormalizeBySum(nn.Module):
    """Scale each spatial position's channel vector to sum to one
    (denominator clamped to 1e-7 to avoid division by zero)."""

    def forward(self, x):
        channel_sum = torch.sum(x, dim=1, keepdim=True)
        return x / channel_sum.clamp(min=1e-7)
103 |
104 |
class MeanShift(nn.Conv2d):
    """Frozen 1x1 conv that shifts RGB channels by +/- the dataset mean
    (sign=-1 subtracts, sign=+1 adds back) and divides by the std."""

    def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        channel_std = torch.Tensor(rgb_std)
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / channel_std.view(3, 1, 1, 1)
        self.bias.data = torch.Tensor(rgb_mean) * (sign * rgb_range) / channel_std
        # the shift is a fixed preprocessing step, never trained
        for param in self.parameters():
            param.requires_grad = False
113 |
114 |
class DSN(nn.Module):
    """Downsampling network: from an HR image, predicts per-pixel resampling
    kernels plus horizontal/vertical kernel offsets for content-adaptive
    downscaling. Attribute names and module ordering are kept identical to
    the original so pretrained state dicts load unchanged."""

    def __init__(self, k_size, input_channels=3, scale=4):
        super(DSN, self).__init__()

        self.k_size = k_size

        # inputs are in [0, 1]; subtract the dataset RGB mean
        self.sub_mean = MeanShift(1)

        self.ds_1 = nn.Sequential(
            nn.ReflectionPad2d(2),
            nn.Conv2d(input_channels, 64, 5),
            nn.LeakyReLU(LEAKY_FACTOR)
        )

        self.ds_2 = DownsampleBlock(2, 64, 128, ksize=1)
        self.ds_4 = DownsampleBlock(2, 128, 128, ksize=1)

        # residual trunk at 1/4 resolution
        self.res_4 = nn.Sequential(*[ResidualBlock(128, 128) for _ in range(5)])

        self.ds_8 = DownsampleBlock(2, 128, 256)

        # kernel branch: shared trunk + weight head (softly normalized later)
        self.kernels_trunk = self._make_trunk(scale)
        self.kernels_weight = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, k_size ** 2, 3)
        )

        # offset branch: shared trunk + one head per direction
        self.offsets_trunk = self._make_trunk(scale)
        self.offsets_h_generation = self._make_offset_head(k_size)
        self.offsets_v_generation = self._make_offset_head(k_size)

    @staticmethod
    def _make_trunk(scale):
        """Three 3x3 conv+ReLU stages, upsample from 1/8 to 1/scale
        resolution, then one more conv+ReLU."""
        return nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            UpsampleBlock(8 // scale, 256, 256, ksize=1),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU()
        )

    @staticmethod
    def _make_offset_head(k_size):
        """Conv stack emitting k_size**2 offsets squashed to (-1, 1)."""
        return nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, k_size ** 2, 3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.sub_mean(x)

        x = self.ds_1(x)
        x = self.ds_2(x)
        x = self.ds_4(x)
        x = x + self.res_4(x)  # residual connection around the 1/4-res trunk
        x = self.ds_8(x)

        trunk_k = self.kernels_trunk(x)
        # clamp raw weights, then normalize so each kernel sums to one
        raw_weight = torch.clamp(self.kernels_weight(trunk_k), min=1e-6, max=1)
        kernels = raw_weight / torch.sum(raw_weight, dim=1, keepdim=True).clamp(min=1e-6)

        trunk_o = self.offsets_trunk(x)
        offsets_h = self.offsets_h_generation(trunk_o)
        offsets_v = self.offsets_v_generation(trunk_o)

        return kernels, offsets_h, offsets_v
215 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import numpy as np
3 | from tqdm import tqdm
4 | from glob import glob
5 | from PIL import Image
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import utils
11 | from EDSR.edsr import EDSR
12 | from modules import DSN
13 | from adaptive_gridsampler.gridsampler import Downsampler
14 | from skimage.color import rgb2ycbcr
15 |
16 |
# Command-line interface and global model setup. The module-level names
# (SCALE, KSIZE, OFFSET_UNIT, BENCHMARK and the three nets) are consumed
# by validation() below.
parser = argparse.ArgumentParser(description='Content Adaptive Resampler for Image downscaling')
parser.add_argument('--model_dir', type=str, default='./models', help='path to the pre-trained model')
parser.add_argument('--img_dir', type=str, help='path to the HR images to be downscaled')
parser.add_argument('--scale', type=int, help='downscale factor')
parser.add_argument('--output_dir', type=str, help='path to store results')
# NOTE(review): argparse `type=bool` is a known pitfall — any non-empty string
# (including "False") parses as True; only the default matters in practice.
parser.add_argument('--benchmark', type=bool, default=True, help='report benchmark results')
args = parser.parse_args()


SCALE = args.scale
KSIZE = 3 * SCALE + 1  # resampling kernel side length grows with scale
OFFSET_UNIT = SCALE    # offsets are expressed in units of the scale factor
BENCHMARK = args.benchmark

kernel_generation_net = DSN(k_size=KSIZE, scale=SCALE).cuda()
downsampler_net = Downsampler(SCALE, KSIZE).cuda()
upscale_net = EDSR(32, 256, scale=SCALE).cuda()

# wrap in DataParallel on GPU 0 (checkpoints were saved with the
# DataParallel "module." key prefix, so wrapping must precede loading)
kernel_generation_net = nn.DataParallel(kernel_generation_net, [0])
downsampler_net = nn.DataParallel(downsampler_net, [0])
upscale_net = nn.DataParallel(upscale_net, [0])

# only kgn and usn carry learned weights; the downsampler is presumably
# parameter-free (no checkpoint is loaded for it) — confirm
kernel_generation_net.load_state_dict(torch.load(os.path.join(args.model_dir, '{0}x'.format(SCALE), 'kgn.pth')))
upscale_net.load_state_dict(torch.load(os.path.join(args.model_dir, '{0}x'.format(SCALE), 'usn.pth')))
torch.set_grad_enabled(False)  # inference only
42 |
43 |
def validation(img, name, save_imgs=False, save_dir=None):
    """Downscale `img` with the learned resampler, reconstruct it with the
    upscaler, optionally save orig/down/recon PNGs under `save_dir`, and
    return (PSNR, SSIM) of the reconstruction vs. the original with a border
    of width SCALE cropped away.

    :param img: (1, 3, H, W) float CUDA tensor in [0, 1]
    :param name: basename used for the saved image files
    """
    for net in (kernel_generation_net, downsampler_net, upscale_net):
        net.eval()

    kernels, offsets_h, offsets_v = kernel_generation_net(img)
    downscaled_img = downsampler_net(img, kernels, offsets_h, offsets_v, OFFSET_UNIT)
    # quantize the downscaled image to 8-bit levels before upscaling
    downscaled_img = torch.round(torch.clamp(downscaled_img, 0, 1) * 255)

    reconstructed_img = upscale_net(downscaled_img / 255.0)

    def to_uint8_hwc(tensor):
        # (B, C, H, W) float tensor in [0, 255] -> (B, H, W, C) uint8 array
        return np.uint8(tensor.data.cpu().numpy().transpose(0, 2, 3, 1))

    orig_img = to_uint8_hwc(img * 255)[0, ...].squeeze()
    recon_img = to_uint8_hwc(torch.clamp(reconstructed_img, 0, 1) * 255)[0, ...].squeeze()
    down_img = to_uint8_hwc(downscaled_img)[0, ...].squeeze()

    if save_imgs and save_dir:
        Image.fromarray(orig_img).save(os.path.join(save_dir, name + '_orig.png'))
        Image.fromarray(down_img).save(os.path.join(save_dir, name + '_down.png'))
        Image.fromarray(recon_img).save(os.path.join(save_dir, name + '_recon.png'))

    crop = SCALE
    psnr = utils.cal_psnr(orig_img[crop:-crop, crop:-crop, ...],
                          recon_img[crop:-crop, crop:-crop, ...],
                          benchmark=BENCHMARK)

    # SSIM is evaluated on the luminance channel only, EDSR-style
    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0][crop:-crop, crop:-crop]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0][crop:-crop, crop:-crop]
    ssim = utils.calc_ssim(recon_img_y, orig_img_y)

    return psnr, ssim
91 |
92 |
if __name__ == '__main__':
    # collect every PNG under img_dir (recursively)
    img_list = glob(os.path.join(args.img_dir, '**', '*.png'), recursive=True)
    assert len(img_list) > 0

    os.makedirs(args.output_dir, exist_ok=True)

    psnr_list, ssim_list = [], []
    for img_file in tqdm(img_list):
        stem = os.path.splitext(os.path.basename(img_file))[0]
        hr_img = utils.load_img(img_file)

        psnr, ssim = validation(hr_img, stem, save_imgs=True, save_dir=args.output_dir)
        psnr_list.append(psnr)
        ssim_list.append(ssim)

    print('Mean PSNR: {0:.2f}'.format(np.mean(psnr_list)))
    print('Mean SSIM: {0:.4f}'.format(np.mean(ssim_list)))
114 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy import signal
4 | from PIL import Image
5 |
6 |
def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
    """
    2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma])
    Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
    """
    half_rows, half_cols = ((dim - 1.0) / 2.0 for dim in shape)
    grid_y, grid_x = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]
    kernel = np.exp(-(grid_x * grid_x + grid_y * grid_y) / (2.0 * sigma * sigma))
    # zero out negligible tail values, exactly as MATLAB's fspecial does
    kernel[kernel < np.finfo(kernel.dtype).eps * kernel.max()] = 0
    total = kernel.sum()
    if total != 0:
        kernel /= total  # normalize to unit sum
    return kernel
20 |
21 |
22 | def calc_ssim(X, Y, sigma=1.5, K1=0.01, K2=0.03, R=255):
23 | '''
24 | X : y channel (i.e., luminance) of transformed YCbCr space of X
25 | Y : y channel (i.e., luminance) of transformed YCbCr space of Y
26 | Please follow the setting of psnr_ssim.m in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution CVPRW2017).
27 | Official Link : https://github.com/LimBee/NTIRE2017/tree/db34606c2844e89317aac8728a2de562ef1f8aba
28 | The authors of EDSR use MATLAB's ssim as the evaluation tool,
29 | thus this function is the same as ssim.m in MATLAB with C(3) == C(2)/2.
30 | '''
31 | gaussian_filter = matlab_style_gauss2D((11, 11), sigma)
32 |
33 | X = X.astype(np.float64)
34 | Y = Y.astype(np.float64)
35 |
36 | window = gaussian_filter
37 |
38 | ux = signal.convolve2d(X, window, mode='same', boundary='symm')
39 | uy = signal.convolve2d(Y, window, mode='same', boundary='symm')
40 |
41 | uxx = signal.convolve2d(X * X, window, mode='same', boundary='symm')
42 | uyy = signal.convolve2d(Y * Y, window, mode='same', boundary='symm')
43 | uxy = signal.convolve2d(X * Y, window, mode='same', boundary='symm')
44 |
45 | vx = uxx - ux * ux
46 | vy = uyy - uy * uy
47 | vxy = uxy - ux * uy
48 |
49 | C1 = (K1 * R) ** 2
50 | C2 = (K2 * R) ** 2
51 |
52 | A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
53 | D = B1 * B2
54 | S = (A1 * A2) / D
55 | mssim = S.mean()
56 |
57 | return mssim
58 |
59 |
def cal_psnr(img_1, img_2, benchmark=False):
    """Peak signal-to-noise ratio between two images in the 0-255 range.

    :param img_1: H x W x 3 array
    :param img_2: H x W x 3 array with the same spatial size
    :param benchmark: if True, project the RGB difference onto the luminance
        (Y) channel using BT.601-style coefficients before computing MSE,
        matching the EDSR benchmark protocol.
    :return: PSNR in dB; +inf for identical inputs.
    """
    assert img_1.shape[0] == img_2.shape[0] and img_1.shape[1] == img_2.shape[1]
    img_1 = np.float64(img_1)
    img_2 = np.float64(img_2)

    diff = (img_1 - img_2) / 255.0
    if benchmark:
        # RGB -> Y conversion weights, scaled for [0, 1] inputs
        gray_coeff = np.array([65.738, 129.057, 25.064]).reshape(1, 1, 3) / 255.0
        diff = diff * gray_coeff
        diff = diff[:, :, 0] + diff[:, :, 1] + diff[:, :, 2]

    mse = np.mean(diff ** 2)
    if mse == 0:
        # identical images: avoid log10(0) RuntimeWarning; the limit is +inf
        return np.inf
    psnr = -10.0 * np.log10(mse)

    return psnr
75 |
76 |
def load_img(img_file):
    """Load an RGB image as a (1, 3, H, W) float CUDA tensor in [0, 1].

    Height and width are cropped down to the nearest multiple of 8, since
    the downsampling network reduces resolution by a factor of 8 internally.

    :param img_file: path to an image file readable by PIL
    :return: (1, 3, H', W') float32 tensor on the GPU
    """
    img = Image.open(img_file).convert('RGB')
    img = np.array(img)
    h, w, _ = img.shape
    # crop so both spatial dims divide evenly by 8
    img = img[:h // 8 * 8, :w // 8 * 8, :]
    # scale to [0, 1]; the redundant second np.array() conversion was removed
    img = img / 255.
    img = img.transpose((2, 0, 1))
    img = torch.from_numpy(img).float().unsqueeze(0).cuda()

    return img
87 |
--------------------------------------------------------------------------------