├── .gitignore ├── LICENSE ├── README.md ├── create_dataset ├── args.go ├── brightness.go ├── images.go ├── main.go ├── material.go ├── modelnet.go ├── objects.go ├── repair.go └── scene.go ├── example ├── 50_rpp.png ├── 512_rpp.png ├── albedo.png ├── denoised_deep.png ├── denoised_deep_aux.png ├── denoised_shallow_aux.png ├── half_and_half.png └── incidence.png ├── go.mod ├── go.sum ├── main.go ├── polish ├── features.go ├── model_data.go ├── model_data_deep.go ├── model_data_deep_aux.go ├── model_data_shallow.go ├── model_data_shallow_aux.go ├── model_test.go ├── models.go ├── nn │ ├── affine.go │ ├── bilateral.go │ ├── conv.go │ ├── conv_test.go │ ├── deconv.go │ ├── deconv_test.go │ ├── group_norm.go │ ├── group_norm_test.go │ ├── nn.go │ ├── simple_ops.go │ ├── tensor.go │ └── tensor_test.go ├── polish.go └── polish_test.go └── training ├── compute_rf.py ├── dump_params.ipynb ├── polish ├── __init__.py ├── baseline.py ├── dataset.py └── models.py ├── run_image.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | *.pt 3 | __pycache__/ 4 | samples.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Alexander Nichol. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 
12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # polish 2 | 3 | This is a simple deep learning system for denoising ray traced images. 4 | 5 | ![Side-by-side of noisy and denoised images](example/half_and_half.png) 6 | 7 | The above image was rendered with 50 rays-per-pixel, and then denoised in RGB space with a deep neural network. For more, see [the example below](#example). 8 | 9 | This repository includes: 10 | 11 | * A command-line utility for denoising images 12 | * A Go inference library with pre-trained models 13 | * A program to create a denoising dataset from scratch 14 | * A training pipeline in [PyTorch](https://pytorch.org/) 15 | 16 | This package supports plain RGB images, as well as images with auxiliary feature channels (e.g. albedo maps). 17 | 18 | # Usage 19 | 20 | **Note:** this code expects a version of Go that supports modules. Ideally, version 1.14 or later. See the [Go downloads page](https://golang.org/dl/). 
21 | 22 | ## Command-line interface 23 | 24 | To build the command-line tool, simply clone this repository (outside of your `GOPATH`) and run: 25 | 26 | ``` 27 | $ go build -o polish_cli 28 | ``` 29 | 30 | Now you can run the `polish_cli` binary to denoise an image: 31 | 32 | ``` 33 | ./polish_cli input.png output.png 34 | ``` 35 | 36 | ## Go API 37 | 38 | There is also a Go API for `polish`, implemented in the [polish](polish) sub-directory. The main API is `PolishImage`: 39 | 40 | ```go 41 | func PolishImage(t ModelType, img image.Image) image.Image 42 | ``` 43 | 44 | For example, you could use the built-in deep CNN model as follows: 45 | 46 | ```go 47 | output := polish.PolishImage(polish.ModelTypeDeep, input) 48 | ``` 49 | 50 | # Training your own models 51 | 52 | The built-in pre-trained models should be sufficient for most use cases. However, if you do need to train your own model, this repository includes everything needed to create a dataset and train a model on it. 53 | 54 | ## Getting data 55 | 56 | You will likely want to get started by downloading the ~2GB [data_1187.tar](https://polish.aqnichol.com/data_1187.tar) dataset, which includes 1187 rendered scenes. 57 | 58 | The dataset was created with the [create_dataset](create_dataset) program, which creates random scenes and renders them at various rays-per-pixel. It expects to use models from [ModelNet40](https://modelnet.cs.princeton.edu/), and textures from ImageNet (or any directory of images, really). It generates scenes by selecting a layout type (either a boxed room or a large dome), randomizing lighting, loading and positioning various 3D models, and selecting random textures and materials for all models and walls. 59 | 60 | ## Training with PyTorch 61 | 62 | The [training](training) directory contains a Python program to train a denoising neural network. It processes data produced by `create_dataset`, and automatically performs data augmentation and other tricks using that data. 
It includes a Jupyter notebook for converting the finished PyTorch models into Go source files that can be integrated into the Go package. 63 | 64 | # Example 65 | 66 | Here is a noisy rendering, produced from the [model3d](https://github.com/unixpickle/model3d) showcase with 50 rays-per-pixel: 67 | 68 | ![50 rays-per-pixel rendering](example/50_rpp.png) 69 | 70 | This picture is pretty noisy. We can make it less noisy by using more rays. Here's a rendering with 10 times as many rays, which makes rendering take 10x as long: 71 | 72 | ![512 rays-per-pixel rendering](example/512_rpp.png) 73 | 74 | Obviously, it'd be nice if we didn't need so much more compute to produce a clean image. Enter `polish`. We can simply denoise the noisy rendering like so: 75 | 76 | ``` 77 | $ polish example/50_rpp.png example/denoised_deep.png 78 | ``` 79 | 80 | ![Denoised 50 rpp](example/denoised_deep.png) 81 | 82 | This denoising took place using only RGB values from the original image. We could also use albedo maps and incidence angles, which are auxiliary channels that look like this: 83 | 84 | ![Albedo](example/albedo.png) 85 | 86 | ![Incidence angles](example/incidence.png) 87 | 88 | The `polish` API can generate these images for a scene, and can denoise using these features. 
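To build intuition for the incidence channel, here is a small, self-contained sketch of what such a map can encode: the cosine of the angle between the view ray and the surface normal, mapped to grayscale. The `incidenceValue` helper is hypothetical, and the exact encoding produced by `polish.CreateIncidenceMap` may differ:

```go
package main

import (
	"fmt"
	"math"
)

// incidenceValue maps the angle between a view ray and the surface
// normal it hits to a grayscale value in [0, 255]. This is a
// hypothetical sketch of one way to encode an incidence channel;
// the encoding used by polish.CreateIncidenceMap may differ.
func incidenceValue(rayDir, normal [3]float64) uint8 {
	dot := rayDir[0]*normal[0] + rayDir[1]*normal[1] + rayDir[2]*normal[2]
	// Both vectors are assumed to be unit length, so the dot
	// product is the cosine of the incidence angle.
	return uint8(math.Round(math.Abs(dot) * 255))
}

func main() {
	// A head-on hit has |cos| = 1, so the channel is bright.
	fmt.Println(incidenceValue([3]float64{0, 0, -1}, [3]float64{0, 0, 1}))
	// A grazing hit has |cos| near 0, so the channel is dark.
	fmt.Println(incidenceValue([3]float64{1, 0, 0}, [3]float64{0, 0, 1}))
}
```

Channels like this give the denoiser geometric cues (edges, grazing surfaces) that are hard to recover from noisy RGB alone.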
Here's how you can use the command-line tool to run a deep model with auxiliary input channels: 89 | 90 | ``` 91 | polish -model deep-aux -incidence example/incidence.png -albedo example/albedo.png example/50_rpp.png example/denoised_deep_aux.png 92 | ``` 93 | 94 | ![Deep denoised with aux](example/denoised_deep_aux.png) 95 | -------------------------------------------------------------------------------- /create_dataset/args.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/unixpickle/essentials" 9 | ) 10 | 11 | type Args struct { 12 | ModelNetPath string 13 | ImagesPath string 14 | 15 | OutputDir string 16 | } 17 | 18 | func (a *Args) Parse() { 19 | flag.StringVar(&a.ModelNetPath, "modelnet", "", "path to ModelNet-40 dataset") 20 | flag.StringVar(&a.ImagesPath, "images", "", "path to (recursive) texture library") 21 | flag.StringVar(&a.OutputDir, "outdir", "../data", "dataset output directory") 22 | flag.Parse() 23 | 24 | var missingArgs []string 25 | if a.ModelNetPath == "" { 26 | missingArgs = append(missingArgs, "-modelnet") 27 | } 28 | if a.ImagesPath == "" { 29 | missingArgs = append(missingArgs, "-images") 30 | } 31 | if len(missingArgs) > 0 { 32 | essentials.Die(fmt.Sprintf("missing required arguments: %s", 33 | strings.Join(missingArgs, ", "))) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /create_dataset/brightness.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "sort" 7 | 8 | "github.com/unixpickle/model3d/render3d" 9 | ) 10 | 11 | func BrightnessScale(img *render3d.Image) float64 { 12 | target := math.Min(0.9, math.Max(0.1, rand.NormFloat64()*0.1+0.3)) 13 | median := math.Max(1e-5, quantileBrightness(img)) 14 | return math.Max(1.0, target/median) 15 | } 16 | 17 | func 
quantileBrightness(img *render3d.Image) float64 { 18 | bs := make([]float64, len(img.Data)) 19 | for i, c := range img.Data { 20 | bs[i] = c.Sum() / 3.0 21 | } 22 | sort.Float64s(bs) 23 | return bs[int(float64(len(bs))*0.8)] 24 | } 25 | -------------------------------------------------------------------------------- /create_dataset/images.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "path/filepath" 6 | "strings" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // ScanImages finds all of the image paths in a directory, 12 | // recursively. 13 | func ScanImages(imageDir string) ([]string, error) { 14 | listing, err := ioutil.ReadDir(imageDir) 15 | if err != nil { 16 | return nil, errors.Wrap(err, "scan images") 17 | } 18 | 19 | var results []string 20 | 21 | for _, d := range listing { 22 | dPath := filepath.Join(imageDir, d.Name()) 23 | if d.IsDir() { 24 | subResults, err := ScanImages(dPath) 25 | if err != nil { 26 | return nil, err 27 | } 28 | results = append(results, subResults...) 
29 | } else { 30 | ext := strings.ToLower(filepath.Ext(d.Name())) 31 | if ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".gif" { 32 | results = append(results, dPath) 33 | } 34 | } 35 | } 36 | 37 | return results, nil 38 | } 39 | -------------------------------------------------------------------------------- /create_dataset/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math" 7 | "math/rand" 8 | "os" 9 | "path/filepath" 10 | "time" 11 | 12 | "github.com/unixpickle/essentials" 13 | "github.com/unixpickle/model3d/render3d" 14 | "github.com/unixpickle/polish/polish" 15 | ) 16 | 17 | const ( 18 | ImageSize = 256 19 | AlbedoSamples = 400 20 | ) 21 | 22 | func main() { 23 | rand.Seed(time.Now().UnixNano()) 24 | 25 | var args Args 26 | args.Parse() 27 | 28 | models, err := ScanModelNet(args.ModelNetPath) 29 | essentials.Must(err) 30 | 31 | images, err := ScanImages(args.ImagesPath) 32 | essentials.Must(err) 33 | 34 | CreateOutput(args.OutputDir) 35 | 36 | for i := 0; true; i++ { 37 | obj, rend, bidir := RandomScene(models, images) 38 | SaveScene(args.OutputDir, obj, rend, bidir) 39 | } 40 | } 41 | 42 | func CreateOutput(outDir string) { 43 | if _, err := os.Stat(outDir); os.IsNotExist(err) { 44 | essentials.Must(os.Mkdir(outDir, 0755)) 45 | } else { 46 | essentials.Must(err) 47 | } 48 | } 49 | 50 | func CreateSceneDir(outDir string) string { 51 | for i := 0; i < 10; i++ { 52 | outName := fmt.Sprintf("%06x", rand.Intn(0x1000000)) 53 | newPath := filepath.Join(outDir, outName) 54 | if _, err := os.Stat(newPath); os.IsNotExist(err) { 55 | essentials.Must(os.Mkdir(newPath, 0755)) 56 | return newPath 57 | } 58 | } 59 | essentials.Die("could not allocate a new output file") 60 | return "" 61 | } 62 | 63 | func SaveScene(outDir string, obj render3d.Object, rend *render3d.RecursiveRayTracer, 64 | bidir *render3d.BidirPathTracer) { 65 | rend.Antialias = 1.0 66 | 
rend.MaxDepth = 10 67 | rend.Cutoff = 1e-4 68 | bidir.Antialias = 1.0 69 | bidir.MinDepth = 3 70 | bidir.MaxDepth = 15 71 | bidir.Cutoff = 1e-5 72 | bidir.RouletteDelta = 0.05 73 | bidir.PowerHeuristic = 2 74 | 75 | variance := rend.RayVariance(obj, 200, 200, 10) 76 | bidirVariance := bidir.RayVariance(obj, 200, 200, 10) 77 | log.Printf("Creating scene (var=%f bidir_var=%f) ...", variance, bidirVariance) 78 | 79 | incidence := polish.CreateIncidenceMap(rend.Camera, obj, ImageSize, ImageSize) 80 | albedo := polish.CreateAlbedoMap(rend.Camera, obj, ImageSize, ImageSize, AlbedoSamples) 81 | 82 | log.Println("Creating low-res renderings ...") 83 | 84 | renderAtRes := func(samples int) *render3d.Image { 85 | rend.NumSamples = samples 86 | img1 := render3d.NewImage(ImageSize, ImageSize) 87 | img2 := render3d.NewImage(ImageSize, ImageSize) 88 | rend.Render(img1, obj) 89 | rend.Render(img2, obj) 90 | img := render3d.NewImage(ImageSize*2, ImageSize) 91 | img.CopyFrom(img1, 0, 0) 92 | img.CopyFrom(img2, ImageSize, 0) 93 | return img 94 | } 95 | 96 | images := map[string]*render3d.Image{} 97 | for _, samples := range []int{1, 16, 64, 128, 512} { 98 | images[fmt.Sprintf("input_%d.png", samples)] = renderAtRes(samples) 99 | } 100 | 101 | scale := BrightnessScale(images["input_512.png"]) 102 | log.Printf("Creating HD rendering (scale=%f) ...", scale) 103 | 104 | bidir.NumSamples = 16384 105 | bidir.MinSamples = 1024 106 | bidir.Convergence = func(mean, stddev render3d.Color) bool { 107 | meanArr := mean.Array() 108 | for i, std := range stddev.Array() { 109 | m := meanArr[i] * scale 110 | std = std * scale 111 | if m-3*std > 1 { 112 | // Oversaturated cutoff. 113 | continue 114 | } 115 | // Gamma-aware error margin. 
116 | delta := math.Pow(m+std, 1/2.2) - math.Pow(m, 1/2.2) 117 | if delta > 0.01 { 118 | return false 119 | } 120 | } 121 | return true 122 | } 123 | bidir.Cutoff = 1e-5 / scale 124 | bidir.RouletteDelta = 0.05 / scale 125 | 126 | var lastFrac float64 127 | bidir.LogFunc = func(frac, samples float64) { 128 | if frac-lastFrac > 0.1 { 129 | lastFrac = frac 130 | log.Printf(" * progress %.1f (samples %d)", frac, int(samples)) 131 | } 132 | } 133 | 134 | target := render3d.NewImage(ImageSize, ImageSize) 135 | bidir.Render(target, obj) 136 | images["target.png"] = target 137 | 138 | // Save all the outputs once we have created them 139 | // to avoid creating empty folders in the dataset 140 | // for a long period of time. 141 | sampleDir := CreateSceneDir(outDir) 142 | for name, img := range images { 143 | img.Scale(scale) 144 | img.Save(filepath.Join(sampleDir, name)) 145 | } 146 | essentials.Must(polish.SaveFeatureMap(filepath.Join(sampleDir, "incidence.png"), incidence)) 147 | essentials.Must(polish.SaveFeatureMap(filepath.Join(sampleDir, "albedo.png"), albedo)) 148 | } 149 | -------------------------------------------------------------------------------- /create_dataset/material.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "image" 5 | "math" 6 | "math/rand" 7 | "os" 8 | 9 | "github.com/unixpickle/essentials" 10 | "github.com/unixpickle/model3d/model3d" 11 | "github.com/unixpickle/model3d/render3d" 12 | ) 13 | 14 | // ModelMaterial defines the material of an object. 15 | type ModelMaterial interface { 16 | FixNormal(ray *model3d.Ray, normal model3d.Coord3D) model3d.Coord3D 17 | Material(coord model3d.Coord3D) render3d.Material 18 | } 19 | 20 | // RandomizeMaterial generates a random material for the 21 | // mesh and returns a new mesh to use in the old mesh's 22 | // place (which may be necessary for refracted objects). 
23 | // 24 | // The material may be based on a random image from a list 25 | // of images, or it may be some other kind of material 26 | // chosen from a distribution. 27 | func RandomizeMaterial(m *model3d.Mesh, images []string) (*model3d.Mesh, ModelMaterial) { 28 | n := rand.Intn(10) 29 | if n == 0 { 30 | m = RepairMesh(m) 31 | } else { 32 | m = RepairOrKeep(m) 33 | } 34 | switch n { 35 | case 0: 36 | return m, createTransparent() 37 | case 1: 38 | return m, createMirror() 39 | case 2, 3, 4, 5: 40 | return m, createColored() 41 | default: 42 | return m, createTextured(m, images) 43 | } 44 | } 45 | 46 | // RandomizeWallMaterial is like RandomizeMaterial, but 47 | // with a restricted class of materials for boundaries of 48 | // the scene. 49 | func RandomizeWallMaterial(c model3d.Collider, images []string) ModelMaterial { 50 | switch rand.Intn(10) { 51 | case 0: 52 | return createMirror() 53 | case 1, 2, 3, 4, 5: 54 | return createColored() 55 | default: 56 | return createTextured(c, images) 57 | } 58 | } 59 | 60 | func createTransparent() ModelMaterial { 61 | reflectFraction := math.Pow(rand.Float64(), 5) 62 | 63 | var refractColor render3d.Color 64 | if rand.Intn(2) == 0 { 65 | refractColor = render3d.NewColor(1 - reflectFraction) 66 | } else { 67 | refractColor = render3d.NewColorRGB( 68 | rand.Float64(), rand.Float64(), rand.Float64(), 69 | ).Scale(1 - reflectFraction) 70 | } 71 | 72 | refractIndex := rand.Float64() + 1 73 | 74 | return StaticModelMaterial{ 75 | Mat: &render3d.RefractMaterial{ 76 | IndexOfRefraction: refractIndex, 77 | RefractColor: refractColor, 78 | SpecularColor: render3d.NewColor(1), 79 | }, 80 | } 81 | } 82 | 83 | func createMirror() ModelMaterial { 84 | return StaticModelMaterial{ 85 | ShouldFixNormal: true, 86 | Mat: &render3d.PhongMaterial{ 87 | Alpha: 200.0, 88 | SpecularColor: render3d.NewColor(0.95 + rand.Float64()*0.05), 89 | }, 90 | } 91 | } 92 | 93 | func createColored() ModelMaterial { 94 | color := 
render3d.NewColorRGB(rand.Float64(), rand.Float64(), rand.Float64()) 95 | diffuse := rand.Float64() 96 | 97 | var mat render3d.Material 98 | if rand.Intn(2) == 0 { 99 | mat = &render3d.LambertMaterial{DiffuseColor: color.Scale(diffuse)} 100 | } else { 101 | specular := rand.Float64() * (1 - diffuse) 102 | alpha := math.Exp(rand.Float64()*5 + 1) 103 | mat = &render3d.PhongMaterial{ 104 | Alpha: alpha, 105 | DiffuseColor: color.Scale(diffuse), 106 | SpecularColor: render3d.NewColor(specular), 107 | } 108 | } 109 | 110 | return StaticModelMaterial{ 111 | ShouldFixNormal: true, 112 | Mat: mat, 113 | } 114 | } 115 | 116 | func createTextured(obj model3d.Bounder, images []string) ModelMaterial { 117 | path := images[rand.Intn(len(images))] 118 | r, err := os.Open(path) 119 | essentials.Must(err) 120 | defer r.Close() 121 | img, _, err := image.Decode(r) 122 | essentials.Must(err) 123 | return NewTexturedModelMaterial(obj, img) 124 | } 125 | 126 | // StaticModelMaterial is a ModelMaterial with a constant 127 | // value. 128 | type StaticModelMaterial struct { 129 | ShouldFixNormal bool 130 | 131 | Mat render3d.Material 132 | } 133 | 134 | func (s StaticModelMaterial) FixNormal(r *model3d.Ray, normal model3d.Coord3D) model3d.Coord3D { 135 | if s.ShouldFixNormal && r.Direction.Dot(normal) > 0 { 136 | return normal.Scale(-1) 137 | } 138 | return normal 139 | } 140 | 141 | func (s StaticModelMaterial) Material(coord model3d.Coord3D) render3d.Material { 142 | return s.Mat 143 | } 144 | 145 | // A TexturedModelMaterial is a ModelMaterial that applies 146 | // the orthographic projection of an image to the model. 147 | type TexturedModelMaterial struct { 148 | Alpha float64 149 | Specular float64 150 | Diffuse float64 151 | Texture image.Image 152 | XBasis model3d.Coord3D 153 | YBasis model3d.Coord3D 154 | } 155 | 156 | // NewTexturedModelMaterial creates an object with a 157 | // texture randomly slapped on along some axis. 
158 | func NewTexturedModelMaterial(obj model3d.Bounder, texture image.Image) *TexturedModelMaterial { 159 | size := obj.Max().Sub(obj.Min()) 160 | maxDim := math.Max(math.Max(size.X, size.Y), size.Z) 161 | 162 | xBasis := model3d.NewCoord3DRandUnit() 163 | yBasis := model3d.NewCoord3DRandUnit().ProjectOut(xBasis).Normalize() 164 | 165 | bounds := texture.Bounds() 166 | scale := math.Exp(rand.Float64()*5) * 0.5 167 | xBasis = xBasis.Scale(scale * float64(bounds.Dx()) / maxDim) 168 | yBasis = yBasis.Scale(scale * float64(bounds.Dy()) / maxDim) 169 | 170 | diffuse := rand.Float64() 171 | specular := rand.Float64() * (1 - diffuse) 172 | 173 | return &TexturedModelMaterial{ 174 | Alpha: math.Exp(rand.Float64()*5 - 1), 175 | Specular: specular, 176 | Diffuse: diffuse, 177 | Texture: texture, 178 | XBasis: xBasis, 179 | YBasis: yBasis, 180 | } 181 | } 182 | 183 | func (t *TexturedModelMaterial) FixNormal(r *model3d.Ray, normal model3d.Coord3D) model3d.Coord3D { 184 | if r.Direction.Dot(normal) > 0 { 185 | return normal.Scale(-1) 186 | } 187 | return normal 188 | } 189 | 190 | func (t *TexturedModelMaterial) Material(p model3d.Coord3D) render3d.Material { 191 | x := int(t.XBasis.Dot(p)) 192 | y := int(t.YBasis.Dot(p)) 193 | 194 | // Add a large offset to prevent the modulus from not 195 | // working. 
196 | x += 1000000 197 | y += 1000000 198 | 199 | bounds := t.Texture.Bounds() 200 | if (x/bounds.Dx())%2 == 0 { 201 | x = bounds.Dx() - (x % bounds.Dx()) - 1 202 | } else { 203 | x = x % bounds.Dx() 204 | } 205 | if (y/bounds.Dy())%2 == 0 { 206 | y = bounds.Dy() - (y % bounds.Dy()) - 1 207 | } else { 208 | y = y % bounds.Dy() 209 | } 210 | 211 | r, g, b, _ := t.Texture.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA() 212 | color := render3d.NewColorRGB(float64(r)/0xffff, float64(g)/0xffff, 213 | float64(b)/0xffff) 214 | 215 | return &render3d.PhongMaterial{ 216 | Alpha: t.Alpha, 217 | SpecularColor: render3d.NewColor(t.Specular), 218 | DiffuseColor: color.Scale(t.Diffuse), 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /create_dataset/modelnet.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "path/filepath" 6 | 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | // ScanModelNet finds all of the .off model file paths. 11 | func ScanModelNet(dir string) (paths []string, err error) { 12 | paths, err = scanModelNet(dir) 13 | if err != nil { 14 | err = errors.Wrap(err, "scan ModelNet") 15 | } 16 | return 17 | } 18 | 19 | func scanModelNet(dir string) (paths []string, err error) { 20 | dirs, err := ioutil.ReadDir(dir) 21 | if err != nil { 22 | return nil, err 23 | } 24 | for _, d := range dirs { 25 | if !d.IsDir() { 26 | if filepath.Ext(d.Name()) == ".off" { 27 | paths = append(paths, filepath.Join(dir, d.Name())) 28 | } 29 | continue 30 | } 31 | dPath := filepath.Join(dir, d.Name()) 32 | subPaths, err := scanModelNet(dPath) 33 | if err != nil { 34 | return nil, err 35 | } 36 | paths = append(paths, subPaths...) 
37 | } 38 | return 39 | } 40 | -------------------------------------------------------------------------------- /create_dataset/objects.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/unixpickle/model3d/model3d" 5 | "github.com/unixpickle/model3d/render3d" 6 | ) 7 | 8 | type ColliderObject struct { 9 | render3d.Object 10 | Mat ModelMaterial 11 | } 12 | 13 | func NewColliderObject(c model3d.Collider, mat ModelMaterial) *ColliderObject { 14 | return &ColliderObject{ 15 | Object: &render3d.ColliderObject{Collider: c}, 16 | Mat: mat, 17 | } 18 | } 19 | 20 | func (c *ColliderObject) Cast(r *model3d.Ray) (model3d.RayCollision, render3d.Material, bool) { 21 | rc, _, ok := c.Object.Cast(r) 22 | if !ok { 23 | return rc, nil, ok 24 | } 25 | mat := c.Mat.Material(r.Origin.Add(r.Direction.Scale(rc.Scale))) 26 | rc.Normal = c.Mat.FixNormal(r, rc.Normal) 27 | return rc, mat, ok 28 | } 29 | 30 | type MeshesObject struct { 31 | render3d.Object 32 | Mats map[*model3d.Triangle]ModelMaterial 33 | } 34 | 35 | func NewMeshesObject(meshes []*model3d.Mesh, mats []ModelMaterial) *MeshesObject { 36 | res := map[*model3d.Triangle]ModelMaterial{} 37 | fullMesh := model3d.NewMesh() 38 | for i, m := range meshes { 39 | m.Iterate(func(t *model3d.Triangle) { 40 | res[t] = mats[i] 41 | }) 42 | fullMesh.AddMesh(m) 43 | } 44 | return &MeshesObject{ 45 | Object: &render3d.ColliderObject{ 46 | Collider: model3d.MeshToCollider(fullMesh), 47 | }, 48 | Mats: res, 49 | } 50 | } 51 | 52 | func (m *MeshesObject) Cast(r *model3d.Ray) (model3d.RayCollision, render3d.Material, bool) { 53 | rc, _, ok := m.Object.Cast(r) 54 | if !ok { 55 | return rc, nil, ok 56 | } 57 | tri := rc.Extra.(*model3d.TriangleCollision).Triangle 58 | mmat := m.Mats[tri] 59 | mat := mmat.Material(r.Origin.Add(r.Direction.Scale(rc.Scale))) 60 | rc.Normal = mmat.FixNormal(r, rc.Normal) 61 | return rc, mat, ok 62 | } 63 | 
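The `MeshesObject` above recovers a triangle's material by looking it up in a map keyed by the triangle's pointer identity. A minimal self-contained sketch of that lookup (the `Triangle` and `Material` types here are stand-ins, not the real `model3d`/`render3d` types):

```go
package main

import "fmt"

// Stand-ins for model3d.Triangle and render3d.Material, used only
// to illustrate the pointer-keyed dispatch in MeshesObject.
type Triangle [3][3]float64
type Material string

func main() {
	t1 := &Triangle{}
	t2 := &Triangle{}
	// NewMeshesObject builds a map like this while iterating each
	// source mesh, so every triangle remembers its mesh's material.
	mats := map[*Triangle]Material{t1: "mirror", t2: "lambert"}
	// A ray collision reports which *Triangle was hit; pointer
	// identity then recovers that triangle's material in O(1).
	hit := t2
	fmt.Println(mats[hit])
}
```

Keying on pointers works here because the triangles stored in the collider are the same objects that were inserted into the map, so no structural comparison is needed.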
-------------------------------------------------------------------------------- /create_dataset/repair.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "github.com/unixpickle/model3d/model3d" 7 | ) 8 | 9 | // RepairMesh creates a manifold, oriented mesh from a 10 | // noisy mesh with incorrect normals, duplicate triangles, 11 | // singularities, self-intersections, etc. 12 | // 13 | // If possible, this simply operates on the mesh. If a 14 | // mesh-level repair does not work, a new mesh is created 15 | // from scratch by thickening the surface of the original 16 | // mesh. 17 | func RepairMesh(m *model3d.Mesh) *model3d.Mesh { 18 | if m1 := RepairDirectly(m); m1 != nil { 19 | return m1 20 | } 21 | return createThicknessMesh(m) 22 | } 23 | 24 | // RepairOrKeep repairs the mesh on a triangle level as 25 | // much as possible, and returns the new mesh. 26 | func RepairOrKeep(m *model3d.Mesh) *model3d.Mesh { 27 | if m1 := RepairDirectly(m); m1 != nil { 28 | return m1 29 | } 30 | // At least eliminate some extra triangle overhead in 31 | // some 3D models. 32 | eliminateDuplicates(m) 33 | return m 34 | } 35 | 36 | // RepairDirectly attempts to repair the mesh by modifying 37 | // its triangles. 38 | // Returns nil if the mesh cannot be directly repaired. 39 | func RepairDirectly(m *model3d.Mesh) *model3d.Mesh { 40 | span := m.Max().Sub(m.Min()).Norm() 41 | 42 | // Fix small holes and duplicate triangles. 43 | if m.NeedsRepair() { 44 | m = m.Repair(span * 1e-5) 45 | eliminateDuplicates(m) 46 | if m.NeedsRepair() { 47 | return nil 48 | } 49 | } 50 | 51 | if !checkRayConsistency(m) { 52 | return nil 53 | } 54 | 55 | if len(m.SingularVertices()) > 0 || m.SelfIntersections() != 0 { 56 | return nil 57 | } 58 | 59 | m, _ = m.RepairNormals(span * 1e-5) 60 | 61 | // Try to make the mesh smaller to speed things up. 
62 | m = m.EliminateCoplanar(1e-5) 63 | 64 | return m 65 | } 66 | 67 | func eliminateDuplicates(m *model3d.Mesh) { 68 | m.Iterate(func(t *model3d.Triangle) { 69 | if len(m.Find(t[0], t[1], t[2])) > 1 { 70 | m.Remove(t) 71 | } 72 | }) 73 | } 74 | 75 | // checkRayConsistency makes sure the even-odd test is 76 | // reliable for the mesh. 77 | func checkRayConsistency(m *model3d.Mesh) bool { 78 | collider := model3d.MeshToCollider(m) 79 | 80 | min, max := m.Min(), m.Max() 81 | 82 | evenOddAt := func(o model3d.Coord3D) bool { 83 | ray := &model3d.Ray{ 84 | Origin: o, 85 | Direction: model3d.NewCoord3DRandUnit(), 86 | } 87 | return collider.RayCollisions(ray, nil)%2 == 1 88 | } 89 | 90 | for i := 0; i < 1000; i++ { 91 | o := min.Add(model3d.Coord3D{ 92 | X: rand.Float64(), 93 | Y: rand.Float64(), 94 | Z: rand.Float64(), 95 | }.Mul(max.Sub(min))) 96 | c1 := evenOddAt(o) 97 | c2 := evenOddAt(o) 98 | if c1 != c2 { 99 | return false 100 | } 101 | } 102 | 103 | // Check that the faces truly split the space. 104 | // This will catch infinitely thin meshes. 105 | for _, t := range m.TriangleSlice() { 106 | center := t[0].Add(t[1]).Add(t[2]).Scale(1.0 / 3) 107 | delta := t.Normal().Scale(1e-8) 108 | c1 := evenOddAt(center.Add(delta)) 109 | c2 := evenOddAt(center.Sub(delta)) 110 | if c1 == c2 { 111 | return false 112 | } 113 | } 114 | 115 | return true 116 | } 117 | 118 | // createThicknessMesh derives a new mesh from m based on 119 | // the surface of m. 
120 | func createThicknessMesh(m *model3d.Mesh) *model3d.Mesh { 121 | delta := m.Max().Sub(m.Min()).Norm() / 100.0 122 | collider := model3d.MeshToCollider(m) 123 | solid := model3d.NewColliderSolidHollow(collider, delta*4) 124 | m = model3d.MarchingCubesSearch(solid, delta, 8) 125 | m = m.EliminateCoplanar(1e-5) 126 | return m 127 | } 128 | -------------------------------------------------------------------------------- /create_dataset/scene.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "os" 7 | 8 | "github.com/unixpickle/essentials" 9 | "github.com/unixpickle/model3d/model3d" 10 | "github.com/unixpickle/model3d/render3d" 11 | ) 12 | 13 | // RandomScene creates a random collection of objects and 14 | // fills out a renderer to render them. 15 | func RandomScene(models, images []string) (render3d.Object, *render3d.RecursiveRayTracer, 16 | *render3d.BidirPathTracer) { 17 | layout := RandomSceneLayout() 18 | numObjects := rand.Intn(10) + 1 19 | numLights := rand.Intn(10) + 1 20 | 21 | var objects render3d.JoinedObject 22 | var lights []render3d.AreaLight 23 | var focusPoints []render3d.FocusPoint 24 | var focusProbs []float64 25 | 26 | for _, wall := range layout.CreateBackdrop() { 27 | mat := RandomizeWallMaterial(wall, images) 28 | objects = append(objects, NewColliderObject(wall, mat)) 29 | } 30 | 31 | var modelMeshes []*model3d.Mesh 32 | var modelMats []ModelMaterial 33 | for i := 0; i < numObjects; i++ { 34 | path := models[rand.Intn(len(models))] 35 | r, err := os.Open(path) 36 | essentials.Must(err) 37 | defer r.Close() 38 | tris, err := model3d.ReadOFF(r) 39 | essentials.Must(err) 40 | mesh := model3d.NewMeshTriangles(tris) 41 | mesh = randomRotation(mesh) 42 | mesh = layout.PlaceMesh(mesh) 43 | mesh, mat := RandomizeMaterial(mesh, images) 44 | modelMeshes = append(modelMeshes, mesh) 45 | modelMats = append(modelMats, mat) 46 | } 47 | objects = 
append(objects, NewMeshesObject(modelMeshes, modelMats)) 48 | 49 | for i := 0; i < numLights; i++ { 50 | light, focusPoint := layout.CreateLight() 51 | objects = append(objects, light) 52 | lights = append(lights, light) 53 | focusPoints = append(focusPoints, focusPoint) 54 | focusProbs = append(focusProbs, 0.3/float64(numLights)) 55 | } 56 | 57 | origin, target := layout.CameraInfo() 58 | fov := (rand.Float64()*0.5 + 0.5) * math.Pi / 3.0 59 | camera := render3d.NewCameraAt(origin, target, fov) 60 | return objects, &render3d.RecursiveRayTracer{ 61 | Camera: camera, 62 | FocusPoints: focusPoints, 63 | FocusPointProbs: focusProbs, 64 | }, &render3d.BidirPathTracer{ 65 | Camera: camera, 66 | Light: render3d.JoinAreaLights(lights...), 67 | } 68 | } 69 | 70 | func randomRotation(m *model3d.Mesh) *model3d.Mesh { 71 | var rotation *model3d.Matrix3 72 | if rand.Intn(3) == 0 { 73 | // Completely random rotation. 74 | rotation = model3d.NewMatrix3Rotation(model3d.NewCoord3DRandUnit(), 75 | rand.Float64()*math.Pi*2) 76 | } else { 77 | // Axis swap rotation 78 | a1 := rand.Intn(3) 79 | a2 := rand.Intn(2) 80 | if a2 >= a1 { 81 | a2++ 82 | } 83 | rotation = &model3d.Matrix3{} 84 | for i := 0; i < 3; i++ { 85 | if i == a1 { 86 | rotation[i*3+a2] = 1 87 | } else if i == a2 { 88 | rotation[i*3+a1] = 1 89 | } else { 90 | rotation[i*3+i] = 1 91 | } 92 | } 93 | } 94 | return m.MapCoords(rotation.MulColumn) 95 | } 96 | 97 | // RandomSceneLayout samples a SceneLayout from some 98 | // distribution. 99 | func RandomSceneLayout() SceneLayout { 100 | if rand.Intn(2) == 0 { 101 | return RoomLayout{ 102 | Width: rand.Float64()*2.0 + 0.5, 103 | Depth: rand.Float64()*3.0 + 2.0, 104 | } 105 | } else { 106 | return WorldLayout{} 107 | } 108 | } 109 | 110 | type SceneLayout interface { 111 | // CameraInfo determines where the scene would like to 112 | // setup the camera for rendering. 
113 | CameraInfo() (position, target model3d.Coord3D) 114 | 115 | // CreateLight creates a randomized light object that 116 | // makes sense in this kind of scene. 117 | CreateLight() (render3d.AreaLight, render3d.FocusPoint) 118 | 119 | // CreateBackdrop creates models which act as walls of 120 | // the scene. 121 | CreateBackdrop() []model3d.Collider 122 | 123 | // PlaceMesh translates and scales the mesh so that it 124 | // fits within the scene. 125 | PlaceMesh(m *model3d.Mesh) *model3d.Mesh 126 | } 127 | 128 | // RoomLayout is a simple scene in a room with lights on 129 | // the walls and ceiling. 130 | type RoomLayout struct { 131 | Width float64 132 | Depth float64 133 | } 134 | 135 | func (r RoomLayout) CameraInfo() (position, target model3d.Coord3D) { 136 | return model3d.Coord3D{Z: 0.5, Y: -r.Depth/2 + 1e-5}, model3d.Coord3D{Z: 0.5, Y: r.Depth / 2} 137 | } 138 | 139 | func (r RoomLayout) CreateLight() (render3d.AreaLight, render3d.FocusPoint) { 140 | var center model3d.Coord3D 141 | var axis model3d.Coord3D 142 | if rand.Intn(2) == 0 { 143 | // Place light on ceiling. 144 | center = model3d.Coord3D{ 145 | X: (rand.Float64() - 0.5) * r.Width, 146 | Y: (rand.Float64() - 0.5) * r.Depth, 147 | Z: 1.0, 148 | } 149 | axis = model3d.Coord3D{Z: 1} 150 | } else { 151 | // Place light on side wall. 
152 | x := r.Width / 2 153 | if rand.Intn(2) == 0 { 154 | x = -x 155 | } 156 | center = model3d.Coord3D{ 157 | X: x, 158 | Y: (rand.Float64() - 0.5) * r.Depth, 159 | Z: rand.Float64() * 0.9, 160 | } 161 | axis = model3d.Coord3D{X: 1 / x} 162 | } 163 | 164 | var light render3d.AreaLight 165 | var focusRadius float64 166 | color := render3d.NewColor((rand.Float64() + 0.1) * 20) 167 | if rand.Intn(2) == 0 { 168 | focusRadius = rand.Float64()*0.2 + 0.05 169 | light = render3d.NewSphereAreaLight( 170 | &model3d.Sphere{Center: center, Radius: focusRadius}, 171 | color, 172 | ) 173 | } else { 174 | size := uniformRandom().Scale(0.1).Add(model3d.Coord3D{X: 0.05, Y: 0.05, Z: 0.05}) 175 | light = render3d.NewMeshAreaLight( 176 | model3d.NewMeshRect( 177 | center.Sub(size), 178 | center.Add(size), 179 | ), 180 | color, 181 | ) 182 | focusRadius = size.Norm() 183 | } 184 | 185 | light = &HalfLight{ 186 | AreaLight: light, 187 | Axis: axis, 188 | MaxDot: 1, 189 | } 190 | 191 | return light, &render3d.SphereFocusPoint{ 192 | Center: light.Min().Mid(light.Max()), 193 | Radius: focusRadius, 194 | } 195 | } 196 | 197 | func (r RoomLayout) CreateBackdrop() []model3d.Collider { 198 | min := model3d.Coord3D{X: -r.Width / 2, Y: -r.Depth / 2} 199 | max := model3d.Coord3D{X: r.Width / 2, Y: r.Depth / 2, Z: 1} 200 | mesh := model3d.NewMeshRect(min, max) 201 | 202 | var walls []model3d.Collider 203 | mesh.Iterate(func(t *model3d.Triangle) { 204 | var neighbor *model3d.Triangle 205 | for _, n := range mesh.Neighbors(t) { 206 | if n.Normal().Dot(t.Normal()) > 0.99 { 207 | neighbor = n 208 | break 209 | } 210 | } 211 | mesh.Remove(neighbor) 212 | mesh.Remove(t) 213 | walls = append(walls, model3d.NewJoinedCollider([]model3d.Collider{t, neighbor})) 214 | }) 215 | 216 | return walls 217 | } 218 | 219 | func (r RoomLayout) PlaceMesh(m *model3d.Mesh) *model3d.Mesh { 220 | placeMin := model3d.Coord3D{X: -r.Width / 2, Y: -r.Depth / 4} 221 | placeMax := model3d.Coord3D{X: r.Width / 2, Y: r.Depth / 
2, Z: 1} 222 | return placeInBounds(placeMin, placeMax, m) 223 | } 224 | 225 | func placeInBounds(placeMin, placeMax model3d.Coord3D, m *model3d.Mesh) *model3d.Mesh { 226 | min, max := m.Min(), m.Max() 227 | diff := max.Sub(min) 228 | pDiff := placeMax.Sub(placeMin) 229 | maxScale := math.Min(pDiff.X/diff.X, math.Min(pDiff.Y/diff.Y, pDiff.Z/diff.Z)) 230 | scale := (rand.Float64()*0.9 + 0.1) * maxScale 231 | m = m.Scale(scale) 232 | 233 | min, max = m.Min(), m.Max() 234 | translateMin := placeMin.Sub(min) 235 | translateMax := placeMax.Sub(max) 236 | translate := uniformRandom().Mul(translateMax.Sub(translateMin)).Add(translateMin) 237 | 238 | // Drop Z to minimum. 239 | translate.Z = translateMin.Z 240 | 241 | return m.MapCoords(translate.Add) 242 | } 243 | 244 | func uniformRandom() model3d.Coord3D { 245 | return model3d.Coord3D{X: rand.Float64(), Y: rand.Float64(), Z: rand.Float64()} 246 | } 247 | 248 | // WorldLayout is a layout that places objects in a large 249 | // hemisphere. 250 | type WorldLayout struct{} 251 | 252 | func (w WorldLayout) CameraInfo() (position, target model3d.Coord3D) { 253 | return model3d.Coord3D{Y: -20, Z: 5}, model3d.Coord3D{Y: 0, Z: 5} 254 | } 255 | 256 | func (w WorldLayout) CreateLight() (render3d.AreaLight, render3d.FocusPoint) { 257 | center := model3d.NewCoord3DRandUnit().Scale(70) 258 | if center.Z < 0 { 259 | center.Z = -center.Z 260 | } 261 | if center.Y > 0 { 262 | // Usually, we want the lights behind the camera. 
263 | if rand.Intn(5) != 0 { 264 | center.Y *= -1 265 | } 266 | } 267 | shape := &model3d.Sphere{Center: center, Radius: rand.Float64()*5.0 + 2.0} 268 | r2 := shape.Radius * shape.Radius 269 | emission := render3d.NewColor((rand.Float64() + 0.5) * 200 / r2) 270 | return render3d.NewSphereAreaLight(shape, emission), 271 | &render3d.SphereFocusPoint{ 272 | Center: shape.Center, 273 | Radius: shape.Radius, 274 | } 275 | } 276 | 277 | func (w WorldLayout) CreateBackdrop() []model3d.Collider { 278 | r := 100.0 279 | p1 := model3d.Coord3D{X: -r, Y: -r} 280 | p2 := model3d.Coord3D{X: -r, Y: r} 281 | p3 := model3d.Coord3D{X: r, Y: r} 282 | p4 := model3d.Coord3D{X: r, Y: -r} 283 | 284 | floor := model3d.NewMesh() 285 | floor.Add(&model3d.Triangle{p1, p2, p3}) 286 | floor.Add(&model3d.Triangle{p1, p3, p4}) 287 | 288 | dome := &model3d.Sphere{Radius: r} 289 | 290 | return []model3d.Collider{model3d.MeshToCollider(floor), dome} 291 | } 292 | 293 | func (w WorldLayout) PlaceMesh(m *model3d.Mesh) *model3d.Mesh { 294 | min := model3d.Coord3D{X: -7, Y: -7} 295 | max := model3d.Coord3D{X: 7, Y: 7, Z: 7} 296 | return placeInBounds(min, max, m) 297 | } 298 | 299 | type HalfLight struct { 300 | render3d.AreaLight 301 | 302 | Axis model3d.Coord3D 303 | MaxDot float64 304 | } 305 | 306 | func (h *HalfLight) SampleLight(gen *rand.Rand) (point, normal model3d.Coord3D, c render3d.Color) { 307 | for { 308 | point, normal, c = h.AreaLight.SampleLight(gen) 309 | if h.Axis.Dot(point) < h.MaxDot { 310 | return 311 | } 312 | } 313 | } 314 | 315 | func (h *HalfLight) TotalEmission() float64 { 316 | return h.AreaLight.TotalEmission() / 2 317 | } 318 | -------------------------------------------------------------------------------- /example/50_rpp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/50_rpp.png 
-------------------------------------------------------------------------------- /example/512_rpp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/512_rpp.png -------------------------------------------------------------------------------- /example/albedo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/albedo.png -------------------------------------------------------------------------------- /example/denoised_deep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/denoised_deep.png -------------------------------------------------------------------------------- /example/denoised_deep_aux.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/denoised_deep_aux.png -------------------------------------------------------------------------------- /example/denoised_shallow_aux.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/denoised_shallow_aux.png -------------------------------------------------------------------------------- /example/half_and_half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/half_and_half.png -------------------------------------------------------------------------------- /example/incidence.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/example/incidence.png -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/unixpickle/polish 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/pkg/errors v0.9.1 7 | github.com/unixpickle/essentials v1.1.0 8 | github.com/unixpickle/model3d v0.2.1 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 2 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 3 | github.com/unixpickle/essentials v1.0.1/go.mod h1:dQ1idvqrgrDgub3mfckQm7osVPzT3u9rB6NK/LEhmtQ= 4 | github.com/unixpickle/essentials v1.1.0 h1:kJ/mU3MfmmSfuU8zyplwkup60lKV9+ucqZC+hR1GgVU= 5 | github.com/unixpickle/essentials v1.1.0/go.mod h1:dQ1idvqrgrDgub3mfckQm7osVPzT3u9rB6NK/LEhmtQ= 6 | github.com/unixpickle/model3d v0.1.1 h1:UK3GlZPOfdFyUctwYsfrdAnqFI94It7deyqP/gfD0GQ= 7 | github.com/unixpickle/model3d v0.1.1/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 8 | github.com/unixpickle/model3d v0.1.2 h1:q2AcDy3iKVPI+Vu7aJd2PfeZ5SGpWT9YMi2DBCugVYo= 9 | github.com/unixpickle/model3d v0.1.2/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 10 | github.com/unixpickle/model3d v0.1.4 h1:bHvTzYXPuZdqgP15f9kc+/jHjdtLxkVL53k+QSMZHsc= 11 | github.com/unixpickle/model3d v0.1.4/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 12 | github.com/unixpickle/model3d v0.1.5 h1:WLMkF7MzZfHMU5zg9Uj5ZAaB2sgpUxDUI7LB7emY2mM= 13 | github.com/unixpickle/model3d v0.1.5/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 14 | github.com/unixpickle/model3d v0.1.9 
h1:mrJrynk1Mn/FEiSaEByg5CXq+1zu+BhnUMiD74vc7VY= 15 | github.com/unixpickle/model3d v0.1.9/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 16 | github.com/unixpickle/model3d v0.1.12 h1:4pboP6I9KVFB/3lD7J82Ur53GfumyZ8kSB/vLRvym5s= 17 | github.com/unixpickle/model3d v0.1.12/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 18 | github.com/unixpickle/model3d v0.1.13 h1:Onepw2lib5T3hkLc82uyCOTQKW5Qo2wqCwwgGvvx/Bk= 19 | github.com/unixpickle/model3d v0.1.13/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 20 | github.com/unixpickle/model3d v0.1.14 h1:UHGM9hTk4bgz3s8TIYQBI6v0EwIBROWMdwEacu44UhY= 21 | github.com/unixpickle/model3d v0.1.14/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 22 | github.com/unixpickle/model3d v0.1.15 h1:h4cxQRhTAwo4tJ7bYjVNqIpHb1aYqnT/ChMDP9GbPPY= 23 | github.com/unixpickle/model3d v0.1.15/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 24 | github.com/unixpickle/model3d v0.2.0 h1:IMD+FvPU6ho1hb/WmluVP4MG9QNaE9eoaVTKIjjac3Y= 25 | github.com/unixpickle/model3d v0.2.0/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 26 | github.com/unixpickle/model3d v0.2.1 h1:zivAhQEXGRXiXjgzoSwR5XabUBoh+2icTi9aKNR0i60= 27 | github.com/unixpickle/model3d v0.2.1/go.mod h1:/jD5uEOZVtoeyF0K1t9rCFsturMfh/gdj3fX/DVV360= 28 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // Command polish denoises images that were produced by a 2 | // Monte Carlo rendering technique (e.g. path tracing). 
3 | package main 4 | 5 | import ( 6 | "flag" 7 | "fmt" 8 | "image" 9 | "image/png" 10 | "os" 11 | 12 | "github.com/unixpickle/polish/polish" 13 | 14 | "github.com/unixpickle/essentials" 15 | ) 16 | 17 | func main() { 18 | var model string 19 | var patchSize int 20 | var patchBorder int 21 | var albedoPath string 22 | var incidencePath string 23 | flag.StringVar(&model, "model", "deep", "type of model to use "+ 24 | "('shallow', 'deep', 'shallow-aux', 'deep-aux', 'bilateral')") 25 | flag.IntVar(&patchSize, "patch", 0, "image patch size to process at once (0 to disable)") 26 | flag.IntVar(&patchBorder, "patch-border", -1, "border for image patches (-1 uses default)") 27 | flag.StringVar(&albedoPath, "albedo", "", "path to albedo map image (for aux models)") 28 | flag.StringVar(&incidencePath, "incidence", "", "path to incidence map image (for aux models)") 29 | 30 | flag.Usage = func() { 31 | fmt.Fprintln(os.Stderr, "Usage: "+os.Args[0]+" [flags] <input.png> <output.png>") 32 | fmt.Fprintln(os.Stderr) 33 | flag.PrintDefaults() 34 | fmt.Fprintln(os.Stderr) 35 | os.Exit(1) 36 | } 37 | 38 | flag.Parse() 39 | if len(flag.Args()) != 2 { 40 | flag.Usage() 41 | } 42 | 43 | var modelType polish.ModelType 44 | if model == "shallow" { 45 | modelType = polish.ModelTypeShallow 46 | } else if model == "deep" { 47 | modelType = polish.ModelTypeDeep 48 | } else if model == "bilateral" { 49 | modelType = polish.ModelTypeBilateral 50 | } else if model == "shallow-aux" { 51 | modelType = polish.ModelTypeShallowAux 52 | } else if model == "deep-aux" { 53 | modelType = polish.ModelTypeDeepAux 54 | } else { 55 | flag.Usage() 56 | } 57 | 58 | if modelType.Aux() { 59 | if incidencePath == "" { 60 | fmt.Fprintln(os.Stderr, "auxiliary model requires -incidence flag") 61 | } 62 | if albedoPath == "" { 63 | fmt.Fprintln(os.Stderr, "auxiliary model requires -albedo flag") 64 | } 65 | if albedoPath == "" || incidencePath == "" { 66 | os.Exit(1) 67 | } 68 | } 69 | 70 | inPath := flag.Args()[0] 71 | outPath :=
flag.Args()[1] 72 | 73 | inImage := readPNG(inPath) 74 | 75 | var outImage image.Image 76 | if !modelType.Aux() { 77 | if patchSize != 0 { 78 | outImage = polish.PolishImagePatches(modelType, inImage, patchSize, patchBorder) 79 | } else { 80 | outImage = polish.PolishImage(modelType, inImage) 81 | } 82 | } else { 83 | albedo := readPNG(albedoPath) 84 | incidence := readPNG(incidencePath) 85 | inTensor := polish.CreateAuxTensorImages(inImage, albedo, incidence) 86 | if patchSize != 0 { 87 | outImage = polish.PolishAuxPatches(modelType, inTensor, patchSize, patchBorder) 88 | } else { 89 | outImage = polish.PolishAux(modelType, inTensor) 90 | } 91 | } 92 | 93 | w, err := os.Create(outPath) 94 | essentials.Must(err) 95 | essentials.Must(png.Encode(w, outImage)) 96 | } 97 | 98 | func readPNG(path string) image.Image { 99 | r, err := os.Open(path) 100 | essentials.Must(err) 101 | defer r.Close() 102 | inImage, err := png.Decode(r) 103 | essentials.Must(err) 104 | return inImage 105 | } 106 | -------------------------------------------------------------------------------- /polish/features.go: -------------------------------------------------------------------------------- 1 | package polish 2 | 3 | import ( 4 | "image" 5 | "image/color" 6 | "image/png" 7 | "math" 8 | "math/rand" 9 | "os" 10 | 11 | "github.com/unixpickle/model3d/model3d" 12 | "github.com/unixpickle/model3d/render3d" 13 | "github.com/unixpickle/polish/polish/nn" 14 | ) 15 | 16 | // albedoMapSamples is the number of BSDF samples used to 17 | // estimate a surface's albedo. 18 | // This should roughly match the sample count used to 19 | // train the model. 20 | const albedoMapSamples = 400 21 | 22 | // CreateAuxTensor creates a Tensor for a rendering with 23 | // auxiliary feature channels. 24 | // 25 | // This Tensor can then be passed to PolishAux. 26 | // 27 | // The channels are ordered as follows: 28 | // 29 | // 1. Red 30 | // 2. Green 31 | // 3. Blue 32 | // 4. Albedo red 33 | // 5. 
Albedo green 34 | // 6. Albedo blue 35 | // 7. Ray-surface cosine map 36 | // 37 | func CreateAuxTensor(c *render3d.Camera, obj render3d.Object, img image.Image) *nn.Tensor { 38 | b := img.Bounds() 39 | w, h := b.Dx(), b.Dy() 40 | albedo := CreateAlbedoMap(c, obj, w, h, albedoMapSamples) 41 | incidence := CreateIncidenceMap(c, obj, w, h) 42 | return CreateAuxTensorImages(img, albedo, incidence) 43 | } 44 | 45 | // CreateAuxTensorImages creates an auxiliary Tensor using 46 | // pre-constructed auxiliary images. 47 | // 48 | // See CreateAuxTensor for details on the channel order. 49 | func CreateAuxTensorImages(img, albedo, incidence image.Image) *nn.Tensor { 50 | b := img.Bounds() 51 | w, h := b.Dx(), b.Dy() 52 | inTensor := nn.NewTensor(h, w, 7) 53 | for y := 0; y < h; y++ { 54 | for x := 0; x < w; x++ { 55 | red, green, blue, _ := img.At(x+b.Min.X, y+b.Min.Y).RGBA() 56 | for i, c := range []uint32{red, green, blue} { 57 | *inTensor.At(y, x, i) = float32(c) / 0xffff 58 | } 59 | red, green, blue, _ = albedo.At(x, y).RGBA() 60 | for i, c := range []uint32{red, green, blue} { 61 | *inTensor.At(y, x, i+3) = float32(c) / 0xffff 62 | } 63 | gray, _, _, _ := incidence.At(x, y).RGBA() 64 | *inTensor.At(y, x, 6) = float32(gray) / 0xffff 65 | } 66 | } 67 | return inTensor 68 | } 69 | 70 | // CreateIncidenceMap creates a feature image where each 71 | // pixel indicates the dot product of the camera ray with 72 | // the normal of the first ray collision.
73 | func CreateIncidenceMap(c *render3d.Camera, obj render3d.Object, 74 | width, height int) *image.Gray { 75 | caster := c.Caster(float64(width)-1, float64(height)-1) 76 | img := image.NewGray(image.Rect(0, 0, width, height)) 77 | for y := 0; y < height; y++ { 78 | for x := 0; x < width; x++ { 79 | ray := &model3d.Ray{ 80 | Origin: c.Origin, 81 | Direction: caster(float64(x), float64(y)), 82 | } 83 | coll, _, ok := obj.Cast(ray) 84 | if ok { 85 | incidence := uint8(math.Abs(coll.Normal.Dot(ray.Direction.Normalize())) * 255.999) 86 | img.SetGray(x, y, color.Gray{Y: incidence}) 87 | } 88 | } 89 | } 90 | return img 91 | } 92 | 93 | // CreateAlbedoMap creates a feature image where each 94 | // pixel indicates the albedo of the surface intersected 95 | // by the camera ray. 96 | // 97 | // The bsdfSamples argument specifies how many times each 98 | // BSDF is sampled to approximate the albedo. 99 | // A higher value gives more accurate results for complex 100 | // materials. 101 | func CreateAlbedoMap(c *render3d.Camera, obj render3d.Object, 102 | width, height, bsdfSamples int) *image.RGBA { 103 | caster := c.Caster(float64(width)-1, float64(height)-1) 104 | res := render3d.NewImage(width, height) 105 | gen := rand.New(rand.NewSource(rand.Int63())) 106 | var idx int 107 | for y := 0; y < height; y++ { 108 | for x := 0; x < width; x++ { 109 | ray := &model3d.Ray{ 110 | Origin: c.Origin, 111 | Direction: caster(float64(x), float64(y)), 112 | } 113 | coll, mat, ok := obj.Cast(ray) 114 | if ok { 115 | dest := ray.Direction.Scale(-1).Normalize() 116 | res.Data[idx] = estimateAlbedo(gen, mat, coll.Normal, dest, bsdfSamples) 117 | } 118 | // Advance even on a miss so later pixels stay aligned. 119 | idx++ 120 | } 121 | } 122 | return res.RGBA() 123 | } 124 | 125 | func estimateAlbedo(gen *rand.Rand, mat render3d.Material, normal, dest model3d.Coord3D, 126 | bsdfSamples int) render3d.Color { 127 | switch mat := mat.(type) { 128 | case *render3d.LambertMaterial: 129 | if normal.Dot(dest) < 0 { 130 | return render3d.Color{} 131 | } 132 | return mat.DiffuseColor 133 | default: 134 | var colorSum render3d.Color 135 | for i := 0; i < bsdfSamples; i++ { 136 | source := mat.SampleSource(gen, normal, dest) 137 | density := mat.SourceDensity(normal, source, dest) 138 | bsdf := mat.BSDF(normal, source, dest) 139 | sourceDot := math.Abs(source.Dot(normal)) 140 | colorSum = colorSum.Add(bsdf.Scale(sourceDot / density)) 141 | } 142 | return colorSum.Scale(1 / float64(bsdfSamples)) 143 | } 144 | } 145 | 146 | // SaveFeatureMap encodes a feature map image to a PNG 147 | // file. 148 | func SaveFeatureMap(path string, img image.Image) error { 149 | w, err := os.Create(path) 150 | if err != nil { 151 | return err 152 | } 153 | defer w.Close() 154 | return png.Encode(w, img) 155 | } 156 | -------------------------------------------------------------------------------- /polish/model_data.go: -------------------------------------------------------------------------------- 1 | package polish 2 | 3 | import ( 4 | "archive/zip" 5 | "bytes" 6 | "encoding/binary" 7 | "fmt" 8 | "io/ioutil" 9 | "math" 10 | 11 | "github.com/unixpickle/essentials" 12 | "github.com/unixpickle/polish/polish/nn" 13 | ) 14 | 15 | func createShallow() nn.Layer { 16 | params := readParameterZip(shallowModelZipData) 17 | return nn.NN{ 18 | loadConv(params, "conv1", 5, 1, 3, 32), 19 | nn.ReLU{}, 20 | loadConv(params, "conv2", 5, 1, 32, 3), 21 | } 22 | } 23 | 24 | func createShallowAux() nn.Layer { 25 | params := readParameterZip(shallowAuxModelZipData) 26 | return nn.NN{ 27 | loadConv(params, "conv1", 5, 1, 7, 32), 28 | nn.ReLU{}, 29 | loadConv(params, "conv2", 5, 1, 32, 3), 30 | } 31 | } 32 | 33 | func createDeep(aux bool) nn.Layer { 34 | paramData := deepModelZipData 35 | inChannels := 3 36 | if aux { 37 | paramData = deepAuxModelZipData 38 | inChannels = 7 39 | } 40 | params := readParameterZip(paramData) 41 | 42 | result := nn.NN{ 43 | loadConv(params, "conv1", 5, 2, inChannels, 64), 44 | nn.ReLU{}, 45 | loadDepthSepConv(params, "conv2", 5, 2,
64, 128), 46 | } 47 | 48 | for i := 0; i < 4; i++ { 49 | layer := fmt.Sprintf("residuals.%d", i) 50 | result = append(result, nn.Residual{ 51 | loadBatchNorm(params, layer+".0"), 52 | nn.ReLU{}, 53 | loadDepthSepConv(params, layer+".2", 3, 1, 128, 256), 54 | nn.ReLU{}, 55 | loadDepthSepConv(params, layer+".4", 3, 1, 256, 128), 56 | }) 57 | } 58 | 59 | result = append(result, 60 | loadDeconv(params, "deconv1", 4, 2, 128, 64), 61 | nn.ReLU{}, 62 | loadDeconv(params, "deconv2", 4, 2, 64, 32), 63 | nn.ReLU{}, 64 | loadConv(params, "conv3", 3, 1, 32, 3), 65 | ) 66 | 67 | return result 68 | } 69 | 70 | func loadConv(p map[string][]float32, key string, kernel, stride, inDepth, outDepth int) nn.Layer { 71 | return nn.NN{ 72 | nn.NewPad(kernel/2, kernel/2, kernel/2, kernel/2), 73 | &nn.Conv{ 74 | InDepth: inDepth, 75 | OutDepth: outDepth, 76 | KernelSize: kernel, 77 | Stride: stride, 78 | Weights: p[key+".weight"], 79 | }, 80 | &nn.Bias{Data: p[key+".bias"]}, 81 | } 82 | } 83 | 84 | func loadDeconv(p map[string][]float32, key string, kernel, stride, 85 | inDepth, outDepth int) nn.Layer { 86 | s := (kernel - 1) / 2 87 | return nn.NN{ 88 | &nn.Deconv{ 89 | InDepth: inDepth, 90 | OutDepth: outDepth, 91 | KernelSize: kernel, 92 | Stride: stride, 93 | Weights: p[key+".weight"], 94 | }, 95 | nn.NewUnpad(s, s, s, s), 96 | &nn.Bias{Data: p[key+".bias"]}, 97 | } 98 | } 99 | 100 | func loadDepthSepConv(p map[string][]float32, key string, 101 | kernel, stride, inDepth, outDepth int) nn.Layer { 102 | return nn.NN{ 103 | nn.NewPad(kernel/2, kernel/2, kernel/2, kernel/2), 104 | &nn.SpatialConv{ 105 | Depth: inDepth, 106 | KernelSize: kernel, 107 | Stride: stride, 108 | Weights: p[key+".spatial.weight"], 109 | }, 110 | &nn.Bias{Data: p[key+".spatial.bias"]}, 111 | nn.ReLU{}, 112 | &nn.Conv{ 113 | InDepth: inDepth, 114 | OutDepth: outDepth, 115 | KernelSize: 1, 116 | Stride: 1, 117 | Weights: p[key+".depthwise.weight"], 118 | }, 119 | &nn.Bias{Data: p[key+".depthwise.bias"]}, 120 | } 121 | 
} 122 | 123 | func loadBatchNorm(p map[string][]float32, key string) nn.Layer { 124 | negMean := append([]float32{}, p[key+".running_mean"]...) 125 | for i, x := range negMean { 126 | negMean[i] = -x 127 | } 128 | variance := p[key+".running_var"] 129 | weight := p[key+".weight"] 130 | bias := p[key+".bias"] 131 | 132 | scale := make([]float32, len(variance)) 133 | offset := make([]float32, len(variance)) 134 | for i, x := range variance { 135 | invStd := float32(1.0 / math.Sqrt(float64(x+1e-5))) 136 | scale[i] = weight[i] * invStd 137 | offset[i] = bias[i] 138 | } 139 | return nn.NN{ 140 | &nn.Bias{Data: negMean}, 141 | &nn.Mul{Data: scale}, 142 | &nn.Bias{Data: offset}, 143 | } 144 | } 145 | 146 | func readParameterZip(rawZip string) map[string][]float32 { 147 | zipData := []byte(rawZip) 148 | byteReader := bytes.NewReader(zipData) 149 | zipReader, err := zip.NewReader(byteReader, int64(len(zipData))) 150 | essentials.Must(err) 151 | 152 | params := map[string][]float32{} 153 | for _, file := range zipReader.File { 154 | r, err := file.Open() 155 | essentials.Must(err) 156 | data, err := ioutil.ReadAll(r) 157 | essentials.Must(err) 158 | values := make([]float32, len(data)/4) 159 | binary.Read(bytes.NewReader(data), binary.LittleEndian, values) 160 | params[file.Name] = values 161 | } 162 | 163 | return params 164 | } 165 | -------------------------------------------------------------------------------- /polish/model_data_shallow.go: -------------------------------------------------------------------------------- 1 | package polish 2 | 3 | const shallowModelZipData = "PK\x03\x04\x14\x00\x00\x00\x00\x00\x04\xa7\xa5P\xea\x02A\xfb\x80%\x00\x00\x80%\x00\x00\x0c\x00\x00\x00conv1.weight<2e\xba 
\x0b<\xbd\xbc{Q\xbc8\xa0S\xbd\x12\x85\xfa\xbc\xd4p\x82=\x93\xb3\x98=\xac\x9a\xc4\xbb<\xf5U\xbd\x1e\xe2\x8b\xbc\xff1\xaf=l\xca+=\x9d\xfb.\xbd\x98\x9b\x86\xbd\xfd\xa9;='.\x90=\x98S}\xbd\x9e\xe4\x9c\xbcA\x17-\xbe\x0d\x9f\xcd]\xd2\x17>\xf51\xeb=|\xa7\x94=\xbay\xf6\xba\xe2\x18{\xbd\x1ef^<7\xc4\xda\xbd[\x87}\xbe\xa9\xb7\xa6\xbd{K\x92=\xd8\xc0/\xbe\x0f\xe2,\xbd,A$\xbe_\xfe\xfa\xbd\xa5\xb5D=\xbd\xea\x14\xbb\x9e\x0dM\xbd\xc2\x8a~=K\x06\x9a<[,\x12=![\x10\xbeX\xa0\xa6\xbcc\x8c\xb4\xbd\x0b\xa24\xbd\x87\xa0x\xa6\xa9(=:\xd6\xc3:V[\x89:\x15R\x0f\xbdK\x13\xe7\xbb\x00\x06`\xbdt7\xe4=-T\xf3IX\xf1<\x81\x1d\xa5\xbd\x9c\x0f\x86<\x9e\x90\x87=\xd0\xa7\xdb\xbdMX\xec=\x9b\x13\x86\xbc\x84\xc9\xac\xbd\x10\x11\xa2\xbd%=\xd4=jF==\x02\xf8\x14=*\xda\x9d=\xd4\x8d\xd9\xbdxQ\xfc=\xc2\xdd\xbf\xbb_\x97\xfd\xbd7f\xa0\xbckDR\xbc\x05\x80%\xbe\xdbQ\x05\xbe\xa0\x17\xbc=W\xed\xd4<\x085\x92=\x07\xaa$>\xd8\xfc\x8e\xbd\xa4\xbe\xd1=\xed)\x92\xbcW\xedn\xbdE\xc1\x1f\xbe-x\xa6=i\xbcc\xbd~\x19\x91=\xb8\xaaX\xbd\x06\xd4T=6~\xc9=]e\xd8\xbdl\xb6\x9a\xbd\x1e\x06\xe2\xbc)\xfc\xf7<\x93l\xa5;\x89\xc0\xed<\x88\x94t<\xa2\x04\x12=)E#\xbe\xce\x80\xef\xbd\x92\xa8\xbc=\x11`\x8b=r\x06\xe6\xbd\xf8p.>\x22\xd4q\xbd\xd9\xc5\x16\xbd\xa7\xe4/=\xd8*@<7L\x8a\xbd\xa0Rf<\xbe\xba_\xbc\x85\xc4\x19\xbb\xc5|\xd3\xbcnua<\x01\xf6\x83<\x8cw\x14\xbdY8\xbf<.k\x17<\xb6H\x8d;*%\x0c\xbdf\x19y\xbdN\xa5y\xbcn@\x04\xbc{\xb0\x0b=\xbf\xc3\xba\xbd\x87z\x01\xbe 
\x9d.\xbdo\x05\xba\xbc\xdd\xa56\xbdR\xaay=>\x9d\x8b=#v\x5c=:\xdeX=\x18\x15\x05=\xaeI\x1c='\xa36=%\x00Y=,\xf1\xd9;\xb3\xb8\x82;x.\x10\xbc\xa4\xa8R=\xdb7n\xbd/\x99\xbc\xbc\xa1\xc5\xeb\xbc\xa9\xea\x99\xa0\xf5\xb3=\xd1\xd6\xe5<}]i=i\xbd\xa5\xbd\xb2d\x12\xbe'\x1d\xcf\xbd]Il\xbd\xaa\x9c\x82\xbd\xb9/i<\x8b\x04\xbe\xbc\x1f\xad\xb4\xbd\x19\x82\xe9\xbci\xf2\x84\xba}\x92\xfd\xbaM6w\xbd\x86\xbe\xde=\xc8\x0b\xf6<\xc5\xf0\xae\xbbu\x94\xa3<\xea1\xd8\xbc\xd6\xd2\x1a>\x9b\x99\x9e=@}\x82<\xd1\xb20<\xfd\xa1\xc0;\xbbP_>\x1a:\x10>7\x08\x8a=\x8b\x04^\xben\xcb\xd0\xbeW\xd2I=T8h\xbc\xaa\x19\x02\xbd\xd8\xee\xdc\xbd\x01\x80`\xbeI\x1d\x84=\xc2+\xd1<3I\x94\xbcN\xce\xae\x0cJ\x93\xbd\x1a\xd6-\xbc\xf3\xd2\xcc=22\x1e\xbdW\x03\x1e\xbd]\xc7\xa8\xf1.\xd3=\xaf}\x97\xbd\xdcw`=MT\x02\xbdI\xc7\xcf=\xa2\xd6<<\x19\x7f\x85=X\xcc\xeb;Y!\x94=(\xc1\x9b\xbc\xaa\x14\xd3=\xdc\xc5\xb5=C\x15T\xba\xf8q\x07=;\x19\x91\xbd\xa7,\xd0=\x82\x15\xc3\xbb;\xf8\xbe\xbd\xd8\xf9\x0b<\xa9%\xe6\xbd\xed8\x97}\x85\xc4\xbd\xc1\xbf\xac\xbc\x0f#\x86=\xfc\xa3\xdd=\xd9e\xb1\xbd\xee\xc1\x80\xbd\xa5\xad\x84<\x1e\x89\x0c\xbe\x9d\x9c\xe8\xbc(\xf0\xee\xbd\xcd25\xbd\xa5k\xab\xbcr\xecV\xbck\xfa\x04\xbe\x98\x0b\xb5\xbd\xbb<\x07>}*\x1c\xbe\xd1\xb5\xb9\xbd\xd0\xda =@\x07\x86<\x1b\x1b\x95\xbdY\x19L=\x8aH0\xbd\xd7\x18I=\x0a\xdc\x02=\xdf\xde\xaf=I\x8e\x02;>%k=\xeb\x1d\x7f<\xe7K\x06\xbd\x95s\xf0\xbc+\x93\xf0=\x7f^\x13\xbe\x8a\x9c\x0f>\xf7\xf7\xda<\xfa\x1c(\xbbh\xca\x9b<\xb8\x1c!=_\xb9\xb1\xbc]\xa9:\xbd\xba\xa7U=\xa4/\x01\xbeI\x1d\x0b<\x1b\xf7\xaf<%f\xf0\xbde\xb1\xb0\xbd\xa0T->\x8b\xb2\xad\xbd\xac\xf6\x13\xbdQ\xd6\x1d=\xefm\xa2=X\xd32\xbd\xe8\xa0\x14>\x16\x1f\x0c\xbd`\xa8\x87:\x93{]=\xab\xe2L=f\xc73\xbe7\xed\xe8=\x96\xeb\xac\xbc\xd8S\x22\xbd\x0a\x82~=\x9c\xd1z\xbd\x87\xb8\x0e\xbe\x8a\xf0#=\x99\x92%\xbd\x81O\x19\xbdz\x84:\xbdg\x9er\xbd\xa6e\x8c<\x12}y\xbd\xa6%\x19>%E\x09=9\xd3Z=\xa9*\x96=\xa9\xde/~\xdb\xcd=\xcaO\xcc=|\x96\x01>\xcb\x0c\x12< f\x96=C\xae\x87=\xfe\xf7O=m%\x8d=\x1a 
\x999\xd1^F\xbd\x1f\xec*:3\xa1\xe4=\xc9\xfd\x18=\x1f\x9e\xe7;^\x93\x16\xbd\xaf\xc4\xb8<\x86\x9d\x00\xbc\x0c\xbfW=_\xf1\xee=\xc7\xb6h\xbd\x96\x97L\xbd\x15\xed\xa5\xbd\xe4\xa5\xc1=\x9d%\xe3=\xe6\x90\xb9\xbd\xdfZ\x18\xbd\xbc5\xdb\xbb\x5c\xfcj=\x1c\xbb\x99\xbd^\xc9Z=j\xc5\xe4\xbaY\x04\xd2;\xf4i@\xbb\xba\xc0~=^E\x97\xbd\xdax>=\xc4=k=\xe39\xc5\xbd\xddX\x89<4w\xb3\xbbjo\xad\xbd\xfe\xaa==\xfeT`\xbd!\xb6O=y\xe0\x0a:Q\xb7=\x84\xcb\xd4\xbc\xf7\x99\x19=zJ\x06=S\x10\xb1\xbd\xd2\xb9t=}\xd8\x12=\xc5Pg\xbd\x09]Q\xbd1\xeb)\xbdf\x82\x89\xbc\xa7c\x9e\xbd|\x06\x9e\xbd\xb1p\x8a\xbd\xcf/\x8d\xbd\xae\xb3\x0a\xbd}\x9c\xba\xbbRl\xbb=\xae\xa7\xdb\xbd67:\xbd\xfb\xf7\x94<\xf4\x13\xf0;\xa0\xe2\xdf=\xa1k7>\xf1Q\x15>\x8d\xa0\x97\xbc\x0f[\xa3\xbbB\xab\x83\xbdm\x9d\x02\xbd\x7f8\x96<\x05Z\x07\xbdBN\xb8\xbdx\xf3\xdc\xbdO\xda\x11\xbe\x85\xea\x09\xbe7{\xac\xbdJ\xe9s\xbc$\xd5n\xbc\xd08e\xbc\x7f-\x86=\xcev\x1d;\xd0\xb7P=\xc8\xf6\xe3=&}!\xbe\xc4\xb7\xd3\xbb\xce\x93n<\x9c\xb0\x19=06A>\x02>\x9d<\xbd\xff\xba=\xb8\xd7\xf1=F\xf4\xe0;\xd2\xe0\x1a\xbcOi\xb9<\x88\xcc\xef<\x90\xcb*=\xb3\x06N=\xf9\x84\xc8\xbd\x8c\xc3E\xbd\x15k\x13\xbe#\xc8Q>n\xb4\x9a\xbc\x7f\xa5\xb7\xbd}\x1c\x8d=\x16\xa0\x1a>\xbb\x0fv\xbd\x01\x22\x95\xbd\xbe\xa8\x95\xbdr2\xb2\xbc\x9a6\xd2=3\xda\x13\xbd\xb5\xa8\x1a\xbca\xbdA\xbdq\xc3\x8e\xbc,9\x93\xbd\xef\xf0\xa2=?`V\xbdY\xe1\xe7\xbd\xfc4(=)#\xdd<\x86,8=:9\xe9:\x9f\x8c_\xbc\xad\x0f\xb7:%\xd1\x1b\xbe\x0b\xa1,>`y\xe0\xbc\xeb 
[... binary model weight data truncated ...]
--------------------------------------------------------------------------------
/polish/nn/conv_test.go:
--------------------------------------------------------------------------------
[... beginning of file truncated ...] > 1e-3 { 61 | t.Errorf("case %d: element %d: expected %f but got %f", i, j, x, a) 62 | break 63 | } 64 | } 65 | } 66 | } 67 | 68 | func TestConvSmall(t *testing.T) { 69 | c := &Conv{ 70 | OutDepth: 1, 71 | InDepth: 1, 72 |
KernelSize: 3, 73 | Stride: 1, 74 | Weights: []float32{ 75 | -0.0139077, 0.2473437, 0.1155401, -0.0495334, -0.0390755, 0.3136772, -0.2264613, -0.2643562, 0.2818739, 76 | }, 77 | } 78 | in := NewTensor(4, 4, 1) 79 | copy(in.Data, []float32{ 80 | 0.6017266, -2.0017490, 0.2513519, 0.5615258, 0.1457711, 0.8694047, 0.3224155, 0.6138996, -0.2135582, 0.8488631, 0.8838763, 0.8060158, 0.7435452, 0.0199154, 1.2212700, -0.4557744, 81 | }) 82 | 83 | expected := NewTensor(2, 2, 1) 84 | copy(expected.Data, []float32{ 85 | -0.3414040, 0.0930938, 0.6755218, -0.1410014, 86 | }) 87 | actual := c.Apply(in) 88 | if expected.Height != actual.Height || expected.Width != actual.Width || 89 | expected.Depth != actual.Depth { 90 | t.Fatal("incorrect output shape") 91 | } 92 | for i, x := range expected.Data { 93 | a := actual.Data[i] 94 | if math.Abs(float64(x-a)) > 1e-4 { 95 | t.Errorf("bad value at %d: expected %f but got %f", i, x, a) 96 | } 97 | } 98 | } 99 | 100 | func TestConvLarge(t *testing.T) { 101 | c := &Conv{ 102 | OutDepth: 4, 103 | InDepth: 5, 104 | KernelSize: 3, 105 | Stride: 2, 106 | Weights: []float32{ 107 | 0.0048753, 0.1400915, 0.0073967, -0.1249796, 0.0990686, 0.0990699, -0.1468822, 0.0789842, 0.1155828, 0.1138192, 0.1149694, 0.0828862, 0.0791527, -0.0879772, 0.0949459, 0.0085811, -0.0385565, 0.0788283, -0.0003407, -0.1072131, -0.1233412, 0.0257104, -0.1403926, -0.1191188, -0.0455806, 0.0049993, 0.1048242, 0.1100614, 0.1131917, -0.0771568, -0.0771224, -0.0381215, 0.0982061, -0.0779766, 0.0589884, 0.1120752, -0.0661897, -0.1146785, 0.0792114, 0.0506790, -0.0886449, 0.1148331, -0.1356669, 0.1354263, 0.0253897, -0.0787591, -0.0843297, -0.1268204, 0.0234138, 0.0515354, 0.0450980, -0.0963772, -0.0022710, 0.1145745, -0.0908253, 0.0116528, -0.0321431, -0.0787754, 0.0449269, -0.0024932, -0.0477832, 0.1000197, -0.0555633, -0.0737215, 0.1214098, 0.0158767, -0.0083296, -0.0655478, -0.1114864, 0.0336006, -0.0806804, -0.1058603, 0.0962375, -0.1000008, 0.0794527, 0.1352933, 
-0.0220469, 0.0106698, 0.0303797, 0.1366412, 0.1070330, 0.0979977, -0.1086167, 0.0448620, 0.1062138, -0.1036574, -0.1454835, -0.0667073, -0.1310376, 0.1209997, 0.0158482, -0.0984464, -0.1211329, -0.1013677, 0.0562877, 0.0543369, -0.0502296, -0.1218463, -0.0558607, -0.0743429, 0.0363706, -0.1145388, 0.1290013, -0.1090252, -0.1209797, 0.0011886, 0.0885055, -0.0460116, 0.0160562, -0.1270224, -0.0416963, -0.1314727, -0.0775326, 0.0971618, 0.0775764, -0.1064220, -0.0264593, -0.0582033, -0.0048981, -0.0312180, -0.0059659, 0.1070529, 0.0789812, -0.0520511, -0.1221963, 0.1213690, 0.1360595, 0.1326346, -0.0381068, 0.0179752, 0.0817789, -0.1232072, -0.0619079, -0.0240207, -0.0806419, 0.0344195, -0.1122220, 0.0418564, 0.0423259, -0.0128046, 0.0986949, 0.1288407, 0.0492961, -0.0569584, 0.0703266, 0.1369467, 0.0336334, 0.0923058, -0.0318223, 0.0872521, 0.1001197, 0.0224605, -0.0315309, -0.0356362, -0.1288342, -0.0696220, 0.0201382, -0.0332683, -0.0427386, 0.0922373, -0.1347726, 0.0486960, 0.0391940, -0.0881831, -0.0382566, 0.0862630, 0.0826437, 0.0556344, -0.0267371, -0.0464863, -0.0745129, 0.1025151, 0.0037476, -0.0211752, -0.0202064, -0.0645246, 0.1358609, 0.0056850, -0.1205265, 0.0110337, 108 | }, 109 | } 110 | in := NewTensor(6, 7, 5) 111 | copy(in.Data, []float32{ 112 | 1.4725239, -0.0588857, -1.0511501, -0.8535296, -0.1226160, -2.0331368, 0.2033319, 0.1808803, 0.7840600, -0.7841823, -0.8234431, 0.2406522, 1.0564671, -0.1470382, 0.0237237, 1.3913455, -0.1125739, 0.0310295, -0.5059676, 0.6760480, -0.1207311, -2.0347879, 3.4141071, -0.8715456, -1.0442810, 1.0539747, 1.0902373, 1.0664262, 0.9359902, -1.5147089, 0.3542635, 0.6339781, -0.7948303, 1.0690100, -0.6837966, -0.9467762, -1.4691463, 0.0043499, 0.8494170, -0.4182343, -0.4738699, -1.9620832, 0.4696709, -1.2688903, -1.9464079, 0.6544799, 0.0915866, 0.3698487, -0.8195642, 0.7171361, -0.0951290, 0.9562339, -0.7151088, 0.4587850, 0.4896681, -1.2527742, -0.9197360, 0.1529317, 2.2528563, -0.7698087, 1.1912090, 0.0766383, 
-0.7134541, -1.5935307, 0.4763172, -0.1287121, -1.3450482, -0.7140349, -0.4916139, -0.3380890, 1.7034016, 0.5017734, 0.1614027, 1.0646375, 0.0937267, 1.3851739, -1.3037589, -2.0523577, -1.2393095, 0.2878762, -1.5465885, -0.1017815, 0.3366855, -1.0065933, -0.7175980, 1.2542903, -1.6045935, -0.0836585, 1.0936671, 0.8313980, -0.0118547, 0.0760615, 0.1303777, -0.0700363, 1.1094716, 0.5530361, 1.0745329, 0.2073360, -1.0503830, -0.6106652, 0.4007267, -1.6435132, 2.0033343, -0.3548864, 1.4754119, 0.0968325, 0.5541025, 1.5374726, -0.4410564, 1.1347706, -0.5501668, -1.5077022, 0.2906386, 0.6215913, 0.7636902, -0.4066638, 1.3752940, 1.0289661, -1.0468789, -0.6141735, 1.4556367, 1.7281872, -0.2950397, 0.5613296, -0.3341977, -0.0380274, 1.0071597, -0.0684543, 0.5041260, 1.5872769, -1.7071162, -0.6374162, 0.0035660, -0.6789651, 0.4903325, -1.6518031, -1.2705370, -0.4742535, -1.0952357, -0.2847594, -0.2438605, 1.6654413, 0.7398302, 0.8784127, -0.3128213, -1.1850022, 1.1968794, -0.2819678, -2.2011805, -1.5756040, 0.4870483, 0.9218722, 0.7067528, -0.8398420, -0.7508997, 0.6922061, 0.6504284, 1.3526452, -2.1302404, 0.3583336, -0.5170502, 1.2601988, -0.5421015, -0.1550326, 0.5573612, 1.2214118, 1.3208184, -0.2973142, -0.9441420, -0.5450425, -0.4168938, 1.7370363, -0.5848581, -1.8912510, -1.0972722, -0.7830400, -0.0830904, -1.9210624, 0.1816452, -2.3774185, 0.0361986, 0.4834803, 0.7642338, -0.5535203, -0.2283998, 0.3258312, -0.3148893, -0.9916397, -0.5933701, -0.9147130, 1.3224818, 0.6048592, -1.7784269, 1.8416162, -0.2276750, 0.9857835, 0.8796785, 0.0706747, -1.5941858, -0.1933369, -1.6960143, 0.8870867, -0.7864155, -1.5835406, 1.4838743, -1.7379398, -1.5213418, 0.6810994, 0.1092481, -0.5279986, 113 | }) 114 | 115 | expected := NewTensor(2, 3, 4) 116 | copy(expected.Data, []float32{ 117 | -0.5791513, -0.4299181, -0.2282083, 0.9621052, 0.0896676, 0.0673461, 0.0544239, -0.9143549, -0.0084198, 0.3577323, -0.4168170, -0.4523573, -0.2743906, -0.7193667, 0.8375948, 0.6181691, 0.6491780, 
-1.1584780, -0.3041711, -0.1271321, -0.8035574, -0.0559687, 0.0250076, 0.0924966, 118 | }) 119 | actual := c.Apply(in) 120 | if expected.Height != actual.Height || expected.Width != actual.Width || 121 | expected.Depth != actual.Depth { 122 | t.Fatal("incorrect output shape") 123 | } 124 | for i, x := range expected.Data { 125 | a := actual.Data[i] 126 | if math.Abs(float64(x-a)) > 1e-4 { 127 | t.Errorf("bad value at %d: expected %f but got %f", i, x, a) 128 | } 129 | } 130 | } 131 | 132 | func TestSpatialConvLarge(t *testing.T) { 133 | c := &SpatialConv{ 134 | Depth: 5, 135 | KernelSize: 3, 136 | Stride: 2, 137 | Weights: []float32{ 138 | -0.1911255, 0.0742732, 0.3244561, 0.2112495, -0.0780912, 0.2744516, 0.1766910, -0.3102708, -0.0744405, 0.2958856, 0.2830285, -0.1699406, -0.0270629, -0.0285498, -0.0274338, 0.0813831, 0.2177308, -0.1350431, 0.2850550, -0.2722039, -0.3297556, -0.1751623, 0.1043071, 0.0395478, 0.0902030, -0.1793895, 0.0326312, 0.2715265, 0.0960180, 0.2904015, -0.0761158, -0.0874449, -0.3238208, 0.2678641, 0.1489393, -0.2541793, -0.0893213, 0.0757578, -0.0567662, 0.1434719, 0.2722551, -0.0703747, 0.0763021, -0.0925830, -0.0667646, 139 | }, 140 | } 141 | in := NewTensor(6, 7, 5) 142 | copy(in.Data, []float32{ 143 | 1.1409054, 1.5649923, -0.1948104, 1.3556889, 0.6847361, 0.6819699, 0.9286522, -0.3399995, -1.8085134, 1.3648162, -0.9051613, 1.0117084, -0.0365825, -0.9432146, 1.2921587, 0.4738652, -1.0386302, 0.8244314, -1.4762123, -0.4245814, 0.3236750, 1.4356518, -0.9547286, -2.9507315, 0.7800845, -0.4020369, 0.3173794, 0.6547441, 1.6436012, -0.4779457, 0.3736600, 0.8305290, -0.1762423, 0.3623725, -1.8178300, -0.1546911, -1.0005274, -1.4735147, 0.3573847, 0.2560021, -0.3821970, -0.3495889, 0.1900006, -2.2389088, -0.9038950, -0.3378337, 0.7981479, 0.9964933, 0.4515326, 0.0229507, -1.6533415, 1.4932644, 1.0282696, -1.4741617, 0.6400365, 0.4244200, 1.4641294, 0.5905561, 0.8236988, -0.6341587, -0.0929454, -1.0886556, 0.2774341, 0.7295105, 0.6951755, 
0.1826468, -0.8445777, -0.4030475, -0.7516201, -0.0158012, 0.6809133, 0.9969447, 0.3797884, 1.6414919, 1.8294731, -1.2760342, 1.0722744, -0.2493834, -0.7005353, -1.2985059, 0.2131938, 0.0640154, 0.1608759, 0.3719594, 0.0770728, -0.3376203, -1.2530504, 1.2432044, 0.1290189, -1.1323935, 1.6972812, -1.4169976, 0.6200145, 0.4643981, 0.0858503, -0.3666513, 0.0400702, 1.2037058, 1.3236476, 0.7106019, -0.7832468, 0.4510391, -0.5057965, -0.6557346, -0.3792492, 0.5938662, -0.4725557, -0.0708619, -0.8851832, 1.6350344, 0.9954666, 0.8817711, 0.6591703, 0.1176392, -1.8714072, 1.8101006, -0.7794924, -0.3896322, -1.9108096, -1.0304877, -0.7368749, -0.1825271, -0.4482952, 0.8978229, 0.7722555, 1.8434216, 2.1341970, -0.1730643, -0.8144822, 0.0329581, 0.0599070, -1.4132864, 0.1444537, 2.6583509, 0.2056222, -0.9210598, -1.1162794, -0.6732786, 1.8129425, -1.8235412, -0.8088601, 0.2537483, -1.4412017, 0.8429257, -3.2031205, 0.1626158, 0.4738091, 2.5309155, -1.6109036, 1.0201719, 1.2063761, 0.3259659, 0.4106356, -1.7807648, 0.8569935, -0.9382124, 0.0530952, -1.4398686, -0.4192069, -0.2597384, 0.1468918, 0.7345250, -0.4082018, 0.2757454, 0.4615452, -1.1347867, 1.6473992, -1.5824908, 1.7915055, 0.8639272, -0.1940993, 0.2653693, 0.6212785, 0.1528680, -0.2220691, 1.2923363, 0.3638144, -2.5262332, -0.4221732, -0.1805679, 1.2553589, -1.0108938, 1.2761886, -1.5684910, 0.0717819, -1.6453557, 0.8571983, 0.1580963, 0.1615494, -1.2693247, 0.8411252, -0.3765129, -1.0511090, 1.0487553, -0.2171007, 0.7409956, -1.2929105, 0.1197028, 0.5129985, 0.2798404, 0.8487886, 1.0983537, -1.5116918, 0.0438500, 0.0298962, 0.4987317, 0.1934926, 1.5776329, -1.1949311, -0.3078639, 144 | }) 145 | 146 | expected := NewTensor(2, 3, 5) 147 | copy(expected.Data, []float32{ 148 | -0.0562831, 0.8750818, 0.4506569, 0.1837227, 0.0125734, 0.5035170, -0.4192492, -0.1522258, -1.4261292, 0.1352992, 0.6485183, 0.2205980, -0.6592113, 0.0669617, 0.0627794, 0.1055823, 0.6763800, -0.3817674, 1.6008079, -0.8646283, 1.9240477, 
-0.1882490, -0.2003380, -0.0306023, 0.0211626, -0.0813035, -0.0887315, 0.3019635, -0.3929369, 0.2267936, 149 | }) 150 | actual := c.Apply(in) 151 | if expected.Height != actual.Height || expected.Width != actual.Width || 152 | expected.Depth != actual.Depth { 153 | t.Fatal("incorrect output shape") 154 | } 155 | for i, x := range expected.Data { 156 | a := actual.Data[i] 157 | if math.Abs(float64(x-a)) > 1e-4 { 158 | t.Errorf("bad value at %d: expected %f but got %f", i, x, a) 159 | } 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /polish/nn/deconv.go: -------------------------------------------------------------------------------- 1 | package nn 2 | 3 | import ( 4 | "runtime" 5 | "sync" 6 | ) 7 | 8 | // Deconv is a 2D transposed convolution operator. 9 | // 10 | // It contains weights of the shape: 11 | // 12 | // [in_depth x out_depth x kernel_size x kernel_size] 13 | // 14 | type Deconv struct { 15 | OutDepth int 16 | InDepth int 17 | KernelSize int 18 | Stride int 19 | Weights []float32 20 | } 21 | 22 | // Apply applies the transposed convolution to a Tensor. 23 | // 24 | // The resulting Tensor's size is determined by 25 | // DeconvOutputSize(). 
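The doc comment above says the output size of a transposed convolution comes from `DeconvOutputSize()`: each spatial dimension grows as `(in-1)*stride + kernelSize`. As a standalone sketch (the helper `deconvOutputSize` below is local to this example, mirroring the formula from `deconv.go`), it can be checked against the 6x7 fixture in `deconv_test.go`, which expects a 13x15 output for kernel size 3 and stride 2:

```go
package main

import "fmt"

// deconvOutputSize mirrors the formula used by DeconvOutputSize in
// polish/nn/deconv.go: each output dimension is (in-1)*stride + kernelSize,
// with zero-sized inputs producing zero-sized outputs.
func deconvOutputSize(height, width, kernelSize, stride int) (int, int) {
	var h, w int
	if height > 0 {
		h = (height-1)*stride + kernelSize
	}
	if width > 0 {
		w = (width-1)*stride + kernelSize
	}
	return h, w
}

func main() {
	// The 6x7 input from TestDeconvLarge with kernel 3 and stride 2
	// yields (6-1)*2+3 = 13 rows and (7-1)*2+3 = 15 columns, matching
	// the test's expected NewTensor(13, 15, 4).
	h, w := deconvOutputSize(6, 7, 3, 2)
	fmt.Println(h, w) // 13 15
}
```

This is the inverse of the usual convolution size rule `(in-kernelSize)/stride + 1`, which is why the 2x3 outputs of the stride-2 conv tests map back to 6x7 inputs here.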
26 | func (d *Deconv) Apply(t *Tensor) *Tensor { 27 | if t.Depth != d.InDepth { 28 | panic("input Tensor does not have the correct number of channels") 29 | } 30 | outH, outW := DeconvOutputSize(t.Height, t.Width, d.KernelSize, d.Stride) 31 | features := d.transposedFeatures() 32 | 33 | out := NewTensor(outH, outW, d.OutDepth) 34 | lock := sync.Mutex{} 35 | 36 | d.iteratePatches(t, func(tmp *Tensor, x, y int, data []float32) { 37 | for i, scale := range data { 38 | feature := features[i] 39 | if i == 0 { 40 | for j, x := range feature.Data { 41 | tmp.Data[j] = x * scale 42 | } 43 | } else { 44 | for j, x := range feature.Data { 45 | tmp.Data[j] += x * scale 46 | } 47 | } 48 | } 49 | outX := x * d.Stride 50 | outY := y * d.Stride 51 | 52 | // Adding to the global image must be synchronized, 53 | // but it takes up so little of the computation time 54 | // that it doesn't seem to matter. 55 | lock.Lock() 56 | addPatch(out, tmp, outX, outY) 57 | lock.Unlock() 58 | }) 59 | 60 | return out 61 | } 62 | 63 | func (d *Deconv) transposedFeatures() []*Tensor { 64 | featureStride := d.KernelSize * d.KernelSize * d.OutDepth 65 | var featureIdx int 66 | var result []*Tensor 67 | for i := 0; i < d.InDepth; i++ { 68 | feature := d.Weights[featureIdx : featureIdx+featureStride] 69 | featureIdx += featureStride 70 | 71 | tensor := NewTensor(d.KernelSize, d.KernelSize, d.OutDepth) 72 | var idx int 73 | for y := 0; y < tensor.Height; y++ { 74 | for x := 0; x < tensor.Width; x++ { 75 | for z := 0; z < tensor.Depth; z++ { 76 | tensor.Data[idx] = feature[(y+z*d.KernelSize)*d.KernelSize+x] 77 | idx++ 78 | } 79 | } 80 | } 81 | result = append(result, tensor) 82 | } 83 | return result 84 | } 85 | 86 | func (d *Deconv) iteratePatches(t *Tensor, f func(tmp *Tensor, x, y int, data []float32)) { 87 | numGos := runtime.GOMAXPROCS(0) 88 | var wg sync.WaitGroup 89 | for i := 0; i < numGos; i++ { 90 | wg.Add(1) 91 | go func(goIdx int) { 92 | defer wg.Done() 93 | tmp := NewTensor(d.KernelSize, 
d.KernelSize, d.OutDepth) 94 | idx := goIdx * t.Width * t.Depth 95 | for y := goIdx; y < t.Height; y += numGos { 96 | for x := 0; x < t.Width; x++ { 97 | f(tmp, x, y, t.Data[idx:idx+t.Depth]) 98 | idx += t.Depth 99 | } 100 | idx += (numGos - 1) * t.Width * t.Depth 101 | } 102 | }(i) 103 | } 104 | wg.Wait() 105 | } 106 | 107 | func addPatch(dst, src *Tensor, outX, outY int) { 108 | var srcIdx int 109 | dstIdx := (outX + outY*dst.Width) * dst.Depth 110 | dstStride := dst.Width*dst.Depth - src.Width*src.Depth 111 | for y := 0; y < src.Height; y++ { 112 | for x := 0; x < src.Width; x++ { 113 | for z := 0; z < src.Depth; z++ { 114 | dst.Data[dstIdx] += src.Data[srcIdx] 115 | dstIdx++ 116 | srcIdx++ 117 | } 118 | } 119 | dstIdx += dstStride 120 | } 121 | } 122 | 123 | // DeconvOutputSize gets the output dimensions from a 124 | // transposed convolution operation. 125 | func DeconvOutputSize(height, width, kernelSize, stride int) (heightOut, widthOut int) { 126 | if height > 0 { 127 | heightOut = (height-1)*stride + kernelSize 128 | } 129 | if width > 0 { 130 | widthOut = (width-1)*stride + kernelSize 131 | } 132 | return 133 | } 134 | -------------------------------------------------------------------------------- /polish/nn/deconv_test.go: -------------------------------------------------------------------------------- 1 | package nn 2 | 3 | import ( 4 | "math" 5 | "runtime" 6 | "testing" 7 | ) 8 | 9 | func TestDeconvLarge(t *testing.T) { 10 | runTest := func(t *testing.T) { 11 | c := &Deconv{ 12 | OutDepth: 4, 13 | InDepth: 5, 14 | KernelSize: 3, 15 | Stride: 2, 16 | Weights: []float32{ 17 | -0.1219822, 0.0415145, -0.0003108, 0.0420820, -0.0970473, -0.1313528, 0.1042072, -0.0170022, -0.0938565, -0.0299630, 0.0318583, 0.0244425, 0.0210479, 0.0111827, -0.0148932, 0.0557383, 0.1224927, 0.0547441, 0.1582459, 0.0499690, 0.1207057, -0.0196894, -0.0755690, 0.0689451, 0.0779862, 0.0873665, 0.1419234, -0.0543128, 0.1608845, 0.0582658, 0.0228303, -0.0477583, 0.0354382, 
0.0951089, 0.0891034, -0.0037552, -0.1537833, 0.1476649, -0.0221033, -0.0759088, -0.0557675, -0.1296122, 0.1325820, 0.0835043, -0.1161792, 0.1155595, 0.0505537, -0.0784579, -0.0735167, 0.0492661, 0.1260416, -0.1229414, 0.1009175, -0.1287802, 0.1222984, -0.0440249, -0.0013464, -0.1196105, -0.1063310, -0.1411955, 0.0688984, 0.1312972, -0.0141451, -0.0510540, 0.0973860, -0.1332392, -0.1191129, 0.0066505, 0.1188149, 0.0877433, 0.1113341, 0.1222921, -0.1393880, -0.0432469, -0.0526133, -0.0949420, -0.1076849, -0.0002497, -0.1306648, 0.0294515, -0.0090034, 0.0040909, -0.1298766, 0.0008129, -0.0666431, 0.0601226, 0.0552047, -0.1309684, -0.0444079, -0.1611355, 0.1431194, 0.1094252, 0.1139870, 0.0102238, -0.1103540, -0.1296468, -0.0996154, -0.0453108, 0.1056267, 0.1260485, 0.1171038, -0.1362305, 0.1390443, 0.0645312, -0.1069985, -0.1441645, -0.1028010, -0.0422021, 0.0047795, -0.0863843, 0.0737386, 0.0269379, -0.0824751, 0.1281872, -0.1497540, 0.0671954, -0.1099316, 0.0260724, 0.0106858, -0.1560881, -0.1535428, 0.0284854, 0.1236318, -0.1221088, -0.1611681, 0.0809807, 0.0805877, 0.0553429, 0.0253775, 0.0030661, -0.0974576, -0.1127261, -0.0760133, -0.0775680, 0.1187243, -0.0476503, -0.1262008, 0.0253365, 0.0707496, -0.0859080, 0.1214689, -0.0510921, 0.0854513, 0.1506919, 0.0251312, -0.1434479, 0.1403285, -0.0385780, 0.1050188, 0.0446120, 0.0138543, 0.1392026, 0.0407732, 0.1158056, -0.1064162, -0.0432423, -0.0038336, -0.1366038, 0.0687647, -0.1084612, -0.1086784, -0.1020633, 0.0336938, 0.1003439, 0.0182202, 0.0090614, 0.1004127, -0.0765266, -0.1150280, -0.1253591, 0.1291900, -0.0042546, 0.0653943, 0.0769058, -0.1206319, -0.0232126, 0.1008501, -0.1636609, -0.0687802, -0.0380300, 18 | }, 19 | } 20 | in := NewTensor(6, 7, 5) 21 | copy(in.Data, []float32{ 22 | -1.2988282, 0.2510391, 1.0101794, 0.0262716, -0.7962795, -1.3709655, 1.5619336, -4.1645885, 1.3380519, -1.3223622, 1.3869857, 0.1907596, 1.5335323, -1.2222977, 0.1216181, 0.3406470, 0.1903207, 0.6818309, 0.5130429, 1.0867368, 
0.1619944, 0.1881683, -1.0403576, -0.5780251, 0.6729466, -0.1694447, -1.8407322, -0.3128998, 1.2927204, -1.9052949, 0.3897138, 1.0135902, 0.3811399, 2.2074044, -0.5668194, -1.3145959, -0.5424687, -1.2134125, -0.4866764, 1.9015125, -0.5780223, 0.6818247, -1.5994284, -0.9151881, 0.1030166, -0.7112994, 1.0685599, -0.3549566, -2.1857874, -1.0917655, -0.6021930, 1.1224357, 1.6376072, -0.5959309, 1.0260910, -0.0712967, -1.2713999, -0.0322425, -0.1124378, 0.9655091, 0.8144102, 0.0382287, -0.1795858, -0.2953408, -0.4305271, -0.2038272, -0.1020874, -0.2738957, 0.3701357, 0.7156716, 0.0504618, -0.9757811, 0.7032495, 0.6709840, -1.0611930, -0.3453544, -0.3327364, 0.7522176, -0.1924587, -0.3480925, -0.3838927, 0.4166580, -1.2638218, 0.1913680, 1.5650450, -0.8434668, -0.6515325, 0.4179985, 0.9464355, -0.8318047, 1.0476940, 1.9159533, 0.5978024, -1.3885217, -1.3461010, -0.7859114, 0.9567471, -1.6975802, 1.7899201, 0.9705070, 0.4316317, 0.0558143, 2.8513019, 2.0602522, 1.3202075, -2.0741122, -1.8953130, 0.8939016, 0.9133815, -2.2644036, -0.6281854, 0.3507041, 0.5025447, 0.4345118, 0.0053923, 1.8058343, 0.7382588, 1.8524241, -0.0612551, 1.2026122, 1.7526712, -0.1822277, 0.6281086, 2.0601828, -0.1985361, 0.6348800, 0.6708742, -1.0700990, -0.4328031, 0.7890400, 0.6280078, 0.4541491, -0.7496076, -1.4483989, -2.1997216, -0.4227039, -1.5337064, -1.1435885, 1.1559694, -0.4253571, -0.6400719, 1.2148696, 0.2517402, -1.5043896, -1.8011912, 0.8911669, 0.0587563, -0.3249962, 0.6687999, -0.0895100, -0.6761006, 0.4636458, -0.5522307, 0.9506739, 1.1748210, -0.3795080, 0.0771183, -0.0588746, -0.2903397, -0.4915347, -2.0633969, -0.2064304, -0.9264892, -0.1987959, 1.7669750, 0.7308894, 0.0915315, -0.2743991, -0.9268026, 2.4203701, 0.1530970, -0.6694155, 0.4007120, -0.1968893, -0.1985747, -0.3613594, 0.4583485, 0.4560544, 1.2247474, 1.3021513, 2.0714045, 0.3739655, 0.1966641, 0.9291840, 1.6853676, 0.6274776, -0.0975672, -0.9863145, 0.7114818, -0.5865821, 0.7559505, -1.4822845, 0.0133979, 2.5776844, 
-0.1955276, -1.3894904, 0.0986306, -0.6527658, -1.0126148, -1.3212727, -0.2197127, 1.0792760, 0.4495734, -0.0663339, 0.6242186, 1.8274466, 0.2306639, -0.5027097, 0.7396556, -1.5373755, 23 | }) 24 | 25 | expected := NewTensor(13, 15, 4) 26 | copy(expected.Data, []float32{ 27 | -0.0408645, -0.0194694, -0.0549688, 0.1871940, 0.0514174, -0.0748684, -0.0438619, -0.1216055, 0.3125930, 0.0659977, -0.6144946, -0.8956703, 0.4279370, 0.7311863, -0.6516200, -0.8114839, -0.3169491, -0.3421381, -0.2645444, 0.3775117, 0.1075694, -0.1713427, 0.1732734, 0.5835122, -0.2944925, 0.3595108, 0.5658865, -0.1464769, -0.1874506, -0.1782442, 0.2206931, 0.1595043, 0.2606104, -0.0576293, 0.0272467, -0.1266818, 0.0329028, 0.0720025, -0.0784942, 0.0395116, 0.4079916, -0.3463828, -0.3584284, 0.1734509, -0.1036744, 0.1587522, -0.0813099, -0.5309020, -0.3750744, 0.1195588, 0.3434209, 0.0365922, 0.0399913, 0.0980617, 0.0818439, -0.1096016, 0.0406522, -0.3897277, 0.1348119, -0.1519297, -0.1381952, -0.1140959, -0.0012608, 0.1788205, -0.0825243, 0.1681010, -0.1225363, 0.1451142, 0.4118539, -0.0112780, -0.4082728, -0.7436250, 0.2451778, 0.0299881, 0.1339143, -0.2771372, -0.0481255, 0.1620539, 0.1381359, 0.7327967, -0.1967984, 0.0656770, -0.1629940, 0.1361714, -0.4514995, -0.1899078, -0.0127284, -0.2433596, -0.0452810, -0.0796590, -0.0621000, -0.0403041, 0.0944393, 0.3397935, -0.2647886, -0.1517983, 0.2041608, -0.1598594, 0.2064631, -0.0395846, 0.1811487, -0.0998886, 0.3401464, 0.6302477, -0.1539173, 0.1857012, -0.0742397, -0.0911694, 0.3262306, -0.6464262, 0.1654890, -0.0606451, -0.3769718, 0.2175172, -0.4513310, -0.1637520, 0.0750152, 0.3769193, -0.3711143, 0.3044261, 0.2093513, 0.0268531, -0.5182128, -0.1569117, -0.3492104, -0.2155812, -0.0392161, -0.3552701, 0.9319751, 0.1550327, -0.2962227, 0.9153284, 0.1470551, 0.3057349, 0.0617516, 0.5934935, -0.1118675, 0.6566533, -1.1410300, 0.7201911, 0.4613485, 0.4751079, -0.1956451, 0.0268744, -0.7405378, -0.0367488, 0.0942594, -0.5576359, 0.1901441, 
-0.4608253, 0.0174982, 0.2970467, 0.4230424, -0.1541975, 0.3828651, -0.1474487, -0.2809403, -0.0612479, 0.1860174, -0.0150790, -0.3129278, 0.3389230, 0.0144943, 0.2549954, -0.2039080, -0.1000422, -0.1439215, 0.1764802, -0.1757951, 0.3836617, -0.1316161, 0.0943417, 0.0123362, -0.2093474, 0.1063277, 0.2612825, -0.2790885, -0.0199187, 0.2368225, 0.5640559, 0.0145945, 0.1605127, 0.0941011, -0.3979308, 0.5283298, -0.3879965, 0.5292951, -0.0214575, 0.3128785, 0.1392319, -0.0357783, -0.2425334, 0.2766051, -0.1091767, 0.2472210, 0.0051574, -0.2188763, 0.1701127, 0.0324907, -0.0849292, 0.1132794, 0.1102249, 0.0826966, 0.2312904, -0.6955340, -0.3363871, 0.0778619, -0.3215622, -0.0235915, -0.0101222, -0.0934486, 0.1692782, -0.0411102, 0.3454086, -0.2621068, -0.0112024, 0.1919640, -0.2004677, 0.2520424, -0.0198838, 0.2598956, -0.0351781, 0.0903174, -0.0463146, -0.0826847, 0.0505923, -0.0602382, -0.0148637, -0.1614131, -0.1214503, 0.1613176, -0.1173977, 0.0996005, -0.1109961, 0.0922735, -0.0570295, 0.1194476, 0.0700215, -0.0606208, 0.1270892, 0.0712810, -0.2112693, -0.1931393, -0.1757720, 0.0951650, -0.3090639, -0.2786337, -0.3837940, 0.5698888, 0.1031060, 0.1700851, 0.2232326, -0.0569422, 0.1148025, 0.2036915, 0.1051866, 0.6243874, 0.6185188, -0.0999638, 0.0934461, -0.3487709, 0.5158741, 0.3924620, -0.1105984, 0.5366999, -0.8700936, -1.0222739, -0.1621883, 0.1432491, -0.0641485, -0.0606568, -0.5425753, -0.7233461, -0.6093550, 0.2249727, -0.3911178, 0.6345437, -0.0414115, -0.4625136, 0.2917506, 0.2109767, 0.6412838, 0.3307448, -0.7391987, -0.2074914, 0.3656350, 0.0904820, -0.2948250, -0.1879081, -0.2023569, 0.3664604, 0.2634708, -0.3531348, -0.6327595, 0.4403827, 0.2161390, 0.2077463, -0.3671280, 0.5333624, -0.1885908, 0.0684396, -0.0730250, 0.1153513, 0.3906485, -0.1929944, 0.1588487, -0.1496136, 0.0034723, 0.1058906, -0.0769879, 0.1061978, -0.0501986, -0.0496141, 0.0670391, -0.0377283, 0.0874361, 0.0650878, -0.0328690, -0.0664496, -0.6001179, 0.2986904, -0.2680899, 0.2626746, 
-0.1132197, 0.1229487, -0.0200973, 0.0313885, 0.6359798, -0.0922342, 0.1241876, -0.0888708, 0.0009254, 0.1357672, 0.0730435, -0.3235241, -0.1779824, -0.2997456, 0.2863789, -0.3487119, 0.1518145, -0.5578908, -0.2576382, -0.1182949, -0.4619887, 0.1600166, -0.1453047, 0.0680032, -0.2419467, 0.0001772, -0.1798283, -0.2129710, 0.9689248, -0.3833165, 0.0573459, -0.4214272, -0.0438824, 0.2583529, 0.5035083, -0.6810591, 0.1002424, 0.0361443, -0.3277960, -0.4970536, 0.2453915, -0.3229753, -0.1435948, -0.1736451, -0.7230493, -0.4691697, 0.2895883, -0.1456939, -0.1936812, -0.0954810, -0.1039214, -0.0166284, -0.2001947, -0.3186257, -0.1116009, 0.5662885, -0.2154126, 0.1714665, -0.4241474, 0.2451154, 0.7165399, -0.5589361, -0.1572717, 0.8552569, -0.2934281, -0.2107465, -0.2415778, 0.1111134, 0.0122270, 0.5683089, 0.0708979, 0.7062442, 0.4443365, -0.0249985, 0.7697493, 0.5327728, 0.3675256, -0.2004680, -0.7901496, -0.4140352, 0.1204381, 0.8648459, 0.0406502, -0.4981539, 0.3936258, -0.7343534, -0.4858570, -0.9478843, -0.4858515, 0.1701055, -0.4284994, -0.4669115, -0.6881824, -0.1255757, -0.4638782, 0.6190758, 0.4772179, 0.0836801, -0.0954533, 0.2589582, 0.6404762, -0.1024121, 0.2725188, -0.0567672, 0.1182317, 0.4446644, -0.3406713, 0.0559832, -0.4679064, -0.0479808, 0.0521079, -0.0870822, 0.0273100, -0.1110922, 0.0004704, -0.3015858, 0.0804126, -0.2845514, 0.0019115, -0.2926601, 0.0155520, -0.1964415, -0.0635939, -0.3067963, 0.3861094, -0.4183314, 0.1341913, -0.4031016, -0.2167615, 0.0837665, 0.3254414, -0.2497691, -0.1365314, 0.1347667, -0.1443004, 0.1201873, -0.0760484, -0.0813915, 0.2773890, -0.0225355, 0.2626073, -0.1171071, 0.2435607, -0.0927482, 0.1001449, -0.1702099, -0.3196163, 0.5877363, -0.0942353, 0.1096913, -0.0580086, 0.1658542, -0.1532428, 0.3838004, -0.1364820, 0.2379148, 0.0226721, -0.9486727, 0.0471871, -0.2889495, -0.0846890, 0.1797559, -0.2022166, -0.5794636, -0.1370407, -0.4106274, 0.2817611, -0.4343900, -0.6964189, 0.0995082, -0.0442456, -0.0315741, 
-0.0110493, 0.2104852, -0.3327667, -0.0426008, -0.1433302, 0.0624215, 0.0513770, 0.0806343, -0.2067210, -0.1681530, -0.9116020, 0.3187425, 0.0514263, 0.1780499, -0.0933935, -0.1283691, 0.2050675, 0.3542832, 0.5504524, 0.1712812, 0.3673798, -0.2179071, 0.0309926, 0.0976007, -0.1699147, 0.6405470, 0.6613162, 0.3034024, 0.7905031, -0.6097742, 0.4246245, 0.7306086, 0.5807211, -0.0064128, 0.3656771, -0.4617210, 0.2678794, -0.2004292, -0.3578294, -0.1726816, 0.0554052, 0.0350157, 0.5915754, -0.0203585, 0.0759241, -0.1140942, 0.1183311, -0.1510688, -0.0134705, -0.0978259, 0.2710258, -0.1428374, 0.2259427, -0.2612808, -0.1995445, 0.0324642, -0.2433879, -0.1193245, 0.0245992, -0.1118952, -0.1185197, -0.0737830, -0.1092648, -0.0074915, -0.0763595, 0.1441956, -0.1513248, 0.0880497, -0.1092044, 0.1980333, 0.2730721, -0.2452524, 0.3368914, 0.0111948, 0.0548908, 0.0059155, 0.0511906, -0.0759984, -0.0003335, 0.1121711, -0.4613860, 0.5134895, -0.3359855, 0.4769215, 0.0134460, 0.2828873, 0.2104597, -0.1148628, -0.2343793, 0.2841360, -0.3608468, 0.2986756, -0.0285679, -0.0970565, 0.0940803, 0.0710430, 0.3466266, -0.0252922, 0.0143420, 0.0146390, 0.0356186, 0.0324571, -0.1025298, 0.0905135, -0.1609289, 0.2104131, 0.4094425, 0.5369654, 0.3708864, -0.4917074, 0.2981741, 0.6447821, 0.0181734, -0.0014421, -0.2475360, 0.1685373, -0.0719392, -0.1766563, -0.0786047, 0.3887443, 0.5683194, 0.1227366, -0.3889403, 0.2959919, -0.0507108, 0.3383470, -0.0792183, -0.2859499, -0.1036216, 0.1382751, -0.1976376, 0.2762082, 0.3267169, -0.4594756, 0.0598829, 0.2913103, -0.3694354, 0.3737148, -0.5183859, -0.7008716, -0.0395234, 0.4852855, -0.2171500, -0.7053072, -0.4991091, 0.2024119, 0.1592132, -0.4152100, -0.4256546, 0.3087745, -0.0806829, -0.0935526, 0.0294502, -0.1627689, -0.4609432, 0.5260309, -0.6456886, 0.2007331, 0.2933916, -0.1396879, -0.0415259, -0.0670704, 0.0115323, 0.1782231, -0.0767321, -0.1105404, -0.2647381, -0.0274911, -0.0698640, -0.0038630, -0.0970322, -0.0603650, -0.0857060, 
0.2031547, 0.1831581, -0.4518503, 0.1654684, -0.1426963, -0.1503477, -0.1393246, -0.2026948, 0.0429422, 0.2370295, -0.1831069, 0.3895363, -0.0695245, 0.0433065, -0.0564393, -0.1417701, 0.2155443, -0.2995106, 0.3182778, 0.5450820, -0.2252713, 0.0363688, -0.1718405, -0.2620012, 0.4342357, 0.2754403, -0.0221255, 0.0961393, 0.1443960, 0.1017210, 0.1325648, 0.1425537, -0.1785801, -0.3250968, 0.0754474, -0.3749839, -0.0162527, -0.0094159, -0.0786251, 0.0378912, 0.0945599, 0.1461810, -0.3464882, 0.3093731, -0.3585347, 0.2326562, -0.3336061, -0.1460380, -0.2435825, -0.0401674, 0.1928711, 0.0807574, -0.2018481, -0.4270057, -0.2849124, -0.3355846, 0.3214097, -0.3571669, -0.2502927, -0.0129574, 0.0189537, -0.4385507, 0.2072561, 0.0513688, 0.2988456, -0.0501798, -0.0621900, 0.1694670, -0.2064065, 0.0896689, 0.7913690, 0.3585398, -0.0817096, 0.0598960, 0.1050461, 0.2475841, -0.6490353, 0.2523943, -0.2230127, -0.0029347, 0.0097553, -0.4517760, -0.3046168, 0.1346664, -0.1032889, 0.7216869, 0.6028582, 0.4497806, -0.2393321, 0.1755342, 0.1653123, -0.0413742, 0.2623122, -0.1941507, -0.6127819, -0.1472983, 0.1895358, 0.0048919, 0.0290338, 0.0057643, 0.0785637, -0.0767744, 0.4029086, 0.5604403, -0.1909199, 0.3173218, 0.3480718, 0.4091369, -0.3377852, 0.3681487, 0.0921967, 0.2124878, 28 | }) 29 | actual := c.Apply(in) 30 | if expected.Height != actual.Height || expected.Width != actual.Width || 31 | expected.Depth != actual.Depth { 32 | t.Fatal("incorrect output shape") 33 | } 34 | for i, x := range expected.Data { 35 | a := actual.Data[i] 36 | if math.Abs(float64(x-a)) > 1e-4 { 37 | t.Errorf("bad value at %d: expected %f but got %f", i, x, a) 38 | } 39 | } 40 | } 41 | 42 | p := runtime.GOMAXPROCS(0) 43 | defer runtime.GOMAXPROCS(p) 44 | 45 | runtime.GOMAXPROCS(1) 46 | t.Run("Proc1", runTest) 47 | 48 | runtime.GOMAXPROCS(2) 49 | t.Run("Proc2", runTest) 50 | } 51 | -------------------------------------------------------------------------------- /polish/nn/group_norm.go: 
-------------------------------------------------------------------------------- 1 | package nn 2 | 3 | import "math" 4 | 5 | // GroupNorm implements the normalization step of group 6 | // normalization. 7 | type GroupNorm struct { 8 | NumGroups int 9 | } 10 | 11 | // Apply applies the normalization step. 12 | func (g *GroupNorm) Apply(t *Tensor) *Tensor { 13 | if t.Depth%g.NumGroups != 0 { 14 | panic("number of groups must divide number of input channels") 15 | } 16 | sums := make([]float32, g.NumGroups) 17 | sqSums := make([]float32, g.NumGroups) 18 | Groups(t, g.NumGroups, func(group, idx int) { 19 | v := t.Data[idx] 20 | sums[group] += v 21 | sqSums[group] += v * v 22 | }) 23 | normalize := 1.0 / float32(t.Width*t.Height*t.Depth/g.NumGroups) 24 | 25 | biases := sums 26 | scales := sqSums 27 | for i, x := range biases { 28 | biases[i] = -x * normalize 29 | } 30 | for i, sqSum := range scales { 31 | b := biases[i] 32 | x := sqSum*normalize - b*b 33 | if x < 0 { 34 | x = 0 35 | } 36 | scales[i] = float32(1 / math.Sqrt(float64(x)+1e-5)) 37 | } 38 | 39 | res := NewTensor(t.Height, t.Width, t.Depth) 40 | Groups(t, g.NumGroups, func(group, idx int) { 41 | res.Data[idx] = (t.Data[idx] + biases[group]) * scales[group] 42 | }) 43 | return res 44 | } 45 | 46 | // Groups iterates over the entries of t in order, but 47 | // adds a groupIdx parameter indicating which group each 48 | // component belongs to for group normalization. 
49 | func Groups(t *Tensor, numGroups int, f func(groupIdx, dataIdx int)) { 50 | if t.Depth%numGroups != 0 { 51 | panic("number of groups must divide number of input channels") 52 | } 53 | groupSize := t.Depth / numGroups 54 | var idx int 55 | for y := 0; y < t.Height; y++ { 56 | for x := 0; x < t.Width; x++ { 57 | for g := 0; g < numGroups; g++ { 58 | for z := 0; z < groupSize; z++ { 59 | f(g, idx) 60 | idx++ 61 | } 62 | } 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /polish/nn/group_norm_test.go: -------------------------------------------------------------------------------- 1 | package nn 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | ) 7 | 8 | func TestGroupNorm(t *testing.T) { 9 | in := NewTensor(4, 5, 12) 10 | copy(in.Data, []float32{ 11 | -1.0221001, -1.5796757, 1.7134480, -0.6383699, -0.8252002, -0.3276477, -1.0969015, -0.3216827, 0.9023801, -0.2086772, -0.0098957, 0.3430660, -0.9665561, -1.0619389, 0.7971905, 2.5056205, 1.7643378, 0.6906035, -0.4143253, -0.9578822, 0.9929895, -0.3113273, -0.0693117, 0.4030469, -0.1519197, -1.5732781, -0.2080776, 1.2209413, 0.4384030, -1.7938681, -2.1126060, 0.6378745, -0.6695310, -1.0245168, -0.9538743, 0.0238561, -0.0467072, 0.5612687, 1.0865034, 1.5901088, -1.3611679, 0.8314494, 2.0650454, -0.3145818, -0.5839299, 0.4105616, -1.1479204, -0.5899442, 0.5731580, -1.5805511, 1.0306307, -0.3374717, 0.8066963, 1.9507209, 0.2116757, 0.1386233, -0.8553388, 0.4269760, 0.7171491, 0.4923056, -0.3509787, 0.3014634, 0.7269417, 1.6744821, 1.0522863, 0.2112580, -0.9874152, 0.5523420, 0.0158053, -0.3125320, 1.9377381, 0.1459422, 0.5488644, 0.6677971, -0.9905274, 1.6344807, 0.1424473, -0.0821457, 1.1207696, 0.0182036, 1.4424065, -0.0713405, 0.4244196, 0.0350529, 1.0896454, 2.0629742, 0.7795575, 1.0992870, -0.9697220, 0.2713334, -1.3391825, -0.5634347, -0.6239132, -0.4311570, -0.4288743, -1.9728240, 2.0597763, -0.7501921, -0.6738955, -0.6277537, -0.7750289, -1.4438819, 
0.2871203, -0.9755278, 0.3835825, -0.6603342, 2.0187373, -0.9440562, 0.0718082, -0.1985949, 1.2051098, -0.1503626, 1.3310804, 0.0989851, 0.0927790, -1.0305945, -1.0669045, 0.8813666, 0.0706334, 0.7924773, -0.5513872, 0.5201599, -1.3582357, 0.0590850, 0.0187143, -0.4046591, 1.2982638, -0.3689553, 0.3967363, 0.5261109, 0.2245125, -0.7741958, -0.4043697, -0.6617534, -0.6659707, -0.4774512, 1.8953596, 2.0029273, 1.4613733, -1.8399051, 0.0611079, -1.4877330, -0.3591553, -1.1237507, 0.8407897, 0.9940514, -0.1909862, 0.3209904, 0.5886565, 1.5291401, 0.1317610, 0.6872989, -0.9559830, 1.4065583, 0.1249584, 0.4124418, -0.2684685, 1.4394376, -0.0221158, 0.5195087, -0.9137533, 0.9451744, -0.4042242, -0.6771299, -0.9027446, 1.0519906, -0.7811508, -0.0777123, -0.3535839, 1.1632098, -0.4913792, 0.3651035, -0.8887430, -1.0435823, -0.0712068, 0.1240887, -0.6695143, -1.3721397, 0.1611713, 1.6129134, 1.4333644, -1.5874431, 0.4289396, -0.4726154, -2.1013939, -1.2848064, 2.4758003, 0.7925473, 2.1013570, 0.1712477, 0.9803172, 0.4901890, -0.0817213, -0.6183630, -3.1022520, 0.6858925, -2.0902109, -1.6194723, -0.5491874, 0.5372946, 0.8821929, -0.7812951, -0.5972003, -0.3235330, 0.1639364, -0.8202379, 0.8938593, -0.3348892, 0.0765916, -0.3791303, -0.2271143, -0.7200960, 0.0506867, -1.0809671, -0.3619950, 0.0645623, 1.5365840, -1.4482405, 0.5266994, -0.2942330, 0.2541175, -0.6751279, -0.3904678, 1.1792732, -0.6438045, -0.7042071, 1.7414868, -0.8441688, 1.1386329, -1.0414594, -1.1076441, -0.8434691, -2.3673568, -1.6799883, 0.8815190, -1.1306807, -0.2620730, 0.2116996, 1.2455457, -0.1427833, 12 | }) 13 | expected := NewTensor(4, 5, 12) 14 | copy(expected.Data, []float32{ 15 | -1.1425461, 0.9038563, -1.4868827, -0.4609960, 1.4304428, -0.8965625, -0.6618699, -0.0191097, -0.7013139, 1.5647557, 1.0326166, 0.4262294, -1.7087859, -0.4207285, 0.7398060, 0.2015845, -0.1581028, -1.5984653, 0.4626254, -1.6444832, 1.0934234, -1.3517406, -0.5734295, -0.4709284, 1.6355102, -0.1749521, 1.9925710, 0.6336744, 
0.3621542, 0.2180706, -1.5085871, -0.4601394, -0.2093793, 0.5950147, -0.3956917, -0.1960979, -0.7528531, 0.3047465, -0.4240335, 1.5959388, -0.0464522, -1.1069686, -0.0212326, -1.2625155, -0.4728613, -0.2754089, -0.1314744, 1.3194387, -0.9425865, -0.2588438, -0.6975671, 0.9640745, 1.0602508, 0.3192993, -0.0635981, 0.7990984, -0.6906853, -1.8479443, 0.3391623, -0.4406866, -0.4373025, -1.7022889, 0.3123779, 0.1099774, 2.0816746, -0.7762004, -0.5078916, 0.9599332, 1.1965512, -1.0595541, -0.6110276, -0.4990035, -1.2185099, -0.3158744, -1.2703215, -1.1073222, 0.7348406, 2.0352519, 1.2791775, -0.2836605, -0.5732903, 2.5711954, 1.0438803, 1.8622384, -0.4312448, 1.1353503, -0.7036749, 0.4563615, 1.0703688, -1.0739418, -0.4704235, 0.2536142, 0.1058579, 0.9460666, -0.1424384, -0.6341322, 0.8118389, 0.3406520, 0.4775009, -0.0885126, -1.1008759, -0.0078807, 0.3331030, 0.5345066, -0.1604876, 2.2096820, 0.2548335, 1.2802017, -0.3164833, -1.9263067, -1.7096750, -0.4219519, 0.2015036, -0.2916452, 0.4688705, 1.5214622, 1.3039299, 0.3462211, -0.1851519, -0.8246102, -0.1146129, -2.2499976, 0.9420825, 1.8632855, -1.4885924, 1.1814206, 0.1523692, 0.0550345, -0.2935247, 1.1273528, -0.0383851, -0.8885095, 0.2438335, 0.5432231, -0.4472791, 0.0436466, -0.6745127, -0.2410295, -0.8956881, 0.6380231, 0.5333828, 0.6541491, -0.5143437, -0.6334567, -1.0861390, -0.7844984, 0.7146683, 0.4528299, -0.7379797, 1.3136158, -0.5075879, -1.0864580, -0.6771672, 0.1019873, 0.2298232, -2.1047232, -1.1830039, -1.1450003, 1.8764701, 0.5736107, -0.5356988, 0.0206390, -0.7776896, 1.3928232, -0.8266598, -0.4161237, -0.8627536, -1.4410900, 0.7050148, -1.0732602, 0.1104015, -1.1104828, -0.5333033, 0.0141263, -0.7821154, 0.0478958, 0.1121387, -2.8142419, -0.1686083, 1.0319661, 2.4399924, -0.0803367, 0.0362140, 1.5553157, -2.1535442, -1.1647562, -0.5842806, 0.3495846, 0.3006905, 0.8430947, 0.2432196, -0.9107506, 1.6871907, -0.1519964, -0.9731934, 0.0400974, 2.0783186, -1.2028605, 1.9057776, -0.3649714, -0.4655087, 
-1.8371475, 1.6644112, -0.0721367, 0.5967715, 0.4654269, 0.3290475, -0.1879856, -0.8704984, 0.8416802, 2.0186605, 1.4273272, -1.1438719, -1.3826638, -1.2173449, 0.3852761, -0.5253270, 0.9988233, 0.6237296, 1.0336220, -0.7904317, -0.0091135, 1.4503468, -0.1064457, 0.3364926, -0.3493368, 0.6893988, 1.3834227, -1.0773302, 1.5102543, 0.3953922, -0.0860769, -0.7420099, 0.7483987, -2.0140572, 0.4619420, 1.7381047, 0.6996281, -0.1031862, 0.0430338, 16 | }) 17 | actual := (&GroupNorm{NumGroups: 3}).Apply(in) 18 | if len(expected.Data) != len(actual.Data) { 19 | t.Fatal("incorrect output shape") 20 | } 21 | for i, x := range expected.Data { 22 | a := actual.Data[i] 23 | if math.Abs(float64(x-a)) > 1e-4 { 24 | t.Errorf("data index %d: expected %f but got %f", i, x, a) 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /polish/nn/nn.go: -------------------------------------------------------------------------------- 1 | // Package nn implements a small collection of neural 2 | // network layers for denoising auto-encoders. 3 | // 4 | // It is designed for inference, not for training. 5 | package nn 6 | 7 | // A Layer is a Tensor operation. 8 | type Layer interface { 9 | Apply(t *Tensor) *Tensor 10 | } 11 | 12 | // An NN is a special Layer that composes multiple other 13 | // Layers. 14 | type NN []Layer 15 | 16 | // Apply applies all the layers in order. 17 | func (n NN) Apply(t *Tensor) *Tensor { 18 | res := t 19 | for _, l := range n { 20 | res = l.Apply(res) 21 | } 22 | return res 23 | } 24 | 25 | // Residual is a special Layer that composes multiple 26 | // other Layers and adds the output to the input. 27 | type Residual []Layer 28 | 29 | // Apply applies the layers in order and adds the output 30 | // to the original input.
31 | func (r Residual) Apply(t *Tensor) *Tensor { 32 | t1 := NN(r).Apply(t) 33 | if t1.Width != t.Width || t1.Height != t.Height || t1.Depth != t.Depth { 34 | panic("dimensions must match for residual connection") 35 | } 36 | res := NewTensor(t.Height, t.Width, t.Depth) 37 | for i, x := range t.Data { 38 | res.Data[i] = x + t1.Data[i] 39 | } 40 | return res 41 | } 42 | -------------------------------------------------------------------------------- /polish/nn/simple_ops.go: -------------------------------------------------------------------------------- 1 | package nn 2 | 3 | // A ReLU layer applies the rectified linear unit. 4 | type ReLU struct{} 5 | 6 | // Apply applies the rectified linear unit. 7 | func (r ReLU) Apply(t *Tensor) *Tensor { 8 | res := NewTensor(t.Height, t.Width, t.Depth) 9 | for i, x := range t.Data { 10 | if x > 0 { 11 | res.Data[i] = x 12 | } 13 | } 14 | return res 15 | } 16 | 17 | // A Pad layer pads input Tensors. 18 | type Pad struct { 19 | Top int 20 | Right int 21 | Bottom int 22 | Left int 23 | } 24 | 25 | // NewPad creates a Pad with the given values. 26 | func NewPad(t, r, b, l int) *Pad { 27 | return &Pad{t, r, b, l} 28 | } 29 | 30 | // Apply pads the Tensor. 31 | func (p *Pad) Apply(t *Tensor) *Tensor { 32 | return t.Pad(p.Top, p.Right, p.Bottom, p.Left) 33 | } 34 | 35 | // An Unpad layer unpads (crops) input Tensors. 36 | type Unpad struct { 37 | Top int 38 | Right int 39 | Bottom int 40 | Left int 41 | } 42 | 43 | // NewUnpad creates an Unpad with the given values. 44 | func NewUnpad(t, r, b, l int) *Unpad { 45 | return &Unpad{t, r, b, l} 46 | } 47 | 48 | // Apply unpads (crops) the Tensor. 
49 | func (u *Unpad) Apply(t *Tensor) *Tensor { 50 | return t.Unpad(u.Top, u.Right, u.Bottom, u.Left) 51 | } 52 | -------------------------------------------------------------------------------- /polish/nn/tensor.go: -------------------------------------------------------------------------------- 1 | package nn 2 | 3 | import ( 4 | "image" 5 | "image/color" 6 | ) 7 | 8 | // Tensor is a 3D array of numbers. 9 | // 10 | // It is arranged as [height x width x depth], with the 11 | // outer dimension being height. 12 | type Tensor struct { 13 | Height int 14 | Width int 15 | Depth int 16 | 17 | Data []float32 18 | } 19 | 20 | // NewTensorRGB creates an RGB Tensor from an image. 21 | func NewTensorRGB(img image.Image) *Tensor { 22 | b := img.Bounds() 23 | res := NewTensor(b.Dy(), b.Dx(), 3) 24 | var idx int 25 | for y := 0; y < res.Height; y++ { 26 | for x := 0; x < res.Width; x++ { 27 | red, green, blue, _ := img.At(x+b.Min.X, y+b.Min.Y).RGBA() 28 | for _, c := range []uint32{red, green, blue} { 29 | res.Data[idx] = float32(c) / 0xffff 30 | idx++ 31 | } 32 | } 33 | } 34 | return res 35 | } 36 | 37 | // NewTensor creates a zero tensor. 38 | func NewTensor(height, width, depth int) *Tensor { 39 | return &Tensor{ 40 | Height: height, 41 | Width: width, 42 | Depth: depth, 43 | Data: make([]float32, width*height*depth), 44 | } 45 | } 46 | 47 | // At gets a pointer to the given coordinate. 48 | func (t *Tensor) At(y, x, z int) *float32 { 49 | return &t.Data[z+t.Depth*(x+y*t.Width)] 50 | } 51 | 52 | // Add adds a scalar to every entry. 53 | func (t *Tensor) Add(s float32) *Tensor { 54 | res := NewTensor(t.Height, t.Width, t.Depth) 55 | for i, x := range t.Data { 56 | res.Data[i] = x + s 57 | } 58 | return res 59 | } 60 | 61 | // Pad creates a zero-padded version of the Tensor. 
62 | func (t *Tensor) Pad(top, right, bottom, left int) *Tensor { 63 | res := NewTensor(t.Height+top+bottom, t.Width+left+right, t.Depth) 64 | rowSize := t.Depth * t.Width 65 | for i := 0; i < t.Height; i++ { 66 | start := t.Depth * t.Width * i 67 | copy(res.Data[res.Depth*(left+(i+top)*res.Width):], t.Data[start:start+rowSize]) 68 | } 69 | return res 70 | } 71 | 72 | // Unpad cuts out the edges of the Tensor, effectively 73 | // inverting the operation done by Pad. 74 | func (t *Tensor) Unpad(top, right, bottom, left int) *Tensor { 75 | res := NewTensor(t.Height-(top+bottom), t.Width-(left+right), t.Depth) 76 | rowSize := t.Depth * (t.Width - (left + right)) 77 | for i := top; i < t.Height-bottom; i++ { 78 | start := t.Depth * (left + t.Width*i) 79 | copy(res.Data[res.Depth*res.Width*(i-top):], t.Data[start:start+rowSize]) 80 | } 81 | return res 82 | } 83 | 84 | // RGB creates an RGB image out of the Tensor. 85 | // 86 | // If the tensor does not have three channels, this will 87 | // panic(). 
88 | func (t *Tensor) RGB() image.Image { 89 | if t.Depth != 3 { 90 | panic("expected exactly 3 output channels") 91 | } 92 | res := image.NewRGBA(image.Rect(0, 0, t.Width, t.Height)) 93 | var idx int 94 | for y := 0; y < t.Height; y++ { 95 | for x := 0; x < t.Width; x++ { 96 | var colors [3]uint8 97 | for i := 0; i < 3; i++ { 98 | x := t.Data[idx+i] 99 | if x < 0 { 100 | x = 0 101 | } else if x > 1 { 102 | x = 1 103 | } 104 | colors[i] = uint8(x * 255.999) 105 | } 106 | idx += 3 107 | res.SetRGBA(x, y, color.RGBA{ 108 | R: colors[0], 109 | G: colors[1], 110 | B: colors[2], 111 | A: 0xff, 112 | }) 113 | } 114 | } 115 | return res 116 | } 117 | -------------------------------------------------------------------------------- /polish/nn/tensor_test.go: -------------------------------------------------------------------------------- 1 | package nn 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestPadUnpad(t *testing.T) { 11 | tensor := NewTensor(5, 10, 3) 12 | for i := range tensor.Data { 13 | tensor.Data[i] = float32(rand.NormFloat64()) 14 | } 15 | 16 | runShape := func(top, r, b, l int) { 17 | t.Run(fmt.Sprintf("%d,%d,%d,%d", top, r, b, l), func(t *testing.T) { 18 | padded := tensor.Pad(top, r, b, l) 19 | unpadded := padded.Unpad(top, r, b, l) 20 | if !reflect.DeepEqual(tensor, unpadded) { 21 | t.Error("unexpected pad->unpad") 22 | } 23 | }) 24 | } 25 | 26 | runShape(0, 0, 0, 0) 27 | runShape(1, 0, 0, 0) 28 | runShape(0, 1, 0, 0) 29 | runShape(0, 0, 1, 0) 30 | runShape(0, 0, 0, 1) 31 | runShape(1, 1, 1, 1) 32 | runShape(1, 2, 3, 4) 33 | } 34 | -------------------------------------------------------------------------------- /polish/polish.go: -------------------------------------------------------------------------------- 1 | package polish 2 | 3 | import ( 4 | "image" 5 | 6 | "github.com/unixpickle/essentials" 7 | "github.com/unixpickle/polish/polish/nn" 8 | ) 9 | 10 | // PolishImage applies a denoising network to an 
image. 11 | // 12 | // Use a model that does not expect any extra feature 13 | // channels besides RGB colors. 14 | func PolishImage(t ModelType, img image.Image) image.Image { 15 | patchSize := essentials.MaxInt(img.Bounds().Dx(), img.Bounds().Dy()) 16 | return PolishImagePatches(t, img, patchSize, 0) 17 | } 18 | 19 | // PolishImagePatches is like PolishImage, but it applies 20 | // the operation to patches of the image at a time to save 21 | // memory. 22 | // 23 | // It is useful for very large images, where a neural 24 | // network forward pass is too costly. 25 | // 26 | // The patchSize argument specifies how large the output 27 | // patches should be (they are always square). 28 | // 29 | // The border argument specifies how many extra pixels are 30 | // included on the side of each patch before it is fed 31 | // into the network. 32 | // A value of -1 will use a reasonable default border. 33 | // Larger border values ensure more accuracy at the cost 34 | // of redundant computation, while lower values may cause 35 | // checkerboarding artifacts. 36 | func PolishImagePatches(t ModelType, img image.Image, patchSize, border int) image.Image { 37 | if t.Aux() { 38 | panic("model requires auxiliary features") 39 | } 40 | inTensor := nn.NewTensorRGB(img) 41 | outTensor := operatePatches(inTensor, patchSize, border, func(in *nn.Tensor) *nn.Tensor { 42 | pad, unpad := padAndUnpad(t, in) 43 | outTensor := pad.Apply(in) 44 | outTensor = t.Layer().Apply(outTensor) 45 | outTensor = unpad.Apply(outTensor) 46 | return outTensor 47 | }) 48 | return outTensor.RGB() 49 | } 50 | 51 | // PolishAux applies a denoising network to an image with 52 | // auxiliary feature channels. 53 | // 54 | // The Tensor may be created via CreateAuxTensor(). 55 | // 56 | // This should be used with a model that expects auxiliary 57 | // features.
58 | func PolishAux(t ModelType, auxImage *nn.Tensor) image.Image { 59 | patchSize := essentials.MaxInt(auxImage.Width, auxImage.Height) 60 | return PolishAuxPatches(t, auxImage, patchSize, 0) 61 | } 62 | 63 | // PolishAuxPatches is like PolishAux, but it applies the 64 | // operation to patches of the image at a time to save 65 | // memory. 66 | // 67 | // See PolishImagePatches for more information. 68 | func PolishAuxPatches(t ModelType, auxImage *nn.Tensor, patchSize, border int) image.Image { 69 | if !t.Aux() { 70 | panic("model does not support auxiliary features") 71 | } 72 | outTensor := operatePatches(auxImage, patchSize, border, func(in *nn.Tensor) *nn.Tensor { 73 | pad, unpad := padAndUnpad(t, in) 74 | outTensor := pad.Apply(in) 75 | outTensor = t.Layer().Apply(outTensor) 76 | outTensor = unpad.Apply(outTensor) 77 | return outTensor 78 | }) 79 | return outTensor.RGB() 80 | } 81 | 82 | func padAndUnpad(t ModelType, in *nn.Tensor) (pad, unpad nn.Layer) { 83 | lcd := t.LCD() 84 | rightPad := (lcd - in.Width%lcd) % lcd 85 | bottomPad := (lcd - in.Height%lcd) % lcd 86 | return nn.NewPad(0, rightPad, bottomPad, 0), nn.NewUnpad(0, rightPad, bottomPad, 0) 87 | } 88 | 89 | func operatePatches(t *nn.Tensor, patchSize, border int, f func(*nn.Tensor) *nn.Tensor) *nn.Tensor { 90 | if patchSize >= t.Width && patchSize >= t.Height { 91 | // Special case when the patch fills the image. 92 | // This is utilized by PolishImage(). 
93 | return f(t) 94 | } 95 | 96 | if border == -1 { 97 | border = patchSize / 2 98 | } 99 | var output *nn.Tensor 100 | for y := 0; y < t.Height; y += patchSize { 101 | patchHeight := essentials.MinInt(patchSize, t.Height-y) 102 | extraTop := essentials.MinInt(y, border) 103 | extraBottom := essentials.MinInt(t.Height-(y+patchHeight), border) 104 | for x := 0; x < t.Width; x += patchSize { 105 | patchWidth := essentials.MinInt(patchSize, t.Width-x) 106 | extraLeft := essentials.MinInt(x, border) 107 | extraRight := essentials.MinInt(t.Width-(x+patchWidth), border) 108 | 109 | patch := nn.NewTensor(patchHeight+extraTop+extraBottom, 110 | patchWidth+extraLeft+extraRight, t.Depth) 111 | for subY := 0; subY < patch.Height; subY++ { 112 | for subX := 0; subX < patch.Width; subX++ { 113 | destIdx := (subX + subY*patch.Width) * patch.Depth 114 | dest := patch.Data[destIdx : destIdx+patch.Depth] 115 | sourceIdx := ((subX + x - extraLeft) + (subY+y-extraTop)*t.Width) * t.Depth 116 | source := t.Data[sourceIdx : sourceIdx+t.Depth] 117 | copy(dest, source) 118 | } 119 | } 120 | 121 | patchOut := f(patch) 122 | patchOut = patchOut.Unpad(extraTop, extraRight, extraBottom, extraLeft) 123 | if output == nil { 124 | output = nn.NewTensor(t.Height, t.Width, patchOut.Depth) 125 | } 126 | for subY := 0; subY < patchHeight; subY++ { 127 | for subX := 0; subX < patchWidth; subX++ { 128 | destIdx := ((subX + x) + (subY+y)*t.Width) * output.Depth 129 | dest := output.Data[destIdx : destIdx+output.Depth] 130 | sourceIdx := (subX + subY*patchOut.Width) * patchOut.Depth 131 | source := patchOut.Data[sourceIdx : sourceIdx+patchOut.Depth] 132 | copy(dest, source) 133 | } 134 | } 135 | } 136 | } 137 | return output 138 | } 139 | -------------------------------------------------------------------------------- /polish/polish_test.go: -------------------------------------------------------------------------------- 1 | package polish 2 | 3 | import ( 4 | "image" 5 | "image/color" 6 | "math/rand" 7 
| "testing" 8 | 9 | "github.com/unixpickle/essentials" 10 | ) 11 | 12 | func TestPatchEquivalence(t *testing.T) { 13 | img := image.NewRGBA(image.Rect(0, 0, 213, 192)) 14 | for y := 0; y < img.Bounds().Dy(); y++ { 15 | for x := 0; x < img.Bounds().Dx(); x++ { 16 | img.SetRGBA(x, y, color.RGBA{ 17 | R: uint8(rand.Intn(256)), 18 | G: uint8(rand.Intn(256)), 19 | B: uint8(rand.Intn(256)), 20 | A: 0xff, 21 | }) 22 | } 23 | } 24 | 25 | // Use shallow model, since it has no global norms 26 | // like the deep model (which uses group norm). 27 | // Thus, the shallow model has a finite receptive 28 | // field. 29 | expected := PolishImage(ModelTypeShallow, img) 30 | 31 | actual := []image.Image{ 32 | PolishImagePatches(ModelTypeShallow, img, 100, 50), 33 | PolishImagePatches(ModelTypeShallow, img, 100, 20), 34 | PolishImagePatches(ModelTypeShallow, img, 55, 18), 35 | } 36 | 37 | CaseLoop: 38 | for i, a := range actual { 39 | for y := 0; y < a.Bounds().Dy(); y++ { 40 | for x := 0; x < a.Bounds().Dx(); x++ { 41 | r1, g1, b1, a1 := a.At(x, y).RGBA() 42 | r2, g2, b2, a2 := expected.At(x, y).RGBA() 43 | // Allow for small rounding errors. 
44 | threshold := 0x200 45 | if essentials.AbsInt(int(r1)-int(r2)) > threshold || 46 | essentials.AbsInt(int(g1)-int(g2)) > threshold || 47 | essentials.AbsInt(int(b1)-int(b2)) > threshold || 48 | essentials.AbsInt(int(a1)-int(a2)) > threshold { 49 | t.Errorf("case %d: mismatch at (%d, %d)", i, x, y) 50 | continue CaseLoop 51 | } 52 | } 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /training/compute_rf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from polish.models import all_models 7 | 8 | 9 | def main(): 10 | args = arg_parser().parse_args() 11 | 12 | models = all_models() 13 | if args.model_type not in models: 14 | raise ValueError('unknown model: ' + args.model_type) 15 | model = models[args.model_type] 16 | # Get running stats going. 17 | model.train() 18 | for i in range(3): 19 | model(torch.randn(4, 3, 64, 64)) 20 | model.eval() 21 | 22 | # Set all biases and scales high to avoid 23 | # hitting ReLUs.
24 | for p in model.parameters(): 25 | if len(p.shape) == 1: 26 | p.data.detach().fill_(100.0) 27 | 28 | image = nn.Parameter(torch.randn(3, 512, 512)) 29 | out = model(image[None]) 30 | px = torch.sum(out[0, :, 256, 256]) 31 | px.backward() 32 | 33 | bits = torch.abs(image.grad) > 1e-8 34 | bitseq = torch.sum(bits.long(), dim=(0, 1)) 35 | print('radius: %.1f' % bitseq_radius(bitseq.numpy())) 36 | 37 | 38 | def bitseq_radius(seq): 39 | min_idx = -1 40 | max_idx = 0 41 | for i, x in enumerate(seq): 42 | if x: 43 | if min_idx == -1: 44 | min_idx = i 45 | max_idx = i 46 | return (max_idx-min_idx) / 2 47 | 48 | 49 | def arg_parser(): 50 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | parser.add_argument('--model-type', default='shallow') 52 | return parser 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /training/dump_params.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "import tempfile\n", 13 | "from zipfile import ZipFile\n", 14 | "\n", 15 | "import torch\n", 16 | "import polish.models" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "model = polish.models.DeepDenoiser(conv2d=polish.models.SepConv2d)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "" 39 | ] 40 | }, 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "# Insert your trained model here.\n", 48 | "model.load_state_dict(torch.load('model.pt', 
map_location='cpu'))" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "odict_keys(['conv1.weight', 'conv1.bias', 'conv2.spatial.weight', 'conv2.spatial.bias', 'conv2.depthwise.weight', 'conv2.depthwise.bias', 'residuals.0.0.weight', 'residuals.0.0.bias', 'residuals.0.2.spatial.weight', 'residuals.0.2.spatial.bias', 'residuals.0.2.depthwise.weight', 'residuals.0.2.depthwise.bias', 'residuals.0.4.spatial.weight', 'residuals.0.4.spatial.bias', 'residuals.0.4.depthwise.weight', 'residuals.0.4.depthwise.bias', 'residuals.1.0.weight', 'residuals.1.0.bias', 'residuals.1.2.spatial.weight', 'residuals.1.2.spatial.bias', 'residuals.1.2.depthwise.weight', 'residuals.1.2.depthwise.bias', 'residuals.1.4.spatial.weight', 'residuals.1.4.spatial.bias', 'residuals.1.4.depthwise.weight', 'residuals.1.4.depthwise.bias', 'residuals.2.0.weight', 'residuals.2.0.bias', 'residuals.2.2.spatial.weight', 'residuals.2.2.spatial.bias', 'residuals.2.2.depthwise.weight', 'residuals.2.2.depthwise.bias', 'residuals.2.4.spatial.weight', 'residuals.2.4.spatial.bias', 'residuals.2.4.depthwise.weight', 'residuals.2.4.depthwise.bias', 'residuals.3.0.weight', 'residuals.3.0.bias', 'residuals.3.2.spatial.weight', 'residuals.3.2.spatial.bias', 'residuals.3.2.depthwise.weight', 'residuals.3.2.depthwise.bias', 'residuals.3.4.spatial.weight', 'residuals.3.4.spatial.bias', 'residuals.3.4.depthwise.weight', 'residuals.3.4.depthwise.bias', 'deconv1.weight', 'deconv1.bias', 'deconv2.weight', 'deconv2.bias', 'conv3.weight', 'conv3.bias'])" 60 | ] 61 | }, 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | } 66 | ], 67 | "source": [ 68 | "# We will export the parameters using these keys.\n", 69 | "#\n", 70 | "# Seeing the list may be helpful for implementing the\n", 71 | "# model in the Go API.\n", 72 | "model.state_dict().keys()" 73 | ] 74 | }, 75 | { 76 | 
"cell_type": "code", 77 | "execution_count": 5, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Created zip file of 1845282 bytes.\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "# Create a zip_data variable containing the parameters\n", 90 | "# as a zip file, with one file per array.\n", 91 | "with tempfile.TemporaryDirectory() as temp_dir:\n", 92 | " zip_path = os.path.join(temp_dir, 'params.zip')\n", 93 | " with ZipFile(zip_path, 'w') as f:\n", 94 | " for k, v in model.state_dict().items():\n", 95 | " arr = v.detach().cpu().numpy().flatten()\n", 96 | " f.writestr('%s' % k, arr.tobytes())\n", 97 | " with open(zip_path, 'rb') as f:\n", 98 | " zip_data = f.read()\n", 99 | "print('Created zip file of %d bytes.' % len(zip_data))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "def byte_str(b):\n", 111 | " \"\"\"Convert a byte into an escape sequence for a string.\"\"\"\n", 112 | " if b >= 32 and b <= 126 and b != ord('\\\\') and b != ord('\"'):\n", 113 | " return chr(b)\n", 114 | " return '\\\\x%02x' % b" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 8, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Created code of length 5225861.\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "variable = 'deepModelZipData'\n", 132 | "go_code = 'package polish\\n\\nconst %s = \"' % variable\n", 133 | "go_code += ''.join(byte_str(x) for x in zip_data)\n", 134 | "go_code += '\"\\n'\n", 135 | "print('Created code of length %d.' 
% len(go_code))\n", 136 | "with open('model_data_deep.go', 'wt+') as f:\n", 137 | " f.write(go_code)" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "Python 3", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.6.1" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 2 162 | } 163 | -------------------------------------------------------------------------------- /training/polish/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/polish/4170538ad0dd695389fb90e50ebd9a6d555bf3c5/training/polish/__init__.py -------------------------------------------------------------------------------- /training/polish/baseline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Baseline methods for denoising images. 
3 | """ 4 | 5 | import torch 6 | 7 | 8 | def identity_baseline(dataset, num_samples=200): 9 | i = 0 10 | total_loss = 0 11 | count = 0 12 | for inputs, outputs in dataset: 13 | total_loss += torch.mean(torch.abs(inputs[:, :3] - outputs)).item() 14 | count += 1 15 | i += inputs.shape[0] 16 | if i >= num_samples: 17 | break 18 | return total_loss / count 19 | -------------------------------------------------------------------------------- /training/polish/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | 9 | 10 | class PolishDataset(data.IterableDataset): 11 | def __init__(self, data_dir, train=True, aux=False, crop_size=192, samples=(128, 512), 12 | extra_aug=False): 13 | self.crop_size = crop_size 14 | self.samples = samples 15 | self.aux = aux 16 | self.extra_aug = extra_aug 17 | all_dirs = [x for x in os.listdir(data_dir) 18 | if os.path.isdir(os.path.join(data_dir, x)) and not x.startswith('.')] 19 | test_prefixes = 'abcd' 20 | if train: 21 | dirs = [x for x in all_dirs if not any(x.startswith(c) for c in test_prefixes)] 22 | else: 23 | dirs = [x for x in all_dirs if any(x.startswith(c) for c in test_prefixes)] 24 | self.dirs = [os.path.join(data_dir, x) for x in dirs] 25 | if not len(self.dirs): 26 | raise RuntimeError('missing data: %s (%s)' % (data_dir, str(train))) 27 | 28 | def __iter__(self): 29 | paths = self.dirs.copy() 30 | while True: 31 | random.shuffle(paths) 32 | for path in paths: 33 | sample = self.get_sample(path) 34 | yield sample 35 | 36 | def get_sample(self, path): 37 | input_case = random.choice(self.samples) 38 | outputs = np.array(Image.open(os.path.join(path, 'target.png'))) 39 | inputs = np.array(Image.open(os.path.join(path, 'input_%d.png' % input_case))) 40 | aug = Augmentation(inputs.shape[0], self.crop_size, self.extra_aug) 41 | inputs = aug(inputs) 42 
| outputs = aug(outputs) 43 | if self.aux: 44 | incident = np.array(Image.open(os.path.join(path, 'incidence.png')))[..., None] 45 | albedo = np.array(Image.open(os.path.join(path, 'albedo.png'))) 46 | inputs = torch.cat([inputs, aug(albedo), aug(incident)], dim=0) 47 | return inputs, outputs 48 | 49 | 50 | class Augmentation: 51 | """ 52 | Applies the same randomly sampled crop, flips, rotation, and 53 | channel permutation to both input and target images. 54 | """ 55 | 56 | def __init__(self, img_size, crop_size, extra_aug): 57 | self.img_size = img_size 58 | self.crop_size = crop_size 59 | self.x = random.randrange(img_size - crop_size) 60 | self.y = random.randrange(img_size - crop_size) 61 | self.flip_x = random.random() < 0.5 62 | self.channel_perm = [0, 1, 2] 63 | if extra_aug: 64 | self.flip_y = random.random() < 0.5 65 | self.rotation = random.randrange(4) 66 | random.shuffle(self.channel_perm) 67 | else: 68 | self.flip_y = False 69 | self.rotation = 0 70 | self.mask = (np.random.uniform(size=(img_size, img_size, 1)) > 0.5).astype('float32') 71 | 72 | def __call__(self, x): 73 | x = x.astype('float32') / 255.0 74 | if x.shape[1] > x.shape[0]: 75 | # Mix up the two samples to generate an 76 | # almost infinite amount of unbiased training 77 | # data. 
78 | x = x[:, :self.img_size]*self.mask + x[:, self.img_size:]*(1-self.mask) 79 | x = x[self.y:, self.x:][:self.crop_size, :self.crop_size] 80 | for i in range(self.rotation): 81 | x = np.transpose(x, axes=[1, 0, 2]) 82 | x = x[::-1] 83 | if self.flip_x: 84 | x = x[:, ::-1] 85 | if self.flip_y: 86 | x = x[::-1] 87 | 88 | if x.shape[2] == 3: 89 | x_copy = np.array(x) 90 | for i, p in enumerate(self.channel_perm): 91 | x[..., i] = x_copy[..., self.channel_perm[i]] 92 | 93 | return torch.from_numpy(np.array(x)).permute(2, 0, 1).contiguous() 94 | -------------------------------------------------------------------------------- /training/polish/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Machine learning models for denoising. 3 | """ 4 | 5 | from abc import abstractproperty 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | AUX_FEATURE_CHANNELS = 4 13 | 14 | 15 | def all_models(**kwargs): 16 | """ 17 | Get a dict of supported models. 18 | """ 19 | return { 20 | 'linear': LinearDenoiser(**kwargs), 21 | 'shallow': ShallowDenoiser(**kwargs), 22 | 'deep': DeepDenoiser(**kwargs), 23 | 'bilateral': BilateralDenoiser(**kwargs), 24 | } 25 | 26 | 27 | class Denoiser(nn.Module): 28 | def loss(self, images, targets): 29 | """ 30 | Compute the reconstruction loss using a noisy but 31 | unbiased estimate of the target errors. 32 | """ 33 | return torch.mean(torch.abs(self(images) - targets)) 34 | 35 | @abstractproperty 36 | def dim_lcd(self): 37 | """ 38 | Get a factor that must divide both the width and 39 | height of input images. 40 | """ 41 | pass 42 | 43 | 44 | class LinearDenoiser(Denoiser): 45 | """ 46 | This is the simplest possible denoiser, consisting of 47 | one convolutional filter. 
48 | """ 49 | 50 | def __init__(self, aux=False, kernel_size=7): 51 | super().__init__() 52 | if not kernel_size % 2: 53 | raise ValueError('kernel_size must be odd') 54 | self.conv = nn.Conv2d(3 + AUX_FEATURE_CHANNELS if aux else 3, 55 | 3, kernel_size, padding=kernel_size//2) 56 | 57 | @property 58 | def dim_lcd(self): 59 | return 1 60 | 61 | def forward(self, x): 62 | return self.conv(x) 63 | 64 | 65 | class ShallowDenoiser(Denoiser): 66 | """ 67 | A denoiser that has one hidden layer and doesn't 68 | require any spatial LCD. 69 | """ 70 | 71 | def __init__(self, aux=False, kernel_size=5, hidden_size=32): 72 | super().__init__() 73 | if not kernel_size % 2: 74 | raise ValueError('kernel_size must be odd') 75 | self.conv1 = nn.Conv2d(3 + AUX_FEATURE_CHANNELS if aux else 3, 76 | hidden_size, kernel_size, 77 | padding=kernel_size//2) 78 | self.conv2 = nn.Conv2d(hidden_size, 3, kernel_size, padding=kernel_size//2) 79 | 80 | @property 81 | def dim_lcd(self): 82 | return 1 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = F.relu(x) 87 | x = self.conv2(x) 88 | return x 89 | 90 | 91 | class SepConv2d(nn.Module): 92 | def __init__(self, depth_in, depth_out, kernel_size, stride=1, padding=0): 93 | super().__init__() 94 | self.spatial = nn.Conv2d(depth_in, depth_in, kernel_size, 95 | stride=stride, padding=padding, groups=depth_in) 96 | self.depthwise = nn.Conv2d(depth_in, depth_out, 1) 97 | 98 | def forward(self, x): 99 | x = self.spatial(x) 100 | x = F.relu(x) 101 | x = self.depthwise(x) 102 | return x 103 | 104 | 105 | class DeepDenoiser(Denoiser): 106 | """ 107 | A denoiser that has multiple hidden layers. 
108 | """ 109 | 110 | def __init__(self, aux=False, conv2d=SepConv2d, batch_norm=True): 111 | super().__init__() 112 | self.conv1 = nn.Conv2d(3 + AUX_FEATURE_CHANNELS if aux else 3, 113 | 64, 5, padding=2, stride=2) 114 | self.conv2 = conv2d(64, 128, 5, padding=2, stride=2) 115 | 116 | def create_norm(): 117 | if batch_norm: 118 | return nn.BatchNorm2d(128) 119 | else: 120 | return nn.GroupNorm(8, 128) 121 | 122 | self.residuals = nn.ModuleList([nn.Sequential( 123 | create_norm(), 124 | nn.ReLU(), 125 | conv2d(128, 256, 3, padding=1), 126 | nn.ReLU(), 127 | conv2d(256, 128, 3, padding=1), 128 | ) for _ in range(4)]) 129 | 130 | self.deconv1 = nn.ConvTranspose2d(128, 64, 4, padding=1, stride=2) 131 | self.deconv2 = nn.ConvTranspose2d(64, 32, 4, padding=1, stride=2) 132 | self.conv3 = nn.Conv2d(32, 3, 3, padding=1) 133 | 134 | @property 135 | def dim_lcd(self): 136 | return 4 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = F.relu(x) 141 | x = self.conv2(x) 142 | 143 | for r in self.residuals: 144 | x = x + r(x) 145 | 146 | x = self.deconv1(x) 147 | x = F.relu(x) 148 | x = self.deconv2(x) 149 | x = F.relu(x) 150 | x = self.conv3(x) 151 | return x 152 | 153 | 154 | class BilateralDenoiser(Denoiser): 155 | """ 156 | A denoiser that uses a bilateral filter with learned 157 | parameters. 
158 | """ 159 | 160 | def __init__(self, filter_size=15, aux=False): 161 | super().__init__() 162 | 163 | if not filter_size % 2: 164 | raise ValueError('filter size must be odd') 165 | 166 | self.aux = aux 167 | self.filter_size = filter_size 168 | 169 | distances = np.zeros([filter_size]*2, dtype=np.float32) 170 | middle = filter_size // 2 171 | for i in range(filter_size): 172 | for j in range(filter_size): 173 | distances[i, j] = (i-middle)*(i-middle) + (j-middle)*(j-middle) 174 | 175 | self.register_buffer('distances', torch.from_numpy(distances).view(-1)) 176 | self.blur_sigma = nn.Parameter(torch.Tensor([5.0])[0]) 177 | self.diff_sigma = nn.Parameter(torch.Tensor([1.0])[0]) 178 | 179 | @property 180 | def dim_lcd(self): 181 | return 1 182 | 183 | def forward(self, x): 184 | if self.aux: 185 | x = x[:, :3] 186 | 187 | # Create patches tensor: [N x C x K^2 x H x W] 188 | padding = self.filter_size // 2 189 | # Pad with huge negative values instead of zeros 190 | # so that the filter does not include the padding. 
191 | padded = F.pad(x + 100, [padding] * 4) - 100 192 | patches = F.unfold(padded, self.filter_size) 193 | patches = patches.view(*x.shape[:2], self.filter_size**2, *x.shape[2:]) 194 | 195 | diffs = torch.pow(patches - x[:, :, None], 2) 196 | diffs = diffs / torch.pow(self.diff_sigma, 2) 197 | 198 | blurs = torch.zeros_like(patches) + self.distances[None, None, :, None, None] 199 | blurs = blurs / torch.pow(self.blur_sigma, 2) 200 | 201 | probs = torch.exp(-(diffs + blurs)) 202 | probs = probs / torch.sum(probs, dim=2, keepdim=True) 203 | 204 | return torch.sum(patches * probs, dim=2) 205 | -------------------------------------------------------------------------------- /training/run_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | 7 | from polish.models import all_models 8 | 9 | 10 | def main(): 11 | args = arg_parser().parse_args() 12 | 13 | models = all_models() 14 | if args.model_type not in models: 15 | raise ValueError('unknown model: ' + args.model_type) 16 | model = models[args.model_type] 17 | model.load_state_dict(torch.load(args.model_path, map_location='cpu')) 18 | model.eval() 19 | 20 | img_in = torch.from_numpy(np.array(Image.open(args.image_in))).float() / 255.0 21 | img_in = img_in.permute(2, 0, 1) 22 | with torch.no_grad(): 23 | img_out = model(img_in[None]) 24 | img_out = img_out[0].permute(1, 2, 0) 25 | img_out = img_out.clamp(0, 1) * 255 26 | pil_img = Image.fromarray(img_out.detach().cpu().numpy().astype('uint8')) 27 | pil_img.save(args.image_out) 28 | 29 | 30 | def arg_parser(): 31 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 32 | parser.add_argument('--model-path', default='model.pt') 33 | parser.add_argument('--model-type', default='shallow') 34 | parser.add_argument('image_in') 35 | parser.add_argument('image_out') 36 | return parser 37 | 
38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | import torch.optim as optim 8 | 9 | from polish.baseline import identity_baseline 10 | from polish.dataset import PolishDataset 11 | from polish.models import all_models 12 | 13 | 14 | def main(): 15 | use_cuda = torch.cuda.is_available() 16 | device = torch.device('cuda') if use_cuda else torch.device('cpu') 17 | 18 | args = arg_parser().parse_args() 19 | 20 | models = all_models(aux=args.aux) 21 | if args.model_type not in models: 22 | raise ValueError('unknown model: ' + args.model_type) 23 | model = models[args.model_type] 24 | if os.path.exists(args.model_path): 25 | model.load_state_dict(torch.load(args.model_path, map_location='cpu')) 26 | model.to(device) 27 | 28 | trains, tests = create_datasets(args.data, args.batch, aux=args.aux) 29 | print('baseline: train %f, test %f' % (identity_baseline(trains), identity_baseline(tests))) 30 | 31 | opt = optim.Adam(model.parameters(), lr=args.lr) 32 | 33 | i = 0 34 | for (train_in, train_out), (test_in, test_out) in zip(trains, tests): 35 | train_in, train_out = train_in.to(device), train_out.to(device) 36 | test_in, test_out = test_in.to(device), test_out.to(device) 37 | 38 | train_loss = model.loss(train_in, train_out) 39 | with torch.no_grad(): 40 | model.eval() 41 | test_loss = model.loss(test_in, test_out) 42 | model.train() 43 | 44 | opt.zero_grad() 45 | train_loss.backward() 46 | opt.step() 47 | 48 | if not i % args.save_interval: 49 | torch.save(model.state_dict(), args.model_path) 50 | with torch.no_grad(): 51 | model.eval() 52 | test_pred = model(test_in) 53 | model.train() 54 | save_rendering(test_in, test_pred, test_out) 55 | 56 | print('step %d: train=%f test=%f' % 
(i, train_loss.item(), test_loss.item())) 57 | i += 1 58 | 59 | 60 | def create_datasets(data_dir, batch, **kwargs): 61 | dl_kwargs = {'num_workers': 8, 'pin_memory': True, 'batch_size': batch} 62 | train_loader = torch.utils.data.DataLoader(PolishDataset(data_dir, extra_aug=True, **kwargs), 63 | **dl_kwargs) 64 | test_loader = torch.utils.data.DataLoader(PolishDataset(data_dir, train=False, **kwargs), 65 | **dl_kwargs) 66 | return train_loader, test_loader 67 | 68 | 69 | def save_rendering(inputs, outputs, targets): 70 | joined = torch.cat([inputs[:, :3], outputs, targets], dim=-1).permute(0, 2, 3, 1).contiguous() 71 | joined = joined.view(-1, *joined.shape[2:]) 72 | arr = joined.detach().cpu().numpy() 73 | arr = np.clip(arr, 0, 1) 74 | arr = (arr * 255).astype('uint8') 75 | Image.fromarray(arr).save('samples.png') 76 | 77 | 78 | def arg_parser(): 79 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 80 | parser.add_argument('--data', default='../data') 81 | parser.add_argument('--aux', action='store_true') 82 | parser.add_argument('--model-path', default='model.pt') 83 | parser.add_argument('--model-type', default='shallow') 84 | parser.add_argument('--save-interval', default=10, type=int) 85 | parser.add_argument('--lr', default=0.001, type=float) 86 | parser.add_argument('--batch', default=4, type=int) 87 | return parser 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | --------------------------------------------------------------------------------
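For readers tracing `BilateralDenoiser.forward` above, the per-pixel weighting it computes with `unfold` can be sketched in plain NumPy for a single channel and a single K x K patch. This is an illustrative sketch, not part of the repository: the helper `bilateral_weights` and the default sigma values are assumptions here, whereas the real model learns `blur_sigma` and `diff_sigma` and runs batched over whole images in PyTorch.

```python
import numpy as np

def bilateral_weights(patch, blur_sigma=5.0, diff_sigma=1.0):
    # patch: a K x K single-channel window centered on the pixel
    # being filtered (the model defaults to K = 15).
    k = patch.shape[0]
    mid = k // 2
    ii, jj = np.mgrid[0:k, 0:k]
    # Spatial term ("blurs"): squared distance from the center pixel.
    dist2 = (ii - mid) ** 2 + (jj - mid) ** 2
    # Range term ("diffs"): squared intensity difference to the center.
    diff2 = (patch - patch[mid, mid]) ** 2
    # Softmax-style normalization, matching exp(-(diffs + blurs))
    # divided by its sum over the patch dimension in the model.
    w = np.exp(-(dist2 / blur_sigma ** 2 + diff2 / diff_sigma ** 2))
    return w / w.sum()

# The filtered center pixel is the weighted sum over the patch;
# a constant patch is left unchanged by construction.
patch = np.full((5, 5), 0.5, dtype=np.float32)
w = bilateral_weights(patch)
filtered = float((w * patch).sum())
```

The center pixel always receives the largest weight (zero spatial and zero range penalty), and pixels that differ strongly in intensity are down-weighted, which is what lets the learned filter smooth noise while keeping edges.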