# SPDX-License-Identifier: CECILL-C

# Path of the JSON file describing the known corpora, located in the `corpora`
# directory next to this source file.
const corpora_file = joinpath(@__DIR__, "corpora", "corpora.json")

"""
    SpeechDatasetInfos(; name="", lang="", license="", source="",
                       authors=AbstractString[], description="")

Descriptive metadata attached to a speech dataset.

Every field has a default value, so any subset of them can be given as keyword
arguments.

# Fields
- `name`: name of the dataset.
- `lang`: either a single string or a vector of strings.
- `license`: license string.
- `source`: source string.
- `authors`: vector of author strings (empty by default).
- `description`: free-form description.
"""
@kwdef struct SpeechDatasetInfos
    name::AbstractString = ""
    lang::Union{AbstractString, Vector{AbstractString}} = ""
    license::AbstractString = ""
    source::AbstractString = ""
    # Typed empty literal: avoids building a `Vector{Any}` that must be converted.
    authors::Vector{AbstractString} = AbstractString[]
    description::AbstractString = ""
end

"""
    SpeechDatasetInfos(infos::AbstractDict)

Build a `SpeechDatasetInfos` from a dictionary whose keys are field names given
as strings.

Entries of `infos` whose key is not a field name are ignored. Fields that are
absent from `infos`, or whose stored value is `nothing`, keep their default value.
"""
function SpeechDatasetInfos(infos::AbstractDict)
    kwargs = Dict{Symbol, Any}()
    for key in fieldnames(SpeechDatasetInfos)
        val = get(infos, String(key), nothing)
        # Only forward the fields that were actually found.
        isnothing(val) || (kwargs[key] = val)
    end
    # Keyword splat (note the `;`): a positional splat would pass the values
    # without their names and break as soon as one field is missing.
    return SpeechDatasetInfos(; kwargs...)
end

"""
    SpeechDatasetInfos(name::AbstractString)

Load the metadata of the corpus called `name` from `corpora_file`.

The file is parsed with `JSON.parsefile`; the first entry whose `"name"` value
equals `name` is converted with `SpeechDatasetInfos(::AbstractDict)`.

# Throws
- `ArgumentError` if no entry of `corpora_file` is named `name`.
"""
function SpeechDatasetInfos(name::AbstractString)
    corpora_infos = JSON.parsefile(corpora_file)
    # Stop at the first match instead of filtering the whole list.
    idx = findfirst(entry -> entry["name"] == name, corpora_infos)
    if isnothing(idx)
        throw(ArgumentError("no corpus named \"$name\" found in $corpora_file"))
    end
    return SpeechDatasetInfos(corpora_infos[idx])
end

"""
    SpeechDataset(infos, idxs, annotations, recordings)

Data container pairing recordings with their annotations.

# Fields
- `infos`: metadata of the dataset.
- `idxs`: string keys of the entries; integer indexing is resolved through
  this vector.
- `annotations`: annotations stored by string key.
- `recordings`: recordings stored by string key.
"""
struct SpeechDataset <: MLUtils.AbstractDataContainer
    # Required by the constructors and by `Base.filter`, which pass and read `infos`.
    infos::SpeechDatasetInfos
    idxs::Vector{AbstractString}
    annotations::Dict{AbstractString, Annotation}
    recordings::Dict{AbstractString, Recording}
end

# Build a dataset from already loaded manifests; the entry keys are taken from
# the keys of `annotations`.
function SpeechDataset(
    infos::SpeechDatasetInfos,
    annotations::Dict{AbstractString, Annotation},
    recordings::Dict{AbstractString, Recording},
)
    return SpeechDataset(infos, collect(keys(annotations)), annotations, recordings)
end

# Build a dataset from the manifest files found in `manifestroot`.
#
# Annotations are read from `annotations-<partition>.jsonl`, or from
# `annotations.jsonl` when `partition` is empty. Recordings are always read
# from `recordings.jsonl`.
function SpeechDataset(
    infos::SpeechDatasetInfos,
    manifestroot::AbstractString,
    partition::AbstractString,
)
    suffix = isempty(partition) ? "" : string("-", partition)
    annotations = load_manifest(
        Annotation, joinpath(manifestroot, "annotations$(suffix).jsonl")
    )
    recordings = load_manifest(Recording, joinpath(manifestroot, "recordings.jsonl"))
    return SpeechDataset(infos, annotations, recordings)
end

# Look up an entry by its string key; returns the `(recording, annotation)` tuple.
function Base.getindex(d::SpeechDataset, key::AbstractString)
    return (d.recordings[key], d.annotations[key])
end

# Integer indexing is resolved through `d.idxs`, which maps a position to a key.
Base.getindex(d::SpeechDataset, idx::Integer) = d[d.idxs[idx]]

# Vector indexing returns one result per element of `idxs`.
Base.getindex(d::SpeechDataset, idxs::AbstractVector) = map(i -> d[i], idxs)

# Number of entries, i.e. the number of keys stored in `d.idxs`.
Base.length(d::SpeechDataset) = length(d.idxs)
"""
    filter(fn, d::SpeechDataset)

Return a new `SpeechDataset` containing only the entries of `d` for which `fn`
returns `true`.

`fn` is called with a single argument, the tuple `(recording, annotation)` of
an entry. The returned dataset shares `d.infos`; its keys, annotations and
recordings are restricted to the selected entries.
"""
function Base.filter(fn, d::SpeechDataset)
    fidxs = filter(d.idxs) do i
        fn((d.recordings[i], d.annotations[i]))
    end
    # Set for constant-time membership tests while filtering the dictionaries.
    idset = Set(fidxs)

    fannotations = filter(d.annotations) do (k, v)
        k in idset
    end

    frecs = filter(d.recordings) do (k, v)
        k in idset
    end

    return SpeechDataset(d.infos, fidxs, fannotations, frecs)
end