# SPDX-License-Identifier: CECILL-2.1

module SpeechRecognitionFSTs

using ..TensorFSTs
using ..TensorFSTs.LinearFSTs

export EmissionMapping, lexiconfst, tokenfst

"""
    struct EmissionMapping <: AbstractMatrix{Int}
        numpdfs::Int
        numtokens::Int
    end

Default emission mapping which assumes the same number of emissions
(`numpdfs`) per token. Entry `(i, j)` is the unique id
`(j - 1) * numpdfs + i`, i.e. ids `1:numpdfs * numtokens` laid out
column-major, each used exactly once.
"""
struct EmissionMapping <: AbstractMatrix{Int}
    numpdfs::Int
    numtokens::Int
end

Base.size(x::EmissionMapping) = (x.numpdfs, x.numtokens)

function Base.getindex(x::EmissionMapping, i, j)
    # Enforce the AbstractMatrix contract: out-of-range indices must raise
    # a BoundsError instead of silently returning an invalid emission id.
    @boundscheck checkbounds(x, i, j)
    return (j - 1) * x.numpdfs + i
end

# Entries are the consecutive integers 1:numpdfs*numtokens, so the extrema
# are known in O(1) without scanning the matrix.
Base.minimum(x::EmissionMapping) = 1
Base.maximum(x::EmissionMapping) = x.numpdfs * x.numtokens



"""
    tokenfst(S, topo, initweights, finalweights, tokens, mapping = nothing)

Create a FST composed of a set of smaller FSTs. Each token shares the same
FST topology specified with `topo = [(src, dest, weight), ...]`. `S` is a
`Semiring` of the output FST, `tokens` is a list of token IDs, and
`mapping[topo_state, i]` is the input symbol emitted by topology state
`topo_state` for `tokens[i]`. When `mapping` is `nothing`, a default
`EmissionMapping` is built from the number of distinct topology source
states and the number of tokens.
"""
function tokenfst(
    S,
    topo,
    initweights,
    finalweights,
    tokens,
    mapping = nothing
)
    states = Set(Int[])
    arcs, init, final = [], [], []

    # Distinct source states of the shared topology. Needed both to build
    # the default mapping and to detect the last topology state below.
    # Bug fix: previously this set was only computed when `mapping` was
    # `nothing`, so passing an explicit mapping raised an UndefVarError at
    # the final-arc check.
    st = Set(Int[])
    for top in topo
        push!(st, top[1])
    end

    if isnothing(mapping)
        mapping = EmissionMapping(length(st), length(tokens))
    end

    # Extra shared state used as entry point and loop-back (infinite loop),
    # numbered just past the ids produced by `mapping`.
    init_state = length(mapping) + 1
    push!(states, init_state)

    for (j, token) in enumerate(tokens)
        # Shift the token-local topology state ids into a global range.
        offset = length(states) - 1

        for topo_arc in topo

            src, dest, weight = offset + topo_arc[1], offset + topo_arc[2], topo_arc[3]

            arc = Arc(
                src = src,
                isym = mapping[topo_arc[2], j],
                osym = 0,
                dest = dest,
                weight = S(weight)
            )

            # Arc from the shared init state into the token's first state,
            # emitting the token id as output symbol.
            if topo_arc[1] == 1
                init_arc = Arc(
                    src = init_state,
                    isym = 0,
                    osym = token,
                    dest = src,
                    weight = S(weight)
                )
                push!(arcs, init_arc)
            end

            # Arc from the token's last state back to the shared init state,
            # allowing token sequences of arbitrary length.
            if topo_arc[1] == length(st)
                final_arc = Arc(
                    src = src,
                    isym = 0,
                    osym = 0,
                    dest = init_state,
                    weight = S(weight)
                )
                push!(arcs, final_arc)
            end

            push!(states, src)
            push!(states, dest)
            push!(arcs, arc)
        end

        for (state, weight) in finalweights
            state = offset + state
            push!(states, state)
            push!(final, state => S(weight))
        end
    end

    # There is a single init state shared by all tokens; the per-state part
    # of `initweights` is intentionally ignored.
    for (_, weight) in initweights
        push!(init, init_state => S(weight))
    end

    TensorFST(arcs, init, final)
end

include("trie.jl")
include("lexicon.jl")

end