Skip to content
Snippets Groups Projects
Verified Commit 0d8fcc5a authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

speechdataset iterate over tuple of recording, annotation

parent 0b47f281
No related branches found
No related tags found
No related merge requests found
......@@ -4,14 +4,12 @@ struct SpeechDataset <: MLUtils.AbstractDataContainer
idxs::Vector{AbstractString}
annotations::Dict{AbstractString, Annotation}
recordings::Dict{AbstractString, Recording}
partition::Symbol
end
"""
dataset(manifestroot, partition)
dataset(manifestroot)
Load `SpeechDataset` from manifest files stored in `manifestroot`.
Partition is specified by `partition`, e.g. `:train`, `:test`.
Each item of the dataset is a nested tuple `((samples, sampling_rate), Annotation.data)`.
......@@ -32,25 +30,20 @@ julia> ds[1]
)
```
"""
function dataset(manifestroot, partition)
function dataset(manifestroot::AbstractString, partition)
annot_path = joinpath(manifestroot, "annotations-$(partition).jsonl")
rec_path = joinpath(manifestroot, "recordings.jsonl")
annotations = load(Annotation, annot_path)
recordings = load(Recording, rec_path)
dataset(annotations, recordings, partition)
dataset(annotations, recordings)
end
function dataset(annotations, recordings, partition)
function dataset(annotations::AbstractDict, recordings::AbstractDict)
idxs = collect(keys(annotations))
SpeechDataset(idxs, annotations, recordings, Symbol(partition))
SpeechDataset(idxs, annotations, recordings)
end
function Base.getindex(d::SpeechDataset, key::AbstractString)
ann = d.annotations[key]
rec = d.recordings[ann.recording_id]
samples, sr = load(rec, ann)
(samples=samples, sampling_rate=sr), ann.data
end
Base.getindex(d::SpeechDataset, key::AbstractString) = d.recordings[key], d.annotations[key]
Base.getindex(d::SpeechDataset, idx::Integer) = getindex(d, d.idxs[idx])
# Fix1 -> partial funcion with fixed 1st argument
Base.getindex(d::SpeechDataset, idxs::AbstractVector) = map(Base.Fix1(getindex, d), idxs)
......@@ -58,9 +51,19 @@ Base.getindex(d::SpeechDataset, idxs::AbstractVector) = map(Base.Fix1(getindex,
Base.length(d::SpeechDataset) = length(d.idxs)
function Base.filter(fn, d::SpeechDataset)
fidxs = filter(fn, d.idxs)
fannotations = filter(k_v -> fn(first(k_v)), d.annotations)
frecs = filter(k_v -> fn(first(k_v)), d.recordings)
SpeechDataset(fidxs, fannotations, frecs, :custom)
fidxs = filter(d.idxs) do i
fn((d.recordings[i], d.annotations[i]))
end
idset = Set(fidxs)
fannotations = filter(d.annotations) do (k, v)
k idset
end
frecs = filter(d.recordings) do (k, v)
k idset
end
SpeechDataset(fidxs, fannotations, frecs)
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment