├── README_CN.md ├── README.md ├── GCN_and_Cora.ipynb ├── queuepool.jl ├── dataset.jl ├── VGG_and_Cifar.ipynb ├── MLP_and_MNIST.ipynb ├── DCGAN_and_Fashion.ipynb ├── Julia_quickstart.ipynb └── ResNet_and_ImageNet.ipynb /README_CN.md: -------------------------------------------------------------------------------- 1 | ## Julia 中的深度学习 2 | 3 | [English Document](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/README.md) 4 | 5 | 该项目编写、收集和整理了 Julia 中的深度学习方法。模型的结构和训练方法可能和论文略有出入,但复现了其核心思想。我们希望该项目能够为相关科研工作者使用模板快速开始项目提供帮助。非常欢迎其他深度学习方法的实现和建议。 6 | 7 | ## 实现列表 8 | 9 | - Julia 教程 10 | - [开始教程](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/Julia_quickstart.ipynb) 11 | 12 | - 卷积神经网络 13 | - [MLP+MNIST](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/MLP_and_MNIST.ipynb) 14 | - [VGG+Cifar10](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/VGG_and_Cifar.ipynb) 15 | - [ResNet+ImageNet](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/ResNet_and_ImageNet.ipynb) 16 | - [UNet+ISBI](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/UNet_and_ISBI.ipynb) 17 | 18 | - 生成对抗网络 19 | - [DCGAN+Fashion](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/DCGAN_and_Fashion.ipynb) 20 | 21 | - 图卷积网络 22 | - [GCN+Cora](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/GCN_and_Cora.ipynb) 23 | 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Julia-Deeplearning 2 | 3 | [中文文档](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/README_CN.md) 4 | 5 | A collection of Julia implementations of deep learning methods presented in research papers. Model architectures will not always mirror the ones proposed in the papers, as I have chosen to focus on covering the core ideas rather than reproducing every layer configuration exactly. 
Contributions and suggestions of deep learning methods to implement are very welcome. 6 | 7 | ## Table of Contents 8 | 9 | - Julia Tutorials 10 | - [Julia Quickstart](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/Julia_quickstart.ipynb) 11 | 12 | - Convolutional Neural Networks 13 | - [MLP+MNIST](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/MLP_and_MNIST.ipynb) 14 | - [VGG+Cifar10](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/VGG_and_Cifar.ipynb) 15 | - [ResNet+ImageNet](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/ResNet_and_ImageNet.ipynb) 16 | - [UNet+ISBI](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/UNet_and_ISBI.ipynb) 17 | 18 | - Generative Adversarial Networks 19 | - [DCGAN+Fashion](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/DCGAN_and_Fashion.ipynb) 20 | 21 | - Graph Convolutional Networks 22 | - [GCN+Cora](https://github.com/tczhangzhi/Julia-Deeplearning/blob/master/GCN_and_Cora.ipynb) 23 | -------------------------------------------------------------------------------- /GCN_and_Cora.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### GCN and Cora\n", 8 | "\n", 9 | "在该实现中您可以看到如下功能:\n", 10 | "1. 下载 Cora 数据集(JLD2 文件)\n", 11 | "2. 使用 GeometricFlux 定义模型并进行训练\n", 12 | "\n", 13 | "In this template you can see the following features:\n", 14 | "1. Download the Cora dataset (JLD2 files)\n", 15 | "2. 
Define and train the model using GeometricFlux" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | ";wget \"https://github.com/tczhangzhi/Julia-Deeplearning/releases/download/v0.0.1/cora_features.jld2\"" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | ";wget \"https://github.com/tczhangzhi/Julia-Deeplearning/releases/download/v0.0.1/cora_graph.jld2\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | ";wget \"https://github.com/tczhangzhi/Julia-Deeplearning/releases/download/v0.0.1/cora_labels.jld2\"" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "┌ Warning: Package GeometricFlux does not have LightGraphs in its dependencies:\n", 55 | "│ - If you have GeometricFlux checked out for development and have\n", 56 | "│ added LightGraphs as a dependency but haven't updated your primary\n", 57 | "│ environment's manifest file, try `Pkg.resolve()`.\n", 58 | "│ - Otherwise you may need to report an issue with GeometricFlux\n", 59 | "│ Loading LightGraphs into GeometricFlux from project dependency, future warnings for GeometricFlux are suppressed.\n", 60 | "└ @ nothing nothing:909\n" 61 | ] 62 | }, 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "accuracy(train_X, train_y) = 0.21565731166912852\n", 68 | "accuracy(train_X, train_y) = 0.3068685376661743\n", 69 | "accuracy(train_X, train_y) = 0.404357459379616\n", 70 | "accuracy(train_X, train_y) = 0.49741506646971934\n", 71 | "accuracy(train_X, train_y) = 0.569423929098966\n", 72 | "accuracy(train_X, train_y) = 0.6347858197932054\n", 73 | "accuracy(train_X, train_y) = 0.681314623338257\n", 
74 | "accuracy(train_X, train_y) = 0.7127031019202363\n", 75 | "accuracy(train_X, train_y) = 0.7341211225997046\n", 76 | "accuracy(train_X, train_y) = 0.7573855243722304\n", 77 | "accuracy(train_X, train_y) = 0.7784342688330872\n", 78 | "accuracy(train_X, train_y) = 0.7913589364844904\n", 79 | "accuracy(train_X, train_y) = 0.8035450516986706\n", 80 | "accuracy(train_X, train_y) = 0.8157311669128509\n", 81 | "accuracy(train_X, train_y) = 0.8282865583456426\n", 82 | "accuracy(train_X, train_y) = 0.8338257016248154\n", 83 | "accuracy(train_X, train_y) = 0.8415805022156573\n", 84 | "accuracy(train_X, train_y) = 0.8508124076809453\n", 85 | "accuracy(train_X, train_y) = 0.8589364844903988\n", 86 | "accuracy(train_X, train_y) = 0.8670605612998523\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "using GeometricFlux\n", 92 | "using Flux\n", 93 | "using Flux: onehotbatch, onecold, crossentropy, throttle\n", 94 | "using JLD2 # use v0.1.2\n", 95 | "using Statistics: mean\n", 96 | "using SparseArrays\n", 97 | "using LightGraphs.SimpleGraphs\n", 98 | "using LightGraphs: adjacency_matrix\n", 99 | "using CuArrays\n", 100 | "\n", 101 | "@load \"cora_features.jld2\" features\n", 102 | "@load \"cora_labels.jld2\" labels\n", 103 | "@load \"cora_graph.jld2\" g\n", 104 | "\n", 105 | "num_nodes = 2708\n", 106 | "num_features = 1433\n", 107 | "hidden = 16\n", 108 | "target_catg = 7\n", 109 | "epochs = 20\n", 110 | "\n", 111 | "## Preprocessing data\n", 112 | "train_X = Float32.(features) # dim: num_features * num_nodes\n", 113 | "train_y = Float32.(labels) # dim: target_catg * num_nodes\n", 114 | "\n", 115 | "adj_mat = Matrix{Float32}(adjacency_matrix(g))\n", 116 | "\n", 117 | "## Model\n", 118 | "model = Chain(GCNConv(adj_mat, num_features=>hidden, relu),\n", 119 | " Dropout(0.5),\n", 120 | " GCNConv(adj_mat, hidden=>target_catg),\n", 121 | " softmax)\n", 122 | "\n", 123 | "## Loss\n", 124 | "loss(x, y) = crossentropy(model(x), y)\n", 125 | "accuracy(x, y) = mean(onecold(model(x)) .== 
onecold(y))\n", 126 | "\n", 127 | "## Training\n", 128 | "ps = Flux.params(model)\n", 129 | "train_data = [(train_X, train_y)]\n", 130 | "opt = ADAM(0.0001)\n", 131 | "evalcb() = @show(accuracy(train_X, train_y))\n", 132 | "\n", 133 | "for i = 1:epochs\n", 134 | "    Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))\n", 135 | "end" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "Julia 1.4.1", 142 | "language": "julia", 143 | "name": "julia-1.4" 144 | }, 145 | "language_info": { 146 | "file_extension": ".jl", 147 | "mimetype": "application/julia", 148 | "name": "julia", 149 | "version": "1.4.1" 150 | } 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 2 154 | } 155 | -------------------------------------------------------------------------------- /queuepool.jl: -------------------------------------------------------------------------------- 1 | using Distributed, Serialization 2 | 3 | mutable struct QueuePool 4 | # The worker PIDs 5 | workers::Vector{Int} 6 | 7 | # Channels for communication 8 | queued_jobs::RemoteChannel 9 | results::RemoteChannel 10 | kill_switch::RemoteChannel 11 | 12 | # The ID of the next job to be submitted 13 | next_job::Int 14 | 15 | # Buffer space where we store results for out-of-order execution 16 | results_buffer::Dict{Int,Any} 17 | end 18 | 19 | function QueuePool(num_workers::Int, proc_func::Function, setup = :nothing, queue_size=128) # `setup` is left untyped: the default `:nothing` is a Symbol, not an Expr 20 | workers = addprocs(num_workers) #; topology = :master_worker) 21 | 22 | # Tell the workers to include this file and whatever other setup they need, 23 | # so that they can communicate with us and complete their tasks.
24 | Distributed.remotecall_eval(Main, workers, quote 25 | include($(@__FILE__)) 26 | Core.eval(Main, $(setup)) 27 | end) 28 | 29 | # Create our QueuePool 30 | qp = QueuePool( 31 | workers, 32 | RemoteChannel(() -> Channel{Tuple}(Inf)), 33 | RemoteChannel(() -> Channel{Tuple}(queue_size)), 34 | RemoteChannel(() -> Channel{Bool}(1)), 35 | 0, 36 | Dict{Int,Any}(), 37 | ) 38 | 39 | # immediately add a finalizer to it to flip the kill switch and wait for the 40 | # workers to finish. EDIT: This doesn't work because we can't switch tasks 41 | # in a finalizer, apparently. :( 42 | #= 43 | finalizer(close, qp) 44 | =# 45 | 46 | # Launch workers, running the `worker_task` with a handle to this QueuePool object 47 | # and the processing function that will be called within the worker loop. 48 | for id in workers 49 | Distributed.remote_do(worker_task, id, qp, proc_func) 50 | end 51 | 52 | # Return QP 53 | return qp 54 | end 55 | 56 | function Base.close(qp::QueuePool) # extend Base.close rather than shadowing it, so close(qp.queued_jobs) below still dispatches to Base 57 | # Tell the worker processes to die 58 | close(qp.queued_jobs) 59 | put!(qp.kill_switch, true) 60 | 61 | # Wait for the workers to descend into the long, dark sleep 62 | rmprocs(qp.workers...; waitfor=10) 63 | end 64 | 65 | function worker_task(qp::QueuePool, proc_func) 66 | # Loop unless we're burning this whole queue pool down 67 | while !isready(qp.kill_switch) 68 | # Grab the next queued job from the master 69 | job_id, x = take!(qp.queued_jobs) 70 | 71 | local y 72 | try 73 | # Push x through proc_func to get y 74 | y = proc_func(x) 75 | catch e 76 | if isa(e, InterruptException) 77 | println(e) 78 | rethrow(e) 79 | end 80 | # Just skip bad processing runs 81 | @warn("Failed to run worker task $(proc_func) on $(x): $(e)") 82 | continue 83 | end 84 | 85 | # Push the result onto qp.results 86 | put!(qp.results, (job_id, y)) 87 | end 88 | end 89 | 90 | 91 | """ 92 | try_buffer_result!(qp::QueuePool) 93 | Does a nonblocking read of the next result from the QueuePool into our result 94 | buffer. 
If no result is available, returns `nothing` immediately. 95 | """ 96 | function try_buffer_result!(qp::QueuePool) 97 | if isready(qp.results) 98 | job_id, result = take!(qp.results) 99 | qp.results_buffer[job_id] = result 100 | return job_id 101 | end 102 | return nothing 103 | end 104 | 105 | # Check to see if it's `nothing` and `yield()` if it is. 106 | function try_buffer_result!(qp::QueuePool, t_start::Float64, timeout::Nothing) 107 | if try_buffer_result!(qp) === nothing 108 | # No new results available, so just yield 109 | yield() 110 | end 111 | end 112 | 113 | # Check to see if we've broken through our timeout 114 | function try_buffer_result!(qp::QueuePool, t_start::Float64, timeout::Float64) 115 | try_buffer_result!(qp, t_start, nothing) 116 | 117 | if (time() - t_start) > timeout 118 | error("timeout within fetch_result") 119 | end 120 | end 121 | 122 | 123 | 124 | """ 125 | push_job!(qp::QueuePool, value) 126 | Push a new job onto the QueuePool, returning the job id associated with this job, 127 | for future use with `fetch_result(qp, job_id)` 128 | """ 129 | function push_job!(qp::QueuePool, value) 130 | job_id = qp.next_job 131 | qp.next_job += 1 132 | 133 | put!(qp.queued_jobs, (job_id, value)) 134 | return job_id 135 | end 136 | 137 | """ 138 | fetch_result(qp::QueuePool; timeout = nothing) 139 | Return a result from the QueuePool, regardless of order. By default, this will wait 140 | forever; set `timeout` to a value in seconds to time out and throw an error 141 | if a value does not arrive. 142 | """ 143 | function fetch_result(qp::QueuePool; timeout = nothing) 144 | # If we don't have any results buffered, then pull one in 145 | t_start = time() 146 | while isempty(qp.results_buffer) 147 | try_buffer_result!(qp, t_start, timeout) 148 | end 149 | return pop!(qp.results_buffer).second 150 | end 151 | 152 | """ 153 | fetch_result(qp::QueuePool, job_id::Int; timeout = nothing) 154 | Return a result from the QueuePool, in specific order. 
By default, this will wait 155 | forever; set `timeout` to a value in seconds to time out and throw an error 156 | if a value does not arrive. 157 | """ 158 | function fetch_result(qp::QueuePool, job_id::Int; timeout=nothing) 159 | # Keep accumulating results until we get the job_id we're interested in. 160 | t_start = time() 161 | while !haskey(qp.results_buffer, job_id) 162 | try_buffer_result!(qp, t_start, timeout) 163 | end 164 | return pop!(qp.results_buffer, job_id) 165 | end -------------------------------------------------------------------------------- /dataset.jl: -------------------------------------------------------------------------------- 1 | using Flux, Images, Metalhead 2 | using Distributed, Random, Printf 3 | 4 | include("queuepool.jl") 5 | 6 | """ 7 | load_img(filename::AbstractString) 8 | Thin wrapper around `Images.load()` that immediately converts the resultant 9 | array to a homogeneous Float32 tensor. 10 | """ 11 | function load_img(filename::AbstractString) 12 | # Load the image 13 | im = load(filename) 14 | 15 | # Permute dimensions to get (R, G, B), then expand to four dimensions to 16 | # get a singleton batch axis, as the rest of Metalhead expects. 17 | im = permutedims(channelview(RGB.(im)), (3, 2, 1))[:,:,:,:] 18 | 19 | # Return this as a Float32. (This should no longer be necessary once Flux 20 | # does the conversion for us, but until then we'll frontload it.) 
21 | return Float32.(im) 22 | end 23 | 24 | # Resize an image such that its smallest dimension is the given length 25 | function resize_smallest_dimension(im::AbstractArray{T, 4}, len) where {T} 26 | # Images.jl doesn't like our batch axis, so drop that temporarily 27 | im = im[:,:,:,1] 28 | 29 | reduction_factor = len/minimum(size(im)[1:2]) 30 | new_size = size(im) 31 | new_size = ( 32 | round(Int, size(im,1)*reduction_factor), 33 | round(Int, size(im,2)*reduction_factor), 34 | new_size[3], # number of channels 35 | ) 36 | if reduction_factor < 1.0 37 | # Use restrict() to quarter our size each step, which is much faster 38 | # than a single large Gaussian imfilter(). 39 | while reduction_factor < 0.5 40 | im = cat((restrict(im[:,:,cidx]) for cidx in 1:size(im, 3))..., dims=3) 41 | reduction_factor *= 2 42 | end 43 | # low-pass filter 44 | im = imfilter(im, KernelFactors.gaussian(0.75/reduction_factor), Inner()) 45 | end 46 | 47 | # Expand the result back up to a 4d tensor 48 | return imresize(im, new_size)[:,:,:,:] 49 | end 50 | 51 | 52 | """ 53 | center_crop(im, len) 54 | Extracts the `len`-by-`len` square of pixels centered within `im`. 55 | """ 56 | function center_crop(im::AbstractArray{T, 4}, len::Integer) where {T} 57 | l2 = div(len,2) 58 | adjust = len % 2 == 0 ? 1 : 0 59 | return im[ 60 | div(end,2)-l2 : div(end,2)+l2 - adjust, 61 | div(end,2)-l2 : div(end,2)+l2 - adjust, 62 | :, # across all channels 63 | :, # across all batches 64 | ] 65 | end 66 | 67 | 68 | """ 69 | channel_normalize(im) 70 | Normalizes the channels of `im` according to the standard ImageNet training 71 | coefficients, yielding roughly unit normal distribution outputs across the 72 | ImageNet corpus. 
(These values gratefully taken from PyTorch) 73 | """ 74 | function channel_normalize(im::AbstractArray{T, 4}) where {T} 75 | # Convert our channel normalization arrays (in R, G, B order) 76 | # to 1x1x3x1 tensors so that we can use dot-operators to directly 77 | # subtract and divide to normalize. 78 | μ = reshape([0.485, 0.456, 0.406], (1, 1, 3, 1)) 79 | σ = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1)) 80 | return (im .- μ)./σ 81 | end 82 | 83 | 84 | 85 | """ 86 | imagenet_val_preprocess(im) 87 | Perform the typical ImageNet preprocessing steps for validation: a resize, 88 | center crop, and normalization. 89 | """ 90 | function imagenet_val_preprocess(im) 91 | # Make sure that `im` is loaded 92 | t_0 = time() 93 | im = load_img(im) 94 | t_1 = time() 95 | 96 | # Resize such that smallest edge is 256 pixels long, center-crop to 97 | # 224x224, then normalize channels and return 98 | im = resize_smallest_dimension(im, 256) 99 | t_2 = time() 100 | im = center_crop(im, 224) 101 | t_3 = time() 102 | return (channel_normalize(im), t_1 - t_0, t_2 - t_1, t_3 - t_2) 103 | end 104 | 105 | 106 | """ 107 | imagenet_train_preprocess(im) 108 | Perform the typical ImageNet preprocessing steps for training: a random crop, 109 | resize, random flip, and normalization. 110 | """ 111 | function imagenet_train_preprocess(im) 112 | # TODO: random crop 113 | return imagenet_val_preprocess(im) 114 | end 115 | 116 | function recursive_readdir(root::String) 117 | ret = String[] 118 | for (r, dirs, files) in walkdir(root) 119 | for f in files 120 | push!(ret, joinpath(r, f)[length(root)+2:end]) 121 | end 122 | end 123 | return ret 124 | end 125 | 126 | 127 | """ 128 | imagenet_train_data_loader(filename) 129 | Worker thread data loading routine; loads a filename, figures out its label, 130 | and returns the (x, y) pair for later collation. 
This is used for training, 131 | and expects data pathnames to look something like `train/nXXX/nXXX_YYYYY.JPEG` 132 | """ 133 | function imagenet_train_data_loader(filename::String) 134 | t_start = time() 135 | synset_mapping = Metalhead.ImageNet.synset_mapping 136 | 137 | # Load image file and preprocess it to get x 138 | x, dt0, dt1, dt2 = imagenet_train_preprocess(filename) 139 | 140 | # Map directory name to class label, then one-hot that 141 | label = split(basename(filename), "_")[1] 142 | y = Flux.onehot(synset_mapping[label], 1:length(synset_mapping))[:,:] 143 | 144 | #println(@sprintf("[%.03fs, %.03fs, %.03fs]: %s", dt0, dt1, dt2, filename)) 145 | return (x, y) 146 | end 147 | 148 | """ 149 | imagenet_val_data_loader(filename) 150 | Worker thread data loading routine; loads a filename, figures out its label, 151 | and returns the (x, y) pair for later collation. This is used for validation, 152 | and expects data basenames to look something like `test_XXX.JPEG`. 153 | """ 154 | function imagenet_val_data_loader(filename::String) 155 | t_start = time() 156 | synset_mapping = Metalhead.ImageNet.synset_mapping 157 | 158 | # Load image file and preprocess it to get x 159 | x, _, _, _ = imagenet_val_preprocess(filename) # the local helper above (not a Metalhead function); it also returns timings 160 | 161 | # Map filename to class index, then one-hot that 162 | test_idx = parse(Int, split(splitext(basename(filename))[1], "_")[end]) 163 | label = Metalhead.ImageNet.imagenet_val_labels[test_idx] 164 | y = Flux.onehot(synset_mapping[label], 1:length(synset_mapping))[:,:] 165 | 166 | println(@sprintf("%s: %.3fs", filename, time() - t_start)) 167 | return (x, y) 168 | end 169 | 170 | struct ImagenetDataset 171 | # Data we're initialized with 172 | dataset_root::String 173 | batch_size::Int 174 | data_loader::Function 175 | 176 | # Data we calculate once, at startup 177 | filenames::Vector{String} 178 | queue_pool::QueuePool 179 | 180 | function ImagenetDataset(dataset_root::String, num_workers::Int, batch_size::Int, 181 | 
data_loader::Function = imagenet_val_data_loader) 182 | # Scan dataset_root for files 183 | filenames = filter(f -> endswith(f, ".JPEG"), recursive_readdir(dataset_root)) 184 | 185 | @assert !isempty(filenames) "Empty dataset folder!" 186 | @assert num_workers >= 1 "Must have a positive integer number of workers!" 187 | @assert batch_size >= 1 "Must have a positive integer batch size!" 188 | 189 | # Start our worker pool 190 | @info("Adding $(num_workers) new data workers...") 191 | queue_pool = QueuePool(num_workers, data_loader, quote 192 | # The workers need to be able to load images and preprocess them via Metalhead 193 | using Flux, Images, Metalhead 194 | include($(@__FILE__)) 195 | end) 196 | 197 | return new(dataset_root, batch_size, data_loader, filenames, queue_pool) 198 | end 199 | end 200 | 201 | # Serialize the arguments needed to recreate this ImagenetDataset 202 | function freeze_args(id::ImagenetDataset) 203 | return (id.dataset_root, length(id.queue_pool.workers), id.batch_size, id.data_loader) 204 | end 205 | Base.length(id::ImagenetDataset) = div(length(id.filenames),id.batch_size) 206 | 207 | mutable struct ImagenetIteratorState 208 | batch_idx::Int 209 | job_offset::Int 210 | 211 | function ImagenetIteratorState(id::ImagenetDataset) 212 | @info("Creating IIS with $(length(id.filenames)) images") 213 | 214 | # Build permutation for this iteration 215 | permutation = shuffle(1:length(id.filenames)) 216 | 217 | # Push first job, save value to get job_offset (we know that all jobs 218 | # within this iteration will be consecutive, so we only save the offset 219 | # of the first one, and can use that to determine the job ids of every 220 | # subsequent job): 221 | filename = joinpath(id.dataset_root, id.filenames[permutation[1]]) 222 | job_offset = push_job!(id.queue_pool, filename) 223 | 224 | # Next, push every other job 225 | for pidx in permutation[2:end] 226 | filename = joinpath(id.dataset_root, id.filenames[pidx]) 227 | 
push_job!(id.queue_pool, filename) 228 | end 229 | return new( 230 | 0, 231 | job_offset, 232 | ) 233 | end 234 | end 235 | 236 | function Base.iterate(id::ImagenetDataset, state=ImagenetIteratorState(id)) 237 | # If we're at the end of this epoch, give up the ghost 238 | if state.batch_idx >= length(id) # batch_idx is zero-based, so the valid batches are 0:length(id)-1 239 | return nothing 240 | end 241 | 242 | # Otherwise, wait for the next batch worth of jobs to finish on our queue pool 243 | next_batch_job_ids = state.job_offset .+ (0:(id.batch_size-1)) .+ id.batch_size*state.batch_idx 244 | # Next, wait for the currently-being-worked-on batch to be done. 245 | pairs = fetch_result.(Ref(id.queue_pool), next_batch_job_ids) 246 | state.batch_idx += 1 247 | 248 | # Collate X's and Y's into big tensors: 249 | X = cat((p[1] for p in pairs)...; dims=ndims(pairs[1][1])) 250 | Y = cat((p[2] for p in pairs)...; dims=ndims(pairs[1][2])) 251 | 252 | # Return the fruit of our labor 253 | return (X, Y), state 254 | end -------------------------------------------------------------------------------- /VGG_and_Cifar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### VGG and Cifar\n", 8 | "\n", 9 | "在该实现中您可以看到如下功能:\n", 10 | "1. 设置训练过程中的参数\n", 11 | "2. 读取 Cifar10 数据集并创建训练集和测试集\n", 12 | "3. 使用图像增广\n", 13 | "4. 使用预训练的 VGG19 模型\n", 14 | "5. 训练、测试和保存模型\n", 15 | "\n", 16 | "In this template you can see the following features:\n", 17 | "1. Set the parameters used during training\n", 18 | "2. Read the Cifar10 dataset and create the training and test sets\n", 19 | "3. Use image augmentation\n", 20 | "4. Use the pre-trained VGG19 model\n", 21 | "5. 
Train, test and save the model" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "using Flux, Metalhead, Statistics\n", 31 | "using Flux: onehotbatch, onecold, logitcrossentropy, throttle, flatten\n", 32 | "using Metalhead: trainimgs\n", 33 | "using Parameters: @with_kw\n", 34 | "using Images: channelview\n", 35 | "using Statistics: mean\n", 36 | "using Base.Iterators: partition\n", 37 | "using CUDAapi" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stderr", 47 | "output_type": "stream", 48 | "text": [ 49 | "┌ Info: Training on GPU-3\n", 50 | "└ @ Main In[2]:7\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "using CUDAapi, CUDAdrv, CUDAnative\n", 56 | "gpu_id = 3 ## set < 0 for no cuda, >= 0 for using a specific device (if available)\n", 57 | "\n", 58 | "if has_cuda_gpu() && gpu_id >=0\n", 59 | "    device!(gpu_id)\n", 60 | "    device = Flux.gpu\n", 61 | "    @info \"Training on GPU-$(gpu_id)\"\n", 62 | "else\n", 63 | "    device = Flux.cpu\n", 64 | "    @info \"Training on CPU\"\n", 65 | "end" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "为了便于调整参数和记录试验结果,我们需要使用 parameters 将参数记录和封装。\n", 73 | "\n", 74 | "To make it easier to adjust hyperparameters and record experiment results, we use the Parameters package to record and encapsulate them." 
75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "Args" 86 | ] 87 | }, 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "using Parameters: @with_kw\n", 95 | "@with_kw mutable struct Args\n", 96 | "    batchsize::Int = 128\n", 97 | "    throttle::Int = 10\n", 98 | "    lr::Float64 = 5e-5\n", 99 | "    epochs::Int = 10\n", 100 | "    splitr_::Float64 = 0.1\n", 101 | "end" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "参照pytorch实现图像增广的预处理过程。\n", 109 | "\n", 110 | "Following PyTorch, we implement the preprocessing pipeline for image augmentation." 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 4, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "preprocess (generic function with 1 method)" 122 | ] 123 | }, 124 | "execution_count": 4, 125 | "metadata": {}, 126 | "output_type": "execute_result" 127 | } 128 | ], 129 | "source": [ 130 | "# without augmentation\n", 131 | "function preprocess(X)\n", 132 | "    Float32.(permutedims(channelview(X), (2, 3, 1)))\n", 133 | "end\n", 134 | "\n", 135 | "# # with augmentation\n", 136 | "# function resize_smallest_dimension(im, len)\n", 137 | "#   reduction_factor = len/minimum(size(im)[1:2])\n", 138 | "#   new_size = size(im)\n", 139 | "#   new_size = (\n", 140 | "#       round(Int, size(im,1)*reduction_factor),\n", 141 | "#       round(Int, size(im,2)*reduction_factor),\n", 142 | "#   )\n", 143 | "#   if reduction_factor < 1.0\n", 144 | "#     # Images.jl's imresize() needs to first lowpass the image, it won't do it for us\n", 145 | "#     im = imfilter(im, KernelFactors.gaussian(0.75/reduction_factor), Inner())\n", 146 | "#   end\n", 147 | "#   return imresize(im, new_size)\n", 148 | "# end\n", 149 | "\n", 150 | "# # Take the len-by-len square of pixels at the center of image 
`im`\n", 151 | "# function center_crop(im, len)\n", 152 | "# l2 = div(len,2)\n", 153 | "# adjust = len % 2 == 0 ? 1 : 0\n", 154 | "# return im[div(end,2)-l2:div(end,2)+l2-adjust,div(end,2)-l2:div(end,2)+l2-adjust]\n", 155 | "# end\n", 156 | "\n", 157 | "# function preprocess(im)\n", 158 | "# # Resize such that smallest edge is 256 pixels long\n", 159 | "# im = resize_smallest_dimension(im, 256)\n", 160 | "\n", 161 | "# # Center-crop to 224x224\n", 162 | "# im = center_crop(im, 224)\n", 163 | "\n", 164 | "# # Convert to channel view and normalize (these coefficients taken\n", 165 | "# # from PyTorch's ImageNet normalization code)\n", 166 | "# μ = [0.485, 0.456, 0.406]\n", 167 | "# # the sigma numbers are suspect: they cause the image to go outside of 0..1\n", 168 | "# # 1/0.225 = 4.4 effective scale\n", 169 | "# σ = [0.229, 0.224, 0.225]\n", 170 | "# #im = (channelview(im) .- μ)./σ\n", 171 | "# im = (channelview(im) .- μ)\n", 172 | "\n", 173 | "# # Convert from CHW (Image.jl's channel ordering) to WHCN (Flux.jl's ordering)\n", 174 | "# # and enforce Float32, as that seems important to Flux\n", 175 | "# # result is (224, 224, 3, 1)\n", 176 | "# #return Float32.(permutedims(im, (3, 2, 1))[:,:,:,:].*255) # why\n", 177 | "# return Float32.(permutedims(im, (3, 2, 1))[:,:,:,:])\n", 178 | "# end" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "构建训练集合、验证集合和测试集合。\n", 186 | "\n", 187 | "Build training set, validation set and test set." 
188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 5, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "get_test_data (generic function with 1 method)" 199 | ] 200 | }, 201 | "execution_count": 5, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "using Metalhead: trainimgs\n", 208 | "using Images, ImageMagick\n", 209 | "\n", 210 | "function get_processed_data(args)\n", 211 | "    # Fetching the train and validation data and getting them into proper shape\t\n", 212 | "    X = trainimgs(CIFAR10)\n", 213 | "    imgs = [preprocess(X[i].img) for i in 1:40000]\n", 214 | "    #onehot encode labels of batch\n", 215 | "    \n", 216 | "    labels = onehotbatch([X[i].ground_truth.class for i in 1:40000],1:10)\n", 217 | "\n", 218 | "    train_pop = Int((1-args.splitr_)* 40000)\n", 219 | "    train = device.([(cat(imgs[i]..., dims = 4), labels[:,i]) for i in partition(1:train_pop, args.batchsize)])\n", 220 | "    valset = collect(train_pop+1:40000)\n", 221 | "    valX = cat(imgs[valset]..., dims = 4) |> device\n", 222 | "    valY = labels[:, valset] |> device\n", 223 | "\n", 224 | "    val = (valX,valY)\n", 225 | "    return train, val\n", 226 | "end\n", 227 | "\n", 228 | "function get_test_data()\n", 229 | "    # Fetch the test data from Metalhead and get it into proper shape.\n", 230 | "    test = valimgs(CIFAR10)\n", 231 | "\n", 232 | "    # CIFAR-10 does not specify a validation set, so valimgs fetches the test data instead of testimgs\n", 233 | "    testimgs = [preprocess(test[i].img) for i in 1:1000]\n", 234 | "    testY = onehotbatch([test[i].ground_truth.class for i in 1:1000], 1:10) |> device\n", 235 | "    testX = cat(testimgs..., dims = 4) |> device\n", 236 | "\n", 237 | "    test = (testX,testY)\n", 238 | "    return test\n", 239 | "end" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "使用Metalhead中提供的模型结构和预训练参数。\n", 247 | "\n", 248 | "Use the model structure 
and pretrained parameters provided in Metalhead.\n", 249 | "\n", 250 | "在源码中可以找到预训练权重的下载地址,例如[github](https://github.com/FluxML/Metalhead.jl/blob/fd4687a0f91a188f099a43d6464000162b20aa60/src/utils.jl),我们从提供的[vgg19](https://github.com/FluxML/Metalhead.jl/releases/download/Models/vgg19.bson)下载地址中下载VGG19的权重文件,放在 deps 文件夹中。我的 deps 文件夹的路径是:\"~/.juliapro/JuliaPro_v1.4.1-1/packages/Metalhead/RZn9O/deps\"。\n", 251 | "\n", 252 | "The download URL for the pretrained weights can be found in the source code, for example on [github](https://github.com/FluxML/Metalhead.jl/blob/fd4687a0f91a188f099a43d6464000162b20aa60/src/utils.jl). We download the VGG19 weight file from the provided [vgg19](https://github.com/FluxML/Metalhead.jl/releases/download/Models/vgg19.bson) link and place it in the deps folder. The path of my deps folder is: \"~/.juliapro/JuliaPro_v1.4.1-1/packages/Metalhead/RZn9O/deps\".\n", 253 | "\n", 254 | "此外,从 Metalhead 导入的模型默认是测试模式的,如果要 Finetune 需要先将它设置为训练模式。\n", 255 | "Also, Metalhead models are loaded with testmode set to true by default, so let's set that to false before fine-tuning." 
256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 6, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "Chain(Chain(Conv((3, 3), 3=>64, relu), Conv((3, 3), 64=>64, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 64=>128, relu), Conv((3, 3), 128=>128, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 128=>256, relu), Conv((3, 3), 256=>256, relu), Conv((3, 3), 256=>256, relu), Conv((3, 3), 256=>256, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 256=>512, relu), Conv((3, 3), 512=>512, relu), Conv((3, 3), 512=>512, relu), Conv((3, 3), 512=>512, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 512=>512, relu), Conv((3, 3), 512=>512, relu), Conv((3, 3), 512=>512, relu), Conv((3, 3), 512=>512, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), #44), Dense(512, 4096, relu), Dropout(0.5), Dense(4096, 4096, relu), Dropout(0.5), Dense(4096, 10))" 267 | ] 268 | }, 269 | "execution_count": 6, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "using Metalhead\n", 276 | "\n", 277 | "# function VGG19()\n", 278 | "# return Chain(\n", 279 | "# Conv((3, 3), 3 => 64, relu, pad=(1, 1), stride=(1, 1)),\n", 280 | "# BatchNorm(64),\n", 281 | "# Conv((3, 3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),\n", 282 | "# BatchNorm(64),\n", 283 | "# MaxPool((2,2)),\n", 284 | "# Conv((3, 3), 64 => 128, relu, pad=(1, 1), stride=(1, 1)),\n", 285 | "# BatchNorm(128),\n", 286 | "# Conv((3, 3), 128 => 128, relu, pad=(1, 1), stride=(1, 1)),\n", 287 | "# BatchNorm(128),\n", 288 | "# MaxPool((2,2)),\n", 289 | "# Conv((3, 3), 128 => 256, relu, pad=(1, 1), stride=(1, 1)),\n", 290 | "# BatchNorm(256),\n", 291 | "# Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),\n", 292 | "# BatchNorm(256),\n", 293 | "# Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),\n", 294 
| "# BatchNorm(256),\n", 295 | "# Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),\n", 296 | "# MaxPool((2,2)),\n", 297 | "# Conv((3, 3), 256 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 298 | "# BatchNorm(512),\n", 299 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 300 | "# BatchNorm(512),\n", 301 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 302 | "# BatchNorm(512),\n", 303 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 304 | "# MaxPool((2,2)),\n", 305 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 306 | "# BatchNorm(512),\n", 307 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 308 | "# BatchNorm(512),\n", 309 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 310 | "# BatchNorm(512),\n", 311 | "# Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),\n", 312 | "# MaxPool((2,2)),\n", 313 | "# flatten,\n", 314 | "# Dense(512, 4096, relu),\n", 315 | "# Dropout(0.5),\n", 316 | "# Dense(4096, 4096, relu),\n", 317 | "# Dropout(0.5),\n", 318 | "# Dense(4096, 10))\n", 319 | "# end\n", 320 | "# model = VGG19() |> device\n", 321 | "\n", 322 | "# Finetune Metalhead VGG19 without augmentation\n", 323 | "vgg = VGG19()\n", 324 | "model = Chain(vgg.layers[1:end-6],\n", 325 | "              Dense(512, 4096, relu),\n", 326 | "              Dropout(0.5),\n", 327 | "              Dense(4096, 4096, relu),\n", 328 | "              Dropout(0.5),\n", 329 | "              Dense(4096, 10)) |> device\n", 330 | "Flux.trainmode!(model, true)\n", 331 | "\n", 332 | "# # Finetune Metalhead VGG19 with augmentation, images are resized to 224*224\n", 333 | "# vgg = VGG19()\n", 334 | "# model = Chain(vgg.layers[1:end-2],\n", 335 | "#               Dense(4096,10),\n", 336 | "#               softmax) |> device\n", 337 | "\n", 338 | "# # Finetune your trained models\n", 339 | "# function vgg19()\n", 340 | "#     ws = weights(\"vgg19.bson\")\n", 341 | "#     return Chain(\n", 342 | "#         Conv(ws[:conv1_1_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv1_1_b_0], relu, 
pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 343 | "# Conv(ws[:conv1_2_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv1_2_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 344 | "# MaxPool((2,2)),\n", 345 | "# Conv(ws[:conv2_1_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv2_1_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 346 | "# Conv(ws[:conv2_2_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv2_2_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 347 | "# MaxPool((2,2)),\n", 348 | "# Conv(ws[:conv3_1_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv3_1_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 349 | "# Conv(ws[:conv3_2_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv3_2_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 350 | "# Conv(ws[:conv3_3_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv3_3_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 351 | "# Conv(ws[:conv3_4_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv3_4_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 352 | "# MaxPool((2,2)),\n", 353 | "# Conv(ws[:conv4_1_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv4_1_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 354 | "# Conv(ws[:conv4_2_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv4_2_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 355 | "# Conv(ws[:conv4_3_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv4_3_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 356 | "# Conv(ws[:conv4_4_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv4_4_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 357 | "# MaxPool((2,2)),\n", 358 | "# Conv(ws[:conv5_1_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv5_1_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 359 | "# Conv(ws[:conv5_2_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv5_2_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 360 | "# 
Conv(ws[:conv5_3_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv5_3_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 361 | "# Conv(ws[:conv5_4_w_0][end:-1:1,:,:,:][:,end:-1:1,:,:], ws[:conv5_4_b_0], relu, pad = (1,1), stride = (1,1), dilation = (1,1)),\n", 362 | "# MaxPool((2,2)),\n", 363 | "# x -> reshape(x, :, size(x, 4)),\n", 364 | "# Dense(ws[:fc6_w_0]', ws[:fc6_b_0], relu),\n", 365 | "# Dropout(0.5f0),\n", 366 | "# Dense(ws[:fc7_w_0]', ws[:fc7_b_0], relu),\n", 367 | "# Dropout(0.5f0),\n", 368 | "# Dense(ws[:fc8_w_0]', ws[:fc8_b_0]),\n", 369 | "# softmax)\n", 370 | "# end\n", 371 | "# model = vgg19() |> device" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": {}, 377 | "source": [ 378 | "训练模型并微调参数。\n", 379 | "\n", 380 | "Train the model and fine-tune the parameters." 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 7, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "train (generic function with 1 method)" 392 | ] 393 | }, 394 | "execution_count": 7, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "function train(model; kws...)\n", 401 | " # Initialize the hyperparameters\n", 402 | " args = Args(; kws...)\n", 403 | " \n", 404 | " # Load the train, validation data \n", 405 | " train, val = get_processed_data(args)\n", 406 | "\n", 407 | " @info(\"Constructing Model\")\n", 408 | " # Defining the loss and accuracy functions\n", 409 | "\n", 410 | " loss(x, y) = logitcrossentropy(model(x), y)\n", 411 | "\n", 412 | " ## Training\n", 413 | " # Defining the callback and the optimizer\n", 414 | " evalcb = throttle(() -> @show(loss(val...)), args.throttle)\n", 415 | " opt = ADAM(args.lr)\n", 416 | " @info(\"Training....\")\n", 417 | " # Starting to train models\n", 418 | " Flux.@epochs args.epochs Flux.train!(loss, params(model), train, opt, cb=evalcb)\n", 419 | "end" 420 | ] 421 | }, 422 | { 423 | "cell_type": 
"markdown", 424 | "metadata": {}, 425 | "source": [ 426 | "需要耐心等待几分钟,正在下载数据集和数据增广。\n", 427 | "\n", 428 | "This may take a few minutes: the dataset is being downloaded and image augmentation is in progress." 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 8, 434 | "metadata": {}, 435 | "outputs": [ 436 | { 437 | "name": "stderr", 438 | "output_type": "stream", 439 | "text": [ 440 | "┌ Info: Constructing Model\n", 441 | "└ @ Main In[7]:8\n", 442 | "┌ Info: Training....\n", 443 | "└ @ Main In[7]:17\n", 444 | "┌ Info: Epoch 1\n", 445 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 446 | ] 447 | }, 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "loss(val...) = 2.3092935f0\n", 453 | "loss(val...) = 1.1011595f0\n", 454 | "loss(val...) = 0.8892075f0\n", 455 | "loss(val...) = 0.857905f0\n" 456 | ] 457 | }, 458 | { 459 | "name": "stderr", 460 | "output_type": "stream", 461 | "text": [ 462 | "┌ Info: Epoch 2\n", 463 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 464 | ] 465 | }, 466 | { 467 | "name": "stdout", 468 | "output_type": "stream", 469 | "text": [ 470 | "loss(val...) = 0.81348246f0\n", 471 | "loss(val...) = 0.72334063f0\n", 472 | "loss(val...) = 0.66348326f0\n" 473 | ] 474 | }, 475 | { 476 | "name": "stderr", 477 | "output_type": "stream", 478 | "text": [ 479 | "┌ Info: Epoch 3\n", 480 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 481 | ] 482 | }, 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "loss(val...) = 0.68518114f0\n", 488 | "loss(val...) = 0.6231706f0\n", 489 | "loss(val...) 
= 0.62917036f0\n" 490 | ] 491 | }, 492 | { 493 | "name": "stderr", 494 | "output_type": "stream", 495 | "text": [ 496 | "┌ Info: Epoch 4\n", 497 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 498 | ] 499 | }, 500 | { 501 | "name": "stdout", 502 | "output_type": "stream", 503 | "text": [ 504 | "loss(val...) = 0.60531837f0\n", 505 | "loss(val...) = 0.6573223f0\n", 506 | "loss(val...) = 0.81194323f0\n" 507 | ] 508 | }, 509 | { 510 | "name": "stderr", 511 | "output_type": "stream", 512 | "text": [ 513 | "┌ Info: Epoch 5\n", 514 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 515 | ] 516 | }, 517 | { 518 | "name": "stdout", 519 | "output_type": "stream", 520 | "text": [ 521 | "loss(val...) = 0.7699521f0\n", 522 | "loss(val...) = 0.64679813f0\n", 523 | "loss(val...) = 0.78103256f0\n", 524 | "loss(val...) = 0.6638566f0\n" 525 | ] 526 | }, 527 | { 528 | "name": "stderr", 529 | "output_type": "stream", 530 | "text": [ 531 | "┌ Info: Epoch 6\n", 532 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 533 | ] 534 | }, 535 | { 536 | "name": "stdout", 537 | "output_type": "stream", 538 | "text": [ 539 | "loss(val...) = 0.6445502f0\n", 540 | "loss(val...) = 0.7403376f0\n", 541 | "loss(val...) = 0.9013936f0\n" 542 | ] 543 | }, 544 | { 545 | "name": "stderr", 546 | "output_type": "stream", 547 | "text": [ 548 | "┌ Info: Epoch 7\n", 549 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 550 | ] 551 | }, 552 | { 553 | "name": "stdout", 554 | "output_type": "stream", 555 | "text": [ 556 | "loss(val...) = 0.6655385f0\n", 557 | "loss(val...) = 0.6930324f0\n", 558 | "loss(val...) 
= 0.7701569f0\n" 559 | ] 560 | }, 561 | { 562 | "name": "stderr", 563 | "output_type": "stream", 564 | "text": [ 565 | "┌ Info: Epoch 8\n", 566 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 567 | ] 568 | }, 569 | { 570 | "name": "stdout", 571 | "output_type": "stream", 572 | "text": [ 573 | "loss(val...) = 0.79056406f0\n", 574 | "loss(val...) = 0.7227654f0\n", 575 | "loss(val...) = 0.6939518f0\n" 576 | ] 577 | }, 578 | { 579 | "name": "stderr", 580 | "output_type": "stream", 581 | "text": [ 582 | "┌ Info: Epoch 9\n", 583 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 584 | ] 585 | }, 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "loss(val...) = 0.714453f0\n", 591 | "loss(val...) = 0.76212484f0\n", 592 | "loss(val...) = 0.8016011f0\n" 593 | ] 594 | }, 595 | { 596 | "name": "stderr", 597 | "output_type": "stream", 598 | "text": [ 599 | "┌ Info: Epoch 10\n", 600 | "└ @ Main /home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/Flux/Fj3bt/src/optimise/train.jl:121\n" 601 | ] 602 | }, 603 | { 604 | "name": "stdout", 605 | "output_type": "stream", 606 | "text": [ 607 | "loss(val...) = 0.76166624f0\n", 608 | "loss(val...) = 0.87577313f0\n", 609 | "loss(val...) = 0.88311064f0\n" 610 | ] 611 | } 612 | ], 613 | "source": [ 614 | "train(model)" 615 | ] 616 | }, 617 | { 618 | "cell_type": "markdown", 619 | "metadata": {}, 620 | "source": [ 621 | "测试模型在测试集上的准确率.\n", 622 | "\n", 623 | "Test the accuracy of the model on the test set." 
624 | ] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": 9, 629 | "metadata": {}, 630 | "outputs": [ 631 | { 632 | "data": { 633 | "text/plain": [ 634 | "accuracy (generic function with 1 method)" 635 | ] 636 | }, 637 | "execution_count": 9, 638 | "metadata": {}, 639 | "output_type": "execute_result" 640 | } 641 | ], 642 | "source": [ 643 | "accuracy(x, y, m) = mean(onecold(cpu(m(x)), 1:10) .== onecold(cpu(y), 1:10))" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 10, 649 | "metadata": {}, 650 | "outputs": [ 651 | { 652 | "data": { 653 | "text/plain": [ 654 | "test (generic function with 1 method)" 655 | ] 656 | }, 657 | "execution_count": 10, 658 | "metadata": {}, 659 | "output_type": "execute_result" 660 | } 661 | ], 662 | "source": [ 663 | "function test(model)\n", 664 | " test_data = get_test_data() |> device\n", 665 | " # Print the final accuracy\n", 666 | " @show(accuracy(test_data..., model))\n", 667 | "end" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 11, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "name": "stdout", 677 | "output_type": "stream", 678 | "text": [ 679 | "accuracy(test_data..., model) = 0.807\n" 680 | ] 681 | }, 682 | { 683 | "data": { 684 | "text/plain": [ 685 | "0.807" 686 | ] 687 | }, 688 | "execution_count": 11, 689 | "metadata": {}, 690 | "output_type": "execute_result" 691 | } 692 | ], 693 | "source": [ 694 | "test(model)" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 12, 700 | "metadata": {}, 701 | "outputs": [], 702 | "source": [ 703 | "using Tracker\n", 704 | "using BSON: @load, @save\n", 705 | "\n", 706 | "pretrained = model |> cpu\n", 707 | "weights = Tracker.data.(params(pretrained))\n", 708 | "@save \"weights.bson\" weights" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 13, 714 | "metadata": {}, 715 | "outputs": [], 716 | "source": [ 717 | "# # load weights\n", 718 | "# weights 
= BSON.load(filename)\n", 719 | "# Flux.loadparams!(model, weights)" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": null, 725 | "metadata": {}, 726 | "outputs": [], 727 | "source": [] 728 | } 729 | ], 730 | "metadata": { 731 | "kernelspec": { 732 | "display_name": "Julia 1.4.1", 733 | "language": "julia", 734 | "name": "julia-1.4" 735 | }, 736 | "language_info": { 737 | "file_extension": ".jl", 738 | "mimetype": "application/julia", 739 | "name": "julia", 740 | "version": "1.4.1" 741 | } 742 | }, 743 | "nbformat": 4, 744 | "nbformat_minor": 2 745 | } 746 | -------------------------------------------------------------------------------- /MLP_and_MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### MLP and MNIST\n", 8 | "\n", 9 | "在该实现中您可以看到如下功能:\n", 10 | "1. 从 Flux 中导入标准数据集\n", 11 | "2. 对数据集进行切分\n", 12 | "3. 指定使用的设备/GPU\n", 13 | "4. 定义模型\n", 14 | "5. 定义损失函数\n", 15 | "6. 定义评估方法\n", 16 | "7. 使用模型进行训练和推断\n", 17 | "\n", 18 | "In this template you can see the following steps:\n", 19 | "1. Import a standard dataset from Flux\n", 20 | "2. Split the dataset\n", 21 | "3. Specify the device / GPU to use\n", 22 | "4. Define the model\n", 23 | "5. Define the loss function\n", 24 | "6. Define the evaluation method\n", 25 | "7. Use the model for training and inference" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "Flux 是 Julia 中的深度学习库,其完全由 Julia 实现,结构轻量化,是 Julia 中的 PyTorch。\n", 33 | "\n", 34 | "Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 1, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "using Flux, Flux.Data.MNIST, Statistics\n", 44 | "using Flux: onehotbatch, onecold, crossentropy, throttle, params\n", 45 | "using Base.Iterators: repeated, partition\n", 46 | "using CuArrays" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "尽管 Flux 中目前已经实现了 gpu 方法,但功能有限。所幸 Flux 在 GPU 上的功能基于 CuArrays 实现,我们可以使用 CUDAapi, CUDAdrv, CUDAnative 来设置 Flux 使用哪个 GPU,或是只使用 CPU。\n", 54 | "\n", 55 | "Although Flux already provides a gpu method, its functionality is limited. Fortunately, Flux's GPU support is built on CuArrays, so we can use CUDAapi, CUDAdrv, and CUDAnative to choose which GPU Flux uses, or fall back to the CPU." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "┌ Info: Training on GPU-1\n", 68 | "└ @ Main In[2]:7\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "using CUDAapi, CUDAdrv, CUDAnative\n", 74 | "gpu_id = 1  ## set < 0 for no cuda, >= 0 for using a specific device (if available)\n", 75 | "\n", 76 | "if has_cuda_gpu() && gpu_id >= 0\n", 77 | "    device!(gpu_id)\n", 78 | "    device = Flux.gpu\n", 79 | "    @info \"Training on GPU-$(gpu_id)\"\n", 80 | "else\n", 81 | "    device = Flux.cpu\n", 82 | "    @info \"Training on CPU\"\n", 83 | "end" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "加载数据集和对应的label,对label进行onehot编码。\n", 91 | "\n", 92 | "Load the dataset and the corresponding labels, and one-hot encode the labels." 
93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:\n", 104 | " 0 1 0 0 0 0 0 0 0 0 0 0 0 … 0 0 0 0 0 0 0 0 0 0 0 0\n", 105 | " 0 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0\n", 106 | " 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0\n", 107 | " 0 0 0 0 0 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 1 0 0 0\n", 108 | " 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 109 | " 1 0 0 0 0 0 0 0 0 0 0 1 0 … 0 0 0 0 0 1 0 0 0 1 0 0\n", 110 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0\n", 111 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0\n", 112 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1\n", 113 | " 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0" 114 | ] 115 | }, 116 | "execution_count": 3, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "imgs = MNIST.images()\n", 123 | "labels = onehotbatch(MNIST.labels(), 0:9)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "准备训练数据集,将每 1000 张图像分为一个 batch,并全部图像迁移到 GPU 中。\n", 131 | "\n", 132 | "Prepare a training data set, divide every 1000 images into a batch, and migrate all images to the GPU." 
133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "60-element Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}:\n", 144 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 145 | "\n", 146 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 147 | "\n", 148 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 149 | "\n", 150 | "...\n", 151 | "\n", 152 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 153 | "\n", 154 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 155 | "\n", 156 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0])\n", 157 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 158 | "\n", 159 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 160 | "\n", 161 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 162 | "\n", 163 | "...\n", 164 | "\n", 165 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 166 | "\n", 167 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 168 | "\n", 169 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [1 0 … 0 1; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0])\n", 170 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 171 | "\n", 172 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 173 | "\n", 174 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 175 | "\n", 176 | 
"...\n", 177 | "\n", 178 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 179 | "\n", 180 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 181 | "\n", 182 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 1 0; … ; 0 1 … 0 0; 0 0 … 0 0])\n", 183 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 184 | "\n", 185 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 186 | "\n", 187 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 188 | "\n", 189 | "...\n", 190 | "\n", 191 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 192 | "\n", 193 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 194 | "\n", 195 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 1 0 … 1 0])\n", 196 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 197 | "\n", 198 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 199 | "\n", 200 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 201 | "\n", 202 | "...\n", 203 | "\n", 204 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 205 | "\n", 206 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 207 | "\n", 208 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 1 0; … ; 0 0 … 0 0; 0 0 … 0 0])\n", 209 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 210 | "\n", 211 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 212 | "\n", 213 | "[0.0 0.0 … 0.0 0.0; 0.0 
0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 214 | "\n", 215 | "...\n", 216 | "\n", 217 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 218 | "\n", 219 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 220 | "\n", 221 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 1])\n", 222 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 223 | "\n", 224 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 225 | "\n", 226 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 227 | "\n", 228 | "...\n", 229 | "\n", 230 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 231 | "\n", 232 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 233 | "\n", 234 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 1 … 0 0; 0 0 … 1 0; … ; 0 0 … 0 0; 0 0 … 0 1])\n", 235 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 236 | "\n", 237 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 238 | "\n", 239 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 240 | "\n", 241 | "...\n", 242 | "\n", 243 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 244 | "\n", 245 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 246 | "\n", 247 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 1 0 … 0 0; 0 0 … 0 0])\n", 248 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 249 | "\n", 250 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 
… 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 251 | "\n", 252 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 253 | "\n", 254 | "...\n", 255 | "\n", 256 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 257 | "\n", 258 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 259 | "\n", 260 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [1 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1])\n", 261 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 262 | "\n", 263 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 264 | "\n", 265 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 266 | "\n", 267 | "...\n", 268 | "\n", 269 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 270 | "\n", 271 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 272 | "\n", 273 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 0])\n", 274 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 275 | "\n", 276 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 277 | "\n", 278 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 279 | "\n", 280 | "...\n", 281 | "\n", 282 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 283 | "\n", 284 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 285 | "\n", 286 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 1 … 0 0; 0 0 … 0 0])\n", 287 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 
0.0]\n", 288 | "\n", 289 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 290 | "\n", 291 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 292 | "\n", 293 | "...\n", 294 | "\n", 295 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 296 | "\n", 297 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 298 | "\n", 299 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 1 … 0 0])\n", 300 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 301 | "\n", 302 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 303 | "\n", 304 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 305 | "\n", 306 | "...\n", 307 | "\n", 308 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 309 | "\n", 310 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 311 | "\n", 312 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 1])\n", 313 | " ⋮\n", 314 | " ([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 315 | "\n", 316 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 317 | "\n", 318 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 319 | "\n", 320 | "...\n", 321 | "\n", 322 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 323 | "\n", 324 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", 325 | "\n", 326 | "[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 
0])\n", 328 | " ⋮ (output truncated: repeated tuples of all-zero image batches and one-hot label matrices)" 470 | ] 471 | }, 472 | "execution_count": 4, 473 | "metadata": {}, 474 | "output_type": "execute_result" 475 | }, 476 | { 477 | "name": "stderr", 478 | "output_type": "stream", 479 | "text": [ 480 | "┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`\n", 481 | "└ @ GPUArrays 
/home/zhangzhi/.juliapro/JuliaPro_v1.4.1-1/packages/GPUArrays/WZupy/src/host/indexing.jl:43\n" 482 | ] 483 | } 484 | ], 485 | "source": [ 486 | "train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])\n", 487 | " for i in partition(1:60_000, 1000)] |> device" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "选择前1000张图片作为测试数据集,也迁移到 GPU 中。\n", 495 | "\n", 496 | "Select the first 1000 pictures as the test data set, and also migrate to the GPU." 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 5, 502 | "metadata": {}, 503 | "outputs": [ 504 | { 505 | "data": { 506 | "text/plain": [ 507 | "10×1000 Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}:\n", 508 | " 0 0 0 1 0 0 0 0 0 0 1 0 0 … 0 0 0 0 0 1 0 0 0 1 0 0\n", 509 | " 0 0 1 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0\n", 510 | " 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 0 0 0\n", 511 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0\n", 512 | " 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 513 | " 0 0 0 0 0 0 0 0 1 0 0 0 0 … 0 0 0 0 0 0 0 0 0 0 0 0\n", 514 | " 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0\n", 515 | " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 516 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0\n", 517 | " 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 1" 518 | ] 519 | }, 520 | "execution_count": 5, 521 | "metadata": {}, 522 | "output_type": "execute_result" 523 | } 524 | ], 525 | "source": [ 526 | "test_X = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> device\n", 527 | "test_y = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> device" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": {}, 533 | "source": [ 534 | "定义模型、损失函数和评估方法。\n", 535 | "\n", 536 | "Define models, loss functions and evaluation methods." 
537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 6, 542 | "metadata": {}, 543 | "outputs": [ 544 | { 545 | "data": { 546 | "text/plain": [ 547 | "accuracy (generic function with 1 method)" 548 | ] 549 | }, 550 | "execution_count": 6, 551 | "metadata": {}, 552 | "output_type": "execute_result" 553 | } 554 | ], 555 | "source": [ 556 | "model = Chain(\n", 557 | "    Conv((2,2), 1=>16, relu),\n", 558 | "    MaxPool((2, 2)),\n", 559 | "    Conv((2,2), 16=>8, relu),\n", 560 | "    MaxPool((2, 2)),\n", 561 | "    x -> reshape(x, :, size(x, 4)),\n", 562 | "    Dense(288, 10), softmax\n", 563 | ") |> device\n", 564 | "\n", 565 | "\n", 566 | "loss(x, y) = crossentropy(model(x), y)\n", 567 | "accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "训练并打印测试集的准确率。\n", 575 | "\n", 576 | "Train and print the accuracy of the test set." 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 7, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "name": "stdout", 586 | "output_type": "stream", 587 | "text": [ 588 | "accuracy(test_X, test_y) = 0.192\n", 589 | "accuracy(test_X, test_y) = 0.919\n", 590 | "accuracy(test_X, test_y) = 0.951\n", 591 | "accuracy(test_X, test_y) = 0.962\n", 592 | "accuracy(test_X, test_y) = 0.967\n" 593 | ] 594 | } 595 | ], 596 | "source": [ 597 | "opt = ADAM(0.01)\n", 598 | "evalcb() = @show(accuracy(test_X, test_y))\n", 599 | "\n", 600 | "epochs = 5\n", 601 | "\n", 602 | "for i = 1:epochs\n", 603 | "    Flux.train!(loss, Flux.params(model), train, opt, cb=throttle(evalcb, 10))\n", 604 | "end" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": {}, 610 | "source": [ 611 | "针对单独的图片进行推断。\n", 612 | "\n", 613 | "Run inference on a single image." 
614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 8, 619 | "metadata": {}, 620 | "outputs": [ 621 | { 622 | "name": "stdout", 623 | "output_type": "stream", 624 | "text": [ 625 | "Predicted: [4]\n" 626 | ] 627 | } 628 | ], 629 | "source": [ 630 | "using Colors, FileIO, ImageShow\n", 631 | "\n", 632 | "img = test_X[:, :, 1:1, 7:7]\n", 633 | "\n", 634 | "println(\"Predicted: \", Flux.onecold(model(img |> device)) .- 1)\n", 635 | "save(\"outputs.jpg\", collect(test_X[:, :, 1, 7]))" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": null, 641 | "metadata": {}, 642 | "outputs": [], 643 | "source": [] 644 | } 645 | ], 646 | "metadata": { 647 | "kernelspec": { 648 | "display_name": "Julia 1.4.1", 649 | "language": "julia", 650 | "name": "julia-1.4" 651 | }, 652 | "language_info": { 653 | "file_extension": ".jl", 654 | "mimetype": "application/julia", 655 | "name": "julia", 656 | "version": "1.4.1" 657 | } 658 | }, 659 | "nbformat": 4, 660 | "nbformat_minor": 2 661 | } -------------------------------------------------------------------------------- /DCGAN_and_Fashion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### DCGAN and Fashion_MNIST\n", 8 | "\n", 9 | "在该实现中您可以看到如下功能:\n", 10 | "1. GAN 的定义\n", 11 | "2. GAN 的对抗训练\n", 12 | "3. 生成图像的可视化\n", 13 | "\n", 14 | "This implementation demonstrates the following features:\n", 15 | "1. Definition of GAN\n", 16 | "2. GAN's adversarial training\n", 17 | "3. 
Visualization of generated images" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "using Base.Iterators: partition\n", 27 | "using Flux\n", 28 | "using Flux.Optimise: update!\n", 29 | "using Flux: logitbinarycrossentropy\n", 30 | "using Images\n", 31 | "using MLDatasets\n", 32 | "using Statistics\n", 33 | "using Parameters: @with_kw\n", 34 | "using Printf\n", 35 | "using Random" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "┌ Info: Training on GPU-3\n", 48 | "└ @ Main In[2]:7\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "using CUDAapi, CUDAdrv, CUDAnative\n", 54 | "gpu_id = 3 ## set < 0 for no cuda, >= 0 for using a specific device (if available)\n", 55 | "\n", 56 | "if has_cuda_gpu() && gpu_id >=0\n", 57 | " device!(gpu_id)\n", 58 | " device = Flux.gpu\n", 59 | " @info \"Training on GPU-$(gpu_id)\"\n", 60 | "else\n", 61 | " device = Flux.cpu\n", 62 | " @info \"Training on CPU\"\n", 63 | "end" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "Args" 75 | ] 76 | }, 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "output_type": "execute_result" 80 | } 81 | ], 82 | "source": [ 83 | "using Parameters: @with_kw\n", 84 | "@with_kw mutable struct Args\n", 85 | " batch_size::Int = 128\n", 86 | " latent_dim::Int = 100\n", 87 | " epochs::Int = 20\n", 88 | " verbose_freq::Int = 1000\n", 89 | " output_x::Int = 6\n", 90 | " output_y::Int = 6\n", 91 | " lr_dscr::Float64 = 0.00005\n", 92 | " lr_gen::Float64 = 0.00005\n", 93 | "end" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "create_output_image (generic 
function with 1 method)" 105 | ] 106 | }, 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "function create_output_image(gen, fixed_noise, args)\n", 114 | " @eval Flux.istraining() = false\n", 115 | " fake_images = @. cpu(gen(fixed_noise))\n", 116 | " @eval Flux.istraining() = true\n", 117 | " image_array = permutedims(dropdims(reduce(vcat, reduce.(hcat, partition(fake_images, args.output_y))); dims=(3, 4)), (2, 1))\n", 118 | " image_array = @. Gray(image_array + 1f0) / 2f0\n", 119 | " return image_array\n", 120 | "end" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "Generator (generic function with 1 method)" 132 | ] 133 | }, 134 | "execution_count": 5, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "function Discriminator()\n", 141 | " return Chain(\n", 142 | " Conv((4, 4), 1 => 64; stride = 2, pad = 1),\n", 143 | " x->leakyrelu.(x, 0.2f0),\n", 144 | " Dropout(0.25),\n", 145 | " Conv((4, 4), 64 => 128; stride = 2, pad = 1),\n", 146 | " x->leakyrelu.(x, 0.2f0),\n", 147 | " Dropout(0.25), \n", 148 | " x->reshape(x, 7 * 7 * 128, :),\n", 149 | " Dense(7 * 7 * 128, 1))\n", 150 | "end\n", 151 | "\n", 152 | "function Generator(latent_dim)\n", 153 | " return Chain(\n", 154 | " Dense(latent_dim, 7 * 7 * 256),\n", 155 | " BatchNorm(7 * 7 * 256, relu),\n", 156 | " x->reshape(x, 7, 7, 256, :),\n", 157 | " ConvTranspose((5, 5), 256 => 128; stride = 1, pad = 2),\n", 158 | " BatchNorm(128, relu),\n", 159 | " ConvTranspose((4, 4), 128 => 64; stride = 2, pad = 1),\n", 160 | " BatchNorm(64, relu),\n", 161 | " ConvTranspose((4, 4), 64 => 1, tanh; stride = 2, pad = 1),\n", 162 | " )\n", 163 | "end" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 6, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 
172 | "data": { 173 | "text/plain": [ 174 | "generator_loss (generic function with 1 method)" 175 | ] 176 | }, 177 | "execution_count": 6, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "function discriminator_loss(real_output, fake_output)\n", 184 | " real_loss = mean(logitbinarycrossentropy.(real_output, 1f0))\n", 185 | " fake_loss = mean(logitbinarycrossentropy.(fake_output, 0f0))\n", 186 | " return real_loss + fake_loss\n", 187 | "end\n", 188 | "\n", 189 | "generator_loss(fake_output) = mean(logitbinarycrossentropy.(fake_output, 1f0))" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 7, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "train_generator! (generic function with 1 method)" 201 | ] 202 | }, 203 | "execution_count": 7, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "function train_discriminator!(gen, dscr, x, opt_dscr, args)\n", 210 | " noise = randn!(similar(x, (args.latent_dim, args.batch_size))) \n", 211 | " fake_input = gen(noise)\n", 212 | " ps = Flux.params(dscr)\n", 213 | " # Taking gradient\n", 214 | " loss, back = Flux.pullback(ps) do\n", 215 | " discriminator_loss(dscr(x), dscr(fake_input))\n", 216 | " end\n", 217 | " grad = back(1f0)\n", 218 | " update!(opt_dscr, ps, grad)\n", 219 | " return loss\n", 220 | "end\n", 221 | "\n", 222 | "function train_generator!(gen, dscr, x, opt_gen, args)\n", 223 | " noise = randn!(similar(x, (args.latent_dim, args.batch_size))) \n", 224 | " ps = Flux.params(gen)\n", 225 | " # Taking gradient\n", 226 | " loss, back = Flux.pullback(ps) do\n", 227 | " generator_loss(dscr(gen(noise)))\n", 228 | " end\n", 229 | " grad = back(1f0)\n", 230 | " update!(opt_gen, ps, grad)\n", 231 | " return loss\n", 232 | "end" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 
241 | "data": { 242 | "text/plain": [ 243 | "train (generic function with 1 method)" 244 | ] 245 | }, 246 | "execution_count": 8, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "using MLDatasets\n", 253 | "using Images\n", 254 | "using Printf\n", 255 | "\n", 256 | "function train(; kws...)\n", 257 | " # Model Parameters\n", 258 | " args = Args(; kws...)\n", 259 | "\n", 260 | " # Load FashionMNIST dataset\n", 261 | " images, _ = MLDatasets.FashionMNIST.traindata(Float32)\n", 262 | " # Normalize to [-1, 1]\n", 263 | " image_tensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)\n", 264 | " # Partition into batches\n", 265 | " data = [image_tensor[:, :, :, r] |> device for r in partition(1:60000, args.batch_size)]\n", 266 | "\n", 267 | " fixed_noise = [randn(args.latent_dim, 1) |> device for _=1:args.output_x*args.output_y]\n", 268 | "\n", 269 | " # Discriminator\n", 270 | " d_model = Discriminator() |> device\n", 271 | "\n", 272 | " # Generator\n", 273 | " g_model = Generator(args.latent_dim) |> device\n", 274 | "\n", 275 | " # Optimizers\n", 276 | " opt_dscr = ADAM(args.lr_dscr)\n", 277 | " opt_gen = ADAM(args.lr_gen)\n", 278 | "\n", 279 | " # Training\n", 280 | " train_steps = 0\n", 281 | " for ep in 1:args.epochs\n", 282 | " @info \"Epoch $ep\"\n", 283 | " for x in data\n", 284 | " # Update discriminator and generator\n", 285 | " loss_dscr = train_discriminator!(g_model, d_model, x, opt_dscr, args)\n", 286 | " loss_gen = train_generator!(g_model, d_model, x, opt_gen, args)\n", 287 | "\n", 288 | " if train_steps % args.verbose_freq == 0\n", 289 | " @info(\"Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)\")\n", 290 | " # Save generated fake image\n", 291 | "# output_image = create_output_image(g_model, fixed_noise, args)\n", 292 | "# display(output_image)\n", 293 | "# save(@sprintf(\"dcgan_steps_%06d.png\", train_steps), output_image)\n", 294 | " end\n", 295 | " 
train_steps += 1\n", 296 | " end\n", 297 | " end\n", 298 | "\n", 299 | " output_image = create_output_image(g_model, fixed_noise, args)\n", 300 | " display(output_image)\n", 301 | "# save(@sprintf(\"dcgan_steps_%06d.png\", train_steps), output_image)\n", 302 | "end" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 9, 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "name": "stderr", 312 | "output_type": "stream", 313 | "text": [ 314 | "┌ Info: Epoch 1\n", 315 | "└ @ Main In[8]:31\n", 316 | "┌ Info: Train step 0, Discriminator loss = 1.4381738, Generator loss = 0.70620656\n", 317 | "└ @ Main In[8]:38\n", 318 | "┌ Info: Epoch 2\n", 319 | "└ @ Main In[8]:31\n", 320 | "┌ Info: Epoch 3\n", 321 | "└ @ Main In[8]:31\n", 322 | "┌ Info: Train step 1000, Discriminator loss = 1.203705, Generator loss = 0.8882565\n", 323 | "└ @ Main In[8]:38\n", 324 | "┌ Info: Epoch 4\n", 325 | "└ @ Main In[8]:31\n", 326 | "┌ Info: Epoch 5\n", 327 | "└ @ Main In[8]:31\n", 328 | "┌ Info: Train step 2000, Discriminator loss = 1.3487449, Generator loss = 0.754305\n", 329 | "└ @ Main In[8]:38\n", 330 | "┌ Info: Epoch 6\n", 331 | "└ @ Main In[8]:31\n", 332 | "┌ Info: Epoch 7\n", 333 | "└ @ Main In[8]:31\n", 334 | "┌ Info: Train step 3000, Discriminator loss = 1.450408, Generator loss = 0.6517055\n", 335 | "└ @ Main In[8]:38\n", 336 | "┌ Info: Epoch 8\n", 337 | "└ @ Main In[8]:31\n", 338 | "┌ Info: Epoch 9\n", 339 | "└ @ Main In[8]:31\n", 340 | "┌ Info: Train step 4000, Discriminator loss = 1.3254168, Generator loss = 0.8234496\n", 341 | "└ @ Main In[8]:38\n", 342 | "┌ Info: Epoch 10\n", 343 | "└ @ Main In[8]:31\n", 344 | "┌ Info: Epoch 11\n", 345 | "└ @ Main In[8]:31\n", 346 | "┌ Info: Train step 5000, Discriminator loss = 1.3110251, Generator loss = 0.80107075\n", 347 | "└ @ Main In[8]:38\n", 348 | "┌ Info: Epoch 12\n", 349 | "└ @ Main In[8]:31\n", 350 | "┌ Info: Epoch 13\n", 351 | "└ @ Main In[8]:31\n", 352 | "┌ Info: Train step 6000, Discriminator loss = 
1.303761, Generator loss = 0.7783703\n", 353 | "└ @ Main In[8]:38\n", 354 | "┌ Info: Epoch 14\n", 355 | "└ @ Main In[8]:31\n", 356 | "┌ Info: Epoch 15\n", 357 | "└ @ Main In[8]:31\n", 358 | "┌ Info: Train step 7000, Discriminator loss = 1.4186232, Generator loss = 0.688388\n", 359 | "└ @ Main In[8]:38\n", 360 | "┌ Info: Epoch 16\n", 361 | "└ @ Main In[8]:31\n", 362 | "┌ Info: Epoch 17\n", 363 | "└ @ Main In[8]:31\n", 364 | "┌ Info: Epoch 18\n", 365 | "└ @ Main In[8]:31\n", 366 | "┌ Info: Train step 8000, Discriminator loss = 1.1227105, Generator loss = 0.9515403\n", 367 | "└ @ Main In[8]:38\n", 368 | "┌ Info: Epoch 19\n", 369 | "└ @ Main In[8]:31\n", 370 | "┌ Info: Epoch 20\n", 371 | "└ @ Main In[8]:31\n", 372 | "┌ Info: Train step 9000, Discriminator loss = 1.3346882, Generator loss = 0.80039674\n", 373 | "└ @ Main In[8]:38\n" 374 | ] 375 | }, 376 | { 377 | "data": { 378 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKgAAACoCAAAAABRIPpoAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAAAmJLR0QA/4ePzL8AAD9rSURBVHjalb13mFxHlT58KtzcOUyOGo1yzrJsS5ZzZsFgzGIwmIyXsB9p2W9hdwnLwkZ2yZiwBozBGDDYOGDLQU6ycpZG0uTQPdO5++aq+v3RPTM90tgy9fh5LPW9VfXek+qcU6dKCC7asKxh22UCsMJdXvsRienH0a8vfOT3U72XKo/0WRf2lcAVF5/igkFR03/19FVOftOdeUbfQPd4Y2t+IOOjeMyfKLHqj7Ozf/dNpOOebENQ/dzEWvOCvpRR/w0jnX0RX9/YW3FHgpnZkS7aO7no2uFLxotbuk/0PrVuz8Fz0l5vDpSTybIUNMnJxgzvL9+338J+3dOgtPXr3/pzH7xxqgIAwNuv2vBXeZN/auunz06PdnGKNl8duDIV0DpG0ZHV2Yb3BYa+PTDKZyd+WzsSMqvkW33o6J3oDUw+9i0229kRq374qI/4RWeZ2y7bfHDFMdnr3mUbhdpP5KKdfGf18Twkz50Ybz86gDomhw61D80CReHrkJlLM/noKIoV+sor800P1pGP41PP5V3/opOc14YPqc9b5NpTLwzhaSl9baAIVyfccDNWtaSjxlqKaqi9Qru23HRmYmZqfG9X0EOG4RC1qSRECyHoifIs0k5Fxc3Y/As5D1s+/WdW+UCOD1bI/EADd1DzPXsBEFKDnzxdAQBAb149Be1WTlfSeTkxNRqMI8e3R6enljZvxBlq+BZO5G0SNFNj9EezBMTHStZ78CXr3/H4xcFJdYK85JrJ0meX8dHAWW9aavCclysPNF12DAAI6bhtT7b6m+UZkK6Ey3ka8UrQWExnSfvSGSXcdm2GBXBZUMkKxEhxRCs9MmtSgBul2MH9i9seeB2SxhHCSXnbP94Snv1tv0xbfv/S0cKQP4PvPNb7fX2DAMB5fl9/VSdQsh05DkXAOMUOYaioJ0nqlemZv/areLstSo4OFaSgkmO
MDR6oQ6V3P/TSpPf0kdcBagssTJZ+5kSdFXY3ymf7nl+z/5yH56coAAiZEMAUIypXf1ijY4yYzwTYLkUSlUUFJae/r+FW57dUoWSowhx33Anmz736RD2oD2VKCjo4HK4xAF0wHciPf2RH6H1P/cf2bz44+2N4dQ9Pl343KE0vAWge8/Tcsd88sfak13xp6mkBAJIm+RwxIFgI4CCoo40Fcc0CNf2L9NOrESORCLIDZDApW+WXy/WD6UGTihxc9UdLAAC0uFlW/zh888TunfEXrv8EX3WXnCYzz97cuPf3pkCM1CbC4gKgKCzOfutNG6yPd0kPPMMAQAqrtsQZ5QSAG6C7smSUGmoU+uxNuG2A635uIoEURyjKn9cPzVkyWY9WFlE46lT/OgGzFhULCT3cERhBwXs8guIAgGeALiW9ncN2kk2K6mB8HoOf3wpPD01mu15p+HV1vFHDFw72iQeY2MSSCj7RjtfATO67+pyW0aglQunJ0edWiswr3hySbbtiwT7AE6na+7PPEGViydFrbclD4IM90uMUbpn9wq/pLNaPlg1N1i2qcwVmRQ/EAv6OLw+tvfamfToAgN8RkjzgnPlAEADyCRbWtNJvX42BNWCCK6IcP37QC+GFfM4q9OdovBEvu1ATlDW3/Uf07sCf/7oBbXpCEK0Xk4aXgjOPr+hmRuf13Yt78MwSX09RtPzwkW0gdPzlblxcRuxkCQACls0ZcMGpsKksPOrRVKymZ3BpAQQA0JZmpeiaJxSjY3AOQSF06R3SGgMaJ/g0vwUAkCd+s+vxvcXPW//8ah7/rJlhAWWNiZbcjGjHSg1rEn2L3/12ex6KXnnggPibG/GSy55aJHkSZaUGAACbyH4FEEYSoYJy4mGpiffWurGzAD6owHZDWBkmCbxH5nMsEe9sDtpfDWwPLqySQiEA0dt+1tn763/8ZevzO+6+5E/HuqiEhYWBl7UZisWIvk0+sraDoNg8FH3q2izjAHvAMivHVg4deWoSAICMhyrMwTJyEEEWdWhZoMDJWpfdExuAWaCYRd+sCHuknR2Zy+Il35Lf4b/NfKKykPoAGDkK587hfeHT7S/+JmcvxlpnBj583wM/GVvR+tysUyZ+vNjLvJpV+4dmbes0UMK7PnHmu90DQrTdIbHENa6546o3nQUA2pGAMwsUGxBSLAUTW1aIY9Y4GV2qADI8Ilq8WPDWYv+Sdes/Mwfoq4Oi/aUxZdnkB/8ZKkl75c1fHe6+c8e5a+zWy3VMVkm3+P8TeLmTATwLNOjMdEuuPDXWrDWNLD+z/U+zrMdIuSGYBO2r/8sQFcpDiX8e9ks/tc6MFgAAEkWByp6FACHL45yBa2dtkwkAALwiCoAEF6ElPmldwR9FrS/PSjwAgG/hRrA3ffWac2seelPrLZ9a1yuG/uXan/PU42niTFQk73c/fNCqirWfnyEfSmqjbnfMXLaRDtZRFMfxMbao1BFZ0dc+3Mq//yP1F4lxTiJ5BwBAj4ayjoY5UIRzFNS80DUeqgqiQvJBECAh39W8Cr7uNI6+46E5FFV7gpf+74p3Jx6NHHzaRM98B1NgDHZ1Id0OcyX6vb/R3XMzDJ/pZWQLrXa6vD2zRnQcnwXqpxGHF/3H8ff3PeyMiu+BV0oBQKZKk9wRfRLbRHMYRczUPFYhIelgTVVO5U0QFYoLZsktjDz7ylcCQzBn3v/8EN0N5tnJ4XPcBwGM1QRRVCALkNoO8zVk9Z0oskNaMf1L+1C9jAoB4INY9Y8T3pyYQQAAGKtafFU2gIAXLftYATkQ92tGIxEMRwEMLuSIRqXyB3dJt6lz5/xixlFRID/o/yU+vsy1xbnBRe1d+NL7lozPyuh0Oz3izdPLsBQ7Y3quSSLmCBDLJGQyE63a4XhURUBsiwd6mK+3mWU/xOcGYVzjtynFgTv/Is9Z5UZgrKDr+TKmM9FdHdDKPKMhYOtNQ11i615gKuahfHPRz7axmnWLhScpxKjBrFLQc9g
GcPzI3P5T5OA/2PiJD/5FQdNObNCtUmhiNV2qvK+e9a/bMk+25caLAdXNxFChkrQyhQ43m6o+O/TQHyYgW9Gdc+XCpMUs/m9/V5nbu+fooV3ilBj4S3DCsH/uWPBg5tah0z8oHJuHovM1Gcu3r12srd26wF6yKFHp2iGPLF3WfLpZAgCAxivv6IBYkJSTCVJsDrVDUJbm9tfCpxDwvzAIdVBTb5tz3fJ4ZufiZW+Qolh0BtuLZ05377738e1De1duSba2LMmNTlUlcYex6mn4BwxNwsGbVPMI9OoFya3vr5mbH8qhhJ/9S4BOkm0LzI+1h9y7Wq4ceSNAEQhE899c+43M/08mS6WTgv30cTqpDZ487XLZBQLaF0YexvvTwjxTDDaM9j1NTqcHWwc5wEx2ZoG1YMvunvfs+r2YdtXrFAFh4LJfF3pX/4TsSGmPivolyTWxXItFyKySIqWxJzRwlfzbPn8WaO0PKPnmu72JluI3nnHhvAkxkXzGAVRiXpiXQj0faO47siL5xK7zImbJCPjy6ncWjqj9v53rb83wkiIi2T5H1U9EMxQl6rp3JXpKTfnOn71SQzMztqpe+vFYsExLfw+7vPOecUCEY0jEQ0enkdZh+uKVqWtHo4WNG34yMAcG7UlY2o2Nbec2npUvTK0BAIDBwj7YnFfHE7UojURu/eXqpdlwX0mJr+xoWD88G1DICzaEf6LAFx5+8L4Fzes6XuYAgOuwCAwf//WNH39i2yn3AqUJfMk0WX92rGGy8dV6kt50w+qRe5j87z996jfiTae8+axs4uudRsdblY72jxdyHpoOl9G1d/0tskc7LMg2Z/C6a5Zle/unpeq69y99c4JXepcng5c5Rtf4OQ6I1A+N9A/bu1OtIXO7XTiPjZG/H9a9VL6VFa8f62czQaj2uXC5ncdHW17yvNIRW5on5YOW9OZJey7cu6IS6RxzRA0o1j9x9sDJxnPjcvh4H2ss9u9uPepPw6Cr9kwVldEoa3u1Yq3NPCmqi24dH8+eG+o8bNs35nKluQuTv5ocKQZaTg5ObR08ZktaWxEEAESCp+hohXfuOuOD73J1nhUR8S63bTyQW5HOFlYOFz1UBZr4llZOLmBSMJmXtQ4H97Zf8niphuJ9V7soHuW66CoSb0X8+GNwXrIyqG1euTgTktbj5k4llK7/Bvr9E1IwNFqoUB071vKtb2k76QLgjZe1oogTy0t7hQCY186iltVLU5IcCjuRxS0sVRJVZryjdagIwzmpya5I0YnjZRF3IjU2BTvMMaFaWO5CiuwcPvr4eeoC+MNGSxyVxZLMJOxIyqj+I3zi005FVttemdTvoOcOw/MEAIkOWRDcJYVhAwEAmDfZt+aapC8FR546PTp4VumZTkBMOk2DFRp0ciJmYaWlVJxsmY6yqQ/hXME23JTvTbQ3RifPl3z+Z8jEZEkfdIXt5Zfvq3/GCA9nDzFTyTjlPcM0/cUSACCtWXBWwTSgziue1UaORfX8kWPCDTduT8mitoQePOOiIue0aDNa5g6klFhXLdJkZ8qYF5AgOYa0qVzcvWDIowKfHFTLrg1k1AnPTS6fCqWefbiBprMJKe9ZXs7zARA/NVKkkl1kmKvkNVLeaDSfN6wzUxOTg6NpMjENdHuX7mqy8Cl4QqYUwhqjNQl3kgkA6nNAlQyW9PzwBWMmPBw5W5o6mT8ndyzZMdd5WNK5gu07furJya9GEIAQIABQBDX15mxaybLG9a2GNB9UpHRLti4lDQyiSY2sTNAq65MKBDxJKMznCDPH8FC+UpNwjn3VsBQusEBgwLG5S7ktAN7/0q3CxilsZhexzgzx55quE3nLKET79r288+HaA5yf8G2d246JG4iTdE+7AgDQJafr8lJGJGJ46dCSSCq/YG2LFN6bqwKNa4TKvmLLWLIVpnlNo5FsbVgeC9hTIa5aVEJ+INCh1SXB8MZ93IKzCfTMyQ2afYyEw1Y8nKkXO6wu/JMnrXjYhPs
i07rt0r1mAbaumpRFe2QZBP4kntAXXP7utrueYQAIQACgbW1LM2VFammklW7FJO0BpQr0yd4mPceobIMvm5zreU7E9Hr4QqU9UOGKUhFMDJsL670Ysc9SAX6/ZteftMTJcs46prcaRq7uBXigV/kj3PTNlIDJqZkffXQYaCATACaKwp+QPtoqu+/af+B5DgACCQDAZ1eeFoAwos3DadKQsjJ+dd7t7XK/KiSEBeVEohC1ZGuaoqt7NYYtghSf0S4arIcBMd75InzqmVHqSSkUcJ5e2GEE5ijTbfsUIl58Toc5Jk0A+H5bigDxCE10x9/aK4Vps/IRt/YaartbOheThRstN2ohIimNazNVoCEMlPtIU0UeY+boPCgi006XVdFJQfUlTCu2ARNzklU7jkgAl/5xUU7g1KLQmVJqudN4ao5StPGFE8tJ+1kAMmd5Fcd6FC7yJosFM/v3rDswub7lOQ4AgEGAvKXnmKwpZclmp1MSkbqtFlGdl0BAxIoRV+bcUoRmo0K0c0/VvUKRUJMdcIijCSEiTsO1986SRuytLAT45OhvR/JDU16B+yc3dS98xanDgxVe5Cd7tj5zwdbb1Y3DhMmOqoxah195sEIeSWYYAGCEOd5yeZFrZWLY3D+bRbImul+qrfV0mZEtebJSAZtbXMhTpZbje2sUjQfa3CmbyBXfc9NkofmHutkKxUEX0laxwEuO5XA+vjXx0sF6yq0qOV8WI+fGh+H8zbsOmxDHQhhOHTqYt5jPSv60O0dR+HTUQihTOnSg7+S5HOSOnahUKbq+M5TtQQZDboD5OODGM7ZW8xMZXRYYjDtBwB4mTHWDwWIdAxUZgDIfCeCICOZ0xJrmUO6WnIXgzPo7X+IACNd9grwzkxVeAAMZSzliRnQBGJEMFZrtSqBknz5pe8IHsBw9U7XPQSoY5SaSDc9FwlQBjUxv7UEj9TgQy6cKK4Oba1pax9jmRgYQwB26hPGijSoo3YX2OfssXGQBvHdcKwNgue4JgnxFuLICvNA+VyhIIGGnhyYms+APHC8T0Fasg0TY41WgMlMdSYsSTY7Kqh5xcdZoqS2hYtSKFIJKgxY04uFEWIMr65T+hvVXAV6y6cObQtvbL1/0D8BJPNxQPy02mgzYsBUQAsZ21D1480LVxiDRZHnKCSA068qwSobqkm8oBIc8z4s22Cezxw+MsSrrH+jSCxMRCSzVE4WA0MYyU6emc71PrVrkj6mulFeYn9IF/Gp2Nu/FqAJYJOxy+9LBdP8pqEw6Y4V6oE92GyU49VjYFgDsxTqZCZw6GwBA1MueYw716wItxolBfJmLbMYok5g+XiFe3q95T63LA+MrEgTAjwpbUolRVlunGeVeE+7b3CFhTrBRiRN9djFH0a0L+5/p/tf9A9th3UimXX0ZevLBeqWHnZUB2Vm43ZBdAXTNCzMKRddVooT4Gmp8ASHDGGfUrcmowGpcSizwPb0xcQpLUiHn0LKAmj8axO5x8EtlLWilJDZZJomhk8q00CB70LcLeV12C7I1VUnMkqWrcWlKXKLbdMViN78kcQQELc1seSMAALN0RkAxHFYFQAeZ/cbkqEcMcKheckNqdNHK9lgwVJU1ElzK5PYFLLFkUdppah7Ormh3EZqO67uQUdBZMi1NJc4IrykVO7wjMCP6PDhg+u3D3Go+qbot+zbPcBCdvuUPx8joq+R46sw1r57OAoDe8dgc1QgY8eTo5ysJJHmgfnnnjNp7l78ScCIpv7CyEh2OlMvDwGpfyMuHlfKudG5I/KKUTgskRms8AACAX1+pWKejnpcJ+aX8Ats/OHZyZv84dagpt6/dLI51mObpTtef3QkU1r2ZSf6y03kKsVcrZcZB9OdT0wGQAAC4//JTE/Dj952o+AB9X5sNOXJTk26cD+MFqdQUGyuXfTQdhgkfPPTs4Y7x0QFP1FnfKtCWRdS/rNXJr/HE+LVxOL6Cu1N6rjouvU6pbF9m9a/S7MH1QfVcYJZ
eFs9yYeYjwwQSngAA3Nxa0Up1FL28cW+g2NvYEC9Y0Pbx2c1wPVZYGHe9ha1BU7eUCX9uXla4kyUbzY1oq0A36ei97Zq2TULsLU2KfwldmTYpqQKNhdBHNwW1WxVauT4SRI111lCNKi4I2Of70gkhkACkFZc/NC2iAgAk5/J9w8mFckMW0KesluFpPAsDK1q78we7VqWiAa80T67EPi/RWgVKfndV5v/GruhDC2By/4bo0MGWVybLCGkWACB+KPeH9E0ZK6I52VDAeVlFAAioC5ismmAI0GSYCA6SBwiW2Wk5VhEcAAcrTEDMhva1LZ3+wUBswpCSV/3SqrlH8iupsXFXjKJsR3CQESYAEJBZT1ZgLNAcDxzV/RXfvHW/Ebx3Jk2EZhkiX3Jt5nSpcsg5/8u19U0nBiwOgDGbSUfNZMSw5CMGxu04nTuTrnf9sRaV8qbPRX3WbKYfIYGVH0mlGvFXUzPEnrPgjba+lC7Y9jxhNhfitlB0JHVBqgApHxzMm4Akmc7Qo36nFQMmK//63Ttv6O536vedjea3Mq3CMK0DOqs4hti0cOnAwhFvNltfJwlI8pAQyRsyB8fx+YCQFLwU7Y/e/JW5PhACXWorFd7xwWOrPnT1N8zzvxC1VMyW/3w5/jNqyJ/uTT/wfb8WcSny4vb3ptzY1uGOb255ly3qdUkXkrn6OAPOQF0wUZpOTNU7CgIwQotatMvlCYbEXGFmbnkyLz12XhiOEE2EQrG3nXnp1V5nOG6zud8hrfRjt0PRDDXt6Flftq+9LrG/yspItHlthatrRo8fuGz8SMKuowvu8nSWcxjjAH7G0ear0hEI642xzv5Vx7KWjOS59CE2R7kLRBRUEol1Hv3V3vJxv3ydnZ/LCU12YWT0QPtoQy5+NKMsZOofPQEAyPc6iyYYz3zlT5nn9oyvKtX5B5JVMr0Zx08Ep8t05sgoCn1yyUq8gIQj3rLObe3j9biIMIoX1AMiKbGyYVFxPM0htPzaxV1RHKpLgqPWpCZ3tKzJL4sscmXWLOPAuQEGAKhpxYJeqvjHznrlTLx79YszEyG5G1cEnxFYJCHvAqAYpP+55bCn9AUTraN6fOvIxGxkjFDLg+aRC+xdeFls7fXKS4dynqP38mx8Z8UcmdEOhNbecHDQD0HQ6LYKuvCHSejBUQ4A5Kb1XnwLfeGJgm86qPS4qfozGOKt41zIQlQdP4I36NXwtS5WE9A2/sopGjrXSomrTb58qC7XIFDs48cuzLeqSyJLhotxQDjWkVpsyY654YA9k5MM3XE4g4QUUc2M05DXRIeXH6t+ghWQAoXysqcQlcU482Em6cxJatRDs6Wc3DleE4v6oJKUnvtBTlB/e+zspJV+1axXJzTjn9Y3ll5ayhQmFJM2EHIqvp8c6yCzruWqvYdM3/MnWn1L53TKC4Umato2PtFRGtfyzZax9mz/HAX1SwxgtuhQ8ClxAdAmUwyXTAbKnqttquRYfbaWCBddSFBoXI2KDEeDor2jCTdHU4lAYLCEBnwAANR5+8llTrlFjQGzFVIMgpfiGvUBgCcXVbDK1UX+OiybE7VKsEAJAIQPuD62mgExK6P0id3jXoEBDmU2K38acuYsGhvybB6c+MpFvusCtLevjmgxViBR0rPunVJVZtru1sco4u2dEVQJ6QIztaLDvkkfAMjmXsvhCCeWrY6Kt54p+0gOX7dr7e+qGz9JMZvZRed5TwAALdLSyggD4Dnp/vwEiDmLgTD0qQuBEjXACqrfJSgDkLzOvBxokNt7omUAgOaVkpamzTGwsXAtJYfk5rKhyDYAINXIIyIMFmLRwNWt676ybkmyAy+tpsjFZGzW2sxSZ4aiyPn9mSkHCBKC9Ve9nDqKjkvmXFtf1clmg3g+RpghwTyi6mUtUTRfOGABACDUlEoVEJFBMJ2xkqvK4cruSQ8AcCKmOY6gKFco+Dl9+fZmv/VM/j8O16Zg82V38czEUsx2kaxxNFNHVcfspD3
H1qO7YlUuNdCywMAREEklnqtz3yc9YQAA1LTWL6uGZhBQNM4hLYg9OZKucinZwEARBOtqU8LQhguo1fMbP91YG/3C7RxEaxsqGIEo95UBN0gkjAUE4LwUzNVzliky+d//AgAQkIlT8REACIy1IMbM75NIXAAA0JV61vFVCRBliIDncFYmiu0DAKg6820LymXMvYymlm3NHYxaYMzhdwADEIKAalLgzatIVUY5AAAjFBW27O88QCFermWHp1tcL87+LdR7au+PAADUEJV9DgJABl91jHw5lPT9kg4AkIyWPcegCPmAhI8itinHCnq10AUHheCGSX2syFj2EJgYMwSNx+qIgSwOALDleK9/azR6vz6rTIJRueKdLR7lAk+dT/zv1KMu7ttWpbfIZxwhmCwYw9Sy0GCxU6MT4xUAAF6eAgKIcAABlE1l5CLVz7Bqxi47gYQnY0qVYKHixyun9KbJ0+5Y/YyCAQCDV+AAWja4dh+us6Ob80UGjUO6jziG86zmp+6dnIu8Wk6kBH1XNzAQ1plxA9ohupb4IlMBAJRQNbtiyC4nAIKS9h8ay4r8aNV5FoFkrqJgLGhx7a/3W1MZvzUxXE5PXKhCKOh33NRzbM3Vn5kFSj7731DiI1KzZ4mbHpi7rKNLvgXztCQKTgoJ/DJTXy71qvmiQXKhEHUAQCyQwxUJMcIEICHLpXK2YAYqVfsdwbGSiwn4cuibfyC+JxR/X9ETDM6nTzKkbX9r7NhHAsWGOtbjvyo8eHJN/6qMnE8BnWMixL+ppXmAOr6hBKhvoZDfPqz7IVoOGYFQRgAANk1NASKEjRW5IHk9zGkh+Wi1X5knRyxkuz46e8ICwMLzmj2PCQCjfbIAuGaf0A2ff/Fqwyp0qpJTJ6P8ydiJU4Jo5wpCvHi+r/7cvAUIpYkJC7tFJLtqWYGMnNVS4WJh3AcAMZrImIK7fjmAPYS8k87YIMbjVYK5Q+OW5zolJ3PiHACiMg2o7WcrgBkvhcvSjrEDHBDW/2byG881u1xfPjb0/GQdgDvO/lFBbWpsSDsDUt0pjSpN59uq9hY2CS+kUORrqw+UFxfAXVJmkvABANJyo1P2sCjxmOq5UquArQOl/mo/d0HDFHKxW6hwtQK63qTtpGV8tgibb1swvqFVK4RKHoup4qOn3K+uW7shMdyx8eFZoEgZs72gVe6KDQHVzyNpwJtvrddp6wjTk2W9TRsfWZbMlRoiY8F4XACAENxwKpqi5RoTphWLcMFKGV5zglFBczMNgeBkZ6/ygryZhhdtPtbSef9ecbQkfThMXu22+18iHz7y/lMgjp94cuzRNGaFOopOrn92orjM7s40FLC39tk5kOwlR+E8uQUAsxlrwt4X9yL/e2Q0YYaKfiZAAlW/1Q5Ho5NqOSCmpuJFb+gypMU2vajX2HNZJEAaUgoOZdYea0Fr+7Qj8UznqOCV4+gjhHAfOMDf1d5N/TsA1Dsl4pub+grkwND9xQwUyZ7zmDwB9RXK0/B30WOHp8Zk8pMKFxNQEuWf7XSnpjgAgH9wQ64YLLJxQzPztveimBo0y7VjRPyX+nAxnTMmkzlemcgYx3h75MRwvsaJ+erZ5tQ9XX6pbQQ3LLi2oQWI3D33NbRtvs5WS/HlF08WJieKTACAAPir1dATxAAAbG13D1dta/BkzhnZVTgBzSuWDLTWaPKuxa1k8fKW7NmpvXam8MzgcKs3EA6+7pGlWYri1X6jt0JbGes6Bvra8/ZmyX6Y3lWr/8rWy/97LpkbM9k+pep6yQusCc17uGA+XRgt2EMgwuVmrzoClVrbfTX++JlhOSOYKaA/f2D4+dc/oDXr5sU2xyKysXJp9ixOb1mwZt8cBNH3ZIoifslElS0IVf9rbXCfn3M+gPSOnHo6i1wAiPV07xsdPpnxUSbvFQdzsPWKYwdPlkAAQHOs4/Hi+NE9GVFhAAIQjI+VphzJe22c9fUKaNmLclo/tXMmpkP1C7zizV+4FqXVYJ7
qnlOzYbWaKjTfubDa/h2OvntbSfrfA27kHUv/NJ3WmFO8pSgOYc7ML/ViEba54MLfd/GzWHOajAAApJXPf+am8FwpQ+eTYrZdNWpz33v6Xbf8OuvkHpmnQBAprXps3q18yGTZ1retbrR4+bYLP+L1W0vxRLH31oj+hYdWoZmOBGNkyPqiLzSGwptkhOS6Dt/pd3+duv/tfenJ0nf6nvnp+wPnDbg00RZapCsLLw+iaRjT/9v0mxZXyijp1DZgXL6AE6/b0H/d+sslr3TtWnxP4c8v/sblVTtGBQFJUT598hF0fffjcXFgcPaE5llQbHVItNv5caNhYAF66J45A+pXjWXVSZ9uNUajhzJlEDCjTPT9Lu1vfL470FkmueA//UW8R3jvj56IPb+seOWEa1/9PFSqMio4IObGd+0uVsYPjd40tbeuvqHbij4RDqJyXz7ylB2IWS8/M3fA4TEyadkhm+WuqGTMOoom3n13arNkM0XmPgWK3ygtARD1EejrNi1cIjX3B6WIOvTrX0x6AAASll2G9YIAIEtu+tt99z1WmD6H8ZHPBqWMJlngoFJObrEUskjUVeghqU3NO35MjXcub5sYeCE9IWoURQ8+OVDejAFpCDm4pP7zG8bZcFl+46jvjlvj3eu8lBIDVlL+YAsAQNsbmddUqtYnoPFd+fdlx10KAIDQtst95IeRp3AnJ+KFbCDyI382eERI3uzlYhviXlObOd6wpGhPQe08k/joSHf3RxHBAoBg9Ma1CC+PtD0rALzgzh7GFphChJSWWuLoULjsjboIAABLiZv7+9zvPHX6xIQAEkxSoQgPJNcnjRnPkH1ZdM6e3RCY6skIdQIrcSUpWan2ruyEU1uZ+uF0uhTDzPd04nHA7I0CffX2/QAAcOzal7YqskglafR4DWixRB0GAgDJUePAxECisfH2I0cQgKicmwpxZMlCKlNbmQr7chYtrTtkIgytsmYiMtWSM4qJPEGdS/iImFlCCxEMgJAgSIM3ihN6I5+6FwCAXr1O0GAAhwBxv8oQ4UF15UArGnqpflouf6mXrem6FUAIYigYiM/Btx2EsSqiO99dtxeMtAaqtfq9IiZkQBqGJYGMXec9IcGJbQBBr72STVus2v4FfuLbewCQWPeElO1LMFfIsqcn6hckTOSPX95vLJ3cQFZLhlB2A4DwEwpGAoEvG8xzda+QRLc21k1COhJnc1JYRowKl3hyMan9adZ7IhIggiSEoWK+Fk5yA6YAOPw3/70CAwDE4zvejqAn9hQpSAlwgjHKNdFRvyprV33uZqmt2VO2rAtTV8BZASCQoIAEgGSbDPuCNodwe7yO84idfn7gHHIBC0AcGMgJd5b1vqMgxAXi3h8vey2g/DH+gR/JldVP7j2CAQAKv3/1ciQK9DsdkeVpYzIwzlrAGKlzr6nS/ZI7uRq7verUKb93f/KgAAB+vEwVz5V8xRTcKDChh7sidcuLO3K6NNSc5w0AgLgrS6VX+SxQpCABQAXI21vmYzpohlP5+i/+zv02Zn1ducU2AOD2JUTRXn51TQumUqhZ/c2GhMfqNrtxfEmvvOgSzYvrUnAN0VDgn94pAFBAlbEpY+6UkYl9PfbzBQqvU2BSbnPVqBoSCCPwNFV1kpjPABUCVYvIUXwec5+wvFdvnoJHT6xy7lIfe0sxHR8FADWoC+cn7SoixEU68x5czB3dmN4Y0vxlX3TNk4VyY9DghVKg6EXklQAAopXIkO1AQIK47Med8V+/PwSzNdkIVl+WneiMojJwDKKMKxGEZx3n6zAgEIAA1AvtKL1yovzZfsR3u61n7ypf9ctgtbhpZbyA4qv9Bo6EFAWuJl1Z4c1Vimo7bpi60TidXJUMIMPww4QHw5xGJRcAfOTJYx1Ed4RPiyaPmaanqDNA6YJPltPdujFpYy4MzJHkbrpvFugTrgxCMCzAl8j59sl/gDAAAQ6cEefgfgAhAQCkf9XaWLwfv6NsiYXZ9qLy+76dmJS
qifF4YNcVe88duilnBpIsG+iuFJO2JYoMAKBvssHLlgJcYoirfj8Z/sFXyKxz5ftPV/wFebucdJhCs+WGiHto9qSttESq+o4ISWAUzwNKlOkVrhYbEQFAeVO7oa34oGKUhCRTGlX4uMZBIwxwMrl+Q1QtrWldOY6VnlyjywT4WIaoYgIACStMKIh5nkeo3tMjThK0abrwCzet1qNEInrUYTIGFNEltOGpgWnJv+x0dVF+DX132NxHggP0wjWkxRxQKoW0RHIDkiSEyiwhS4CC72rtXGT2vTg4MYya20rjNEbVjpiLQjEMAFB0CEpLnFFNLxZiYS6GEJ+uUkLXvTmeK02kC9aUB8IrRmS1GIkmZyj657fYek2UxYVXNnBEGCB51nHgAPRz34/3FOkO3YtEhJ2MIVtHFWwIQoD0Xrc6BFmleclg21TreCtPur7PSyFM3UgZAGKS8AOYqaYvYhWU64Q0I9NhbssmFhAFImOmCSYRydeE7jc4s2nH320+AgACCSTQBXXMQijYF3PzD+KFEy/AFahFGNRDASwEFqLguYplAx36RWZRs9ULrLcJikjTLRQMZhTJz4yZAABPPfmmvGcJEL7KIGP3Q1p4OQAgDIGaVhIVoWGBBCAuGDhE5CZy0+YJIwgAAGBAgEC5AClyxfkJCIOub7oyEPwnTSAZAchIARAKAszBNU7dFu3Mb+5qdZoalVIiY1CgjVz1w81Vvd5CCfU8yczSALVpGwAWCgCgxcK6emUYpZCfQPlyws7ospJ11uWiI9+rAW0GMZNGRfMUxuMLQ9DywTMbEoY/EhGMKcDHWoUDAnxEAPjg+OFb35M4empT+UDf5anv/cnYcOhf1gRHnuzqywMAlPPthbyP7WxEL+WWmT4IYTsAIFLB/NJlwdFchFgs3jaVDkULSqN1VPeytZI34AAdMzydxyGdx5/i+8nR7wR+vi9U4CKZEUL2rwUA3RcCgNvOzx/hyiT1qlmL0i74YN23Bns4CNLfXvEnOir89JVAkNZAGLBCWR3cVNGt3r6RcjwfOaitDI0MPOu5UoXVMGE2XTQpQOBABd5AQyDh8w/cWABW/KKBjPTjSw/ser9LD/O1ZP+Soe51UPGPX1qdX21uvvZQS+F4Kdia6jd2KvbDWRBiJqXzAYBGAIEAPBmBc7GZap904eEBldmK5L5+PyToNfriXy1w4aVbdL9od7hAVeflxupegz0wMjShIMcnp0ipzLUnM7WsRrXzlwHtrVpRGQCSbwgoApAu+BHLwrhoT0PVmptAFU/vsccKYvQxQIism9ll8kb9iukJBhZHzS8PzO2a3I2We77nWvaUa5rV0vA5V2AghIBIdb9gRImknIeA8MlDtwcuEnNhaPvP4U8uf9vtq7TWztZ4UpKgzTz898HzXkMII6m5sXZGo05vjC9uyz24f+el5Nu7LtB6uW0RpCauWLF717TrMFvnIHW+5cb0K9fJu/5wtsguMA5k0yXPDWZf4+gKEoBkNHNrxOtmPGaAhn5wk1IhNpUk9+Bl5/VQd965/lxgbCNYD339PDcAr11xw1VUrigMlV/8Rp91vnYFv3L9r8a//9pSi7DkcngDrZYpQaEt/9/AkK2Y7ATNk33p+r4o+sUvco+2cEgztX1165nZh8r1/wJfOD1x35lHHxb5vPKWe9664ACtQ4Vu/CaVXs7dvOaV+VF+xNQL4g3hnKbo+i9eRTNcodlCTKpEJt5yuo4RPf/4V/xsrsudKDe6ma5g6r2jM9U/b/tEJBTsT4WVk+V1WUVPOCyWfu7jdZd8vCv+psPPxzf+7775ZzcYcWXPYJWLx7018xTeeVbywjmmRU4ivdEpQt0muHXV8ECm+WQhFNhTaXZHR3PhGaCn9t34UI8T6msJxF/BRiSb6W3mSn09wv1/81+n7Zt+eug15K8ClISsKw+Pl+Bircp69ZmAFAwBlSQTy90gnh+qO9j6cLvX2IyUWCKvhbq5tCRw1aPTX/mvVwa09ka
iQaKsaAmmagZixUfqwu3trSK79A+nX/scmBxdelfvwmV28XWyzXUU/VSEcdlhFDOuyrkg/+DuWQI0rvCRmikEgxUPJTLZDnlleZo8oZV+2qYlFlIdbICdl31rJGTXy+itaNknxvdOvfb0XuiT4QMNl+aUPRduEVRb3BRw1fO1faYzXIISk3yEdEsKGbh+CS3LQhTcoJsTwSmkJE02Mb1jggKOEvZNTl0H9EyAtDrIjpQG68gndq8JfNkdf22z41AyEBbCUm87ac9bB4QaluZXXum8WKPovtEOnxHAZZnQHKX+rrpXreOLeJFhKU+5lNcsnG5YVPMvhHNko49KKpNykkUrMkOFsBaYk7r/3dq+/VSyXxMo+Ke/37o11C+1Gmxe5uOpZ0Xmj45ZA7qjRargGGfYIlgVrr+s7tXkYrWkysyVHI4IMlg4N3Ogz18YJEgW1JN9jBHiEPDLmUfq6dJY2BdYkHk9VfFP+l1dIdTc1X9h3SrS27Wh8NsC+9b3/bS6KN4A0IcBqG7ILoGitKbu7bUUgGCuqjr3Za9MacOMzyKrgCWKmaIqAisEyVh1Juacx1x+zm9clpgJwOdbXNnU42dIMteaPO8l3BsLnP7mtWvu3YbueceC2mnwHoyjDsiW7ynUdwOB+n26hTLWfS7ZDIJFggPYGJ85PUfCEhVYENdFRoUI2Y5VEo/M2ZZuWigVltx79vC/VY8vLtx/ocEU5VibaQZ6oycAAJDAUpclRhG9+Yv5xkDyiwcbvW0qxzWgx7SOJ7YGDdUEEGcStqjn1e7RRrdcVlQHm4qFirw5EZp+5B7vjtAsplKZVrSy4FqZITJnuXzoA+k0kMAzaOVhAZDfP59hdyZ/smmDFj/Eq6E4dyvNp6B9cs+mphTzjz59+OwPdwe/W8s9bY1Ix69psEG4gje5g3JLnUZ8plFFRmRKONwFL9CwaytMX8KEyOKAklUY8YWLbBdhszEDJ+pzgej6tkxHz6B1VepLd65/oXls44ncHM1GslwJdF73QN+bnXBTrShMMlaENj5723Pvqhx+6upWtOHGpLay47FNuCq0RDzvcpCCYdnT5WNMxrMOXSMW5wjiyIi6XpDkTqAAr7mhQkWYEUptoqk2k6GCiEHnEDT5XssLV86Uc5XPlF9wB9XIdXMLa3+8551tL/7OS7mq3dqIBAAC5Yb1B05OSa+k/vyT+8hl1+tLeoJTy6KtQQoAgKWSgQiWGCrLOJvV4sRYtXt6tHYBQHEI+7ZR9mjcF1ZIs2pKqTAtJ2HZwyJYISTg0qn4nOTqnVpjqVhapo0lv/B+ypKtAVWpX6Ri/fdaLWcOLIRjuQ7nln8FgNjGJbGza7RL//DWlxTtY70rcglb5JPUFz1V1o+F/cln1ghiWuBm0/vWbBDpmdH2XBU8rMd8xUYOKfH808/dQGv1wcI+0yGRIiJyUbKkMrjasC/Ncfrve/WKtJ8pFp4dPtc1ZcZKxwfnLKaZfwKA2zlxxn99c3CfACQqA4M0dXvmkXOZqebIMiV96HpcrkT88XT1zB0OBwhZrnpCYKTGAvGRS27+8HQ2F20wSCYeBi7bBoubC0yBRdNA9Vk+kQCu+5qve2EGIuI0n5lTJIM2vm0iKBqcQzbsbmSi+2Tv0vqjhdWm9XT8Fn0oWdoquYKE78i+KbfBTzrBqyhXS0rAJ+EYcRb2VA/OaoqC/ITjyToqekln8g9v3TlbHu1SdELzXEOjGRZX/UoKuctqtzp5PghPgrIqkSmMkKkiOTcnfnH7M41N4sVBwY6uuenB5/yi5wIQhutqacgPtBO/VfyzC7HqAkv/U9uqzI+2LR85u+KGie8+3L2f3/Gh5CsdHA5TAAAZuzKVCGZcGKUS7TTQ4jXTA0lxjhSEgw6HcBbKbdi0jYUzSsspV4jqEC9eIsgwleH2+giWv2AkxsJmcdRik489IgBMAGTc2ennHyiWqlWt6td2PP+Rz2mtvor0IoAQQ3cD+iMWALIrYBTgxz8
RSBBe1W5nv2t5T5Yct8xMMzdy9NFibMbg+c/ZZj5v2sxFDOWdYff/uDWdVHF35V2v7DnIkRyWc01pnLF8vf0xd6Um+8fOWUy40/UNknPwucHfxb64LSQlV67f9sKa/pEKLx8+nh6Z8QgEcCHE7CFMAUzUMiUdOpU6dc4BAghXNoDUSKZdGbJKYTFX49x0dcByMzJUPD2k39mIHINqyLYiilB1j1pEg/qmA2qMuIfrwPt3Xb5c+g+kflgGLjAg22l3AW0/kaKvfz8tBQBgLgLoKiER8s1CKGBmuDJTgcHHu7xx4K4Xg0xmAXIFFcXpKJkVfGciBqYfQMW0ijOouXvMm+Os3e6ihfHnMGVimkX8p/fDtn+PnFHX5/Z614/+8EcLCpPCd3Oy/vr3D1VltIPLgU6qZexSc8rC7Rj7M9spZIWgcaOgmf5U8zga70UdIpqsPcS9ODqp5NqyfqFpnJmN56SBq+NzgO69ro8c2/EbxuMz1ZLMgj+vrnslDQB8uKvivH4eCQMAVJ4ulcwTjmMOqaY1Vpmyn8PlaV/Ge6yUmypQuzTWgLw0H+BP+qXpkJn9rmSRKSoqU2HhTJLK5BFlsjyHgafY0G/PPe36PP+6KODuQ5P9r3+TLgUAkHYYjrIqbg/2GN7Z9UhCYTSzZU9uDJhkofCGNxJ7alHAR2sM78j0BRHrWvjJ2xGMb5HskSsMZ/89JFCg9Q5wQ4d6qUQFIPn1Y6KPbZlIaOXXE1IKABBQKF4Yl8ilBLEbDFFBzXo2WAs1NUVGyyRMo0Ll65MkAy0K2zx9a2C3JO4OIrgsQIJvNmT32mQs1TlZBwm9L371Gu0kJhcrqLmGNprNfRelqPyDtzyc+e9lU1OdGJdJ5oTyzLUnG8tV8Zf+7y0/eKSsHzjdC1YlnAaMsweOTPu2//qdY99aGe47eY1aSq1Wxl9y2SvDCM2eYVUPlLrOVUJtoVHDqtUYzUs1mQ9kdyH99aR05swdkgPUL3L9smWVMwer91vMDoqDXSF/fGvo+MFp9ryRyhh505tj5v+c9QNXtR/cY1+84wyUIHfOlxRMp/uJiIRkDxbetsP1vvSIWdtPqrXQx3Zme0qtZbTne/trZcoXx0mu/E4Lg8Qv05/bHOFTNx6rzv06HWegFBe3v2CL2g0b1VZ3+hID/O17BiP/vint3hD46QOluk9P7vOlQsNI2NQtpxD+8aMD5x0uhWBZXEAq7atvfX7j7sk3R8vRn1yd8jpj3/3qG0pkAwCg6/uebbSVz7x3oz2j1LPfI5C25kShxWhY1OO1bx/JzIBRDsa88YSTMUOOGWvXmu5UCnNTCnS+crj4Pxhdhrw0KUbL8QRpa5Dj19//hiup/m79DjBhZ/j055+uLoNozpk7oRWbDviFRHqUrsxMZaep3bnZHZAGx1Wj/zg1iuPFllf65wLl8wFwF658mhE53y/Cu4q6WwS+8t/f6KXj2Ogafe9X336o8tJ70if88ygKAI3XbipHIoudsNYc7mYHa8Pit98WTPpSLGqjZIPAYTUsHxi9OG22fDkQSkQkNRoRyZaYY4Ql9dvmRXtVW/gXyZHt25d3Nd6C1jdHzrH6y9MAAFo20sl4x3jWQDlTXTLN+o4reWgiE0yYtiaVio1ESbstdTjVmouCsOD129K3Sy4POJ4q3AoJ2ZZOa1vS57fpguSZu2YAyWny0Xt9KfDQRu40fGKg+adqls/xBJZZGaikK11cpeFK5zSYShO2cnFu8hAIuUlUcpjXO53T0VxgwY1NQYTQdOq/QdLCnqsxhuMqUdtl7IrMfECnxWE2rR540infI2NtywLL9lim/x13Xy3XH8EAdGgxMry0YnpTiUDw2LSIZuRhQ2NyQca0LAPkA4F4R12v2uH2aNfxLQ7HMUEyVZvw+NYYr1CBy6pDi6rwXFWNpN4Q45dfod4zsbxf/dpEwgk4KSm+VDrvCu9BKUqzCZrMac0Jb2O0Nmxv97CMmUk9YFS4yPCs4v7zRXT
xpnXgTHZIH2+aWjl+qwkAcE9S8iQMDtgq04hRph5//XTttNxI/7qTv+e/FhU+Mdw42JjVke62SOsPzLlsKEv0KUnzTeRbqYZpNUOf14yg6RngOpgyLgGYxwbreiEA/NkrsudaTmY/vy3IeeBglfdhjFxVMJmCi1XmY0LFvB7n/jvOMAAU/KNzvQ8A8JXHygx1fzqhmtl0Z2G8rYIHcGMjmUNRS8RUJ1RsyWgRhcZqvjreTkJ2WbOAKSWJSCyaCj1afxQnwDuu25ZSOkuLbmqQmKX6v3cBAFAUIY0BcSyiOIA0RDx64Y4guePDa14NvKL+z+e6sakXAQDfdeWaqCYMmq5gVIaFNO9E06GG8ByguJxSrAnF5UN2kA/XLlIS3/uHCe4iahJLqlCklsvh+n0jhHD7bntoYwUarizszS9+0vkjAwAQY62BEiEYC5lrjNkTBiEXXpDH9u644eiZTewDPS77dnW6F9BGLh1RxjKpjQ5HXC6FeDJz1pwD9JLrYtlxhUapoUVN+eYfCgCAhg/RUiToe8KRPVVSndArjWP1RW1bw43lRUvUKXUBL2w8p020NgwBAKBWFftU8n1HyEyi+omeORne6obF4i9t2PoLSSC6mYP2tz8pAoAUXCWaV67MFlFAqErQzlqq4+Sn6u3oXYEvRAsWjfg2wigVINf9EAAAPq+jYisyVYUXmQKOrsNQfX5JutodaQApE11gnbmv/EwR1wwJ4QTGF4GgIckSKnjP30ig5fQMG4oAAK3Hhz+XzagkgoVHhd99AgDC2zzUd0m5Px6SLb8LhoshbSLB0Oy/iQDB7389LkxJ9klQDWgNQaOjqk4OgrBrBRHlIUk1IlR+uL7KAH1sU2hjT0AOUAW3NgxUJL92iYBCfcEJUhFyqePaki84qm13ItzZcB0AWnAK9n8s4NghJJhAwvcAAIhzDmil1BuumATMSTs7SHtKMF5X7egv+NEP2g/8tRaIBMvjTSoj41U8P7981bGuJCgW8VBWB3W870S9pfjZ2comymSHnsunnjviEzmYRwIA+FgrnixrDBzFQ5bPfv6WK2BbbV9MDPEY4QKvF32yu7n0vNg//k7x5P5zAADFE67pnuhbGNVNUvEG/nxo4QKIHip4NaBIIne+t+PWS4biFUXWvaCmCdFSvf/jzkXS+KZ25hMniHRZ8xvdxjqPjjYvJprEFL0UIZORgC13Bdf8tAgAtFHFagiQjSxDBJ24PQWwCgCAKMn8Hddtz669Bv287RumyMIlAsvfdIKkCAAg2wkVtW8Z1s4gkgqQS78/+dlhf19FUADd6vrOH+gPXn6ysmpxxDXbCKcJfSpSu9sGNlM48FbX0ahSMYO+aYD85Ox5Awwxxgpx3S45yEn8VXwqtHFb5eEiAGAXwwgBn+rILCUk8CdArAIAYNZQ7M6WgTNf+tm30V0VDgAcmGWBU/38tvZmzO9CRnBFcXJoi37aLniuioSg6J4PZLuj6aV68em3rlfNhniorNomimuSXT270eAj7JMQA65nJZlzuVC3s452vGmqqDEbNA+phF/ywpbF+mMuAEBcBU7BJwwx1URWAAEn1YIFAZnt8tK+DQDwvXqbU/18KafKaM+OFEkXIjTvLIViaOVYEIDSI+iEO1Vo1Zaufst4A5GnytnO8HgAOVbtH/L53S3a0RMRCXzJMiuNEEhXZiv/BV+Vz8sYMY4rqkxje1y+Z+qRAgDA+N5tjmsD95nieCZy2Dc+wmfMr++/xi4uwNDxpoIwnx67xAtatj9o83Il6wwJoNqjuNsNeMrfqgKHkCCEJhiNRTVHWlrNty1rpFI4IjjzIiU/iKSAsfZPM1UFDdG82wygOc0DpctBPrJX2b/AcwEA1OWYKsJXmFVKBH2QAGMowsUbY/45nGo79myT0p4cFh0Arah4CoCaE4mTnUhmAoFvB8TI3Tuij3wuKqWm2oeroaiK+JTh+UiVzZFmqxyXw6dmKCrtRKYeqPiBeOZUl+cYy17
4FT9T9fedXNwZ45JnBmWvkAAAVbjWGwBK9EBOtGshq6vBPrsTKiBYxfQAKLv8ztiyzc+a97/pP05z6gM8CfBYtBWTJS8KAAB5kU8ggUImKkeoQZsrCrrlN9NIvaduOK5kQuPdk62jzpIlh7/rTLt9EOpgCEmpBPenWpgYX4QSoEbxxYsI1E3F05SOdB3vH+geHj+yA5C0YG8IgIqRr1XF+FGYcWLZVCboHq8abu/p7bp7PCbZtu6bx5aJAhT3zsroVHGwsjLfB4v2pirF7Kt1XkfmpbWls3aEl6xmXCkqttjLOX8DRxGK/9P1RzZy3OgjC6eOPbH/N8AKaasIgOc/8imKds0nIauaKd0YJVNhCqdXJkPpeKjOWxMYRZrDZJHSBkbX6tnDxwikS8JxOxmNlheH1Sk9wWCphJ6JXwwmgPzR7hYpuWpFvslLpu4bPQI02jG0+A3c3k5aJXRlk0SXYxDvDhO+Rn8nm+WgnG1ad+k43RAdjt99KapP4sZkCYeB+vfCEvvgFlqBAMJbQ+mLTQetsW2/9uORY4fDrQ/kUdYGIKxXQuK1gYoqHvrPHx58/guXjeJFzN63PpQ+fFX/Gerz2uJE94k9FqYDp7OryUuJMp69Z5z94vr//P0IHB94iAH7dkUC3zr09aHXBogEIAEIrnqyaSyyexS5+56xgfkCFjjDA8MS+39zAsOMa1aIDAAAAABJRU5ErkJggg==", 379 | "text/plain": [ 380 | "168×168 Array{Gray{Float32},2} with eltype Gray{Float32}:\n", 381 | " Gray{Float32}(0.00140151) … Gray{Float32}(0.000945836)\n", 382 | " Gray{Float32}(7.24196f-6) Gray{Float32}(4.94123f-5)\n", 383 | " Gray{Float32}(1.00136f-5) Gray{Float32}(4.27365f-5)\n", 384 | " Gray{Float32}(1.66893f-6) Gray{Float32}(3.92795f-5)\n", 385 | " Gray{Float32}(7.83801f-6) Gray{Float32}(4.20809f-5)\n", 386 | " Gray{Float32}(2.98023f-7) … Gray{Float32}(0.000131667)\n", 387 | " Gray{Float32}(5.66244f-7) Gray{Float32}(0.00031516)\n", 388 | " Gray{Float32}(5.66244f-7) Gray{Float32}(0.00034821)\n", 389 | " Gray{Float32}(3.42727f-6) Gray{Float32}(0.00037089)\n", 390 | " Gray{Float32}(6.85453f-7) Gray{Float32}(0.000832379)\n", 391 | " Gray{Float32}(4.61936f-6) … Gray{Float32}(0.00180849)\n", 392 | " Gray{Float32}(5.126f-6) Gray{Float32}(0.000150561)\n", 393 | " Gray{Float32}(3.66569f-6) Gray{Float32}(6.19888f-6)\n", 394 | " ⋮ ⋱ \n", 395 | " Gray{Float32}(4.23193f-6) Gray{Float32}(0.0191346)\n", 396 | " Gray{Float32}(1.27256f-5) Gray{Float32}(0.00387678)\n", 397 | " Gray{Float32}(0.000157267) Gray{Float32}(0.00180429)\n", 398 | " Gray{Float32}(0.00412166) Gray{Float32}(0.00066939)\n", 399 | " Gray{Float32}(0.0183611) … Gray{Float32}(0.000279278)\n", 
400 | "  Gray{Float32}(0.0533536)    Gray{Float32}(5.22435f-5)\n", 401 | "  Gray{Float32}(0.0142402)    Gray{Float32}(2.33948f-5)\n", 402 | "  Gray{Float32}(0.00237703)   Gray{Float32}(1.27256f-5)\n", 403 | "  Gray{Float32}(0.00011003)   Gray{Float32}(6.25849f-6)\n", 404 | "  Gray{Float32}(6.19888f-6)  …  Gray{Float32}(7.10487f-5)\n", 405 | "  Gray{Float32}(7.45058f-7)    Gray{Float32}(0.000135839)\n", 406 | "  Gray{Float32}(0.00011459)    Gray{Float32}(0.0109034)" 407 | ] 408 | }, 409 | "metadata": {}, 410 | "output_type": "display_data" 411 | } 412 | ], 413 | "source": [ 414 | "train()" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [] 423 | } 424 | ], 425 | "metadata": { 426 | "kernelspec": { 427 | "display_name": "Julia 1.4.1", 428 | "language": "julia", 429 | "name": "julia-1.4" 430 | }, 431 | "language_info": { 432 | "file_extension": ".jl", 433 | "mimetype": "application/julia", 434 | "name": "julia", 435 | "version": "1.4.1" 436 | } 437 | }, 438 | "nbformat": 4, 439 | "nbformat_minor": 2 440 | } 441 | -------------------------------------------------------------------------------- /Julia_quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 开始教程\n", 8 | "\n", 9 | "该教程粗略的摘录了 Julia 的基本语法,不熟悉 Julia 的同学可以先粗略地通读该教程,以大致熟悉基本操作。\n", 10 | "\n", 11 | "如果想要了解详细内容可以查阅[官方文档](https://docs.julialang.org/en/v1/),Julia 中文社区提供了中文版本的[文档](https://docs.juliacn.com/latest/)。视频学习可以在 [B 站](https://www.bilibili.com/video/BV1Cb411W7Sr?p=1)进行。\n", 12 | "\n", 13 | "\n", 14 | "Julia Quickstart\n", 15 | "\n", 16 | "This tutorial is a brief overview of Julia's basic syntax. 
Readers who are not yet familiar with Julia can skim it first to get a feel for the basic operations.\n", 17 | "\n", 18 | "If you want more details, you can refer to the [official documentation](https://docs.julialang.org/en/v1/)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## 一、注释\n", 26 | "\n", 27 | "Comment" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "### 1. 单行注释\n", 35 | "\n", 36 | "Single-line comment" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# 单行注释只需要一个井号" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "### 2. 多行注释\n", 53 | "\n", 54 | "Multi-line comments" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "#= 多行注释\n", 64 | "   只需要以 '#=' 开始 '=#' 结束\n", 65 | "   还可以嵌套.\n", 66 | "=#" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## 二、原始类型与操作符\n", 74 | "\n", 75 | "Data types and operators" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### 1. 
数字" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "4" 94 | ] 95 | }, 96 | "execution_count": 3, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "# Julia 中一切皆是表达式。\n", 103 | "\n", 104 | "# 这是一些基本数字类型.\n", 105 | "3 # => 3 (Int64)\n", 106 | "3.2 # => 3.2 (Float64)\n", 107 | "2 + 1im # => 2 + 1im (Complex{Int64})\n", 108 | "2//3 # => 2//3 (Rational{Int64})\n", 109 | "\n", 110 | "# 支持所有的普通中缀操作符。\n", 111 | "1 + 1 # => 2\n", 112 | "8 - 1 # => 7\n", 113 | "10 * 2 # => 20\n", 114 | "35 / 5 # => 7.0\n", 115 | "5 / 2 # => 2.5 # 用 Int 除 Int 永远返回 Float\n", 116 | "div(5, 2) # => 2 # 使用 div 截断小数点\n", 117 | "5 \\ 35 # => 7.0\n", 118 | "2 ^ 2 # => 4 # 次方, 不是二进制 xor\n", 119 | "12 % 10 # => 2\n", 120 | "\n", 121 | "# 用括号提高优先级\n", 122 | "(1 + 3) * 2 # => 8\n", 123 | "\n", 124 | "# 二进制操作符\n", 125 | "~2 # => -3 # 非\n", 126 | "3 & 5 # => 1 # 与\n", 127 | "2 | 4 # => 6 # 或\n", 128 | "2 >>> 1 # => 1 # 逻辑右移\n", 129 | "2 >> 1 # => 1 # 算术右移\n", 130 | "2 << 1 # => 4 # 逻辑/算术 左移" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "### 2. 
布尔值\n", 138 | "\n", 139 | "Bool" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 4, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "false" 151 | ] 152 | }, 153 | "execution_count": 4, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "# 布尔值是原始类型\n", 160 | "true\n", 161 | "false\n", 162 | "\n", 163 | "# 布尔操作符\n", 164 | "!true # => false\n", 165 | "!false # => true\n", 166 | "1 == 1 # => true\n", 167 | "2 == 1 # => false\n", 168 | "1 != 1 # => false\n", 169 | "2 != 1 # => true\n", 170 | "1 < 10 # => true\n", 171 | "1 > 10 # => false\n", 172 | "2 <= 2 # => true\n", 173 | "2 >= 2 # => true\n", 174 | "\n", 175 | "# 比较可以串联\n", 176 | "1 < 2 < 3 # => true\n", 177 | "2 < 3 < 2 # => false" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "### 3. 字符串\n", 185 | "\n", 186 | "String" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 5, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "5 is less than 5.300000I'm Julia. Nice to meet you!\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "using Printf\n", 204 | "\n", 205 | "# 字符串可以由 \" 创建\n", 206 | "\"This is a string.\"\n", 207 | "\n", 208 | "# 字符字面量可用 ' 创建\n", 209 | "'a'\n", 210 | "\n", 211 | "# 可以像取数组取值一样用 index 取出对应字符\n", 212 | "\"This is a string\"[1] # => 'T' # Julia 的 index 从 1 开始 :(\n", 213 | "# 但是对 UTF-8 无效,\n", 214 | "# 因此建议使用遍历器 (map, for loops, 等).\n", 215 | "\n", 216 | "# $ 可用于字符插值:\n", 217 | "\"2 + 2 = $(2 + 2)\" # => \"2 + 2 = 4\"\n", 218 | "# 可以将任何 Julia 表达式放入括号。\n", 219 | "\n", 220 | "# 另一种格式化字符串的方式是 printf 宏.\n", 221 | "@printf \"%d is less than %f\" 4.5 5.3 # 5 is less than 5.300000\n", 222 | "\n", 223 | "# 打印字符串很容易\n", 224 | "println(\"I'm Julia. 
Nice to meet you!")" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "## 二、变量\n", 232 | "\n", 233 | "Variable\n", 234 | "\n", 235 | "注意 Julia 的命名规约:\n", 236 | "\n", 237 | "* 变量名为小写,单词之间以下划线连接('\\_')。\n", 238 | "* 类型名以大写字母开头,单词以 CamelCase 方式连接。\n", 239 | "* 函数与宏的名字小写,无下划线。\n", 240 | "* 会改变输入的函数名末位为 !,这类函数有时被称为 mutating functions 或 in-place functions.\n", 241 | "\n", 242 | "Julia's naming conventions:\n", 243 | "\n", 244 | "* Variable names are lowercase, with words separated by underscores ('\\_').\n", 245 | "* Type names start with capital letters, and words are connected in CamelCase.\n", 246 | "* The names of functions and macros are lowercase without underscores.\n", 247 | "* Functions that mutate their inputs have names ending in !; such functions are sometimes called mutating functions or in-place functions." 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 6, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "UndefVarError(:some_other_var)\n" 260 | ] 261 | }, 262 | { 263 | "data": { 264 | "text/plain": [ 265 | "6.283185307179586" 266 | ] 267 | }, 268 | "execution_count": 6, 269 | "metadata": {}, 270 | "output_type": "execute_result" 271 | } 272 | ], 273 | "source": [ 274 | "# 给变量赋值就是声明变量\n", 275 | "some_var = 5 # => 5\n", 276 | "some_var # => 5\n", 277 | "\n", 278 | "# 访问未声明变量会抛出异常\n", 279 | "try\n", 280 | "    some_other_var # => ERROR: some_other_var not defined\n", 281 | "catch e\n", 282 | "    println(e)\n", 283 | "end\n", 284 | "\n", 285 | "# 变量名需要以字母开头.\n", 286 | "# 之后任何字母,数字,下划线,叹号都是合法的。\n", 287 | "SomeOtherVar123! 
= 6 # => 6\n", 288 | "\n", 289 | "# 用数学符号非常方便\n", 290 | "2 * π # => 6.283185307179586" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": {}, 296 | "source": [ 297 | "## 三、复杂数据类型\n", 298 | "\n", 299 | "Complex data types" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "### 数组\n", 307 | "\n", 308 | "Array" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 7, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "2×2 Array{Int64,2}:\n", 320 | " 1 2\n", 321 | " 3 4" 322 | ] 323 | }, 324 | "execution_count": 7, 325 | "metadata": {}, 326 | "output_type": "execute_result" 327 | } 328 | ], 329 | "source": [ 330 | "# 数组存储一列值,index 从 1 开始。\n", 331 | "a = Int64[] # => 0-element Int64 Array\n", 332 | "\n", 333 | "# 一维数组可以以逗号分隔值的方式声明。\n", 334 | "b = [4, 5, 6] # => 包含 3 个 Int64 类型元素的数组: [4, 5, 6]\n", 335 | "b[1] # => 4\n", 336 | "b[end] # => 6\n", 337 | "\n", 338 | "# 二维数组以分号分隔维度。\n", 339 | "matrix = [1 2; 3 4] # => 2x2 Int64 数组: [1 2; 3 4]" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 8, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "data": { 349 | "text/plain": [ 350 | "3-element Array{Int64,1}:\n", 351 | " 4\n", 352 | " 5\n", 353 | " 6" 354 | ] 355 | }, 356 | "execution_count": 8, 357 | "metadata": {}, 358 | "output_type": "execute_result" 359 | } 360 | ], 361 | "source": [ 362 | "# 使用 push! 和 append! 往数组末尾添加元素\n", 363 | "push!(a,1) # => [1]\n", 364 | "push!(a,2) # => [1,2]\n", 365 | "push!(a,4) # => [1,2,4]\n", 366 | "push!(a,3) # => [1,2,4,3]\n", 367 | "append!(a,b) # => [1,2,4,3,4,5,6]\n", 368 | "\n", 369 | "# 用 pop 弹出末尾元素\n", 370 | "pop!(b) # => 6 and b is now [4,5]\n", 371 | "\n", 372 | "# 可以再放回去\n", 373 | "push!(b,6) # b 又变成了 [4,5,6]." 
374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 9, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "text/plain": [ 384 | "6" 385 | ] 386 | }, 387 | "execution_count": 9, 388 | "metadata": {}, 389 | "output_type": "execute_result" 390 | } 391 | ], 392 | "source": [ 393 | "a[1] # => 1 # 永远记住 Julia 的 index 从 1 开始!\n", 394 | "\n", 395 | "# 用 end 可以直接取到最后索引. 可用作任何索引表达式\n", 396 | "a[end] # => 6" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 10, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "data": { 406 | "text/plain": [ 407 | "3-element Array{Int64,1}:\n", 408 | " 4\n", 409 | " 5\n", 410 | " 6" 411 | ] 412 | }, 413 | "execution_count": 10, 414 | "metadata": {}, 415 | "output_type": "execute_result" 416 | } 417 | ], 418 | "source": [ 419 | "# 以叹号结尾的函数名表示它会改变参数的值\n", 420 | "arr = [5,4,6] # => 包含三个 Int64 元素的数组: [5,4,6]\n", 421 | "sort(arr) # => [4,5,6]; arr 还是 [5,4,6]\n", 422 | "sort!(arr) # => [4,5,6]; arr 现在是 [4,5,6]" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 11, 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "name": "stdout", 432 | "output_type": "stream", 433 | "text": [ 434 | "BoundsError([1, 2, 4, 3, 4, 5, 6], (0,))\n" 435 | ] 436 | } 437 | ], 438 | "source": [ 439 | "# 越界会抛出 BoundsError 异常\n", 440 | "try\n", 441 | " a[0] # => ERROR: BoundsError() in getindex at array.jl:270\n", 442 | " a[end+1] # => ERROR: BoundsError() in getindex at array.jl:270\n", 443 | "catch e\n", 444 | " println(e)\n", 445 | "end" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 12, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "text/plain": [ 456 | "8-element Array{Int64,1}:\n", 457 | " 1\n", 458 | " 2\n", 459 | " 3\n", 460 | " 4\n", 461 | " 5\n", 462 | " 1\n", 463 | " 2\n", 464 | " 3" 465 | ] 466 | }, 467 | "execution_count": 12, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 
472 | "source": [ 473 | "# 可以用 range 初始化数组\n", 474 | "a = collect(1:5) # => 5-element Int64 Array: [1,2,3,4,5]\n", 475 | "\n", 476 | "# 可以切割数组\n", 477 | "a[1:3] # => [1, 2, 3]\n", 478 | "a[2:end] # => [2, 3, 4, 5]\n", 479 | "\n", 480 | "# 用 splice! 切割原数组\n", 481 | "arr = [3,4,5]\n", 482 | "splice!(arr,2) # => 4 ; arr 变成了 [3,5]\n", 483 | "\n", 484 | "# 用 append! 连接数组\n", 485 | "b = [1,2,3]\n", 486 | "append!(a,b) # a 变成了 [1, 2, 3, 4, 5, 1, 2, 3]" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 13, 492 | "metadata": {}, 493 | "outputs": [ 494 | { 495 | "data": { 496 | "text/plain": [ 497 | "8" 498 | ] 499 | }, 500 | "execution_count": 13, 501 | "metadata": {}, 502 | "output_type": "execute_result" 503 | } 504 | ], 505 | "source": [ 506 | "# 检查元素是否在数组中\n", 507 | "in(1, a) # => true\n", 508 | "\n", 509 | "# 用 length 获得数组长度\n", 510 | "length(a) # => 8" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "### 元组\n", 518 | "\n", 519 | "Tuple" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 14, 525 | "metadata": {}, 526 | "outputs": [ 527 | { 528 | "name": "stdout", 529 | "output_type": "stream", 530 | "text": [ 531 | "MethodError(setindex!, ((1, 2, 3), 3, 1), 0x0000000000006a14)\n" 532 | ] 533 | }, 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "(4, 5)" 538 | ] 539 | }, 540 | "execution_count": 14, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "# Tuples 是 immutable 的\n", 547 | "tup = (1,2,3) # => (1,2,3) # an (Int64,Int64,Int64) tuple.\n", 548 | "tup[1] # => 1\n", 549 | "try\n", 550 | "    tup[1] = 3 # => ERROR: no method setindex!((Int64,Int64,Int64),Int64,Int64)\n", 551 | "catch e\n", 552 | "    println(e)\n", 553 | "end\n", 554 | "\n", 555 | "# 大多数数组函数同样支持 tuples\n", 556 | "length(tup) # => 3\n", 557 | "tup[1:2] # => (1,2)\n", 558 | "in(2, tup) # => true\n", 559 | "\n", 560 | "# 可以将 tuples 元素分别赋给变量\n", 561 | 
"a, b, c = (1, 2, 3) # => (1,2,3) # a is now 1, b is now 2 and c is now 3\n", 562 | "\n", 563 | "# 不用括号也可以\n", 564 | "d, e, f = 4, 5, 6 # => (4,5,6)\n", 565 | "\n", 566 | "# 单元素 tuple 不等于其元素值\n", 567 | "(1,) == 1 # => false\n", 568 | "(1) == 1 # => true\n", 569 | "\n", 570 | "# 交换值\n", 571 | "e, d = d, e # => (5,4) # d is now 5 and e is now 4" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": {}, 577 | "source": [ 578 | "### 字典\n", 579 | "\n", 580 | "Dict" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 15, 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "data": { 590 | "text/plain": [ 591 | "Base.ValueIterator for a Dict{String,Int64} with 3 entries. Values:\n", 592 | "  2\n", 593 | "  1\n", 594 | "  3" 595 | ] 596 | }, 597 | "execution_count": 15, 598 | "metadata": {}, 599 | "output_type": "execute_result" 600 | } 601 | ], 602 | "source": [ 603 | "# 字典用于存储映射 (Dictionaries store mappings)\n", 604 | "empty_dict = Dict() # => Dict{Any,Any}()\n", 605 | "\n", 606 | "# 也可以用字面量创建字典\n", 607 | "filled_dict = Dict(\"one\"=> 1, \"two\"=> 2, \"three\"=> 3)\n", 608 | "# => Dict{ASCIIString,Int64}\n", 609 | "\n", 610 | "# 用 [] 获得键值\n", 611 | "filled_dict[\"one\"] # => 1\n", 612 | "\n", 613 | "# 获得所有键\n", 614 | "keys(filled_dict)\n", 615 | "# => KeyIterator{Dict{ASCIIString,Int64}}([\"three\"=>3,\"one\"=>1,\"two\"=>2])\n", 616 | "# 注意,键的顺序不是插入时的顺序\n", 617 | "\n", 618 | "# 获得所有值\n", 619 | "values(filled_dict)\n", 620 | "# => ValueIterator{Dict{ASCIIString,Int64}}([\"three\"=>3,\"one\"=>1,\"two\"=>2])\n", 621 | "# 注意,值的顺序也一样" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 16, 627 | "metadata": {}, 628 | "outputs": [ 629 | { 630 | "data": { 631 | "text/plain": [ 632 | "false" 633 | ] 634 | }, 635 | "execution_count": 16, 636 | "metadata": {}, 637 | "output_type": "execute_result" 638 | } 639 | ], 640 | "source": [ 641 | "# 用 haskey 检查键是否存在\n", 642 | "haskey(filled_dict, \"one\") # => true\n", 643 | 
"haskey(filled_dict, 1) # => false" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 17, 649 | "metadata": {}, 650 | "outputs": [ 651 | { 652 | "name": "stdout", 653 | "output_type": "stream", 654 | "text": [ 655 | "KeyError(\"four\")\n" 656 | ] 657 | }, 658 | { 659 | "data": { 660 | "text/plain": [ 661 | "4" 662 | ] 663 | }, 664 | "execution_count": 17, 665 | "metadata": {}, 666 | "output_type": "execute_result" 667 | } 668 | ], 669 | "source": [ 670 | "# 获取不存在的键的值会抛出异常\n", 671 | "try\n", 672 | " filled_dict[\"four\"] # => ERROR: key not found: four in getindex at dict.jl:489\n", 673 | "catch e\n", 674 | " println(e)\n", 675 | "end\n", 676 | "\n", 677 | "# 使用 get 可以提供默认值来避免异常\n", 678 | "# get(dictionary,key,default_value)\n", 679 | "get(filled_dict,\"one\",4) # => 1\n", 680 | "get(filled_dict,\"four\",4) # => 4" 681 | ] 682 | }, 683 | { 684 | "cell_type": "markdown", 685 | "metadata": {}, 686 | "source": [ 687 | "### 集合\n", 688 | "\n", 689 | "Set" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 18, 695 | "metadata": {}, 696 | "outputs": [ 697 | { 698 | "data": { 699 | "text/plain": [ 700 | "Set{Int64} with 2 elements:\n", 701 | " 4\n", 702 | " 1" 703 | ] 704 | }, 705 | "execution_count": 18, 706 | "metadata": {}, 707 | "output_type": "execute_result" 708 | } 709 | ], 710 | "source": [ 711 | "# 用 Sets 表示无序不可重复的值的集合\n", 712 | "empty_set = Set() # => Set{Any}()\n", 713 | "# 初始化一个 Set 并定义其值\n", 714 | "filled_set = Set([1,2,2,3,4]) # => Set{Int64}(1,2,3,4)\n", 715 | "\n", 716 | "# 添加值\n", 717 | "push!(filled_set,5) # => Set{Int64}(5,4,2,3,1)\n", 718 | "\n", 719 | "# 检查是否存在某值\n", 720 | "in(2, filled_set) # => true\n", 721 | "in(10, filled_set) # => false\n", 722 | "\n", 723 | "# 交集,并集,差集\n", 724 | "other_set = Set([3, 4, 5, 6]) # => Set{Int64}(6,4,5,3)\n", 725 | "intersect(filled_set, other_set) # => Set{Int64}(3,4,5)\n", 726 | "union(filled_set, other_set) # => Set{Int64}(1,2,3,4,5,6)\n", 727 | 
"setdiff(Set([1,2,3,4]),Set([2,3,5])) # => Set{Int64}(1,4)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "markdown", 732 | "metadata": {}, 733 | "source": [ 734 | "## 四、控制流\n", 735 | "\n", 736 | "Control flow" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": 19, 742 | "metadata": {}, 743 | "outputs": [ 744 | { 745 | "name": "stdout", 746 | "output_type": "stream", 747 | "text": [ 748 | "some_var is smaller than 10.\n" 749 | ] 750 | } 751 | ], 752 | "source": [ 753 | "# 声明一个变量\n", 754 | "some_var = 5\n", 755 | "\n", 756 | "# 这是一个 if 语句,缩进不是必要的\n", 757 | "if some_var > 10\n", 758 | "    println(\"some_var is totally bigger than 10.\")\n", 759 | "elseif some_var < 10 # elseif 是可选的.\n", 760 | "    println(\"some_var is smaller than 10.\")\n", 761 | "else # else 也是可选的.\n", 762 | "    println(\"some_var is indeed 10.\")\n", 763 | "end\n", 764 | "# => prints \"some_var is smaller than 10.\"" 765 | ] 766 | }, 767 | { 768 | "cell_type": "code", 769 | "execution_count": 20, 770 | "metadata": {}, 771 | "outputs": [ 772 | { 773 | "name": "stdout", 774 | "output_type": "stream", 775 | "text": [ 776 | "dog is a mammal\n", 777 | "cat is a mammal\n", 778 | "mouse is a mammal\n", 779 | "dog is a mammal\n", 780 | "cat is a mammal\n", 781 | "mouse is a mammal\n", 782 | "dog is a mammal\n", 783 | "cat is a mammal\n", 784 | "mouse is a mammal\n", 785 | "dog is a mammal\n", 786 | "cat is a mammal\n", 787 | "mouse is a mammal\n" 788 | ] 789 | } 790 | ], 791 | "source": [ 792 | "# For 循环遍历\n", 793 | "# Iterable 类型包括 Range, Array, Set, Dict, 以及 String.\n", 794 | "for animal=[\"dog\", \"cat\", \"mouse\"]\n", 795 | "    println(\"$animal is a mammal\")\n", 796 | "    # 可用 $ 将变量或表达式插值到字符串中\n", 797 | "end\n", 798 | "# You can use 'in' instead of '='.\n", 799 | "for animal in [\"dog\", \"cat\", \"mouse\"]\n", 800 | "    println(\"$animal is a mammal\")\n", 801 | "end\n", 802 | "\n", 803 | "for a in 
[\"dog\"=>\"mammal\",\"cat\"=>\"mammal\",\"mouse\"=>\"mammal\"]\n", 804 | " println(\"$(a[1]) is a $(a[2])\")\n", 805 | "end\n", 806 | "\n", 807 | "for (k,v) in [\"dog\"=>\"mammal\",\"cat\"=>\"mammal\",\"mouse\"=>\"mammal\"]\n", 808 | " println(\"$k is a $v\")\n", 809 | "end" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": 21, 815 | "metadata": {}, 816 | "outputs": [ 817 | { 818 | "name": "stdout", 819 | "output_type": "stream", 820 | "text": [ 821 | "0\n", 822 | "1\n", 823 | "2\n", 824 | "3\n" 825 | ] 826 | } 827 | ], 828 | "source": [ 829 | "# While 循环\n", 830 | "x = 0\n", 831 | "while x < 4\n", 832 | " println(x)\n", 833 | " x += 1 # x = x + 1\n", 834 | "end" 835 | ] 836 | }, 837 | { 838 | "cell_type": "code", 839 | "execution_count": 22, 840 | "metadata": {}, 841 | "outputs": [ 842 | { 843 | "name": "stdout", 844 | "output_type": "stream", 845 | "text": [ 846 | "caught it ErrorException(\"help\")\n" 847 | ] 848 | } 849 | ], 850 | "source": [ 851 | "# 用 try/catch 处理异常\n", 852 | "try\n", 853 | " error(\"help\")\n", 854 | "catch e\n", 855 | " println(\"caught it $e\")\n", 856 | "end" 857 | ] 858 | }, 859 | { 860 | "cell_type": "markdown", 861 | "metadata": {}, 862 | "source": [ 863 | "## 五、函数\n", 864 | "\n", 865 | "Function" 866 | ] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "execution_count": 23, 871 | "metadata": {}, 872 | "outputs": [ 873 | { 874 | "name": "stdout", 875 | "output_type": "stream", 876 | "text": [ 877 | "x is 5 and y is 6\n" 878 | ] 879 | }, 880 | { 881 | "data": { 882 | "text/plain": [ 883 | "(1, 2, 3)" 884 | ] 885 | }, 886 | "execution_count": 23, 887 | "metadata": {}, 888 | "output_type": "execute_result" 889 | } 890 | ], 891 | "source": [ 892 | "# 用关键字 'function' 可创建一个新函数\n", 893 | "#function name(arglist)\n", 894 | "# body...\n", 895 | "#end\n", 896 | "function add(x, y)\n", 897 | " println(\"x is $x and y is $y\")\n", 898 | "\n", 899 | " # 最后一行语句的值为返回\n", 900 | " x + y\n", 901 | "end\n", 902 | "\n", 903 | 
"add(5, 6) # => 在 \"x is 5 and y is 6\" 后会打印 11\n", 904 | "\n", 905 | "# 还可以定义接收可变长参数的函数\n", 906 | "function varargs(args...)\n", 907 | " return args\n", 908 | " # 关键字 return 可在函数内部任何地方返回\n", 909 | "end\n", 910 | "# => varargs (generic function with 1 method)\n", 911 | "\n", 912 | "varargs(1,2,3) # => (1,2,3)" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": 24, 918 | "metadata": {}, 919 | "outputs": [ 920 | { 921 | "name": "stdout", 922 | "output_type": "stream", 923 | "text": [ 924 | "x is 1 and y is 2\n" 925 | ] 926 | }, 927 | { 928 | "data": { 929 | "text/plain": [ 930 | "3" 931 | ] 932 | }, 933 | "execution_count": 24, 934 | "metadata": {}, 935 | "output_type": "execute_result" 936 | } 937 | ], 938 | "source": [ 939 | "# 省略号 ... 被称为 splat.\n", 940 | "# 刚刚用在了函数定义中\n", 941 | "# 还可以用在函数的调用\n", 942 | "# Array 或者 Tuple 的内容会变成参数列表\n", 943 | "add([1,2]...)" 944 | ] 945 | }, 946 | { 947 | "cell_type": "code", 948 | "execution_count": 25, 949 | "metadata": {}, 950 | "outputs": [ 951 | { 952 | "name": "stdout", 953 | "output_type": "stream", 954 | "text": [ 955 | "MethodError(defaults, ('h',), 0x0000000000006a19)\n" 956 | ] 957 | } 958 | ], 959 | "source": [ 960 | "# 可定义可选参数的函数\n", 961 | "function defaults(a,b,x=5,y=6)\n", 962 | " return \"$a $b and $x $y\"\n", 963 | "end\n", 964 | "\n", 965 | "defaults('h','g') # => \"h g and 5 6\"\n", 966 | "defaults('h','g','j') # => \"h g and j 6\"\n", 967 | "defaults('h','g','j','k') # => \"h g and j k\"\n", 968 | "try\n", 969 | " defaults('h') # => ERROR: no method defaults(Char,)\n", 970 | " defaults() # => ERROR: no methods defaults()\n", 971 | "catch e\n", 972 | " println(e)\n", 973 | "end" 974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": 26, 979 | "metadata": {}, 980 | "outputs": [ 981 | { 982 | "data": { 983 | "text/plain": [ 984 | "2-element Array{Pair{String,Any},1}:\n", 985 | " \"k1\" => 4\n", 986 | " \"name2\" => \"hello\"" 987 | ] 988 | }, 989 | "execution_count": 
26, 990 | "metadata": {}, 991 | "output_type": "execute_result" 992 | } 993 | ], 994 | "source": [ 995 | "# 还可以定义键值对的参数\n", 996 | "function keyword_args(;k1=4,name2=\"hello\") # note the ;\n", 997 | " return [\"k1\"=>k1,\"name2\"=>name2]\n", 998 | "end\n", 999 | "\n", 1000 | "keyword_args(name2=\"ness\") # => [\"name2\"=>\"ness\",\"k1\"=>4]\n", 1001 | "keyword_args(k1=\"mine\") # => [\"k1\"=>\"mine\",\"name2\"=>\"hello\"]\n", 1002 | "keyword_args() # => [\"name2\"=>\"hello\",\"k1\"=>4]" 1003 | ] 1004 | }, 1005 | { 1006 | "cell_type": "code", 1007 | "execution_count": 27, 1008 | "metadata": {}, 1009 | "outputs": [ 1010 | { 1011 | "name": "stdout", 1012 | "output_type": "stream", 1013 | "text": [ 1014 | "normal arg: 1\n", 1015 | "optional arg: 3\n", 1016 | "keyword arg: 4\n" 1017 | ] 1018 | } 1019 | ], 1020 | "source": [ 1021 | "# 可以组合各种类型的参数在同一个函数的参数列表中\n", 1022 | "function all_the_args(normal_arg, optional_positional_arg=2; keyword_arg=\"foo\")\n", 1023 | " println(\"normal arg: $normal_arg\")\n", 1024 | " println(\"optional arg: $optional_positional_arg\")\n", 1025 | " println(\"keyword arg: $keyword_arg\")\n", 1026 | "end\n", 1027 | "\n", 1028 | "all_the_args(1, 3, keyword_arg=4)" 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "execution_count": 28, 1034 | "metadata": {}, 1035 | "outputs": [ 1036 | { 1037 | "data": { 1038 | "text/plain": [ 1039 | "create_adder (generic function with 1 method)" 1040 | ] 1041 | }, 1042 | "execution_count": 28, 1043 | "metadata": {}, 1044 | "output_type": "execute_result" 1045 | } 1046 | ], 1047 | "source": [ 1048 | "# Julia 有一等函数\n", 1049 | "function create_adder(x)\n", 1050 | " adder = function (y)\n", 1051 | " return x + y\n", 1052 | " end\n", 1053 | " return adder\n", 1054 | "end" 1055 | ] 1056 | }, 1057 | { 1058 | "cell_type": "code", 1059 | "execution_count": 29, 1060 | "metadata": {}, 1061 | "outputs": [ 1062 | { 1063 | "data": { 1064 | "text/plain": [ 1065 | "create_adder (generic function with 1 method)" 
1066 | ] 1067 | }, 1068 | "execution_count": 29, 1069 | "metadata": {}, 1070 | "output_type": "execute_result" 1071 | } 1072 | ], 1073 | "source": [ 1074 | "# 这是用 \"stabby lambda syntax\" 创建的匿名函数\n", 1075 | "(x -> x > 2)(3) # => true\n", 1076 | "\n", 1077 | "# 这个函数和上面的 create_adder 一模一样\n", 1078 | "function create_adder(x)\n", 1079 | " y -> x + y\n", 1080 | "end" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": 30, 1086 | "metadata": {}, 1087 | "outputs": [ 1088 | { 1089 | "data": { 1090 | "text/plain": [ 1091 | "13" 1092 | ] 1093 | }, 1094 | "execution_count": 30, 1095 | "metadata": {}, 1096 | "output_type": "execute_result" 1097 | } 1098 | ], 1099 | "source": [ 1100 | "# 你也可以给内部函数起个名字\n", 1101 | "function create_adder(x)\n", 1102 | " function adder(y)\n", 1103 | " x + y\n", 1104 | " end\n", 1105 | " adder\n", 1106 | "end\n", 1107 | "\n", 1108 | "add_10 = create_adder(10)\n", 1109 | "add_10(3) # => 13" 1110 | ] 1111 | }, 1112 | { 1113 | "cell_type": "code", 1114 | "execution_count": 31, 1115 | "metadata": {}, 1116 | "outputs": [ 1117 | { 1118 | "data": { 1119 | "text/plain": [ 1120 | "3-element Array{Int64,1}:\n", 1121 | " 11\n", 1122 | " 12\n", 1123 | " 13" 1124 | ] 1125 | }, 1126 | "execution_count": 31, 1127 | "metadata": {}, 1128 | "output_type": "execute_result" 1129 | } 1130 | ], 1131 | "source": [ 1132 | "# 内置的高阶函数有\n", 1133 | "map(add_10, [1,2,3]) # => [11, 12, 13]\n", 1134 | "filter(x -> x > 5, [3, 4, 5, 6, 7]) # => [6, 7]\n", 1135 | "\n", 1136 | "# 还可以使用 list comprehensions 替代 map\n", 1137 | "[add_10(i) for i=[1, 2, 3]] # => [11, 12, 13]\n", 1138 | "[add_10(i) for i in [1, 2, 3]] # => [11, 12, 13]" 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "markdown", 1143 | "metadata": {}, 1144 | "source": [ 1145 | "## 六、类型\n", 1146 | "\n", 1147 | "Type" 1148 | ] 1149 | }, 1150 | { 1151 | "cell_type": "code", 1152 | "execution_count": 32, 1153 | "metadata": {}, 1154 | "outputs": [ 1155 | { 1156 | "data": { 1157 | "text/plain": [ 
1158 | "DataType" 1159 | ] 1160 | }, 1161 | "execution_count": 32, 1162 | "metadata": {}, 1163 | "output_type": "execute_result" 1164 | } 1165 | ], 1166 | "source": [ 1167 | "# Julia 有类型系统\n", 1168 | "# 所有的值都有类型;但变量本身没有类型\n", 1169 | "# 你可以用 `typeof` 函数获得值的类型\n", 1170 | "typeof(5) # => Int64\n", 1171 | "\n", 1172 | "# 类型是一等值\n", 1173 | "typeof(Int64) # => DataType\n", 1174 | "typeof(DataType) # => DataType\n", 1175 | "# DataType 是代表类型的类型,也代表他自己的类型\n", 1176 | "\n", 1177 | "# 类型可用作文档化,优化,以及调度\n", 1178 | "# 并不是静态检查类型" 1179 | ] 1180 | }, 1181 | { 1182 | "cell_type": "code", 1183 | "execution_count": 33, 1184 | "metadata": {}, 1185 | "outputs": [ 1186 | { 1187 | "data": { 1188 | "text/plain": [ 1189 | "Tiger(5.6, \"fire\")" 1190 | ] 1191 | }, 1192 | "execution_count": 33, 1193 | "metadata": {}, 1194 | "output_type": "execute_result" 1195 | } 1196 | ], 1197 | "source": [ 1198 | "# 用户还可以自定义类型\n", 1199 | "# 跟其他语言的 records 或 structs 一样\n", 1200 | "# 用 `type` 关键字定义新的类型\n", 1201 | "\n", 1202 | "# type Name\n", 1203 | "# field::OptionalType\n", 1204 | "# ...\n", 1205 | "# end\n", 1206 | "struct Tiger\n", 1207 | " taillength::Float64\n", 1208 | " coatcolor # 不附带类型标注的相当于 `::Any`\n", 1209 | "end\n", 1210 | "\n", 1211 | "# 构造函数参数是类型的属性\n", 1212 | "tigger = Tiger(3.5,\"orange\") # => Tiger(3.5,\"orange\")\n", 1213 | "\n", 1214 | "# 用新类型作为构造函数还会创建一个类型\n", 1215 | "sherekhan = typeof(tigger)(5.6,\"fire\") # => Tiger(5.6,\"fire\")" 1216 | ] 1217 | }, 1218 | { 1219 | "cell_type": "code", 1220 | "execution_count": 34, 1221 | "metadata": {}, 1222 | "outputs": [], 1223 | "source": [ 1224 | "# struct 类似的类型被称为具体类型\n", 1225 | "# 他们可被实例化但不能有子类型\n", 1226 | "# 另一种类型是抽象类型\n", 1227 | "\n", 1228 | "# abstract Name\n", 1229 | "abstract type Cat end # just a name and point in the type hierarchy\n", 1230 | "\n", 1231 | "# 抽象类型不能被实例化,但是可以有子类型\n", 1232 | "# 例如,Number 就是抽象类型\n", 1233 | "subtypes(Number) # => 6-element Array{Any,1}:\n", 1234 | " # Complex{Float16}\n", 1235 | " # Complex{Float32}\n", 1236 | 
" # Complex{Float64}\n", 1237 | " # Complex{T<:Real}\n", 1238 | " # ImaginaryUnit\n", 1239 | " # Real\n", 1240 | "subtypes(Cat) # => 0-element Array{Any,1}\n", 1241 | "\n", 1242 | "# 所有的类型都有父类型; 可以用函数 `super` 得到父类型.\n", 1243 | "typeof(5) # => Int64\n", 1244 | "supertype(Int64) # => Signed\n", 1245 | "supertype(Signed) # => Real\n", 1246 | "supertype(Real) # => Number\n", 1247 | "supertype(Number) # => Any\n", 1248 | "supertype(supertype(Signed)) # => Number\n", 1249 | "supertype(Any) # => Any\n", 1250 | "# 所有这些类型,除了 Int64, 都是抽象类型.\n", 1251 | "\n", 1252 | "# <: 是类型集成操作符\n", 1253 | "struct Lion <: Cat # Lion 是 Cat 的子类型\n", 1254 | " mane_color\n", 1255 | " roar::String\n", 1256 | "end\n", 1257 | "\n", 1258 | "# 可以继续为你的类型定义构造函数\n", 1259 | "# 只需要定义一个同名的函数\n", 1260 | "# 并调用已有的构造函数设置一个固定参数\n", 1261 | "Lion(roar::String) = Lion(\"green\",roar)\n", 1262 | "# 这是一个外部构造函数,因为他再类型定义之外\n", 1263 | "\n", 1264 | "struct Panther <: Cat # Panther 也是 Cat 的子类型\n", 1265 | " eye_color\n", 1266 | " Panther() = new(\"green\")\n", 1267 | " # Panthers 只有这个构造函数,没有默认构造函数\n", 1268 | "end\n", 1269 | "# 使用内置构造函数,如 Panther,可以让你控制\n", 1270 | "# 如何构造类型的值\n", 1271 | "# 应该尽可能使用外部构造函数而不是内部构造函数" 1272 | ] 1273 | }, 1274 | { 1275 | "cell_type": "markdown", 1276 | "metadata": {}, 1277 | "source": [ 1278 | "## 七、多分派\n", 1279 | "\n", 1280 | "Multiple dispatch" 1281 | ] 1282 | }, 1283 | { 1284 | "cell_type": "markdown", 1285 | "metadata": {}, 1286 | "source": [ 1287 | "在Julia中, 所有的具名函数都是类属函数。这意味着他们都是有小方法组成的。" 1288 | ] 1289 | }, 1290 | { 1291 | "cell_type": "code", 1292 | "execution_count": 35, 1293 | "metadata": {}, 1294 | "outputs": [ 1295 | { 1296 | "data": { 1297 | "text/plain": [ 1298 | "meow (generic function with 3 methods)" 1299 | ] 1300 | }, 1301 | "execution_count": 35, 1302 | "metadata": {}, 1303 | "output_type": "execute_result" 1304 | } 1305 | ], 1306 | "source": [ 1307 | "# 每个 Lion 的构造函数都是类属函数 Lion 的方法\n", 1308 | "# 我们来看一个非构造函数的例子\n", 1309 | "\n", 1310 | "# Lion, Panther, Tiger 的 meow 定义为\n", 1311 
| "function meow(animal::Lion)\n", 1312 | " animal.roar # 使用点符号访问属性\n", 1313 | "end\n", 1314 | "\n", 1315 | "function meow(animal::Panther)\n", 1316 | " \"grrr\"\n", 1317 | "end\n", 1318 | "\n", 1319 | "function meow(animal::Tiger)\n", 1320 | " \"rawwwr\"\n", 1321 | "end" 1322 | ] 1323 | }, 1324 | { 1325 | "cell_type": "code", 1326 | "execution_count": 36, 1327 | "metadata": {}, 1328 | "outputs": [ 1329 | { 1330 | "data": { 1331 | "text/plain": [ 1332 | "\"grrr\"" 1333 | ] 1334 | }, 1335 | "execution_count": 36, 1336 | "metadata": {}, 1337 | "output_type": "execute_result" 1338 | } 1339 | ], 1340 | "source": [ 1341 | "# 试试 meow 函数\n", 1342 | "meow(tigger) # => \"rawwr\"\n", 1343 | "meow(Lion(\"brown\",\"ROAAR\")) # => \"ROAAR\"\n", 1344 | "meow(Panther()) # => \"grrr\"" 1345 | ] 1346 | }, 1347 | { 1348 | "cell_type": "code", 1349 | "execution_count": 37, 1350 | "metadata": {}, 1351 | "outputs": [ 1352 | { 1353 | "name": "stdout", 1354 | "output_type": "stream", 1355 | "text": [ 1356 | "The cat says 42\n", 1357 | "MethodError(pet_cat, (Tiger(3.5, \"orange\"),), 0x0000000000006a33)\n" 1358 | ] 1359 | } 1360 | ], 1361 | "source": [ 1362 | "# 定义一个接收 Cats 的函数\n", 1363 | "function pet_cat(cat::Cat)\n", 1364 | " println(\"The cat says $(meow(cat))\")\n", 1365 | "end\n", 1366 | "\n", 1367 | "pet_cat(Lion(\"42\")) # => prints \"The cat says 42\"\n", 1368 | "try\n", 1369 | " pet_cat(tigger) # => ERROR: no method pet_cat(Tiger,)\n", 1370 | "catch e\n", 1371 | " println(e)\n", 1372 | "end" 1373 | ] 1374 | }, 1375 | { 1376 | "cell_type": "code", 1377 | "execution_count": 38, 1378 | "metadata": {}, 1379 | "outputs": [ 1380 | { 1381 | "name": "stdout", 1382 | "output_type": "stream", 1383 | "text": [ 1384 | "The orange tiger wins!\n", 1385 | "The orange tiger wins!\n", 1386 | "The orange tiger wins!\n", 1387 | "The green-maned lion wins!\n", 1388 | "The victorious cat says grrr\n" 1389 | ] 1390 | } 1391 | ], 1392 | "source": [ 1393 | "# 在面向对象语言中,通常都是单分派\n", 1394 | "# 
这意味着分派方法是通过第一个参数的类型决定的\n", 1395 | "# 在Julia中, 所有参数类型都会被考虑到\n", 1396 | "\n", 1397 | "# 让我们定义有多个参数的函数,好看看区别\n", 1398 | "function fight(t::Tiger,c::Cat)\n", 1399 | " println(\"The $(t.coatcolor) tiger wins!\")\n", 1400 | "end\n", 1401 | "# => fight (generic function with 1 method)\n", 1402 | "\n", 1403 | "fight(tigger,Panther()) # => prints The orange tiger wins!\n", 1404 | "fight(tigger,Lion(\"ROAR\")) # => prints The orange tiger wins!\n", 1405 | "\n", 1406 | "# 让我们修改一下传入具体为 Lion 类型时的行为\n", 1407 | "fight(t::Tiger,l::Lion) = println(\"The $(l.mane_color)-maned lion wins!\")\n", 1408 | "# => fight (generic function with 2 methods)\n", 1409 | "\n", 1410 | "fight(tigger,Panther()) # => prints The orange tiger wins!\n", 1411 | "fight(tigger,Lion(\"ROAR\")) # => prints The green-maned lion wins!\n", 1412 | "\n", 1413 | "# 把 Tiger 去掉\n", 1414 | "fight(l::Lion,c::Cat) = println(\"The victorious cat says $(meow(c))\")\n", 1415 | "# => fight (generic function with 3 methods)\n", 1416 | "\n", 1417 | "fight(Lion(\"balooga!\"),Panther()) # => prints The victorious cat says grrr\n", 1418 | "try\n", 1419 | " fight(Panther(),Lion(\"RAWR\")) # => ERROR: no method fight(Panther,Lion)\n", 1420 | "catch\n", 1421 | "end" 1422 | ] 1423 | }, 1424 | { 1425 | "cell_type": "code", 1426 | "execution_count": null, 1427 | "metadata": {}, 1428 | "outputs": [], 1429 | "source": [] 1430 | } 1431 | ], 1432 | "metadata": { 1433 | "kernelspec": { 1434 | "display_name": "Julia 1.4.1", 1435 | "language": "julia", 1436 | "name": "julia-1.4" 1437 | }, 1438 | "language_info": { 1439 | "file_extension": ".jl", 1440 | "mimetype": "application/julia", 1441 | "name": "julia", 1442 | "version": "1.4.1" 1443 | } 1444 | }, 1445 | "nbformat": 4, 1446 | "nbformat_minor": 2 1447 | } -------------------------------------------------------------------------------- /ResNet_and_ImageNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### ResNet and ImageNet\n", 8 | "\n", 9 | "在该实现中您可以看到如下功能:\n", 10 | "1. 读取 ImageFolder 并进行预处理,切分 batch\n", 11 | "2. 模型的读入和保存\n", 12 | "3. 对模型的训练和测试的封装\n", 13 | "\n", 14 | "In this template you can finish the following functions:\n", 15 | "1. Read ImageFolder and pre-process it, divide it into batches\n", 16 | "2. Reading and saving the model\n", 17 | "3. Encapsulation of model training and testing" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "using Flux, Metalhead, Statistics\n", 27 | "using Flux: onehotbatch, onecold, logitcrossentropy, throttle, flatten\n", 28 | "using Metalhead: trainimgs\n", 29 | "using Parameters: @with_kw\n", 30 | "using Images: channelview\n", 31 | "using Statistics: mean\n", 32 | "using Base.Iterators: partition\n", 33 | "using CUDAapi" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "┌ Info: Training on GPU-0\n", 46 | "└ @ Main In[2]:7\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "using CUDAapi, CUDAdrv, CUDAnative\n", 52 | "gpu_id = 0 ## set < 0 for no cuda, >= 0 for using a specific device (if available)\n", 53 | "\n", 54 | "if has_cuda_gpu() && gpu_id >=0\n", 55 | " device!(gpu_id)\n", 56 | " device = Flux.gpu\n", 57 | " @info \"Training on GPU-$(gpu_id)\"\n", 58 | "else\n", 59 | " device = Flux.cpu\n", 60 | " @info \"Training on CPU\"\n", 61 | "end" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "Args" 73 | ] 74 | }, 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | } 79 | ], 80 | "source": [ 81 | "using Parameters: @with_kw\n", 82 | "@with_kw mutable struct Args\n", 83 | " batch_size::Int = 
64\n", 84 | " lr::Float64 = 5e-5\n", 85 | " epochs::Int = 10\n", 86 | " patience::Int = 5\n", 87 | " data_workers::Int = 4\n", 88 | " train_data_dir::String = \"/home/zhangzhi/Data/ImageNet2012/train\"\n", 89 | " val_data_dir::String = \"/home/zhangzhi/Data/ImageNet2012/train\"\n", 90 | "end" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "Args\n", 102 | " batch_size: Int64 64\n", 103 | " lr: Float64 5.0e-5\n", 104 | " epochs: Int64 10\n", 105 | " patience: Int64 5\n", 106 | " data_workers: Int64 4\n", 107 | " train_data_dir: String \"/home/zhangzhi/Data/ImageNet2012/train\"\n", 108 | " val_data_dir: String \"/home/zhangzhi/Data/ImageNet2012/train\"\n" 109 | ] 110 | }, 111 | "execution_count": 4, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "args = Args()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "模仿 pytorch 使用多个 worker 读取数据集,进行预处理,并且分为n个batch。这里将其封装为单独的文件。\n", 125 | "\n", 126 | "Imitating pytorch, we uses multiple workers to read the data set, preprocess it, and divide it into n batches. Here it is packaged as a separate file." 
127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "include(\"dataset.jl\")" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "┌ Info: Adding 4 new data workers...\n", 148 | "└ @ Main /data/zhangzhi/julia/dataset.jl:190\n", 149 | "┌ Info: Adding 4 new data workers...\n", 150 | "└ @ Main /data/zhangzhi/julia/dataset.jl:190\n" 151 | ] 152 | }, 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "ImagenetDataset(\"/home/zhangzhi/Data/ImageNet2012/train\", 64, imagenet_val_data_loader, [\"n01440764/n01440764_10026.JPEG\", \"n01440764/n01440764_10027.JPEG\", \"n01440764/n01440764_10029.JPEG\", \"n01440764/n01440764_10040.JPEG\", \"n01440764/n01440764_10042.JPEG\", \"n01440764/n01440764_10043.JPEG\", \"n01440764/n01440764_10048.JPEG\", \"n01440764/n01440764_10066.JPEG\", \"n01440764/n01440764_10074.JPEG\", \"n01440764/n01440764_1009.JPEG\" … \"n15075141/n15075141_9816.JPEG\", \"n15075141/n15075141_9819.JPEG\", \"n15075141/n15075141_9835.JPEG\", \"n15075141/n15075141_9855.JPEG\", \"n15075141/n15075141_9907.JPEG\", \"n15075141/n15075141_9915.JPEG\", \"n15075141/n15075141_9933.JPEG\", \"n15075141/n15075141_9942.JPEG\", \"n15075141/n15075141_999.JPEG\", \"n15075141/n15075141_9993.JPEG\"], QueuePool([6, 7, 8, 9], RemoteChannel{Channel{Tuple}}(1, 1, 29), RemoteChannel{Channel{Tuple}}(1, 1, 30), RemoteChannel{Channel{Bool}}(1, 1, 31), 0, Dict{Int64,Any}()))" 157 | ] 158 | }, 159 | "execution_count": 6, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "train_dataset = ImagenetDataset(args.train_data_dir, args.data_workers, args.batch_size, imagenet_train_data_loader)\n", 166 | "val_dataset = ImagenetDataset(args.val_data_dir, args.data_workers, args.batch_size, 
imagenet_val_data_loader)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "定义 ResNet。\n", 174 | "\n", 175 | "Define ResNet." 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 7, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "(Chain(Conv((7, 7), 3=>64), MaxPool((3, 3), pad = (1, 1), stride = (2, 2)), Metalhead.ResidualBlock((Conv((1, 1), 64=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), Chain(Conv((1, 1), 64=>256), BatchNorm(256))), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), Chain(Conv((1, 1), 256=>512), BatchNorm(512))), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), Chain(Conv((1, 1), 512=>1024), BatchNorm(1024))), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 
1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), Chain(Conv((1, 1), 1024=>2048), BatchNorm(2048))), Metalhead.ResidualBlock((Conv((1, 1), 2048=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), identity), Metalhead.ResidualBlock((Conv((1, 1), 2048=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), identity), MeanPool((7, 7), pad = (0, 0, 0, 0), stride = (7, 7)), #103, Dense(2048, 1000)),)" 187 | ] 188 | }, 189 | "execution_count": 7, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "using Metalhead\n", 196 | "\n", 197 | "resnet = ResNet()\n", 198 | "model = Chain(resnet.layers[1:end-1]) |> device\n", 199 | "Flux.trainmode!(model, true)\n", 200 | "opt = ADAM(args.lr)\n", 201 | "model.layers" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "对模型的训练和测试的封装。\n", 209 | "\n", 210 | "Encapsulation of model training and testing." 
211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 8, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/plain": [ 221 | "train (generic function with 1 method)" 222 | ] 223 | }, 224 | "execution_count": 8, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "using BSON\n", 231 | "using Tracker\n", 232 | "using Statistics, Printf\n", 233 | "using Flux.Optimise\n", 234 | "\n", 235 | "function save_model(model, filename)\n", 236 | " model_state = Dict(\n", 237 | " :weights => Tracker.data.(params(model))\n", 238 | " )\n", 239 | " open(filename, \"w\") do io\n", 240 | " BSON.bson(io, model_state)\n", 241 | " end\n", 242 | "end\n", 243 | "\n", 244 | "function load_model!(model, filename)\n", 245 | " weights = BSON.load(filename)[:weights]\n", 246 | " Flux.loadparams!(model, weights)\n", 247 | " return model\n", 248 | "end\n", 249 | "\n", 250 | "@with_kw mutable struct State\n", 251 | " epoch::Int = 1\n", 252 | " train_loss_history = []\n", 253 | " val_loss_history = []\n", 254 | "end\n", 255 | "\n", 256 | "state = State()\n", 257 | "\n", 258 | "process_minibatch = (model, opt, x, y) -> begin\n", 259 | " x = x |> device\n", 260 | " y = y |> device\n", 261 | " #@show model_to_host(y_hat)\n", 262 | " #@show model_to_host(y)\n", 263 | " loss(x, y) = logitcrossentropy(model(x), y)\n", 264 | " Flux.train!(loss, params(model), [(x, y)], opt)\n", 265 | " batch_loss = logitcrossentropy(model(x), y)\n", 266 | " @show batch_loss\n", 267 | " return Tracker.data(batch_loss |> cpu)\n", 268 | "end\n", 269 | "\n", 270 | "\n", 271 | "function train_epoch(model, opt)\n", 272 | " # Clear out any previous training loss history\n", 273 | " while length(state.train_loss_history) < state.epoch\n", 274 | " push!(state.train_loss_history, Float64[])\n", 275 | " end\n", 276 | " state.train_loss_history[state.epoch] = zeros(Float64, length(train_dataset))\n", 277 | "\n", 278 | " batch_idx = 1\n", 
279 | " avg_batch_time = 0.0\n", 280 | " t_last = time()\n", 281 | " for (x, y) in train_dataset\n", 282 | " # Store training loss into loss history\n", 283 | " state.train_loss_history[state.epoch][batch_idx] = process_minibatch(model, opt, x, y)\n", 284 | "\n", 285 | " # Update average batch time\n", 286 | " t_now = time()\n", 287 | " avg_batch_time = .99*avg_batch_time + .01*(t_now - t_last)\n", 288 | " t_last = t_now\n", 289 | "\n", 290 | " # Calculate ETA\n", 291 | " time_left = avg_batch_time*(length(train_dataset) - batch_idx)\n", 292 | " hours = floor(Int,time_left/(60*60))\n", 293 | " minutes = floor(Int, (time_left - hours*60*60)/60)\n", 294 | " seconds = time_left - hours*60*60 - minutes*60\n", 295 | " eta = @sprintf(\"%dh%dm%ds\", hours, minutes, seconds)\n", 296 | "\n", 297 | " # Show a smoothed loss approximation per-minibatch\n", 298 | " smoothed_loss = mean(state.train_loss_history[state.epoch][max(batch_idx-50,1):batch_idx])\n", 299 | " println(@sprintf(\n", 300 | " \"[TRAIN %d - %d/%d]: avg loss: %.4f, avg time: %.2fs, ETA: %s \",\n", 301 | " state.epoch, batch_idx, length(train_dataset), smoothed_loss,\n", 302 | " avg_batch_time, eta,\n", 303 | " ))\n", 304 | "\n", 305 | " batch_idx += 1\n", 306 | " end\n", 307 | "end\n", 308 | "\n", 309 | "function validate(model)\n", 310 | " # Get the \"fast model\", \n", 311 | " fast_model = Flux.mapleaves(Tracker.data, model)\n", 312 | " Flux.testmode!(fast_model, true)\n", 313 | "\n", 314 | " avg_loss = 0\n", 315 | " batch_idx = 1\n", 316 | " for (x, y) in val_dataset\n", 317 | " # Push x through our fast model and calculate loss\n", 318 | " y_hat = fast_model(x)\n", 319 | " avg_loss += cpu(Flux.crossentropy(y_hat, y))\n", 320 | "\n", 321 | " print(@sprintf(\n", 322 | " \"\\r[VAL %d - %d/%d]: %.2f\",\n", 323 | " state.epoch, batch_idx, length(val_dataset), avg_loss/batch_idx,\n", 324 | " ))\n", 325 | " batch_idx += 1\n", 326 | " end\n", 327 | " avg_loss /= length(val_dataset)\n", 328 | " 
push!(state.val_loss_history, avg_loss)\n", 329 | "\n", 330 | " # Return the average loss for this epoch\n", 331 | " return avg_loss\n", 332 | "end\n", 333 | "\n", 334 | "\n", 335 | "function train(model, opt)\n", 336 | " # Initialize best_epoch to epoch 0, with Infinity loss\n", 337 | " best_epoch = (0, Inf)\n", 338 | "\n", 339 | " while state.epoch < args.epochs\n", 340 | " # Early-stop if we don't improve after `args.patience` epochs\n", 341 | " if state.epoch > best_epoch[1] + args.patience\n", 342 | " @info(\"Losing patience at epoch $(state.epoch)!\")\n", 343 | " break\n", 344 | " end\n", 345 | "\n", 346 | " # Train for an epoch\n", 347 | " train_epoch(model, opt)\n", 348 | " \n", 349 | " # Validate to see how much we've improved\n", 350 | " epoch_loss = validate(model)\n", 351 | "\n", 352 | " # Check to see if this epoch is the best we've seen so far\n", 353 | " if epoch_loss < best_epoch[2]\n", 354 | " best_epoch = (state.epoch, epoch_loss)\n", 355 | " end\n", 356 | "\n", 357 | " # Save our training state every epoch (but only save the model weights\n", 358 | " # if this was the best epoch yet)\n", 359 | " state.epoch += 1\n", 360 | " end\n", 361 | "end" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": {}, 368 | "outputs": [ 369 | { 370 | "name": "stderr", 371 | "output_type": "stream", 372 | "text": [ 373 | "┌ Info: Beginning training run...\n", 374 | "└ @ Main In[9]:2\n", 375 | "┌ Info: Creating IIS with 1281167 images\n", 376 | "└ @ Main /data/zhangzhi/julia/dataset.jl:212\n" 377 | ] 378 | }, 379 | { 380 | "name": "stdout", 381 | "output_type": "stream", 382 | "text": [ 383 | "batch_loss = 6.9622507f0\n", 384 | "[TRAIN 1 - 1/20018]: avg loss: 6.9623, avg time: 0.78s, ETA: 4h21m11s \n", 385 | "batch_loss = 6.9277363f0\n", 386 | "[TRAIN 1 - 2/20018]: avg loss: 6.9450, avg time: 0.81s, ETA: 4h30m20s \n", 387 | "batch_loss = 6.942004f0\n", 388 | "[TRAIN 1 - 3/20018]: avg loss: 6.9440, avg time: 0.84s, ETA: 
4h38m37s \n", 389 | "batch_loss = 6.9768744f0\n", 390 | "[TRAIN 1 - 4/20018]: avg loss: 6.9522, avg time: 0.86s, ETA: 4h45m42s \n", 391 | "batch_loss = 6.912876f0\n", 392 | "[TRAIN 1 - 5/20018]: avg loss: 6.9443, avg time: 0.87s, ETA: 4h51m37s \n", 393 | "batch_loss = 7.0515375f0\n", 394 | "[TRAIN 1 - 6/20018]: avg loss: 6.9622, avg time: 0.89s, ETA: 4h58m27s \n", 395 | "batch_loss = 7.0313864f0\n", 396 | "[TRAIN 1 - 7/20018]: avg loss: 6.9721, avg time: 0.91s, ETA: 5h4m52s \n", 397 | "batch_loss = 7.167213f0\n", 398 | "[TRAIN 1 - 8/20018]: avg loss: 6.9965, avg time: 0.93s, ETA: 5h11m43s \n", 399 | "batch_loss = 7.07827f0\n", 400 | "[TRAIN 1 - 9/20018]: avg loss: 7.0056, avg time: 0.96s, ETA: 5h19m8s \n", 401 | "batch_loss = 7.0133095f0\n", 402 | "[TRAIN 1 - 10/20018]: avg loss: 7.0063, avg time: 0.97s, ETA: 5h24m56s \n", 403 | "batch_loss = 7.131474f0\n", 404 | "[TRAIN 1 - 11/20018]: avg loss: 7.0177, avg time: 0.99s, ETA: 5h31m22s \n", 405 | "batch_loss = 7.1165767f0\n", 406 | "[TRAIN 1 - 12/20018]: avg loss: 7.0260, avg time: 1.01s, ETA: 5h38m7s \n", 407 | "batch_loss = 7.0434966f0\n", 408 | "[TRAIN 1 - 13/20018]: avg loss: 7.0273, avg time: 1.03s, ETA: 5h44m36s \n", 409 | "batch_loss = 7.341138f0\n", 410 | "[TRAIN 1 - 14/20018]: avg loss: 7.0497, avg time: 1.05s, ETA: 5h50m54s \n", 411 | "batch_loss = 7.0767326f0\n", 412 | "[TRAIN 1 - 15/20018]: avg loss: 7.0515, avg time: 1.07s, ETA: 5h56m41s \n", 413 | "batch_loss = 7.1909633f0\n", 414 | "[TRAIN 1 - 16/20018]: avg loss: 7.0602, avg time: 1.09s, ETA: 6h3m7s \n", 415 | "batch_loss = 7.6535316f0\n", 416 | "[TRAIN 1 - 17/20018]: avg loss: 7.0951, avg time: 1.11s, ETA: 6h9m3s \n", 417 | "batch_loss = 7.26151f0\n", 418 | "[TRAIN 1 - 18/20018]: avg loss: 7.1044, avg time: 1.13s, ETA: 6h15m1s \n", 419 | "batch_loss = 7.31052f0\n", 420 | "[TRAIN 1 - 19/20018]: avg loss: 7.1152, avg time: 1.14s, ETA: 6h20m56s \n", 421 | "batch_loss = 7.09128f0\n", 422 | "[TRAIN 1 - 20/20018]: avg loss: 7.1140, avg time: 1.16s, ETA: 
6h26m55s \n", 423 | "batch_loss = 7.5812626f0\n", 424 | "[TRAIN 1 - 21/20018]: avg loss: 7.1363, avg time: 1.18s, ETA: 6h32m59s \n", 425 | "batch_loss = 7.6248775f0\n", 426 | "[TRAIN 1 - 22/20018]: avg loss: 7.1585, avg time: 1.20s, ETA: 6h39m0s \n", 427 | "batch_loss = 7.38128f0\n", 428 | "[TRAIN 1 - 23/20018]: avg loss: 7.1682, avg time: 1.21s, ETA: 6h44m49s \n", 429 | "batch_loss = 7.5851283f0\n", 430 | "[TRAIN 1 - 24/20018]: avg loss: 7.1856, avg time: 1.23s, ETA: 6h50m31s \n", 431 | "batch_loss = 7.520513f0\n", 432 | "[TRAIN 1 - 25/20018]: avg loss: 7.1989, avg time: 1.25s, ETA: 6h56m15s \n", 433 | "batch_loss = 7.662198f0\n", 434 | "[TRAIN 1 - 26/20018]: avg loss: 7.2168, avg time: 1.27s, ETA: 7h1m56s \n", 435 | "batch_loss = 7.373073f0\n", 436 | "[TRAIN 1 - 27/20018]: avg loss: 7.2226, avg time: 1.28s, ETA: 7h7m20s \n", 437 | "batch_loss = 7.65544f0\n", 438 | "[TRAIN 1 - 28/20018]: avg loss: 7.2380, avg time: 1.30s, ETA: 7h12m51s \n", 439 | "batch_loss = 7.401843f0\n", 440 | "[TRAIN 1 - 29/20018]: avg loss: 7.2437, avg time: 1.32s, ETA: 7h18m59s \n", 441 | "batch_loss = 7.541626f0\n", 442 | "[TRAIN 1 - 30/20018]: avg loss: 7.2536, avg time: 1.33s, ETA: 7h24m20s \n", 443 | "batch_loss = 7.664078f0\n", 444 | "[TRAIN 1 - 31/20018]: avg loss: 7.2668, avg time: 1.35s, ETA: 7h29m23s \n", 445 | "batch_loss = 7.607012f0\n", 446 | "[TRAIN 1 - 32/20018]: avg loss: 7.2775, avg time: 1.37s, ETA: 7h34m48s \n", 447 | "batch_loss = 7.557599f0\n", 448 | "[TRAIN 1 - 33/20018]: avg loss: 7.2860, avg time: 1.38s, ETA: 7h39m22s \n", 449 | "batch_loss = 7.889821f0\n", 450 | "[TRAIN 1 - 34/20018]: avg loss: 7.3037, avg time: 1.40s, ETA: 7h45m28s \n", 451 | "batch_loss = 7.570457f0\n", 452 | "[TRAIN 1 - 35/20018]: avg loss: 7.3113, avg time: 1.41s, ETA: 7h49m53s \n", 453 | "batch_loss = 7.829129f0\n", 454 | "[TRAIN 1 - 36/20018]: avg loss: 7.3257, avg time: 1.43s, ETA: 7h54m51s \n", 455 | "batch_loss = 7.7865896f0\n", 456 | "[TRAIN 1 - 37/20018]: avg loss: 7.3382, avg time: 1.44s, 
ETA: 8h0m39s \n", 457 | "batch_loss = 7.33513f0\n", 458 | "[TRAIN 1 - 38/20018]: avg loss: 7.3381, avg time: 1.46s, ETA: 8h5m11s \n", 459 | "batch_loss = 7.655512f0\n", 460 | "[TRAIN 1 - 39/20018]: avg loss: 7.3462, avg time: 1.47s, ETA: 8h9m59s \n", 461 | "batch_loss = 7.372562f0\n", 462 | "[TRAIN 1 - 40/20018]: avg loss: 7.3469, avg time: 1.49s, ETA: 8h14m58s \n", 463 | "batch_loss = 7.5467167f0\n", 464 | "[TRAIN 1 - 41/20018]: avg loss: 7.3518, avg time: 1.50s, ETA: 8h20m30s \n", 465 | "batch_loss = 7.5951633f0\n", 466 | "[TRAIN 1 - 42/20018]: avg loss: 7.3576, avg time: 1.52s, ETA: 8h24m37s \n", 467 | "batch_loss = 7.5205526f0\n", 468 | "[TRAIN 1 - 43/20018]: avg loss: 7.3614, avg time: 1.53s, ETA: 8h29m14s \n", 469 | "batch_loss = 7.2018003f0\n", 470 | "[TRAIN 1 - 44/20018]: avg loss: 7.3577, avg time: 1.55s, ETA: 8h34m28s \n", 471 | "batch_loss = 7.783224f0\n", 472 | "[TRAIN 1 - 45/20018]: avg loss: 7.3672, avg time: 1.56s, ETA: 8h40m44s \n", 473 | "batch_loss = 7.7686825f0\n", 474 | "[TRAIN 1 - 46/20018]: avg loss: 7.3759, avg time: 1.58s, ETA: 8h44m51s \n", 475 | "batch_loss = 7.1020703f0\n", 476 | "[TRAIN 1 - 47/20018]: avg loss: 7.3701, avg time: 1.59s, ETA: 8h49m16s \n", 477 | "batch_loss = 7.523679f0\n", 478 | "[TRAIN 1 - 48/20018]: avg loss: 7.3733, avg time: 1.60s, ETA: 8h53m45s \n", 479 | "batch_loss = 7.7012f0\n", 480 | "[TRAIN 1 - 49/20018]: avg loss: 7.3800, avg time: 1.62s, ETA: 8h59m49s \n", 481 | "batch_loss = 7.3365755f0\n", 482 | "[TRAIN 1 - 50/20018]: avg loss: 7.3791, avg time: 1.63s, ETA: 9h3m13s \n", 483 | "batch_loss = 7.137593f0\n", 484 | "[TRAIN 1 - 51/20018]: avg loss: 7.3744, avg time: 1.64s, ETA: 9h7m4s \n", 485 | "batch_loss = 6.7748346f0\n", 486 | "[TRAIN 1 - 52/20018]: avg loss: 7.3707, avg time: 1.66s, ETA: 9h11m38s \n", 487 | "batch_loss = 7.6484156f0\n", 488 | "[TRAIN 1 - 53/20018]: avg loss: 7.3848, avg time: 1.67s, ETA: 9h15m22s \n", 489 | "batch_loss = 7.461233f0\n", 490 | "[TRAIN 1 - 54/20018]: avg loss: 7.3950, avg time: 
1.69s, ETA: 9h20m60s \n", 491 | "batch_loss = 7.0180135f0\n", 492 | "[TRAIN 1 - 55/20018]: avg loss: 7.3958, avg time: 1.70s, ETA: 9h24m21s \n", 493 | "batch_loss = 7.2069693f0\n", 494 | "[TRAIN 1 - 56/20018]: avg loss: 7.4016, avg time: 1.71s, ETA: 9h29m58s \n", 495 | "batch_loss = 7.3503304f0\n", 496 | "[TRAIN 1 - 57/20018]: avg loss: 7.4074, avg time: 1.72s, ETA: 9h33m14s \n", 497 | "batch_loss = 7.1234055f0\n", 498 | "[TRAIN 1 - 58/20018]: avg loss: 7.4092, avg time: 1.74s, ETA: 9h37m28s \n", 499 | "batch_loss = 7.170315f0\n", 500 | "[TRAIN 1 - 59/20018]: avg loss: 7.4093, avg time: 1.75s, ETA: 9h41m1s \n", 501 | "batch_loss = 7.1656704f0\n", 502 | "[TRAIN 1 - 60/20018]: avg loss: 7.4110, avg time: 1.76s, ETA: 9h44m48s \n", 503 | "batch_loss = 6.9516582f0\n", 504 | "[TRAIN 1 - 61/20018]: avg loss: 7.4098, avg time: 1.77s, ETA: 9h48m26s \n", 505 | "batch_loss = 7.330106f0\n", 506 | "[TRAIN 1 - 62/20018]: avg loss: 7.4137, avg time: 1.78s, ETA: 9h52m29s \n", 507 | "batch_loss = 6.720178f0\n", 508 | "[TRAIN 1 - 63/20018]: avg loss: 7.4059, avg time: 1.79s, ETA: 9h56m1s \n", 509 | "batch_loss = 7.1089444f0\n", 510 | "[TRAIN 1 - 64/20018]: avg loss: 7.4072, avg time: 1.80s, ETA: 9h59m55s \n", 511 | "batch_loss = 6.7845163f0\n", 512 | "[TRAIN 1 - 65/20018]: avg loss: 7.3963, avg time: 1.81s, ETA: 10h3m14s \n", 513 | "batch_loss = 7.158451f0\n", 514 | "[TRAIN 1 - 66/20018]: avg loss: 7.3979, avg time: 1.83s, ETA: 10h7m1s \n", 515 | "batch_loss = 6.865749f0\n", 516 | "[TRAIN 1 - 67/20018]: avg loss: 7.3915, avg time: 1.84s, ETA: 10h10m22s \n", 517 | "batch_loss = 6.91348f0\n", 518 | "[TRAIN 1 - 68/20018]: avg loss: 7.3770, avg time: 1.85s, ETA: 10h14m7s \n", 519 | "batch_loss = 6.7564955f0\n", 520 | "[TRAIN 1 - 69/20018]: avg loss: 7.3671, avg time: 1.86s, ETA: 10h17m30s \n", 521 | "batch_loss = 7.2078257f0\n", 522 | "[TRAIN 1 - 70/20018]: avg loss: 7.3651, avg time: 1.87s, ETA: 10h20m52s \n", 523 | "batch_loss = 6.5676007f0\n", 524 | "[TRAIN 1 - 71/20018]: avg loss: 
7.3548, avg time: 1.88s, ETA: 10h23m54s \n", 525 | "batch_loss = 6.5625496f0\n", 526 | "[TRAIN 1 - 72/20018]: avg loss: 7.3349, avg time: 1.89s, ETA: 10h27m41s \n", 527 | "batch_loss = 6.7674713f0\n", 528 | "[TRAIN 1 - 73/20018]: avg loss: 7.3180, avg time: 1.90s, ETA: 10h31m52s \n", 529 | "batch_loss = 6.7032466f0\n", 530 | "[TRAIN 1 - 74/20018]: avg loss: 7.3048, avg time: 1.91s, ETA: 10h34m51s \n", 531 | "batch_loss = 6.949676f0\n", 532 | "[TRAIN 1 - 75/20018]: avg loss: 7.2923, avg time: 1.92s, ETA: 10h37m46s \n", 533 | "batch_loss = 6.8355093f0\n", 534 | "[TRAIN 1 - 76/20018]: avg loss: 7.2789, avg time: 1.93s, ETA: 10h41m9s \n", 535 | "batch_loss = 6.761684f0\n", 536 | "[TRAIN 1 - 77/20018]: avg loss: 7.2612, avg time: 1.94s, ETA: 10h44m10s \n", 537 | "batch_loss = 6.742987f0\n", 538 | "[TRAIN 1 - 78/20018]: avg loss: 7.2489, avg time: 1.95s, ETA: 10h47m23s \n", 539 | "batch_loss = 6.8627687f0\n", 540 | "[TRAIN 1 - 79/20018]: avg loss: 7.2333, avg time: 1.96s, ETA: 10h50m39s \n", 541 | "batch_loss = 6.658802f0\n", 542 | "[TRAIN 1 - 80/20018]: avg loss: 7.2187, avg time: 1.97s, ETA: 10h53m57s \n", 543 | "batch_loss = 6.874971f0\n", 544 | "[TRAIN 1 - 81/20018]: avg loss: 7.2057, avg time: 1.98s, ETA: 10h57m41s \n", 545 | "batch_loss = 6.808718f0\n", 546 | "[TRAIN 1 - 82/20018]: avg loss: 7.1889, avg time: 1.99s, ETA: 11h0m38s \n", 547 | "batch_loss = 6.615429f0\n", 548 | "[TRAIN 1 - 83/20018]: avg loss: 7.1695, avg time: 2.00s, ETA: 11h3m33s \n", 549 | "batch_loss = 6.7408957f0\n", 550 | "[TRAIN 1 - 84/20018]: avg loss: 7.1534, avg time: 2.01s, ETA: 11h6m36s \n", 551 | "batch_loss = 6.48777f0\n", 552 | "[TRAIN 1 - 85/20018]: avg loss: 7.1259, avg time: 2.02s, ETA: 11h11m55s \n", 553 | "batch_loss = 6.7756104f0\n", 554 | "[TRAIN 1 - 86/20018]: avg loss: 7.1104, avg time: 2.03s, ETA: 11h14m22s \n", 555 | "batch_loss = 6.249303f0\n", 556 | "[TRAIN 1 - 87/20018]: avg loss: 7.0794, avg time: 2.04s, ETA: 11h17m3s \n", 557 | "batch_loss = 6.717087f0\n", 558 | "[TRAIN 
1 - 88/20018]: avg loss: 7.0584, avg time: 2.05s, ETA: 11h19m48s \n", 559 | "batch_loss = 6.4166083f0\n", 560 | "[TRAIN 1 - 89/20018]: avg loss: 7.0404, avg time: 2.05s, ETA: 11h22m20s \n", 561 | "batch_loss = 6.710187f0\n", 562 | "[TRAIN 1 - 90/20018]: avg loss: 7.0219, avg time: 2.06s, ETA: 11h25m50s \n", 563 | "batch_loss = 6.438019f0\n", 564 | "[TRAIN 1 - 91/20018]: avg loss: 7.0035, avg time: 2.07s, ETA: 11h28m9s \n", 565 | "batch_loss = 6.7583857f0\n", 566 | "[TRAIN 1 - 92/20018]: avg loss: 6.9881, avg time: 2.08s, ETA: 11h31m18s \n", 567 | "batch_loss = 6.522828f0\n", 568 | "[TRAIN 1 - 93/20018]: avg loss: 6.9671, avg time: 2.09s, ETA: 11h33m55s \n", 569 | "batch_loss = 6.576439f0\n", 570 | "[TRAIN 1 - 94/20018]: avg loss: 6.9485, avg time: 2.10s, ETA: 11h36m43s \n", 571 | "batch_loss = 6.5601664f0\n", 572 | "[TRAIN 1 - 95/20018]: avg loss: 6.9360, avg time: 2.11s, ETA: 11h39m28s \n", 573 | "batch_loss = 6.541833f0\n", 574 | "[TRAIN 1 - 96/20018]: avg loss: 6.9116, avg time: 2.12s, ETA: 11h42m38s \n", 575 | "batch_loss = 6.7442126f0\n", 576 | "[TRAIN 1 - 97/20018]: avg loss: 6.8915, avg time: 2.13s, ETA: 11h45m54s \n", 577 | "batch_loss = 7.0219116f0\n", 578 | "[TRAIN 1 - 98/20018]: avg loss: 6.8900, avg time: 2.14s, ETA: 11h48m49s \n", 579 | "batch_loss = 6.8299356f0\n", 580 | "[TRAIN 1 - 99/20018]: avg loss: 6.8764, avg time: 2.14s, ETA: 11h51m58s \n", 581 | "batch_loss = 6.6847277f0\n", 582 | "[TRAIN 1 - 100/20018]: avg loss: 6.8564, avg time: 2.15s, ETA: 11h54m59s \n", 583 | "batch_loss = 6.6075153f0\n", 584 | "[TRAIN 1 - 101/20018]: avg loss: 6.8421, avg time: 2.16s, ETA: 11h57m14s \n", 585 | "batch_loss = 6.675604f0\n", 586 | "[TRAIN 1 - 102/20018]: avg loss: 6.8331, avg time: 2.17s, ETA: 11h59m33s \n", 587 | "batch_loss = 6.882872f0\n", 588 | "[TRAIN 1 - 103/20018]: avg loss: 6.8352, avg time: 2.18s, ETA: 12h1m57s \n", 589 | "batch_loss = 6.5191336f0\n", 590 | "[TRAIN 1 - 104/20018]: avg loss: 6.8131, avg time: 2.18s, ETA: 12h4m16s \n", 591 | 
"batch_loss = 6.5079465f0\n", 592 | "[TRAIN 1 - 105/20018]: avg loss: 6.7944, avg time: 2.19s, ETA: 12h6m27s \n", 593 | "batch_loss = 6.1698112f0\n", 594 | "[TRAIN 1 - 106/20018]: avg loss: 6.7777, avg time: 2.20s, ETA: 12h10m40s \n", 595 | "batch_loss = 6.469122f0\n", 596 | "[TRAIN 1 - 107/20018]: avg loss: 6.7633, avg time: 2.21s, ETA: 12h12m9s \n", 597 | "batch_loss = 6.3438797f0\n", 598 | "[TRAIN 1 - 108/20018]: avg loss: 6.7435, avg time: 2.21s, ETA: 12h14m53s \n", 599 | "batch_loss = 6.2818766f0\n", 600 | "[TRAIN 1 - 109/20018]: avg loss: 6.7270, avg time: 2.22s, ETA: 12h16m52s \n", 601 | "batch_loss = 6.815174f0\n", 602 | "[TRAIN 1 - 110/20018]: avg loss: 6.7201, avg time: 2.23s, ETA: 12h19m50s \n", 603 | "batch_loss = 6.5535192f0\n", 604 | "[TRAIN 1 - 111/20018]: avg loss: 6.7081, avg time: 2.24s, ETA: 12h22m45s \n", 605 | "batch_loss = 6.1814556f0\n", 606 | "[TRAIN 1 - 112/20018]: avg loss: 6.6930, avg time: 2.25s, ETA: 12h26m41s \n", 607 | "batch_loss = 6.302928f0\n", 608 | "[TRAIN 1 - 113/20018]: avg loss: 6.6728, avg time: 2.26s, ETA: 12h30m22s \n", 609 | "batch_loss = 6.178856f0\n", 610 | "[TRAIN 1 - 114/20018]: avg loss: 6.6622, avg time: 2.27s, ETA: 12h31m30s \n", 611 | "batch_loss = 6.2830343f0\n", 612 | "[TRAIN 1 - 115/20018]: avg loss: 6.6460, avg time: 2.27s, ETA: 12h33m31s \n", 613 | "batch_loss = 6.4876084f0\n", 614 | "[TRAIN 1 - 116/20018]: avg loss: 6.6402, avg time: 2.28s, ETA: 12h35m40s \n", 615 | "batch_loss = 6.1451254f0\n", 616 | "[TRAIN 1 - 117/20018]: avg loss: 6.6203, avg time: 2.29s, ETA: 12h39m5s \n", 617 | "batch_loss = 6.297739f0\n", 618 | "[TRAIN 1 - 118/20018]: avg loss: 6.6092, avg time: 2.29s, ETA: 12h40m36s \n", 619 | "batch_loss = 6.553295f0\n", 620 | "[TRAIN 1 - 119/20018]: avg loss: 6.6021, avg time: 2.30s, ETA: 12h42m14s \n", 621 | "batch_loss = 6.344301f0\n", 622 | "[TRAIN 1 - 120/20018]: avg loss: 6.5940, avg time: 2.31s, ETA: 12h44m49s \n", 623 | "batch_loss = 6.113551f0\n", 624 | "[TRAIN 1 - 121/20018]: avg loss: 
6.5726, avg time: 2.31s, ETA: 12h46m12s \n", 625 | "batch_loss = 6.2233915f0\n", 626 | "[TRAIN 1 - 122/20018]: avg loss: 6.5658, avg time: 2.32s, ETA: 12h48m17s \n", 627 | "batch_loss = 6.2611136f0\n", 628 | "[TRAIN 1 - 123/20018]: avg loss: 6.5599, avg time: 2.33s, ETA: 12h51m3s \n", 629 | "batch_loss = 6.4684725f0\n", 630 | "[TRAIN 1 - 124/20018]: avg loss: 6.5541, avg time: 2.33s, ETA: 12h52m31s \n", 631 | "batch_loss = 6.402356f0\n", 632 | "[TRAIN 1 - 125/20018]: avg loss: 6.5482, avg time: 2.33s, ETA: 12h54m4s \n", 633 | "batch_loss = 6.0890493f0\n", 634 | "[TRAIN 1 - 126/20018]: avg loss: 6.5313, avg time: 2.34s, ETA: 12h56m16s \n", 635 | "batch_loss = 6.241959f0\n", 636 | "[TRAIN 1 - 127/20018]: avg loss: 6.5196, avg time: 2.35s, ETA: 12h57m48s \n", 637 | "batch_loss = 6.143773f0\n", 638 | "[TRAIN 1 - 128/20018]: avg loss: 6.5075, avg time: 2.36s, ETA: 13h2m30s \n", 639 | "batch_loss = 6.28998f0\n", 640 | "[TRAIN 1 - 129/20018]: avg loss: 6.4986, avg time: 2.36s, ETA: 13h2m27s \n", 641 | "batch_loss = 6.2962484f0\n", 642 | "[TRAIN 1 - 130/20018]: avg loss: 6.4875, avg time: 2.37s, ETA: 13h5m14s \n", 643 | "batch_loss = 6.232935f0\n", 644 | "[TRAIN 1 - 131/20018]: avg loss: 6.4792, avg time: 2.38s, ETA: 13h7m51s \n", 645 | "batch_loss = 6.422283f0\n", 646 | "[TRAIN 1 - 132/20018]: avg loss: 6.4703, avg time: 2.38s, ETA: 13h10m4s \n", 647 | "batch_loss = 6.2529445f0\n", 648 | "[TRAIN 1 - 133/20018]: avg loss: 6.4594, avg time: 2.39s, ETA: 13h12m41s \n", 649 | "batch_loss = 6.376482f0\n", 650 | "[TRAIN 1 - 134/20018]: avg loss: 6.4547, avg time: 2.40s, ETA: 13h14m32s \n", 651 | "batch_loss = 6.529128f0\n", 652 | "[TRAIN 1 - 135/20018]: avg loss: 6.4506, avg time: 2.41s, ETA: 13h17m16s \n", 653 | "batch_loss = 6.2505274f0\n", 654 | "[TRAIN 1 - 136/20018]: avg loss: 6.4459, avg time: 2.41s, ETA: 13h19m23s \n", 655 | "batch_loss = 6.1326075f0\n", 656 | "[TRAIN 1 - 137/20018]: avg loss: 6.4333, avg time: 2.42s, ETA: 13h20m57s \n", 657 | "batch_loss = 
6.0665383f0\n", 658 | "[TRAIN 1 - 138/20018]: avg loss: 6.4297, avg time: 2.42s, ETA: 13h23m12s \n", 659 | "batch_loss = 6.494875f0\n", 660 | "[TRAIN 1 - 139/20018]: avg loss: 6.4254, avg time: 2.43s, ETA: 13h24m47s \n", 661 | "batch_loss = 6.2976084f0\n", 662 | "[TRAIN 1 - 140/20018]: avg loss: 6.4230, avg time: 2.44s, ETA: 13h26m57s \n", 663 | "batch_loss = 6.24177f0\n", 664 | "[TRAIN 1 - 141/20018]: avg loss: 6.4139, avg time: 2.44s, ETA: 13h29m2s \n", 665 | "batch_loss = 6.016088f0\n", 666 | "[TRAIN 1 - 142/20018]: avg loss: 6.4056, avg time: 2.45s, ETA: 13h31m12s \n", 667 | "batch_loss = 6.215699f0\n", 668 | "[TRAIN 1 - 143/20018]: avg loss: 6.3949, avg time: 2.45s, ETA: 13h32m54s \n", 669 | "batch_loss = 6.331113f0\n", 670 | "[TRAIN 1 - 144/20018]: avg loss: 6.3912, avg time: 2.46s, ETA: 13h35m18s \n", 671 | "batch_loss = 6.5089054f0\n", 672 | "[TRAIN 1 - 145/20018]: avg loss: 6.3899, avg time: 2.47s, ETA: 13h36m50s \n", 673 | "batch_loss = 5.943267f0\n", 674 | "[TRAIN 1 - 146/20018]: avg loss: 6.3778, avg time: 2.47s, ETA: 13h38m29s \n", 675 | "batch_loss = 6.3464794f0\n", 676 | "[TRAIN 1 - 147/20018]: avg loss: 6.3739, avg time: 2.48s, ETA: 13h39m48s \n", 677 | "batch_loss = 6.5075245f0\n", 678 | "[TRAIN 1 - 148/20018]: avg loss: 6.3693, avg time: 2.48s, ETA: 13h41m38s \n", 679 | "batch_loss = 5.843158f0\n", 680 | "[TRAIN 1 - 149/20018]: avg loss: 6.3462, avg time: 2.49s, ETA: 13h42m58s \n", 681 | "batch_loss = 6.015974f0\n", 682 | "[TRAIN 1 - 150/20018]: avg loss: 6.3302, avg time: 2.49s, ETA: 13h44m59s \n", 683 | "batch_loss = 6.2874823f0\n", 684 | "[TRAIN 1 - 151/20018]: avg loss: 6.3224, avg time: 2.50s, ETA: 13h47m46s \n", 685 | "batch_loss = 6.145458f0\n", 686 | "[TRAIN 1 - 152/20018]: avg loss: 6.3134, avg time: 2.51s, ETA: 13h49m56s \n", 687 | "batch_loss = 6.2152166f0\n", 688 | "[TRAIN 1 - 153/20018]: avg loss: 6.3043, avg time: 2.51s, ETA: 13h50m57s \n", 689 | "batch_loss = 6.4732656f0\n", 690 | "[TRAIN 1 - 154/20018]: avg loss: 6.2963, avg time: 
2.52s, ETA: 13h53m8s \n", 691 | "batch_loss = 6.2842054f0\n", 692 | "[TRAIN 1 - 155/20018]: avg loss: 6.2917, avg time: 2.52s, ETA: 13h53m53s \n", 693 | "batch_loss = 6.099002f0\n", 694 | "[TRAIN 1 - 156/20018]: avg loss: 6.2837, avg time: 2.52s, ETA: 13h55m18s \n", 695 | "batch_loss = 6.0108767f0\n", 696 | "[TRAIN 1 - 157/20018]: avg loss: 6.2806, avg time: 2.53s, ETA: 13h56m17s \n", 697 | "batch_loss = 6.3169117f0\n", 698 | "[TRAIN 1 - 158/20018]: avg loss: 6.2776, avg time: 2.53s, ETA: 13h57m48s \n", 699 | "batch_loss = 6.130246f0\n", 700 | "[TRAIN 1 - 159/20018]: avg loss: 6.2734, avg time: 2.53s, ETA: 13h58m58s \n", 701 | "batch_loss = 6.179352f0\n", 702 | "[TRAIN 1 - 160/20018]: avg loss: 6.2714, avg time: 2.54s, ETA: 14h0m40s \n", 703 | "batch_loss = 6.1209974f0\n", 704 | "[TRAIN 1 - 161/20018]: avg loss: 6.2578, avg time: 2.54s, ETA: 14h1m52s \n", 705 | "batch_loss = 5.9963007f0\n", 706 | "[TRAIN 1 - 162/20018]: avg loss: 6.2469, avg time: 2.55s, ETA: 14h3m14s \n", 707 | "batch_loss = 6.0571094f0\n", 708 | "[TRAIN 1 - 163/20018]: avg loss: 6.2444, avg time: 2.55s, ETA: 14h5m8s \n", 709 | "batch_loss = 5.976975f0\n", 710 | "[TRAIN 1 - 164/20018]: avg loss: 6.2380, avg time: 2.56s, ETA: 14h6m28s \n", 711 | "batch_loss = 6.4343805f0\n", 712 | "[TRAIN 1 - 165/20018]: avg loss: 6.2430, avg time: 2.56s, ETA: 14h7m36s \n", 713 | "batch_loss = 6.3204055f0\n", 714 | "[TRAIN 1 - 166/20018]: avg loss: 6.2438, avg time: 2.57s, ETA: 14h9m1s \n", 715 | "batch_loss = 6.144717f0\n", 716 | "[TRAIN 1 - 167/20018]: avg loss: 6.2370, avg time: 2.57s, ETA: 14h10m15s \n", 717 | "batch_loss = 5.978823f0\n", 718 | "[TRAIN 1 - 168/20018]: avg loss: 6.2338, avg time: 2.58s, ETA: 14h12m2s \n", 719 | "batch_loss = 6.2657084f0\n", 720 | "[TRAIN 1 - 169/20018]: avg loss: 6.2332, avg time: 2.58s, ETA: 14h12m52s \n", 721 | "batch_loss = 6.108475f0\n", 722 | "[TRAIN 1 - 170/20018]: avg loss: 6.2244, avg time: 2.58s, ETA: 14h14m26s \n", 723 | "batch_loss = 6.2016907f0\n", 724 | "[TRAIN 1 - 
171/20018]: avg loss: 6.2216, avg time: 2.59s, ETA: 14h15m34s \n", 725 | "batch_loss = 6.105574f0\n", 726 | "[TRAIN 1 - 172/20018]: avg loss: 6.2215, avg time: 2.60s, ETA: 14h19m12s \n", 727 | "batch_loss = 6.411059f0\n", 728 | "[TRAIN 1 - 173/20018]: avg loss: 6.2252, avg time: 2.60s, ETA: 14h18m48s \n", 729 | "batch_loss = 6.2743998f0\n", 730 | "[TRAIN 1 - 174/20018]: avg loss: 6.2254, avg time: 2.60s, ETA: 14h20m1s \n", 731 | "batch_loss = 6.207641f0\n", 732 | "[TRAIN 1 - 175/20018]: avg loss: 6.2203, avg time: 2.61s, ETA: 14h23m0s \n", 733 | "batch_loss = 6.2922597f0\n", 734 | "[TRAIN 1 - 176/20018]: avg loss: 6.2181, avg time: 2.61s, ETA: 14h24m6s \n", 735 | "batch_loss = 6.4356365f0\n", 736 | "[TRAIN 1 - 177/20018]: avg loss: 6.2249, avg time: 2.62s, ETA: 14h25m28s \n", 737 | "batch_loss = 6.0347633f0\n", 738 | "[TRAIN 1 - 178/20018]: avg loss: 6.2209, avg time: 2.62s, ETA: 14h27m38s \n", 739 | "batch_loss = 5.9667406f0\n", 740 | "[TRAIN 1 - 179/20018]: avg loss: 6.2174, avg time: 2.63s, ETA: 14h28m10s \n", 741 | "batch_loss = 5.7165823f0\n", 742 | "[TRAIN 1 - 180/20018]: avg loss: 6.2062, avg time: 2.63s, ETA: 14h29m41s \n", 743 | "batch_loss = 5.9902f0\n", 744 | "[TRAIN 1 - 181/20018]: avg loss: 6.2002, avg time: 2.63s, ETA: 14h30m57s \n", 745 | "batch_loss = 6.1820183f0\n", 746 | "[TRAIN 1 - 182/20018]: avg loss: 6.1992, avg time: 2.64s, ETA: 14h32m12s \n", 747 | "batch_loss = 5.860342f0\n", 748 | "[TRAIN 1 - 183/20018]: avg loss: 6.1881, avg time: 2.64s, ETA: 14h33m3s \n", 749 | "batch_loss = 5.9650106f0\n", 750 | "[TRAIN 1 - 184/20018]: avg loss: 6.1825, avg time: 2.64s, ETA: 14h34m7s \n", 751 | "batch_loss = 6.3115206f0\n", 752 | "[TRAIN 1 - 185/20018]: avg loss: 6.1812, avg time: 2.65s, ETA: 14h34m42s \n", 753 | "batch_loss = 6.3297844f0\n", 754 | "[TRAIN 1 - 186/20018]: avg loss: 6.1773, avg time: 2.65s, ETA: 14h36m2s \n", 755 | "batch_loss = 6.3033185f0\n", 756 | "[TRAIN 1 - 187/20018]: avg loss: 6.1784, avg time: 2.66s, ETA: 14h37m58s \n", 757 | 
"batch_loss = 5.835417f0\n", 758 | "[TRAIN 1 - 188/20018]: avg loss: 6.1725, avg time: 2.66s, ETA: 14h38m53s \n", 759 | "batch_loss = 6.0861673f0\n", 760 | "[TRAIN 1 - 189/20018]: avg loss: 6.1729, avg time: 2.66s, ETA: 14h39m37s \n" 761 | ] 762 | } 763 | ], 764 | "source": [ 765 | "# Train away, train away, train away |> 's/train/sail/ig'\n", 766 | "@info(\"Beginning training run...\")\n", 767 | "train(model, opt)" 768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "execution_count": null, 773 | "metadata": {}, 774 | "outputs": [], 775 | "source": [] 776 | } 777 | ], 778 | "metadata": { 779 | "kernelspec": { 780 | "display_name": "Julia 1.4.1", 781 | "language": "julia", 782 | "name": "julia-1.4" 783 | }, 784 | "language_info": { 785 | "file_extension": ".jl", 786 | "mimetype": "application/julia", 787 | "name": "julia", 788 | "version": "1.4.1" 789 | } 790 | }, 791 | "nbformat": 4, 792 | "nbformat_minor": 2 793 | } 794 | --------------------------------------------------------------------------------