Skip to content
Snippets Groups Projects
dataset.jl 1.69 KiB
Newer Older
Martin Kocour's avatar
Martin Kocour committed
# SPDX-License-Identifier: CECILL-2.1

"""
    SpeechDataset

Data container pairing supervisions with the recordings they refer to.

# Fields
- `idxs`: supervision keys; integer index `i` into the dataset refers to `idxs[i]`.
- `supervisions`: `Supervision`s keyed by supervision key.
- `recordings`: `Recording`s keyed by the `recording_id` stored in each supervision.
- `partition`: name of the partition, e.g. `:train` or `:test`.
"""
struct SpeechDataset <: MLUtils.AbstractDataContainer
    idxs::Vector{AbstractString}
    supervisions::Dict{AbstractString, Supervision}
    recordings::Dict{AbstractString, Recording}
    partition::Symbol
end

"""
    dataset(manifestroot, partition)

Load a `SpeechDataset` from the manifest files stored in `manifestroot`.

The partition is selected by `partition`, e.g. `:train` or `:test`. Supervisions are read
from `supervisions-<partition>.jsonl` and recordings from `recordings.jsonl`, both located
in `manifestroot`.

Each item of the dataset is a nested tuple `((samples, sampling_rate), Supervision.data)`.

See also [`Supervision`](@ref).

# Examples
```julia-repl
julia> ds = dataset("./manifests", :train)
SpeechDataset(
    ...
)
julia> ds[1]
(
    (samples=[...], sampling_rate=16_000),
    Dict(
        "text" => "Supervision text here"
    )
)
```
"""
function dataset(manifestroot::AbstractString, partition)
    # Only the supervision manifest is partition-specific; recordings come from one
    # shared manifest.
    sup_path = joinpath(manifestroot, "supervisions-$(partition).jsonl")
    rec_path = joinpath(manifestroot, "recordings.jsonl")
    supervisions = load(Supervision, sup_path)
    recordings = load(Recording, rec_path)
    # Forward the partition so the 3-argument method can record it on the dataset.
    return dataset(supervisions, recordings, partition)
end

"""
    dataset(supervisions, recordings, partition)

Build a `SpeechDataset` from already-loaded `supervisions` and `recordings`.

Integer indices into the resulting dataset follow the iteration order of
`keys(supervisions)`. `partition` is converted with `Symbol(partition)`.
"""
function dataset(supervisions, recordings, partition)
    idxs = collect(keys(supervisions))
    return SpeechDataset(idxs, supervisions, recordings, Symbol(partition))
end

"""
    getindex(d::SpeechDataset, key::AbstractString)

Return the item for the supervision stored under `key`.

The item is a tuple whose first element is the named tuple
`(samples=..., sampling_rate=...)` built from `load(recording, supervision)`, and whose
second element is the supervision's `data` field.

Throws a `KeyError` if `key` is not a known supervision key, or if the supervision's
`recording_id` is not among the dataset's recordings.
"""
function Base.getindex(d::SpeechDataset, key::AbstractString)
    sup = d.supervisions[key]
    rec = d.recordings[sup.recording_id]
    samples, sr = load(rec, sup)
    return (samples=samples, sampling_rate=sr), sup.data
end
# Positional access: translate the integer position into its supervision key, then
# delegate to the key-based method.
Base.getindex(d::SpeechDataset, idx::Integer) = d[d.idxs[idx]]
# Batched access: look up each element of `idxs` (integer positions or keys) in order
# and collect the resulting items.
Base.getindex(d::SpeechDataset, idxs::AbstractVector) = map(i -> d[i], idxs)

# The dataset has one item per supervision key.
function Base.length(d::SpeechDataset)
    return length(d.idxs)
end