diff --git a/text/seq2seq/Manifest.toml b/text/seq2seq/Manifest.toml new file mode 100644 index 000000000..8b63dfb1d --- /dev/null +++ b/text/seq2seq/Manifest.toml @@ -0,0 +1,278 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractTrees]] +deps = ["Markdown", "Test"] +git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.2.1" + +[[Adapt]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "0.4.2" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BinDeps]] +deps = ["Compat", "Libdl", "SHA", "URIParser"] +git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9" +uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" +version = "0.8.10" + +[[BinaryProvider]] +deps = ["Libdl", "Pkg", "SHA", "Test"] +git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.3" + +[[CodecZlib]] +deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] +git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.5.1" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random", "Test"] +git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.7.5" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] +git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.9.5" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "1.5.1" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] +git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.15.0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["Compat", "StaticArrays"] +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "0.0.4" + +[[DiffRules]] +deps = ["Random", "Test"] +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "0.0.10" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FixedPointNumbers]] +deps = ["Test"] +git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.5.3" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] +git-tree-sha1 = "28e6dbf663fed71ea607414bc5f2f099d2831c0c" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.7.3" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.3" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile", "Test"] +git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.5.4" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Compat"] +git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.4.5" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] +git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.0" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.4.3" + +[[NaNMath]] +deps = ["Compat"] +git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.2" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.0.2" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["Test"] +git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "0.5.2" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] +git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.7.2" + +[[StaticArrays]] +deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] +git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.10.2" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] +git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.27.0" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TranscodingStreams]] +deps = ["Pkg", "Random", "Test"] +git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.8.1" + +[[URIParser]] +deps = ["Test", "Unicode"] +git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" +uuid = "30578b45-9adc-5946-b283-645ec420af67" +version = "0.4.0" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["BinaryProvider", "Libdl", "Printf", "Test"] +git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.8.0" diff --git a/text/seq2seq/Project.toml b/text/seq2seq/Project.toml new file mode 100644 index 000000000..77df42abf --- /dev/null +++ b/text/seq2seq/Project.toml @@ -0,0 +1,2 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/text/seq2seq/src/images/batching.png b/text/seq2seq/src/images/batching.png new file mode 100644 index 000000000..dbad1c5e1 Binary files /dev/null and b/text/seq2seq/src/images/batching.png differ diff --git a/text/seq2seq/src/images/graph.png b/text/seq2seq/src/images/graph.png new file mode 100644 index 000000000..1423db2e8 Binary files /dev/null and b/text/seq2seq/src/images/graph.png differ diff --git a/text/seq2seq/src/seq2seq translation.jl b/text/seq2seq/src/seq2seq translation.jl new file mode 100644 index 000000000..e14780c74 --- /dev/null +++ b/text/seq2seq/src/seq2seq translation.jl @@ -0,0 +1,377 @@ +# # Seq2seq translation in Flux +# In this notebook, I share the code I wrote to make a seq2seq nmt-model (neural machine translation model) to translate simple english sentences to french. +# The code is written in Julia using Flux. +# +# ### note: +# For some reason, when I train the model for some epochs, I get gibberish results. +# +# |Input (Eng)|Prediction (*Fr*)|Expected (Fr)| +# | - | - | - | +# |"You are too skinny"|"Vous êtes ' que . . . . . . . ."| "Vous êtes trop maigre" | +# |"He is painting a picture"|"Il est est de . . . . . . . ."|"Il est en train de peindre un tableau"| +# | ... | ... | ... | +# If you know what I'm doing wrong, please do let me know! + +# # The data +# The english-french sentence pairs dataset is found on this website: http://www.manythings.org/anki/fra-eng.zip. +# For the data preparation, I mainly follow the official Pytorch tutorial on seq2seq models: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html. +# +# We create a `Lang` struct which holds two dictionaries to convert words to indices and back. Every `Lang` instance gets instantiated with a SOS-(start of sentence), EOS(end of sentence)-, UNK(unknown word) and PAD(padding)-token. +# Padding is necessary because we will be training in batches of differently sized sentences. +# +# +# *Since the data is relatively small (a lot of sentences get filtered out), we keep all words instead of discarding scarcely used words. +# This means the `UNK` token will not be used.* +# +# The function `readLangs` takes the text file, splits up the sentences (which are tab-delimited) and intantiates two new languages (lang1 and lang2). and assigns them to two newly created languages. + +using CuArrays, Flux, Statistics, Random + +FILE = "D:/downloads/fra-eng/eng-fra.txt" + +mutable struct Lang + name + word2index + word2count + index2word + n_words +end + +Lang(name) = Lang( + name, + Dict{String, Int}(), + Dict{String, Int}(), + Dict{Int, String}(1=>"SOS", 2=>"EOS", 3=>"UNK", 4=>"PAD"), + 4) + +function (l::Lang)(sentence::String) + for word in split(sentence, " ") + if word ∉ keys(l.word2index) + l.word2index[word] = l.n_words + 1 + l.word2count[word] = 1 + l.index2word[l.n_words + 1] = word + l.n_words += 1 + else + l.word2count[word] += 1 + end + end +end + +function normalizeString(s) + s = strip(lowercase(s)) + s = replace(s, r"([.!?,])"=>s" \1") + s = replace(s, "'"=>" ' ") + return s +end + +function readLangs(lang1, lang2; rev=false) + println("Reading lines...") + lines = readlines(FILE) + pairs = [normalizeString.(pair) for pair in split.(lines, "\t")] + if rev + pairs = reverse.(pairs) + input_lang = Lang(lang2) + output_lang = Lang(lang1) + else + input_lang = Lang(lang1) + output_lang = Lang(lang2) + end + return(input_lang, output_lang, pairs) +end + +# As suggested in the Pytorch tutorial, we create a function to filter out sentences that don't start with `english_prefixes` ("i am", "i'm", "you are"...), as well as sentences that exceed the `MAX_LENGTH` (which is set to 10). +# +# The function `prepareData` takes the names of two languages and creates these language instances as well as the sentence pairs by calling `readLangs`. +# After the sentence pairs get filtered (with `filterPair`), every unique word in a sentence get's added to the corresponding language's vocabulary (`word2index`, `index2word`, `n_words`) while every additional instance of a word increments `n_words` by 1. +# +# Sentences from the input language are added to `xs`, target sentences are added to `ys`. Finally, inputs (`xs`) and targets (`ys`) are shuffled. + +MAX_LENGTH = 10 + +eng_prefixes = [ + "i am ", "i ' m ", + "he is ", "he ' s ", + "she is ", "she ' s ", + "you are ", "you ' re ", + "we are ", "we ' re ", + "they are ", "they ' re "] + +function filterPair(pair) + return(false ∉ (length.(split.(pair, " ")) .<= MAX_LENGTH) && true ∈ (startswith.(pair[1], eng_prefixes))) +end + +function prepareData(lang1, lang2; rev=false) + input_lang, output_lang, pairs = readLangs(lang1, lang2; rev=rev) + println("Read $(length(pairs)) sentence pairs.") + pairs = [pair for pair in pairs if filterPair(pair)] + println("Trimmed to $(length(pairs)) sentence pairs.\n") + xs = [] + ys = [] + for pair in pairs + push!(xs, pair[1]) + push!(ys, pair[2]) + end + println("Counting words...") + for pair in pairs + input_lang(pair[2]) + output_lang(pair[1]) + end + println("Counted words:") + println("• ", input_lang.name, ": ", input_lang.n_words) + println("• ", output_lang.name, ": ", output_lang.n_words) + return(input_lang, output_lang, xs, ys) +end + +fr, eng, xs, ys = prepareData("fr", "eng") +indices = shuffle([1:length(xs)...]) +xs = xs[indices] +ys = ys[indices]; + +# The function `indexesFromSentence` takes a language's `word2index` and maps all the words in a sentence to a index, later this index will get used to get the word's embedding. Note that, at the end of every sentence, the `EOS`-index (2) gets added, this is for the model to know when to stop predicting during inference. +# +# To make batches for mini-batch training, the data (`[indexesFromSentence.([eng], xs), indexesFromSentence.([fr], ys)]`) gets split in chunks of `BATCH_SIZE`. Since sentences in a chunk often have different lengths, the `PAD`-index (4), gets added to the end of sentences to make them as long as the longest sentence of the chunk. +# +# To be able to easily pass a chunk to an RNN, the nth word of every sentence in the chunk get placed next to each other in an array. Also, all the words get OneHot encoded. +# +# ![batching](./images/batching.png) + +BATCH_SIZE = 32 + +indexesFromSentence(lang, sentence) = append!(get.(Ref(lang.word2index), split(lowercase(sentence), " "), 3), 2) + +function batch(data, batch_size, voc_size; gpu=true) + chunks = Iterators.partition(data, batch_size) + batches = [] + for chunk in chunks + max_length = maximum(length.(chunk)) + chunk = map(sentence->append!(sentence, fill(4, max_length-length(sentence))), chunk) + chunk = hcat(reshape.(chunk, :, 1)...) + batch = [] + for i in 1:size(chunk, 1) + if gpu + push!(batch, cu(Flux.onehotbatch(chunk[i, :], [1:voc_size...]))) + else + push!(batch, Flux.onehotbatch(chunk[i, :], [1:voc_size...])) + end + end + push!(batches, batch) + end + return(batches) +end + +x, y = batch.([indexesFromSentence.([eng], xs), indexesFromSentence.([fr], ys)], [BATCH_SIZE], [eng.n_words, fr.n_words]; gpu=true); + +# # The Model +# +# For the model, we're using a **encoder-decoder** architecture. +# ![encoder-decoder](https://smerity.com/media/images/articles/2016/gnmt_arch_attn.svg) +# *image source: https://smerity.com/articles/2016/google_nmt_arch.html* +# +# ### High level overview +# The **encoder** takes the OneHot-encoded words and uses the embedding layer to get their embedding, a multidimensional-representation of that word. Next, the words get passed through a RNN (in our case a GRU). For each word, the RNN spits out a state-vector (encoder-outputs). +# +# The job of the **decoder** is to take the output of the encoder and mold it into a correct translation of the original sentence. The **attention** layer acts as a guide for the decoder. Every timestep (every time the decoder is to predict a word), it takes all the encoder-outputs and creates **one** state vector (the context vector) with the most relevant information for that particular timestep. + +## some constants to be used for the model: +EMB_size = 128 +HIDDEN = 128 +LEARNING_RATE = 0.005 +DROPOUT = 0.2; + +# For the encoder, we're using a bidirectional GRU, the input is read from front to back as well as from back to front. This should help for a more robust `encoder_output`. +# The `Flux.@treelike` macro makes sure all the parameters are recognized by the optimizer to optimise the values. + +struct Encoder + embedding + linear + rnn + out +end +Encoder(voc_size::Integer; h_size::Integer=HIDDEN) = Encoder( + param(Flux.glorot_uniform(EMB_size, voc_size)), + Dense(EMB_size, HIDDEN, relu), + GRU(h_size, h_size), + Dense(h_size, h_size)) +function (e::Encoder)(x; dropout=0) + x = map(x->Dropout(dropout)(e.embedding*x), x) + x = e.linear.(x) + enc_outputs = e.rnn.(x) + h = e.out(enc_outputs[end]) + return(enc_outputs, h) +end +Flux.@treelike Encoder + +# The decoder takes the word it predicted in the previous timestep as well the `encoder_outputs`. The context vector gets created by passing these `encoder_outputs` as well as the current state of the decoder's RNN to the attention layer. Finally, the context vector is concatenated with the word of the previous timestep to predict the word of the current timestep. +# +# *During the first timestep, the decoder doesn't have acces to a previously predicted word. To combat this, a `SOS`-token is provided* + +struct Decoder + embedding + linear + attention + rnn + output +end +Decoder(h_size, voc_size) = Decoder( + param(Flux.glorot_uniform(EMB_size, voc_size)), + Dense(EMB_size, HIDDEN), + Attention(h_size), + GRU(h_size*2, h_size), + Dense(h_size, voc_size, relu)) +function (d::Decoder)(x, enc_outputs; dropout=0) + x = d.embedding * x + x = Dropout(dropout)(x) + x = d.linear(x) + decoder_state = d.rnn.state + context = d.attention(enc_outputs, decoder_state) + x = d.rnn([x; context]) + x = softmax(d.output(x)) + return(x) +end +Flux.@treelike Decoder + +# For the attention mechanism, we follow the implementation from the paper "Grammar as a Foreign Language" (https://arxiv.org/pdf/1412.7449.pdf). +# +# Esentially, the encoder outputs and the hidden state of the decoder are used to a context vector which contains all the necessary information to decode into a translation during a particular timestep. +# The paper shows the following equations: +# +# $ u_i^t = v^T tanh(W_1'h_i+W_2'd_t) $ +# +# $ a_i^t = softmax(u_i^t) $ +# +# $ \sum\limits_{i=1}^{T_a} a_i^t h_i$ +# +# Where the encoder hidden states are denoted `(h1, . . . , hTA )` and we denote the hidden states of the decoder by `(d1, . . . , dTB )` + +struct Attention + W1 + W2 + v +end +Attention(h_size) = Attention( + Dense(h_size, h_size), + Dense(h_size, h_size), + param(Flux.glorot_uniform(1, h_size))) +function (a::Attention)(enc_outputs, d) + U = [a.v*tanh.(x) for x in a.W1.(enc_outputs).+[a.W2(d)]] + A = softmax(vcat(U...)) + out = sum([gpu(collect(A[i, :]')) .* h for (i, h) in enumerate(enc_outputs)]) +end +Flux.@treelike Attention + +testEncoder = Encoder(eng.n_words)|>gpu +testDecoder = Decoder(HIDDEN, fr.n_words)|>gpu; + +# The model function is made to return the loss when the input and the target are provided. +# The hidden states of the RNN from both the encoder as well as the decoder are reset, by doing this you make sure no information of previous sentences is remembered. +# +# The encoder_ouputs are made by passing the input through the encoder, the initial decoder input is made and the decoder's rnn state is initialized with the last encoder output. +# The decoder has to predict `max_length` words with `max_length` being the length of the longes sentence. +# +# First off, the model decides whether teacher forcing will be used this timestep. Teacher forcing means instead of using the decoder output as the next timestep's decoder input, the correct input is used. Teacher forcing is especially useful in the beginning of training since decoder outputs won't make sense. +# +# Every timestep, the decoder's prediction as well as the correct target are passed to a loss function. All the losses of all timesteps are summed up and returned. + +function model(encoder::Encoder, decoder::Decoder, x, y; teacher_forcing = 0.5, dropout=DROPOUT, voc_size=fr.n_words) + total_loss = 0 + max_length = length(y) + batch_size = size(x[1], 2) + Flux.reset!.([encoder, decoder]) + enc_outputs, h = encoder(x; dropout=dropout) + decoder_input = Flux.onehotbatch(ones(batch_size), [1:voc_size...]) + decoder.rnn.state = h + for i in 1:max_length + use_teacher_forcing = rand() < teacher_forcing + decoder_output = decoder(decoder_input, enc_outputs; dropout=dropout) + total_loss += loss(decoder_output, y[i]) + if use_teacher_forcing + decoder_input = y[i] + else + decoder_input = Flux.onehotbatch(Flux.onecold(decoder_output.data), [1:voc_size...]) + end + end + return(total_loss) +end + +model(x, y) = model(testEncoder, testDecoder, x, y; dropout = DROPOUT) + +# When the target is not provided to the `model` function, the model returns a prediction instead of a loss value. +# +# +# *Note that, when the model is trained, the loop could be set to run indefinitely because the loop will break when an `EOS`-token is predicted. +# I've set the loop to run for an arbitrary amount of timesteps (in this case 12) because the model doesn't seem to be able to learn to predict an `EOS token`* + +function model(encoder::Encoder, decoder::Decoder, x; reset=true, voc_size=fr.n_words) + result = [] + if reset Flux.reset!.([encoder, decoder]) end + enc_outputs, h = encoder(x) + decoder_input = Flux.onehot(1, [1:voc_size...]) + decoder.rnn.state = h + for i in 1:12 + decoder_output = Flux.onecold(decoder(decoder_input, enc_outputs)) + if decoder_output[1] == 2 break end + push!(result, decoder_output...) + end + return(result) +end + +# The `loss` function expects a probability distribution over all possible words in the vocabulary, this gets accounted for by the softmax layer in the decoder. The loss function itself is crossentropy (a.k.a. negative-log-likelihood). +# We pass an vector of ones, except for the `PAD`-index (4) as weight to the loss function. This way the model will disregard any predictions that should have been PAD, since padding only occurs after the sentence has ended. +# +# +# For the optimizer, we use ADAM. + +lossmask = ones(fr.n_words)|>gpu +lossmask[4] = 0 + +loss(logits, target) = Flux.crossentropy(logits, target; weight=lossmask) + +opt = ADAM(LEARNING_RATE) +ps = params(testEncoder, testDecoder) + +# The data (`x` and `y`) gets passed to `partitionTrainTest` to split the data in a train and a test set. +# +# Finally the model is trained for a number of epochs. Every epoch, the loss on the test set gets printed. + +function partitionTrainTest(x, y, at) + n = length(x) + idx = shuffle(1:n) + train_idx = view(idx, 1:floor(Int, at*n)) + test_idx = view(idx, (floor(Int, at*n)+1):n) + train_x, test_x = x[train_idx,:], x[test_idx,:] + train_y, test_y = y[train_idx,:], y[test_idx,:] + return(train_x, train_y, test_x, test_y) +end + +train_x, train_y, test_x, test_y = partitionTrainTest(x, y, 0.90); + +EPOCHS = 5 + +for i in 1:EPOCHS + Flux.train!(model, ps, zip(train_x, train_y), opt) + println("loss: ", mean(model.(test_x, test_y)).data) +end + +# The `predict` function takes an encoder, decoder and an english sentence. It converts the sentence into it's OneHot representation and passes it to the `model` function. The output gets converted back to a string and returned. + +function predict(encoder, decoder, sentence::String) + sentence = normalizeString(sentence) + input = append!(get.(Ref(eng.word2index), split(lowercase(sentence), " "), 3), 2) + input = [Flux.onehot(word, [1:eng.n_words...]) for word in input] + output = model(encoder, decoder, input) + output = get.(Ref(fr.index2word), output, "UNK") + println(output) +end + +predict(testEncoder, testDecoder, "she's doing her thing") +predict(testEncoder, testDecoder, "you're too skinny") +predict(testEncoder, testDecoder, "He is singing") + +# As you can see, when I run the model for 70 epochs, the results are quite underwhelming... Even though sentence subjects are mostly correct, most part of the translation does not make sense. +# +# If you have a look at the loss on the test set during these 70 epochs, you can clearly see the model seems to hit a barrier around 18. +# +# I'm not sure why this is happening and I'd love to know! If you've got an idea on how to improve/fix this model, definitely let me know. +# +# Thanks +# ![encoder-decoder](./images/graph.png)