Skip to content
Snippets Groups Projects
Verified Commit 2d86a478 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

basic sparse tensor FSA/FST

parent c666ab46
No related branches found
No related tags found
1 merge request!64Symbols
name = "TensorFSTs"
uuid = "9f8e39db-0d6b-46cc-a14d-daf437028eec"
authors = ["Lucas ONDEL YANG <lucas.ondel@cnrs.fr>",
"Cyril ALLAUZEN <allauzen@google.com>",
"Lukas BURGET <burget@fit.vutbr.cz>",
"Simon DEVAUCHELLE <devauchelle@lisn.upsaclay.fr>",
"Martin KOCOUR <ikocour@fit.vutbr.cz>",
"Tina RAISSI <raissi@i6.informatik.rwth-aachen.de>",
"Pablo RIERA <pablo.riera@gmail.com>",
"Michael RILEY <riley@google.com>",
"Jan TRMAL <yenda@jhu.edu>"]
authors = ["Lucas ONDEL YANG <lucas.ondel@cnrs.fr>", "Cyril ALLAUZEN <allauzen@google.com>", "Lukas BURGET <burget@fit.vutbr.cz>", "Simon DEVAUCHELLE <devauchelle@lisn.upsaclay.fr>", "Martin KOCOUR <ikocour@fit.vutbr.cz>", "Tina RAISSI <raissi@i6.informatik.rwth-aachen.de>", "Pablo RIERA <pablo.riera@gmail.com>", "Michael RILEY <riley@google.com>", "Jan TRMAL <yenda@jhu.edu>"]
version = "1.4.1"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Semirings = "900aad66-9ca5-44d4-b043-321c62cb7767"
SumSparseTensors = "472dc678-8c5a-4778-a6eb-28a4e9c7cb58"
SparseSemiringAlgebra = "f8e6a300-9d4f-438a-9da3-99be810debe9"
[compat]
SumSparseTensors = "0.14"
Semirings = "1.7"
Semirings = "2.0"
julia = "1.6"
......@@ -3,67 +3,86 @@
module TensorFSTs
using Semirings
using SumSparseTensors
using SparseSemiringAlgebra
export SymbolTable,
syms,
symbols,
ids
include("symboltable.jl")
export AbstractFSMArc,
FSTArc,
FSAArc,
AbstractFSM,
AbstractFST,
AbstractFSA,
arcs,
states,
initweight,
finalweight,
symboltable,
isymboltable,
osymboltable
export Arc, TensorFST
include("abstractfst.jl")
include("tensorfst.jl")
export
arcs,
dimensions,
hasinputepsilon,
hasoutputepsilon,
isymrange,
osymrange,
initweights,
finalweights,
osymrange,
numarcs,
semiring,
states,
updatesymrange
include("base.jl")
export draw
export SparseTensorFSA, SparseTensorFST
include("draw.jl")
export
compose,
connect,
weightpush,
toinitial,
tofinal,
closure,
weight,
shortestdistance,
state,
inputlabel,
outputlabel
include("filter.jl")
include("ops.jl")
export project_to
include("utils.jl")
include("tensorfst.jl")
# standalone libraries
include("../lib/LinearFSTs.jl")
include("../lib/NGramFSTs.jl")
include("../lib/SpeechRecognitionFSTs/SpeechRecognitionFSTs.jl")
#export Arc, TensorFST
#
#include("tensorfst.jl")
#
#
#export
# arcs,
# dimensions,
# hasinputepsilon,
# hasoutputepsilon,
# isymrange,
# osymrange,
# initweights,
# finalweights,
# osymrange,
# numarcs,
# semiring,
# states,
# updatesymrange
#
#include("base.jl")
#
#
#export draw
#
#include("draw.jl")
#
#
#export
# compose,
# connect,
# weightpush,
# toinitial,
# tofinal,
# closure,
# weight,
# shortestdistance,
# state,
# inputlabel,
# outputlabel
#
#include("filter.jl")
#include("ops.jl")
#
#export project_to
#include("utils.jl")
#
## standalone libraries
#include("../lib/LinearFSTs.jl")
#include("../lib/NGramFSTs.jl")
#include("../lib/SpeechRecognitionFSTs/SpeechRecognitionFSTs.jl")
end
# SPDX-License-Identifier: CECILL-2.1
"""
    abstract type AbstractFSMArc{T} end

Supertype for the arc of a finite-state machine. `T` is the type of the arc's
weight.
"""
abstract type AbstractFSMArc{T} end
"""
    struct FSTArc{T} <: AbstractFSMArc{T}
        src::Int
        isym::Int
        osym::Int
        weight::T
        dest::Int
    end

Arc of a transducer.

# Fields
- `src`: state the arc leaves from.
- `isym`: id of the input symbol (index into the input symbol table).
- `osym`: id of the output symbol (index into the output symbol table).
- `weight`: weight of the arc.
- `dest`: state the arc points to.
"""
struct FSTArc{T} <: AbstractFSMArc{T}
    src::Int
    isym::Int
    osym::Int
    weight::T
    dest::Int
end
"""
    struct FSAArc{T} <: AbstractFSMArc{T}
        src::Int
        sym::Int
        weight::T
        dest::Int
    end

Arc of an acceptor.

# Fields
- `src`: state the arc leaves from.
- `sym`: id of the symbol (index into the symbol table).
- `weight`: weight of the arc.
- `dest`: state the arc points to.
"""
struct FSAArc{T} <: AbstractFSMArc{T}
    src::Int
    sym::Int
    weight::T
    dest::Int
end
# Conversions from plain (named) tuples to arcs, so that tuples can be used
# wherever an arc is expected (e.g. `push!` onto a `Vector{FSTArc{T}}`).
#
# The weight parameter `T` is passed explicitly to the constructor. Tuples are
# covariant, so with an abstract `T` the tuple may hold a weight of a concrete
# subtype; letting the constructor infer the parameter would then build an arc
# of a different type than the one requested from `convert`.
Base.convert(::Type{<:FSTArc{T}}, tarc::@NamedTuple{src::Int, isym::Int, osym::Int, weight::T, dest::Int}) where T =
    FSTArc{T}(tarc.src, tarc.isym, tarc.osym, tarc.weight, tarc.dest)

# Positional form: `(src, isym, osym, weight, dest)`.
Base.convert(::Type{<:FSTArc{T}}, tarc::Tuple{Int,Int,Int,T,Int}) where T = FSTArc{T}(tarc...)

Base.convert(::Type{<:FSAArc{T}}, tarc::@NamedTuple{src::Int, sym::Int, weight::T, dest::Int}) where T =
    FSAArc{T}(tarc.src, tarc.sym, tarc.weight, tarc.dest)

# Positional form: `(src, sym, weight, dest)`.
Base.convert(::Type{<:FSAArc{T}}, tarc::Tuple{Int,Int,T,Int}) where T = FSAArc{T}(tarc...)
"""
    abstract type AbstractFSM{T} end

Supertype for finite-state machines. `T` is the type of the arcs' weight of the
machine.
"""
abstract type AbstractFSM{T} end
"""
    abstract type AbstractFST{T} <: AbstractFSM{T} end

Supertype for finite-state transducers.
"""
abstract type AbstractFST{T} <: AbstractFSM{T} end
"""
    abstract type AbstractFSA{T} <: AbstractFSM{T} end

Supertype for finite-state acceptors.
"""
abstract type AbstractFSA{T} <: AbstractFSM{T} end
"""
    arcs(fsm::AbstractFSM)

Return an iterable over the arcs of `fsm`.
"""
arcs
"""
    states(fsm::AbstractFSM)

Return an iterable over the states of `fsm`.
"""
states

"""
    initweight(fsm::AbstractFSM, state)

Return the initial weight associated to `state` in `fsm`.
"""
initweight

"""
    finalweight(fsm::AbstractFSM, state)

Return the final weight associated to `state` in `fsm`.
"""
finalweight

"""
    semiring(fsm::AbstractFSM)

Return the semiring type of the transitions' weight of `fsm`.
"""
semiring(fsm::AbstractFSM{T}) where T

"""
    symboltable(fsa::AbstractFSA)

Return the symbol table of the acceptor `fsa`.
"""
symboltable

"""
    isymboltable(fst::AbstractFST)

Return the input symbol table of transducer `fst`.
"""
isymboltable(fst::AbstractFST)

"""
    osymboltable(fst::AbstractFST)

Return the output symbol table of transducer `fst`.
"""
osymboltable(fst::AbstractFST)
#==============================================================================
Drawing FST with graphviz.
==============================================================================#
# Render `fsm` as an SVG image by piping its DOT description through the
# Graphviz `dot` executable, which must be available on the `PATH`.
# Returns the number of bytes written to `io`.
function Base.show(io::IO, ::MIME"image/svg+xml", fsm::AbstractFSM)
    # Generate the graph description in the DOT language.
    fst_dot = sprint(_write_dot, fsm; context = :compact => true)

    # Generate the SVG file from the dot script. The scoped form of
    # `mktempdir` deletes the directory as soon as the block exits (also when
    # `dot` fails), instead of leaving one directory per rendered figure
    # around until the Julia process terminates.
    mktempdir() do dir
        dotpath = joinpath(dir, "fst.dot")
        svgpath = joinpath(dir, "fst.svg")
        write(dotpath, fst_dot)
        run(`dot -Tsvg $(dotpath) -o $(svgpath)`)
        write(io, read(svgpath, String))
    end
end
# Render `weight` as text using the display properties attached to `io`
# (e.g. `:compact => true`). Plain string interpolation goes through `string`,
# which prints into a fresh buffer and therefore ignores those properties.
_dot_weight(io::IO, weight) = sprint(print, weight; context = io)

# Write the DOT statement for one transducer arc. The label is
# `isym:osym`, followed by `/weight` unless the weight is the semiring one.
# Arcs with a zero weight are drawn in grey.
function _write_dot(io::IO, fst::AbstractFST, arc::FSTArc)
    label = "$(isymboltable(fst)[arc.isym]):$(osymboltable(fst)[arc.osym])"
    if !isone(arc.weight)
        label *= "/" * _dot_weight(io, arc.weight)
    end
    color = iszero(arc.weight) ? "grey80" : "black"
    println(io, "$(arc.src) -> $(arc.dest) [label=\"$(label)\" color=\"$(color)\"];")
end

# Write the DOT statement for one acceptor arc. The label is the symbol,
# followed by `/weight` unless the weight is the semiring one. Arcs with a
# zero weight are drawn in grey.
function _write_dot(io::IO, fsa::AbstractFSA, arc::FSAArc)
    label = "$(symboltable(fsa)[arc.sym])"
    if !isone(arc.weight)
        label *= "/" * _dot_weight(io, arc.weight)
    end
    color = iszero(arc.weight) ? "grey80" : "black"
    println(io, " $(arc.src) -> $(arc.dest) [label=\"$(label)\" color=\"$(color)\"];")
end

# Write the complete DOT description of `fsm`: one node per state followed by
# one edge per arc.
function _write_dot(io::IO, fsm::AbstractFSM)
    println(io, "digraph {")
    println(io, " rankdir=LR;")

    for q in states(fsm)
        iw = initweight(fsm, q)
        fw = finalweight(fsm, q)

        # Initial states (non-zero initial weight) are drawn bold; the weight
        # is shown in front of the state number unless it is one.
        if !iszero(iw)
            style = "bold"
            label = isone(iw) ? "$q" : "$(_dot_weight(io, iw))/$(q)"
        else
            style = "solid"
            label = "$q"
        end

        # Final states (non-zero final weight) get a double circle; the weight
        # is appended to the label unless it is one.
        if !iszero(fw)
            shape = "doublecircle"
            label *= isone(fw) ? "" : "/" * _dot_weight(io, fw)
        else
            shape = "circle"
        end

        println(io, " $q [label=\"$(label)\" shape=\"$(shape)\" style=\"$style\"];")
    end

    # Dispatch on the machine/arc type to pick the right label format.
    for arc in arcs(fsm)
        _write_dot(io, fsm, arc)
    end

    println(io, "}")
end
# SPDX-License-Identifier: CECILL-2.1
# Id reserved for the special symbol `ϵ` (the empty string).
const ϵ = 1

"""
    struct SymbolTable
        sym2id::Dict{String,Int}
        id2sym::Dict{Int,String}
    end

Bidirectional map that associates an id to a symbol and vice versa. The mapping
reserves the id `1` for the special symbol `ϵ` used on FST transitions.
"""
struct SymbolTable
    sym2id::Dict{String,Int}
    id2sym::Dict{Int,String}
end

Base.length(symtable::SymbolTable) = length(symtable.sym2id)

# Look-ups work in both directions: an integer id maps to its symbol and a
# string symbol maps to its id.
Base.getindex(symtable::SymbolTable, i::Integer) = symtable.id2sym[i]
Base.getindex(symtable::SymbolTable, s::AbstractString) = symtable.sym2id[s]
Base.get(symtable::SymbolTable, i::Integer, default) = get(symtable.id2sym, i, default)
Base.get(symtable::SymbolTable, s::AbstractString, default) = get(symtable.sym2id, s, default)

"""
    SymbolTable([syms]; ϵsym = "ϵ")

Create a `SymbolTable` instance, i.e. a bidirectional mapping that associates an
integer id to a string symbol and vice versa. `syms` is an iterable of symbols
and `ϵsym` is the representation of the ``ϵ`` symbol.

Duplicated symbols are ignored and the remaining symbols receive consecutive
ids, in sorted order, starting right after the id of `ϵ`.

!!! note
    The id `1` is reserved for the special symbol `ϵ` to indicate the empty
    string. A symbol of `syms` equal to `ϵsym` is skipped so that it cannot
    take over the reserved entry.
"""
function SymbolTable(syms; ϵsym = "ϵ")
    sym2id = Dict{String,Int}(ϵsym => ϵ)
    id2sym = Dict{Int,String}(ϵ => ϵsym)

    labels = String[String(sym) for sym in sort(unique(syms))]
    # Keep the mapping one-to-one: `ϵsym` already owns the reserved id.
    filter!(label -> label != ϵsym, labels)

    for (i, label) in enumerate(labels)
        sym2id[label] = ϵ + i
        id2sym[ϵ + i] = label
    end
    return SymbolTable(sym2id, id2sym)
end

SymbolTable(; ϵsym = "ϵ") = SymbolTable(String[]; ϵsym)

"""
    symbols(t; exclude_ϵ = true)

Return all symbols from symbol table `t`, sorted. The `ϵ` symbol is left out
unless `exclude_ϵ` is `false`.
"""
function symbols(t::SymbolTable; exclude_ϵ = true)
    return sort!([sym for (sym, id) in t.sym2id if !exclude_ϵ || id != ϵ])
end

# Former name of `symbols`, kept as an alias because it is still exported.
syms(t::SymbolTable; exclude_ϵ = true) = symbols(t; exclude_ϵ)

"""
    ids(t; exclude_ϵ = true)

Return all ids from symbol table `t`, sorted. The id of `ϵ` is left out unless
`exclude_ϵ` is `false`.
"""
function ids(t::SymbolTable; exclude_ϵ = true)
    return sort!([id for id in values(t.sym2id) if !exclude_ϵ || id != ϵ])
end
# SPDX-License-Identifier: CECILL-2.1
## The type Arc is just an alias with a convenience constructor over a
# named tuple.
const Arc{S,N} = @NamedTuple{src::NTuple{N,Int}, isym::Int, osym::Int, weight::S, dest::NTuple{N,Int}} where {S<:Semiring,N}
function Arc(; src::Union{Int, NTuple{N, Int}}, isym, osym, weight, dest::Union{Int, NTuple{N, Int}}) where N
(src = tuple(src...), isym = isym, osym = osym, weight = weight, dest = tuple(dest...))
end
struct TensorFST{S<:Semiring}
dims::@NamedTuple{src::NTuple{N,Int}, isym::Int, osym::Int, dest::NTuple{N,Int}} where N
M::AbstractSumSparseTensor{S}
α::AbstractSumSparseTensor{S}
ω::AbstractSumSparseTensor{S}
"""
struct SparseTensorFSA{T} <: AbstractFSA{T}
M::SparseVector{<:SparseMatrix{T}}
α::SparseVector{T}
ω::SparseVector{T}
end
function TensorFST{S}(dims, M, α, ω) where S
if size(α) size(ω)
throw(DimensionMismatch("`size(α)` and `size(ω)` should be equal"))
end
"""
struct SparseTensorFSA{T} <: AbstractFSA{T}
M::SparseVector{<:SparseMatrix{T}}
α::SparseVector{T}
ω::SparseVector{T}
N = length(size(α))
D = length(size(M))
if D 2*(N + 1)
throw(DimensionMismatch("`length(size(M))` should be equal to `2*(length(size(α)) + 1)`"))
end
symboltable::SymbolTable
new{S}(dims, M, α, ω)
function SparseTensorFSA{T}(M, α, ω, symboltable) where T
size(α) size(ω) && throw(ArgumentError("`α` and `ω` should have the same size"))
new{T}(M, α, ω, symboltable)
end
end
"""
TensorFST(arcs, initweights, finalweights; Σrange = ..., Ωrange = ...)
function SparseTensorFSA{T}(arcs, initweights, finalweights; symboltable) where T
# Determine the number of states.
numstate = 0
for arc in arcs numstate = max(numstate, max(arc.src, arc.dest)) end
for (state, weight) in initweights numstate = max(numstate, state) end
for (state, weight) in finalweights numstate = max(numstate, state) end
Create a Finite State Transducer given a list of arcs, initial weights and
finale weights. An arc is a named tuple `(src = , isym = , osym =, weight = , dest = )`.
Initial/final weights are specified with a list of pairs `state => weight`.
`Σrange` and `Ωrange` specifies the range of input and output symbol ids
respectfully.
K = Int
V = Tuple{Vector{NTuple{2,Int}},Vector{T}}
sym_matrices = Dict{K,V}()
for arc in arcs
coordinates, values = get(sym_matrices, arc.sym, (NTuple{2,Int}[], T[]))
push!(coordinates, (arc.src, arc.dest))
push!(values, arc.weight)
sym_matrices[arc.sym] = (coordinates, values)
end
!!! warn "
The specifal symbol ``\\epsilon`` has a reserved id of 0.
"""
TensorFST
indices = collect(keys(sym_matrices))
matrices = SparseMatrix{T}[]
for (coordinates, values) in Base.values(sym_matrices)
Mi = sparse(coordinates, values; size = (numstate, numstate))
push!(matrices, Mi)
end
M = sparsevec(indices, matrices; size = length(symboltable), zeroval = spzeros(T, numstate, numstate))
α = sparsevec(map(first, initweights), map(last, initweights); size = numstate)
ω = sparsevec(map(first, finalweights), map(last, finalweights); size = numstate)
function _range_sym(f, arcs, init)
range_union(r1, r2) = min(first(r1), first(r2)):max(last(r1), last(r2))
mapreduce(f, range_union, arcs; init)
SparseTensorFSA{T}(M, α, ω, symboltable)
end
function TensorFST(arcs::AbstractVector{<:Arc{S,N}},
initweights::AbstractVector{<:Pair},
finalweights::AbstractVector{<:Pair};
Σrange = 1:0,
Ωrange = 1:0) where {S,N}
# We take the largest range possible for the symbol
Σrange = _range_sym(a -> min(1, a.isym):a.isym, arcs, Σrange)
Ωrange = _range_sym(a -> min(1, a.osym):a.osym, arcs, Ωrange)
# Find the state size by finding the maximum state index in each dimension.
maxstate = zeros(Int, N)
for arc in arcs
for n in 1:N
maxstate[n] = max(maxstate[n], arc.src[n])
maxstate[n] = max(maxstate[n], arc.dest[n])
# Collect the arcs of `fsa` into a vector. The outer sparse vector `M` holds
# one (source, destination) sparse matrix per symbol: its stored indices are
# the symbol ids and the stored entries of each matrix are the arc weights.
function arcs(fsa::SparseTensorFSA{T}) where T
    collected = FSAArc{T}[]
    for k in 1:nnz(fsa.M)
        sym = fsa.M.indices[k]
        symmatrix = fsa.M.values[k]
        for j in 1:nnz(symmatrix)
            (src, dest) = symmatrix.coordinates[j]
            # The tuple is turned into an `FSAArc{T}` by `convert`.
            push!(collected, (src, sym, symmatrix.values[j], dest))
        end
    end
    collected
end
for (q, iw) in initweights
for n in 1:N
maxstate[n] = max(maxstate[n], q[n])
end
# `AbstractFSA` accessors for `SparseTensorFSA`. States are numbered from 1 to
# the length of `α`; `α` and `ω` are indexed by state to read the initial and
# final weights respectively.
states(fsa::SparseTensorFSA) = 1:size(fsa.α, 1)
initweight(fsa::SparseTensorFSA, state) = fsa.α[state]
finalweight(fsa::SparseTensorFSA, state) = fsa.ω[state]
symboltable(fsa::SparseTensorFSA) = fsa.symboltable
"""
struct SparseTensorFST{T} <: AbstractFST{T}
M::SparseMatrix{<:SparseMatrix{T}}
α::SparseVector{T}
ω::SparseVector{T}
end
for (q, fw) in finalweights
for n in 1:N
maxstate[n] = max(maxstate[n], q[n])
end
"""
struct SparseTensorFST{T} <: AbstractFST{T}
M::SparseMatrix{<:SparseMatrix{T}}
α::SparseVector{T}
ω::SparseVector{T}
isymboltable::SymbolTable
osymboltable::SymbolTable
function SparseTensorFST{T}(M, α, ω, isymboltable, osymboltable) where T
size(α) size(ω) && throw(ArgumentError("`α` and `ω` should have the same size"))
new{T}(M, α, ω, isymboltable, osymboltable)
end
stateranges = ntuple(n -> 1:maxstate[n], N)
end
# Build the sparse tensor encoding the arcs.
nzcoo = ! isempty(arcs) ? map(a -> (a.src..., a.dest..., a.isym, a.osym), arcs) : NTuple{2N+2,Int}[]
nzval = ! isempty(arcs) ? map(a -> a.weight, arcs) : S[]
M = sort(sparsesum(nzcoo, nzval; axes = (stateranges..., stateranges..., Σrange, Ωrange)), 1)
# Build the sparse vector encoding the initial weights.
nzcoo = ! isempty(initweights) ? map(x -> tuple(first(x)...), initweights) : NTuple{N,Int}[]
nzval = ! isempty(initweights) ? map(x -> last(x), initweights) : S[]
α = sparsesum(nzcoo, nzval; axes = stateranges)
function SparseTensorFST{T}(arcs, initweights, finalweights; isymboltable, osymboltable) where T
numstate = 0
for arc in arcs numstate = max(numstate, max(arc.src, arc.dest)) end
for (state, weight) in initweights numstate = max(numstate, state) end
for (state, weight) in finalweights numstate = max(numstate, state) end
# Build the sparse vector encoding the final weights.
nzcoo = ! isempty(initweights) ? map(x -> tuple(first(x)...), finalweights) : NTuple{N,Int}[]
nzval = ! isempty(initweights) ? map(x -> last(x), finalweights) : S[]
ω = sparsesum(nzcoo, nzval; axes = stateranges)
K = Tuple{Int,Int}
V = Tuple{Vector{NTuple{2,Int}},Vector{T}}
sym_matrices = Dict{K,V}()
for arc in arcs
sym = (arc.isym, arc.osym)
coordinates, values = get(sym_matrices, sym, (NTuple{2,Int}[], T[]))
push!(coordinates, (arc.src, arc.dest))
push!(values, arc.weight)
sym_matrices[sym] = (coordinates, values)
end
coordinates = collect(keys(sym_matrices))
matrices = SparseMatrix{T}[]
for (coordinates, values) in Base.values(sym_matrices)
Mi = sparse(coordinates, values; size = (numstate, numstate))
push!(matrices, Mi)
end
dims = (src = tuple(1:N...), isym = 2N + 1, osym = 2N + 2, dest = tuple(N+1:2N...))
TensorFST{S}(dims, M, α, ω)
size = (length(isymboltable), length(osymboltable))
M = sparse(coordinates, matrices; size, zeroval = spzeros(T, numstate, numstate))
α = sparsevec(map(first, initweights), map(last, initweights); size = numstate)
ω = sparsevec(map(first, finalweights), map(last, finalweights); size = numstate)
SparseTensorFST{T}(M, α, ω, isymboltable, osymboltable)
end
function TensorFST(arcs, initweights, finalweights; Σrange = 1:0, Ωrange = 1:0)
TensorFST(
map(a -> Arc(src = a.src, isym = a.isym, osym = a.osym, weight = a.weight, dest = a.dest), arcs),
map(iw -> tuple(first(iw)...) => last(iw), initweights),
map(fw -> tuple(first(fw)...) => last(fw), finalweights);
Σrange,
Ωrange
)
# Collect the arcs of `fst` into a vector. The outer sparse matrix `M` holds
# one (source, destination) sparse matrix per (input symbol, output symbol)
# pair: its stored coordinates are the symbol ids and the stored entries of
# each inner matrix are the arc weights.
function arcs(fst::SparseTensorFST{T}) where T
    collected = FSTArc{T}[]
    for k in 1:nnz(fst.M)
        (isym, osym) = fst.M.coordinates[k]
        symmatrix = fst.M.values[k]
        for j in 1:nnz(symmatrix)
            (src, dest) = symmatrix.coordinates[j]
            # The tuple is turned into an `FSTArc{T}` by `convert`.
            push!(collected, (src, isym, osym, symmatrix.values[j], dest))
        end
    end
    collected
end
# `AbstractFST` accessors for `SparseTensorFST`. States are numbered from 1 to
# the length of `α`; `α` and `ω` are indexed by state to read the initial and
# final weights respectively.
states(fst::SparseTensorFST) = 1:size(fst.α, 1)
initweight(fst::SparseTensorFST, state) = fst.α[state]
finalweight(fst::SparseTensorFST, state) = fst.ω[state]
isymboltable(fst::SparseTensorFST) = fst.isymboltable
osymboltable(fst::SparseTensorFST) = fst.osymboltable
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment