# SPDX-License-Identifier: CECILL-2.1

module SpeechRecognitionFSTs

using ..TensorFSTs
using ..TensorFSTs.LinearFSTs

export EmissionMapping, lexiconfst, tokenfst


"""
    lexiconfst(S, encoded_lexicon)

Create FST representation of lexicon.

All leafs of `encoded_lexicon` should be represented as `Integer`.

- `encoded_lexicon` - collection of `(token, [unit1, unit2, …])` pairs
- `S` - FST arc's weight type (e.g. Semiring)
# Example
```juliarepl
julia> lexicon = [1 => [1, 2, 3], 1 => [3,4,1], 2 => [1, 2]]
julia> lexiconfst(LogSemiring{Float32, 1.0}, lexicon)
TensorFST(

)
```
"""
function lexiconfst(S, lexicon)
    fst = nothing
    for (token, pron) in lexicon
        isyms = pron
        # Output the token on the first arc; pad the remaining arcs with symbol 1.
        osyms = vcat([token], ones(Int, length(pron) - 1))
        pronfst = linearfst(ones(S, length(pron)), isyms, osyms)
        fst = isnothing(fst) ? pronfst : union(fst, pronfst)
    end
    fst
end


"""
    struct EmissionMapping <: AbstractMatrix{Int}
        numstates::Int
        numtokens::Int
    end

Default emission mapping assigning a distinct emission symbol to every
`(state, token)` pair. It assumes the same number of emissions per token.
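
# Example
A sketch of the mapping for 3 states per token and 2 tokens; each
`(state, token)` pair gets a unique emission symbol, starting at 2:
```julia-repl
julia> EmissionMapping(3, 2)
3×2 EmissionMapping:
 2  5
 3  6
 4  7
```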
"""
struct EmissionMapping <: AbstractMatrix{Int}
    numstates::Int
    numtokens::Int
end

Base.size(x::EmissionMapping) = (x.numstates, x.numtokens)
Base.getindex(x::EmissionMapping, i, j) = (j - 1) * x.numstates + i + 1
Base.maximum(x::EmissionMapping) = (x.numstates * x.numtokens) + 1


"""
    tokenfst(S, topo, initweights, finalweights, tokens, mapping)

Create a FST composed of a set of smaller FSTs. Each token share the same FST
topology specified with `topo = [(src, dest, weight), ...]`. `S` is a
`Semiring` of output FST, `tokens` is a list of token IDs, and
`mapping[tokens[i]]` is the output symbol .
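
# Example
A minimal sketch with a hypothetical two-state, left-to-right topology per
token (output omitted):
```julia-repl
julia> topo = [(1, 2, 1.0), (2, 2, 0.5)]; # forward arc + self-loop

julia> tokenfst(LogSemiring{Float32, 1.0}, topo, [1 => 1.0], [2 => 0.5], [1, 2])
```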
"""
function tokenfst(
    S,
    topo,
    initweights,
    finalweights,
    tokens,
    mapping = EmissionMapping(length(topo), length(tokens))
)
    emission_count = 0
    states = Set{Int}()
    arcs, init, final = [], [], []

    for (i, token) in enumerate(tokens)
        # Shift this token's states past all states allocated so far.
        offset = length(states)
        for (j, topo_arc) in enumerate(topo)
            emission_count += 1
            src, dest, weight = offset + topo_arc[1], offset + topo_arc[2], topo_arc[3]

            arc = Arc(
                src = src,
                isym = mapping[emission_count], # emission symbol for this arc
                osym = j == 1 ? token : 1,      # emit the token on the first arc only
                dest = dest,
                weight = S(weight)
            )

            push!(states, src)
            push!(states, dest)
            push!(arcs, arc)
        end

        # Initial states of this token's sub-FST, shifted by the offset.
        for (state, weight) in initweights
            state = offset + state
            push!(states, state)
            push!(init, state => S(weight))
        end

        # Final states, likewise shifted.
        for (state, weight) in finalweights
            state = offset + state
            push!(states, state)
            push!(final, state => S(weight))
        end
    end

    TensorFST(arcs, init, final)
end

end # module