├── .gitignore
├── lang-detection
│   ├── .gitignore
│   ├── scrape.jl
│   └── model.jl
├── test
│   └── test.jl
├── treebank
│   ├── recursive.jl
│   └── data.jl
├── mnist
│   ├── mlp.jl
│   ├── conv.jl
│   ├── autoencoder.jl
│   └── vae.jl
├── phonemes
│   ├── 0-data.jl
│   └── 1-model.jl
├── LICENSE.md
├── housing
│   └── housing.jl
├── README.md
├── char-rnn
│   └── char-rnn.jl
└── dqn
    ├── dqn.jl
    └── ddqn.jl

/.gitignore:
--------------------------------------------------------------------------------
*.jls

--------------------------------------------------------------------------------
/lang-detection/.gitignore:
--------------------------------------------------------------------------------
corpus

--------------------------------------------------------------------------------
/test/test.jl:
--------------------------------------------------------------------------------
using Flux

info("Hooking train loop")

function Flux.train!(loss, data, opt; cb = () -> ())
    loss(first(data)...)
    opt()
    cb()
end

file(x) = joinpath(@__DIR__, "..", x)

models = [
    ("MNIST MLP","mnist/mlp.jl"),
    ("MNIST Conv","mnist/conv.jl"),
    ("MNIST Autoencoder","mnist/autoencoder.jl")]

info("Testing CPU models")
for (name, p) in models
    info(name)
    include(file(p))
end
info("MNIST VAE")
info("mnist/vae.jl")

if Base.find_in_path("CuArrays") != nothing
    using CuArrays
    info("Testing GPU models")
    for (name, p) in models
        info(name)
        include(file(p))
    end
end

--------------------------------------------------------------------------------
/treebank/recursive.jl:
--------------------------------------------------------------------------------
using Flux
using Flux: crossentropy, throttle
using Flux.Data: Tree, children, isleaf

include("data.jl")

N = 300

embedding = param(randn(N, length(alphabet)))

W = Dense(2N, N, tanh)
combine(a, b) = W([a; b])

sentiment = Chain(Dense(N, 5), softmax)

function forward(tree)
    if isleaf(tree)
        token, sent = tree.value
        phrase = embedding * token
        phrase, crossentropy(sentiment(phrase), sent)
    else
        _, sent = tree.value
        c1, l1 = forward(tree[1])
        c2, l2 = forward(tree[2])
        phrase = combine(c1, c2)
        phrase, l1 + l2 + crossentropy(sentiment(phrase), sent)
    end
end

loss(tree) = forward(tree)[2]

opt = ADAM(params(embedding, W, sentiment))
evalcb = () -> @show loss(train[1])

Flux.train!(loss, zip(train), opt,
            cb = throttle(evalcb, 10))

--------------------------------------------------------------------------------
/mnist/mlp.jl:
--------------------------------------------------------------------------------
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated
# using CuArrays

# Classify MNIST digits with a simple multi-layer-perceptron

imgs = MNIST.images()
# Stack images into one large batch
X = hcat(float.(reshape.(imgs, :))...) |> gpu
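# Each column of X is one flattened 28×28 image, so X is a 28^2 × 60,000 matrix.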

labels = MNIST.labels()
# One-hot-encode the labels
Y = onehotbatch(labels, 0:9) |> gpu

m = Chain(
    Dense(28^2, 32, relu),
    Dense(32, 10),
    softmax) |> gpu

loss(x, y) = crossentropy(m(x), y)

accuracy(x, y) = mean(argmax(m(x)) .== argmax(y))

dataset = repeated((X, Y), 200)
evalcb = () -> @show(loss(X, Y))
opt = ADAM(params(m))

Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))

accuracy(X, Y)

# Test set accuracy
tX = hcat(float.(reshape.(MNIST.images(:test), :))...) |> gpu
tY = onehotbatch(MNIST.labels(:test), 0:9) |> gpu

accuracy(tX, tY)

--------------------------------------------------------------------------------
/phonemes/0-data.jl:
--------------------------------------------------------------------------------
using Flux, Flux.Data.CMUDict
using Flux: onehot, batchseq
using Base.Iterators: partition

dict = cmudict()
alphabet = [:end, CMUDict.alphabet()...]
phones = [:start, :end, CMUDict.symbols()...]

tokenise(s, α) = [onehot(c, α) for c in s]

# Turn a word into a sequence of vectors
tokenise("PHYLOGENY", alphabet)
# Same for phoneme lists
tokenise(dict["PHYLOGENY"], phones)

words = sort(collect(keys(dict)), by = length)

# Finally, create iterators for our inputs and outputs.
batches(xs, p) = [batchseq(b, p) for b in partition(xs, 50)]

Xs = batches([tokenise(word, alphabet) for word in words],
             onehot(:end, alphabet))

Ys = batches([tokenise([dict[word]..., :end], phones) for word in words],
             onehot(:end, phones))

Yo = batches([tokenise([:start, dict[word]...], phones) for word in words],
             onehot(:end, phones))
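
# Ys holds the target phoneme sequences (each ending in :end), while Yo holds
# the same sequences shifted right to begin with :start -- during training the
# decoder is fed Yo and asked to predict Ys one step ahead (teacher forcing).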

data = collect(zip(Xs, Yo, Ys))

--------------------------------------------------------------------------------
/treebank/data.jl:
--------------------------------------------------------------------------------
using Flux
using Flux: onehot
using Flux.Data.Sentiment
using Flux.Data: Tree, leaves

traintrees = Sentiment.train()

# Get the raw labels and phrases as separate trees.
labels = map.(x -> x[1], traintrees)
phrases = map.(x -> x[2], traintrees)

# All tokens in the training set.
tokens = vcat(map(leaves, phrases)...)

# Count how many times each token appears.
freqs = Dict{String,Int}()
for t in tokens
    freqs[t] = get(freqs, t, 0) + 1
end

# Replace singleton tokens with an "unknown" marker.
# This roughly cuts our "alphabet" of tokens in half.
phrases = map.(t -> get(freqs, t, 0) == 1 ? "UNK" : t, phrases)

# Our alphabet of tokens.
alphabet = unique(vcat(map(leaves, phrases)...))

# One-hot-encode our training data with respect to the alphabet.
phrases_e = map.(t -> t == nothing ? t : onehot(t, alphabet), phrases)
labels_e = map.(t -> onehot(t, 0:4), labels)

train = map.(tuple, phrases_e, labels_e)

--------------------------------------------------------------------------------
/mnist/conv.jl:
--------------------------------------------------------------------------------
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated, partition
# using CuArrays

# Classify MNIST digits with a convolutional network

imgs = MNIST.images()

labels = onehotbatch(MNIST.labels(), 0:9)

# Partition into batches of size 1,000
train = [(cat(4, float.(imgs[i])...), labels[:,i])
         for i in partition(1:60_000, 1000)]

train = gpu.(train)

# Prepare test set (first 1,000 images)
tX = cat(4, float.(MNIST.images(:test)[1:1000])...) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu

m = Chain(
    Conv((2,2), 1=>16, relu),
    x -> maxpool(x, (2,2)),
    Conv((2,2), 16=>8, relu),
    x -> maxpool(x, (2,2)),
    x -> reshape(x, :, size(x, 4)),
    Dense(288, 10), softmax) |> gpu

m(train[1][1])

loss(x, y) = crossentropy(m(x), y)

accuracy(x, y) = mean(argmax(m(x)) .== argmax(y))

evalcb = throttle(() -> @show(accuracy(tX, tY)), 10)
opt = ADAM(params(m))

Flux.train!(loss, train, opt, cb = evalcb)

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
These examples are licensed under the MIT "Expat" License:

> Copyright (c) 2017: Mike Innes & contributors.
>
> Permission is hereby granted, free of charge, to any person obtaining a copy
> of this software and associated documentation files (the "Software"), to deal
> in the Software without restriction, including without limitation the rights
> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
> copies of the Software, and to permit persons to whom the Software is
> furnished to do so, subject to the following conditions:
>
> The above copyright notice and this permission notice shall be included in all
> copies or substantial portions of the Software.
>
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
> SOFTWARE.

--------------------------------------------------------------------------------
/housing/housing.jl:
--------------------------------------------------------------------------------
using Flux.Tracker

# This replicates the housing data example from the Knet.jl readme. Although we
# could have reused more of Flux (see the mnist example), the library's
# abstractions are very lightweight and don't force you into any particular
# strategy.

cd(@__DIR__)

isfile("housing.data") ||
    download("https://raw.githubusercontent.com/MikeInnes/notebooks/master/housing.data",
             "housing.data")

rawdata = readdlm("housing.data")'

# The last feature is our target -- the price of the house.

x = rawdata[1:13,:]
y = rawdata[14:14,:]

# Normalise the data
x = (x .- mean(x,2)) ./ std(x,2)

# The model

W = param(randn(1,13)/10)
b = param([0.])

# using CuArrays
# W, b, x, y = cu.((W, b, x, y))

predict(x) = W*x .+ b
meansquarederror(ŷ, y) = sum((ŷ .- y).^2)/size(y, 2)
loss(x, y) = meansquarederror(predict(x), y)
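
# Flux.Tracker records gradients for us: after back!(loss(x, y)) each tracked
# parameter's .grad field holds ∂loss/∂parameter, and update! below takes a
# plain gradient-descent step, w.data .-= η .* w.grad, then zeroes the gradient.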

function update!(ps, η = .1)
    for w in ps
        w.data .-= w.grad .* η
        w.grad .= 0
    end
end

for i = 1:10
    back!(loss(x, y))
    update!((W, b))
    @show loss(x, y)
end

predict(x[:,1]) / y[1]

--------------------------------------------------------------------------------
/lang-detection/scrape.jl:
--------------------------------------------------------------------------------
using Cascadia, Gumbo, HTTP, AbstractTrees

pages = Dict(
    :en => ["Wikipedia", "Osama_bin_Laden_(elephant)", "List_of_lists_of_lists", "Josephine_Butler", "Canadian_football", "Judaism"],
    :it => ["Wikipedia", "Ludovico_Einaudi", "Filosofia_della_scienza", "Pizza", "Effie_Gray", "Galeazzo_Maria_Sforza", "Ebraismo"],
    :fr => ["Wikipedia", "Philosophie_des_sciences", "Seconde_Guerre_mondiale", "Eric_Hakonsson"],
    :es => ["Wikipedia", "Chorizo", "Historia_de_Barcelona", "Espana", "Las_Vegas_Strip", "Judaismo"],
    :da => ["Wikipedia", "H.C._Andersen", "L.A._Ring", "Jiangxi", "NATO", "Thomas_Edison", "Bangladesh"])

rawpage(url) = parsehtml(String(HTTP.get(url).body)).root

function innerText(dom)
    text = IOBuffer()
    for elem in PreOrderDFS(dom)
        elem isa HTMLText && print(text, elem.text)
    end
    return String(take!(text))
end

content(url) = join(innerText.(matchall(sel".mw-parser-output > p", rawpage(url))), "\n")

cd(@__DIR__)
mkpath("corpus")

for (lang, ps) in pages
    open("corpus/$lang.txt", "w") do io
        for p in ps
            write(io, content("https://$lang.wikipedia.org/wiki/$p"))
        end
    end
end

--------------------------------------------------------------------------------
/lang-detection/model.jl:
--------------------------------------------------------------------------------
using Flux
using Flux: onehot, onehotbatch, crossentropy, reset!, throttle

corpora = Dict()

cd(@__DIR__)
for file in readdir("corpus")
    lang = Symbol(match(r"(.*)\.txt", file).captures[1])
    corpus = split(String(read("corpus/$file")), ".")
    corpus = strip.(normalize_string.(corpus, casefold=true, stripmark=true))
    corpus = filter(!isempty, corpus)
    corpora[lang] = corpus
end

langs = collect(keys(corpora))
alphabet = ['a':'z'; '0':'9'; ' '; '\n'; '_']

# See which chars will be represented as "unknown"
unique(filter(x -> x ∉ alphabet, join(vcat(values(corpora)...))))

dataset = [(onehotbatch(s, alphabet, '_'), onehot(l, langs))
           for l in langs for s in corpora[l]] |> shuffle

train, test = dataset[1:end-100], dataset[end-99:end]

N = 15

scanner = Chain(Dense(length(alphabet), N, σ), LSTM(N, N))
encoder = Dense(N, length(langs))
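
# The scanner reads a phrase one one-hot character column at a time; only its
# final output is kept as a summary of the whole phrase, which the encoder
# then maps to a distribution over languages.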

function model(x)
    state = scanner.(x.data)[end]
    reset!(scanner)
    softmax(encoder(state))
end

loss(x, y) = crossentropy(model(x), y)

testloss() = mean(loss(t...) for t in test)

opt = ADAM(params(scanner, encoder))
evalcb = () -> @show testloss()

Flux.train!(loss, train, opt, cb = throttle(evalcb, 10))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Flux Model Zoo

This repository contains various demonstrations of the [Flux](http://fluxml.github.io/) machine learning library. Any of these may freely be used as a starting point for your own models.

- **housing** implements the most basic model possible (a linear regression) on the [UCI housing data set](https://archive.ics.uci.edu/ml/machine-learning-databases/housing/). It's bare-bones and illustrates how to build a model from scratch.
- **mnist** classifies digits from the [MNIST data set](https://en.wikipedia.org/wiki/MNIST_database), using a simple multi-layer perceptron and a convolutional network, as well as showing a simple autoencoder.
- **char-rnn** implements a [character-level language model](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). It comes with a Shakespeare dataset but can work with any text.
- **phonemes** implements a [sequence-to-sequence model with attention](https://arxiv.org/abs/1409.0473), using the [CMU pronouncing dictionary](http://www.speech.cs.cmu.edu/cgi-bin/cmudict) to predict the pronunciations of unknown words.
- **lang-detection** implements a simple sequence-to-classification model, which recognises a language (English, Danish, etc.) from input characters.
- **treebank** shows a recursive neural network trained on the Stanford Sentiment Treebank.

Note that these models are best run line-by-line, either in the REPL or Juno.
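
If you just want to run a whole script non-interactively, a minimal sketch (assuming Flux and the other packages the chosen script imports are already installed) is to `include` it from the repository root:

```julia
# Hypothetical quick start -- any of the scripts listed above can be substituted.
include("mnist/mlp.jl")   # trains the MLP and evaluates test-set accuracy
```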

--------------------------------------------------------------------------------
/char-rnn/char-rnn.jl:
--------------------------------------------------------------------------------
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

cd(@__DIR__)

isfile("input.txt") ||
    download("http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
             "input.txt")

text = collect(readstring("input.txt"))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)
seqlen = 50
nbatch = 50

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen))

m = Chain(
    LSTM(N, 128),
    LSTM(128, 128),
    Dense(128, N),
    softmax)

function loss(xs, ys)
    l = sum(crossentropy.(m.(xs), ys))
    Flux.truncate!(m)
    return l
end

opt = ADAM(params(m), 0.01)
evalcb = () -> @show loss(Xs[5], Ys[5])

Flux.train!(loss, zip(Xs, Ys), opt,
            cb = throttle(evalcb, 30))

# Sampling

function sample(m, alphabet, len; temp = 1)
    Flux.reset!(m)
    buf = IOBuffer()
    c = rand(alphabet)
    for i = 1:len
        write(buf, c)
        c = wsample(alphabet, m(onehot(c, alphabet)).data)
    end
    return String(take!(buf))
end

sample(m, alphabet, 1000) |> println

# evalcb = function ()
#     @show loss(Xs[5], Ys[5])
#     println(sample(deepcopy(m), alphabet, 100))
# end

--------------------------------------------------------------------------------
/mnist/autoencoder.jl:
--------------------------------------------------------------------------------
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, argmax, mse, throttle
using Base.Iterators: partition
using Juno: @progress
# using CuArrays

# Encode MNIST images as compressed vectors that can later be decoded back into
# images.

imgs = MNIST.images()

# Partition into batches of size 1000
data = [float(hcat(vec.(imgs)...)) for imgs in partition(imgs, 1000)]
data = gpu.(data)

N = 32 # Size of the encoding

# The encoder compresses each 28^2-pixel input down to an N-dimensional code,
# so the network learns a lossy compressed representation of its input.
# You can try making the encoder/decoder networks larger.
encoder = Dense(28^2, N, relu) |> gpu
decoder = Dense(N, 28^2, relu) |> gpu

m = Chain(encoder, decoder)

loss(x) = mse(m(x), x)

evalcb = throttle(() -> @show(loss(data[1])), 5)
opt = ADAM(params(m))

@epochs 10 Flux.train!(loss, zip(data), opt, cb = evalcb)

# Sample output

using Images

img(x::Vector) = Gray.(reshape(clamp.(x, 0, 1), 28, 28))

function sample()
    # 20 random digits
    before = [imgs[i] for i in rand(1:length(imgs), 20)]
    # Before and after images
    after = img.(map(x -> cpu(m)(float(vec(x))).data, before))
    # Stack them all together
    hcat(vcat.(before, after)...)
end

cd(@__DIR__)

save("sample.png", sample())

--------------------------------------------------------------------------------
/phonemes/1-model.jl:
--------------------------------------------------------------------------------
# Based on https://arxiv.org/abs/1409.0473

using Flux: combine, flip, crossentropy, reset!, throttle

include("0-data.jl")

Nin = length(alphabet)
Nh = 30 # size of hidden layer

# A recurrent model which takes a token and returns a context-dependent
# annotation.

forward = LSTM(Nin, Nh÷2)
backward = LSTM(Nin, Nh÷2)
encode(tokens) = vcat.(forward.(tokens), flip(backward, tokens))

alignnet = Dense(2Nh, 1)
align(s, t) = alignnet(combine(t, s))

# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.

recur = LSTM(Nh+length(phones), Nh)
toalpha = Dense(Nh, length(phones))

function asoftmax(xs)
    xs = [exp.(x) for x in xs]
    s = sum(xs)
    return [x ./ s for x in xs]
end
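
# Attention step: score each encoder annotation against the decoder's current
# hidden state, softmax the scores into weights, and feed the weighted sum of
# annotations (the context vector) into the decoder RNN along with the
# previous phoneme.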

function decode1(tokens, phone)
    weights = asoftmax([align(recur.state[2], t) for t in tokens])
    context = sum(map((a, b) -> a .* b, weights, tokens))
    y = recur(vcat(phone, context))
    return softmax(toalpha(y))
end

decode(tokens, phones) = [decode1(tokens, phone) for phone in phones]

# The full model

state = (forward, backward, alignnet, recur, toalpha)

function model(x, y)
    ŷ = decode(encode(x), y)
    reset!(state)
    return ŷ
end

loss(x, yo, y) = sum(crossentropy.(model(x, yo), y))

evalcb = () -> @show loss(data[500]...)
opt = ADAM(params(state))

Flux.train!(loss, data, opt, cb = throttle(evalcb, 10))

# Prediction

using StatsBase: wsample

function predict(s)
    ts = encode(tokenise(s, alphabet))
    ps = Any[:start]
    for i = 1:50
        dist = decode1(ts, onehot(ps[end], phones))
        next = wsample(phones, Flux.Tracker.value(dist))
        next == :end && break
        push!(ps, next)
    end
    return ps[2:end]
end

predict("PHYLOGENY")

--------------------------------------------------------------------------------
/mnist/vae.jl:
--------------------------------------------------------------------------------
using Flux, Flux.Data.MNIST
using Flux: throttle, params
using Juno: @progress

# Extend distributions slightly to have a numerically stable logpdf for `p` close to 1 or 0.
using Distributions
import Distributions: logpdf
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())

# Load data, binarise it, and partition into mini-batches of M.
X = float.(hcat(vec.(MNIST.images())...)) .> 0.5
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]


################################# Define Model #################################

# Latent dimensionality, # hidden units.
Dz, Dh = 5, 500

# Components of recognition model / "encoder" MLP.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()
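
# Reparameterisation trick: z = μ + σ·ε with ε ~ N(0, 1), so samples from the
# approximate posterior stay differentiable with respect to μ and logσ.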

# Generative model / "decoder" MLP.
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))


####################### Define ways of doing things with the model. #######################

# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .- 2 .* logσ)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))


################################# Learn Parameters ##############################

evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM(params(A, μ, logσ, f))
@progress for i = 1:20
    info("Epoch $i")
    Flux.train!(loss, zip(data), opt, cb=evalcb)
end


################################# Sample Output ##############################

using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("sample.png", sample)

--------------------------------------------------------------------------------
/dqn/dqn.jl:
--------------------------------------------------------------------------------
using Reinforce: CartPole, actions, reset!, Episode, finished
import Reinforce.action
using Flux, StatsBase, Plots

gr()

# Define a custom policy for choosing actions
mutable struct CartPolePolicy <: Reinforce.AbstractPolicy end

# Load the game environment
env = CartPole()

# Parameters
EPISODES = 3000
STATE_SIZE = length(env.state)
ACTION_SIZE = length(actions(env, env.state))
MEM_SIZE = 2000
BATCH_SIZE = 32
γ = 0.95 # discount rate
ϵ = 1.0 # exploration rate
ϵ_min = 0.01
ϵ_decay = 0.995
η = 0.001 # learning rate

memory = [] # used to remember past results

# Model architecture
model = Chain(Dense(STATE_SIZE, 24, σ), Dense(24, 24, σ), Dense(24, ACTION_SIZE))
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM(params(model), η)
fit_model(dataset) = Flux.train!(loss, dataset, opt)

function remember(state, action, reward, next_state, done)
    if length(memory) == MEM_SIZE
        deleteat!(memory, 1)
    end
    push!(memory, (state, action, reward, next_state, done))
end

function action(policy::CartPolePolicy, reward, state, action)
    if rand() <= ϵ
        return rand(1:ACTION_SIZE)
    end
    act_values = model(state)
    return Flux.argmax(act_values)[1] # returns action
end
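
# Experience replay: sample a random minibatch of stored transitions and nudge
# the network's Q-value for the action actually taken towards the one-step
# target r + γ·max_a′ Q(s′, a′) (just r at terminal states).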

function replay()
    global ϵ
    minibatch = sample(memory, BATCH_SIZE, replace = false)

    for (state, action, reward, next_state, done) in minibatch
        target = reward
        if !done
            target += γ * maximum(model(next_state))
        end
        target_f = model(state).data
        target_f[action, 1] = target
        dataset = [(state, target_f)]
        fit_model(dataset)
    end

    if ϵ > ϵ_min
        ϵ *= ϵ_decay
    end
end

# Render the environment
on_step(env::CartPole, niter, sars) = gui(plot(env))

function episode!(env, policy = RandomPolicy(); stepfunc = on_step, kw...)
    ep = Episode(env, policy)
    for sars in ep
        stepfunc(ep.env, ep.niter, sars)
        state, action, reward, next_state = sars
        state = reshape(state, STATE_SIZE, 1)
        next_state = reshape(next_state, STATE_SIZE, 1)
        done = finished(ep.env, next_state) # check if the game is over
        reward = !done ? reward : -1 # penalty of -1 if the game is over
        remember(state, action, reward, next_state, done)
    end
    ep.total_reward
end

for e = 1:EPISODES
    reset!(env)
    total_reward = episode!(env, CartPolePolicy())
    println("Episode: $e/$EPISODES | Score: $total_reward | ϵ: $ϵ")
    if length(memory) >= BATCH_SIZE
        replay()
    end
end

--------------------------------------------------------------------------------
/dqn/ddqn.jl:
--------------------------------------------------------------------------------
using Reinforce: CartPole, actions, reset!, Episode, finished
import Reinforce.action
using Flux, StatsBase, Plots

gr()

# Define a custom policy for choosing actions
mutable struct CartPolePolicy <: Reinforce.AbstractPolicy end

# Load the game environment
env = CartPole()

# Parameters
EPISODES = 3000
STATE_SIZE = length(env.state)
ACTION_SIZE = length(actions(env, env.state))
MEM_SIZE = 2000
BATCH_SIZE = 32
γ = 0.95 # discount rate
ϵ = 1.0 # exploration rate
ϵ_min = 0.01
ϵ_decay = 0.995
η = 0.001 # learning rate

memory = [] # used to remember past results

huber_loss(x, y) = mean(sqrt.(1+(x-y).^2)-1)

# Model architecture

model = Chain(Dense(STATE_SIZE, 24, σ), Dense(24, 24, σ), Dense(24, ACTION_SIZE))
loss(x, y) = huber_loss(model(x), y)
opt = ADAM(params(model), η)
fit_model(dataset) = Flux.train!(loss, dataset, opt)

# Target model architecture
target_model = Chain(Dense(STATE_SIZE, 24, σ), Dense(24, 24, σ), Dense(24, ACTION_SIZE))

function remember(state, action, reward, next_state, done)
    if length(memory) == MEM_SIZE
        deleteat!(memory, 1)
    end
    push!(memory, (state, action, reward, next_state, done))
end

function update_target()
    for i in eachindex(params(target_model))
        for j in eachindex(params(target_model)[i].data)
            params(target_model)[i].data[j] = params(model)[i].data[j]
        end
    end
end

function action(policy::CartPolePolicy, reward, state, action)
    if rand() <= ϵ
        return rand(1:ACTION_SIZE)
    end
    act_values = model(state)
    return Flux.argmax(act_values)[1] # returns action
end
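
# Double-DQN target: the online model picks the greedy next action, but the
# slowly-updated target_model supplies its value estimate, which reduces the
# overestimation bias of plain Q-learning. update_target() copies the online
# weights into target_model once per episode.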

function replay()
    global ϵ
    minibatch = sample(memory, BATCH_SIZE, replace = false)

    for (state, action, reward, next_state, done) in minibatch
        target = model(state).data

        if done
            target[action, 1] = reward
        else
            a = model(next_state)[:, 1]
            t = target_model(next_state)[:, 1]
            target[action, 1] = reward + γ * t.data[Flux.argmax(a)]
        end

        dataset = [(state, target)]
        fit_model(dataset)
    end

    if ϵ > ϵ_min
        ϵ *= ϵ_decay
    end
end

# Render the environment
on_step(env::CartPole, niter, sars) = gui(plot(env))

function episode!(env, policy = RandomPolicy(); stepfunc = on_step, kw...)
    ep = Episode(env, policy)
    for sars in ep
        stepfunc(ep.env, ep.niter, sars)
        state, action, reward, next_state = sars
        state = reshape(state, STATE_SIZE, 1)
        next_state = reshape(next_state, STATE_SIZE, 1)
        done = finished(ep.env, next_state) # check if the game is over
        reward = !done ? reward : -1 # penalty of -1 if the game is over
        remember(state, action, reward, next_state, done)
    end
    ep.total_reward
end

update_target()

for e = 1:EPISODES
    reset!(env)
    total_reward = episode!(env, CartPolePolicy())
    update_target()
    println("Episode: $e/$EPISODES | Score: $total_reward | ϵ: $ϵ")
    if length(memory) >= BATCH_SIZE
        replay()
    end
end

--------------------------------------------------------------------------------