Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow type specification for hdf5 embeddings loading #10

Merged
merged 1 commit into from
Dec 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ function load_embeddings(filepath::AbstractString;
conceptnet = _load_hdf5_embeddings(filepath,
max_vocab_size,
keep_words,
languages=languages)
languages=languages,
data_type=data_type)
else
conceptnet = _load_gz_embeddings(filepath,
Noop(),
Expand Down Expand Up @@ -125,11 +126,10 @@ Load the ConceptNetNumberbatch embeddings from a HDF5 file.
function _load_hdf5_embeddings(filepath::S1,
max_vocab_size::Union{Nothing,Int},
keep_words::Vector{S2};
languages::Union{Nothing,
Languages.Language,
Vector{<:Languages.Language}
}=nothing) where
{S1<:AbstractString, S2<:AbstractString}
languages::Union{Nothing, Languages.Language,
Vector{<:Languages.Language}}=nothing,
data_type::Type{E}=Int8) where
{S1<:AbstractString, S2<:AbstractString, E<:Real}
local fuzzy_words
type_word = String
payload = h5open(read, filepath)["mat"]
Expand All @@ -142,7 +142,7 @@ function _load_hdf5_embeddings(filepath::S1,
max_vocab_size,
keep_words)
lang_embs, languages, type_lang, _ =
process_language_argument(languages, type_word, Int8)
process_language_argument(languages, type_word, E)
fuzzy_words = Dict{type_lang, Vector{type_word}}()
no_custom_words = length(keep_words)==0
cnt = 0
Expand All @@ -151,7 +151,7 @@ function _load_hdf5_embeddings(filepath::S1,
if haskey(LANGUAGES, lang) && LANGUAGES[lang] in languages # use only languages mapped in LANGUAGES
_llang = LANGUAGES[lang]
if !haskey(lang_embs, _llang)
push!(lang_embs, _llang=>Dict{type_word, Vector{Int8}}())
push!(lang_embs, _llang=>Dict{type_word, Vector{E}}())
push!(fuzzy_words, _llang=>type_word[])
end
occursin("#", word) && push!(fuzzy_words[_llang], word)
Expand All @@ -163,7 +163,7 @@ function _load_hdf5_embeddings(filepath::S1,
end
end
end
return ConceptNet{type_lang, type_word, Int8}(lang_embs, size(embeddings,1), fuzzy_words)
return ConceptNet{type_lang, type_word, E}(lang_embs, size(embeddings,1), fuzzy_words)
end


Expand Down
34 changes: 21 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ using Languages
using ConceptnetNumberbatch

# Test file with just 2 entries (test purposes only)
const DATA_TYPE = Float64
const CONCEPTNET_TEST_DATA = Dict( # filename => output type
const DATA_TYPE = Float32
const CONCEPTNET_TEST_DATA = Dict(
# filename => output type
(joinpath(string(@__DIR__), "data", "_test_file_en.txt.gz") =>
([Languages.English()],
["####_ish", "####_form", "####_metres"],
Expand All @@ -17,26 +18,29 @@ const CONCEPTNET_TEST_DATA = Dict( # filename => output type

(joinpath(string(@__DIR__), "data", "_test_file.txt") =>
(nothing,
["1_konings", "aaklig", "aak"],
["1_konings", "aaklig", "aak"],
ConceptNet{Languages.Language, String, DATA_TYPE})),

(joinpath(string(@__DIR__), "data", "_test_file.h5") =>
(nothing,
["1", "2", "2d"],
ConceptNet{Languages.Language, String, Int8}))
)
ConceptNet{Languages.Language, String, DATA_TYPE}))
)

@testset "Parser: (no arguments)" begin
for (filename, (languages, _, resulting_type)) in CONCEPTNET_TEST_DATA
conceptnet = load_embeddings(filename, languages=languages);
conceptnet = load_embeddings(filename,
languages=languages,
data_type=DATA_TYPE);
@test conceptnet isa resulting_type
end
end

max_vocab_size=5
@testset "Parser: max_vocab_size=5" begin
for (filename, (languages, _, _)) in CONCEPTNET_TEST_DATA
conceptnet = load_embeddings(filename, max_vocab_size=max_vocab_size,
conceptnet = load_embeddings(filename,
max_vocab_size=max_vocab_size,
languages=languages);
@test length(conceptnet) == max_vocab_size
end
Expand All @@ -45,8 +49,10 @@ end
max_vocab_size=5
@testset "Parser: max_vocab_size=5, 3 keep words" begin
for (filename, (languages, keep_words, _)) in CONCEPTNET_TEST_DATA
conceptnet = load_embeddings(filename, max_vocab_size=max_vocab_size,
keep_words=keep_words, languages=languages)
conceptnet = load_embeddings(filename,
max_vocab_size=max_vocab_size,
keep_words=keep_words,
languages=languages)
@test length(conceptnet) == length(keep_words)
for word in keep_words
@test word in conceptnet
Expand All @@ -63,7 +69,7 @@ end
# Test indexing
idx = 1
@test conceptnet[words[idx]] == conceptnet[:en, words[idx]] ==
conceptnet[Languages.English(), words[idx]]
conceptnet[Languages.English(), words[idx]]

# Test values
embeddings = conceptnet[words]
Expand All @@ -84,14 +90,14 @@ end
@test_throws MethodError conceptnet[words] # type of language is Language, cannot directly search
@test_throws KeyError conceptnet[:en, "word"] # English language not present
@test conceptnet[:nl, words[idx]] ==
conceptnet[Languages.Dutch(), words[idx]]
conceptnet[Languages.Dutch(), words[idx]]

# Test values
for (idx, word) in enumerate(words)
@test_throws KeyError conceptnet[Languages.English(), word]
if word in conceptnet
@test vec(conceptnet[Languages.Dutch(), word]) ==
conceptnet.embeddings[Languages.Dutch()][word]
conceptnet.embeddings[Languages.Dutch()][word]
else
@test iszero(conceptnet[Languages.Dutch(),word])
end
Expand All @@ -112,7 +118,9 @@ end

@testset "Document Embedding" begin
filepath = joinpath(string(@__DIR__), "data", "_test_file_en.txt.gz")
conceptnet = load_embeddings(filepath, languages=[Languages.English()])
conceptnet = load_embeddings(filepath,
languages=[Languages.English()],
data_type=DATA_TYPE)
# Document with no matchable words
doc = "a aaaaa b"
embedded_doc, missed = embed_document(conceptnet,
Expand Down