Skip to content
Snippets Groups Projects
Verified Commit 311a0c96 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

src/TensorFSTs.jl

parent f31c7779
No related branches found
No related tags found
1 merge request!64Symbols
......@@ -5,44 +5,31 @@ module TensorFSTs
using Semirings
using SparseSemiringAlgebra
export SymbolTable,
symbols,
ids,
ϵ
include("symboltable.jl")
# Generic FSM API
export AbstractFSMArc,
FSTArc,
FSAArc,
FSMArc,
AbstractFSM,
AbstractFST,
AbstractFSA,
arcs,
states,
initweight,
initweights,
finalweight,
finalweights,
symboltable,
isymboltable,
osymboltable,
draw
draw,
resize
include("abstractfst.jl")
export SparseTensorFSA, SparseTensorFST
include("sparsetensorfst.jl")
export DenseTensorFSA, DenseTensorFST
include("densetensorfst.jl")
export SparseTensorFSM
export closure
include("sparsetensorfsm.jl")
include("ops.jl")
#export DenseTensorFSA, DenseTensorFST
#
#include("densetensorfst.jl")
#
#export closure
#
#include("ops.jl")
#export Arc, TensorFST
#
......
# SPDX-License-Identifier: CECILL-2.1
"""
    abstract type AbstractFSMArc{T} end

Supertype for the arc of a finite-state machine. The concrete arc types
declared in this file store their `weight` field as a `T`.
"""
abstract type AbstractFSMArc{T} end
"""
struct FSTArc{T} <: AbstractFSMArc{T}
struct FSMArc{T,N} <: AbstractFSMArc{T,N}
src::Int
isym::Int
sym::NTuple{N,Int}
osym::Int
weight::T
dest::Int
end
Arc of a transducer.
Arc of a `N`-order relation finite-state machine over the semiring `T`.
"""
struct FSTArc{T} <: AbstractFSMArc{T}
struct FSMArc{T,N}
src::Int
isym::Int
osym::Int
sym::NTuple{N,Int}
weight::T
dest::Int
end
"""
    struct FSAArc{T} <: AbstractFSMArc{T}
        src::Int
        sym::Int
        weight::T
        dest::Int
    end

Arc of an acceptor. It stores the source state `src`, the symbol identifier
`sym`, the weight `weight` (of type `T`) and the destination state `dest`.
"""
struct FSAArc{T} <: AbstractFSMArc{T}
src::Int
sym::Int
weight::T
dest::Int
end
# Conversions towards the arc types. Each arc type can be built from a named
# tuple carrying its fields, from a plain tuple listing the fields in
# declaration order, or from an arc of the same kind whose weight is converted
# to the requested weight type.

function Base.convert(
    ::Type{<:FSMArc{T,N}},
    tarc::@NamedTuple{src::Int, sym::NTuple{N,Int}, weight::T, dest::Int}
) where {T,N}
    return FSMArc(tarc.src, tarc.sym, tarc.weight, tarc.dest)
end

function Base.convert(::Type{<:FSMArc{T,N}}, tarc::Tuple{Int,NTuple{N,Int},T,Int}) where {T,N}
    src, sym, weight, dest = tarc
    return FSMArc(src, sym, weight, dest)
end

function Base.convert(
    ::Type{<:FSTArc{T}},
    tarc::@NamedTuple{src::Int, isym::Int, osym::Int, weight::T, dest::Int}
) where T
    return FSTArc(tarc.src, tarc.isym, tarc.osym, tarc.weight, tarc.dest)
end

function Base.convert(::Type{<:FSTArc{T}}, tarc::Tuple{Int,Int,Int,T,Int}) where T
    src, isym, osym, weight, dest = tarc
    return FSTArc(src, isym, osym, weight, dest)
end

function Base.convert(::Type{<:FSTArc{T1}}, arc::FSTArc{T2}) where {T1,T2}
    weight = convert(T1, arc.weight)
    return FSTArc{T1}(arc.src, arc.isym, arc.osym, weight, arc.dest)
end

function Base.convert(
    ::Type{<:FSAArc{T}},
    tarc::@NamedTuple{src::Int, sym::Int, weight::T, dest::Int}
) where T
    return FSAArc(tarc.src, tarc.sym, tarc.weight, tarc.dest)
end

function Base.convert(::Type{<:FSAArc{T}}, tarc::Tuple{Int,Int,T,Int}) where T
    src, sym, weight, dest = tarc
    return FSAArc(src, sym, weight, dest)
end

function Base.convert(::Type{<:FSAArc{T1}}, arc::FSAArc{T2}) where {T1,T2}
    weight = convert(T1, arc.weight)
    return FSAArc{T1}(arc.src, arc.sym, weight, arc.dest)
end

function Base.convert(::Type{<:FSMArc{T1,N}}, arc::FSMArc{T2,N}) where {T1,T2,N}
    weight = convert(T1, arc.weight)
    return FSMArc{T1,N}(arc.src, arc.sym, weight, arc.dest)
end
"""
    abstract type AbstractFSM{T} end

Supertype for finite-state machines. `T` is the type of the arcs' weight of the
machine.
"""
abstract type AbstractFSM{T} end
"""
abstract type AbstractFST{T} <:AbstractFSM end
Supertype for finite-state transducers.
Supertype for finite-state machines. `T` is the semiring over which the machine
is defined and `N` is the order of the relation represented by the machine.
"""
abstract type AbstractFST{T} <: AbstractFSM{T} end
abstract type AbstractFSM{T,N} end
"""
    abstract type AbstractFSA{T} <: AbstractFSM{T} end

Supertype for finite-state acceptors.
"""
abstract type AbstractFSA{T} <: AbstractFSM{T} end
"""
arcs(fsm::AbstractFST)
arcs(fsm::AbstractFSM)
Return an iterable over the arcs of `fsm`.
"""
......@@ -110,14 +62,6 @@ Return the initial weight associated to `state` in `fsm`.
initweight
"""
initweights(fsm::AbstractFSM)
Return an iterable of pairs `state => initweight`.
"""
initweights
"""
finalweight(fsm::AbstractFSM, state)
......@@ -126,14 +70,6 @@ Return the final weight associated to `state` in `fsm`.
finalweight
"""
finalweights(fsm::AbstractFSM)
Return an iterable of pairs `state => finalweight`.
"""
finalweights
"""
semiring(fsm::AbstractFSM)
semiring(F::Type{<:AbstractFSM})
......@@ -146,52 +82,19 @@ semiring(::Type{<:AbstractFSM{T}}) where T = T
semiring(fsm::AbstractFSM) = semiring(typeof(fsm))
"""
symboltable(fsa::AbstractFSA)
Return the symbol table of the acceptor `fsa`.
"""
symboltable
"""
isymboltable(fst::AbstractFST)
Return the input symbol table of transducer `fst`.
"""
isymboltable(fst::AbstractFST)
"""
osymboltable(fst::AbstractFST)
Return the output symbol table of transducer `fst`.
"""
osymboltable(fst::AbstractFST)
# Return the symbol table(s) of a machine as a tuple: a 1-tuple for an
# acceptor, an (input, output) pair for a transducer.
function _symboltables(fsa::AbstractFSA)
    return (symboltable(fsa),)
end

function _symboltables(fst::AbstractFST)
    input = isymboltable(fst)
    output = osymboltable(fst)
    return (input, output)
end
#function (F::Type{<:AbstractFSM})(fsm::AbstractFSM)
# T1 = semiring(F)
# T2 = semiring(fsm)
# F(
# convert(Vector{_arctype(F)}, arcs(fsm)),
# convert(Vector{Pair{Int,T1}}, initweights(fsm)),
# convert(Vector{Pair{Int,T1}}, finalweights(fsm))
# )
#end
# Map a machine (or a machine type) to the arc type it uses: `FSAArc{T}` for
# acceptors and `FSTArc{T}` for transducers, where `T` is the machine's type
# parameter.
function _arctype(::Type{<:AbstractFSA{T}}) where T
    return FSAArc{T}
end

function _arctype(fsa::AbstractFSA)
    return _arctype(typeof(fsa))
end

function _arctype(::Type{<:AbstractFST{T}}) where T
    return FSTArc{T}
end

function _arctype(fst::AbstractFST)
    return _arctype(typeof(fst))
end
"""
    F(fsm::AbstractFSM, symboltables...)

Build a machine of type `F` from another machine `fsm`.

The arcs, initial weights and final weights of `fsm` are converted to the arc
type and semiring of `F` and passed, together with `symboltables`, to the
constructor of `F`. When `symboltables` is omitted, the symbol tables of `fsm`
are reused.
"""
function (F::Type{<:AbstractFSM})(fsm::AbstractFSM, symboltables... = _symboltables(fsm)...)
    # Only the target semiring is needed: the element-wise `convert` calls
    # below take care of translating the weights of `fsm`.
    T1 = semiring(F)
    return F(
        convert(Vector{_arctype(F)}, arcs(fsm)),
        convert(Vector{Pair{Int,T1}}, initweights(fsm)),
        convert(Vector{Pair{Int,T1}}, finalweights(fsm)),
        symboltables...
    )
end
Base.convert(F::Type{<:AbstractFSM}, fsm::AbstractFSM) = F(fsm)
#Base.convert(F::Type{<:AbstractFSM}, fsm::AbstractFSM) = F(fsm)
#==============================================================================
......@@ -200,6 +103,7 @@ Drawing FST with graphviz.
"""
    struct DrawFSM{F,S,Z}
        fsm::F
        symtables::S
        show_zero_arcs::Z
    end

Bundle of everything needed to render a machine: the machine `fsm`, the
symbol lookup tables `symtables` and the flag `show_zero_arcs`.

The fields are parametric rather than untyped so that each instance has
concretely typed fields. The positional constructor
`DrawFSM(fsm, symtables, show_zero_arcs)` is unchanged.
"""
struct DrawFSM{F,S,Z}
    fsm::F
    symtables::S
    show_zero_arcs::Z
end
......@@ -210,14 +114,15 @@ end
Return a drawable representation of the finite-state machine `fsm` using
graphviz.
"""
draw(fsm; show_zero_arcs = false) = DrawFSM(fsm, show_zero_arcs)
draw(fsm::AbstractFSM{T,N}; symtables = ntuple(n -> Dict(), N), show_zero_arcs = false) where {T,N} =
DrawFSM(fsm, symtables, show_zero_arcs)
function Base.show(io::IO, ::MIME"image/svg+xml", dfsm::DrawFSM)
# Generate the graph description in the DOT language.
buffer = IOBuffer()
ioctx = IOContext(buffer, :compact => true)
_write_dot(ioctx, dfsm.fsm, dfsm.show_zero_arcs)
_write_dot(ioctx, dfsm.fsm, dfsm.symtables, dfsm.show_zero_arcs)
fst_dot = String(take!(buffer))
# Generate the SVG file from the dot script.
......@@ -229,27 +134,8 @@ function Base.show(io::IO, ::MIME"image/svg+xml", dfsm::DrawFSM)
write(io, open(f -> read(f, String), svgpath))
end
# Write the DOT statement for one transducer arc. The label is
# "input:output", followed by "/weight" unless the weight is the semiring zero
# or one. Zero-weight arcs are drawn dashed in translucent grey.
function _write_dot(io::IO, fst::AbstractFST, arc::FSTArc)
    isym = isymboltable(fst)[arc.isym]
    osym = osymboltable(fst)[arc.osym]
    label = "$(isym):$(osym)"

    w = arc.weight
    null = iszero(w)
    if !(null || isone(w))
        label *= "/$(w)"
    end

    color, style = null ? ("#8080804D", "dashed") : ("black", "solid")
    println(io, " $(arc.src) -> $(arc.dest) [label=\"$(label)\" color=\"$(color)\" style=\"$(style)\"];")
end
# Write the DOT statement for one acceptor arc. The label is the arc's symbol,
# followed by "/weight" unless the weight is the semiring zero or one.
# Zero-weight arcs are drawn dashed in translucent grey.
function _write_dot(io::IO, fsa::AbstractFSA, arc::FSAArc)
    label = "$(symboltable(fsa)[arc.sym])"

    w = arc.weight
    null = iszero(w)
    if !(null || isone(w))
        label *= "/$(w)"
    end

    color, style = null ? ("#8080804D", "dashed") : ("black", "solid")
    println(io, " $(arc.src) -> $(arc.dest) [label=\"$(label)\" color=\"$(color)\" style=\"$(style)\"];")
end
function _write_dot(io::IO, fsm::AbstractFSM, show_zero_arcs)
function _write_dot(io::IO, fsm::AbstractFSM{T,N}, symtables, show_zero_arcs) where {T,N}
println(io, "digraph {")
println(io, " rankdir=LR;")
......@@ -277,7 +163,13 @@ function _write_dot(io::IO, fsm::AbstractFSM, show_zero_arcs)
for arc in arcs(fsm)
if ! iszero(arc.weight) || show_zero_arcs
_write_dot(io, fsm, arc)
label = join([get(symtables[n], arc.sym[n], arc.sym[n]) for n in 1:N], ":")
if ! iszero(arc.weight) && ! isone(arc.weight)
label *= "/$(arc.weight)"
end
color = iszero(arc.weight) ? "#8080804D" : "black"
style = iszero(arc.weight) ? "dashed" : "solid"
println(io, " $(arc.src) -> $(arc.dest) [label=\"$(label)\" color=\"$(color)\" style=\"$(style)\"];")
end
end
......
......@@ -12,7 +12,10 @@ Union of two state-machine ``A + B`` defined as
"""
Base.union
function Base.union(A::Tm, B::Tm) where Tm<:AbstractFSM
function Base.union(A::AbstractFSM{T,N}, B::AbstractFSM{T,N}) where {T,N}
newsize = ntuple(max(size(A, n), size(B, n)), N)
A, B = resize(A, newsize), resize(B, newsize)
M = blockdiag.(A.M, B.M)
α = vcat(A.α, B.α)
ω = vcat(A.ω, B.ω)
......@@ -32,8 +35,9 @@ Concatenation of two transducers ``A \\cdot B`` defined as:
"""
Base.cat
function Base.cat(A::Tm, B::Tm) where Tm<:AbstractFSM
T = semiring(A)
function Base.cat(A::AbstractFSM{T,N}, B::AbstractFSM{T,N}) where {T,N}
newsize = ntuple(max(size(A, n), size(B, n)), N)
A, B = resize(A, newsize), resize(B, newsize)
M = blockdiag.(A.M, B.M)
M[ϵ] += vcat(A.ω, zero(T) * B.ω) * transpose(vcat(zero(T) * A.α, B.α))
......@@ -55,11 +59,10 @@ Kleene plus closure of a transducer ``A^+`` defined as:
```
where ``A^n`` is the FST ``A`` concatenated ``n`` times with itself.
"""
function closure_plus(A::Tm) where Tm<:AbstractFSM
M = spzeros(eltype(A.M), size(A.M)...)
M += A.M
function closure_plus(A::AbstractFSM)
M = copy(A.M)
M[ϵ] = A.ω * transpose(A.α)
typeof(A)(M, A.α, A.ω, _symboltables(A)...)
typeof(A)(M, A.α, A.ω)
end
Base.:^(A::AbstractFSM, ::typeof(+)) = closure_plus(A)
......@@ -74,12 +77,7 @@ Kleene closure of a transducer ``A^*`` defined as:
```
where ``A^n`` is the FST ``A`` concatenated ``n`` times with itself.
"""
function closure(A::Tm) where Tm<:AbstractFSM
T = semiring(Tm)
Tarc = _arctype(A)
𝓔 = Tm(Tarc[], [1 => one(T)], [1 => one(T)], _symboltables(A)...)
𝓔 * A^+
end
closure(A::Tm) where Tm<:AbstractFSM = one(Tm) * A^+
Base.:^(A::AbstractFSM, ::typeof(*)) = closure(A)
......
# SPDX-License-Identifier: CECILL-2.1
"""
    struct SparseTensorFSM{T,N} <: AbstractFSM{T,N}
        M::AbstractSparseArray{<:DirectMatrixSemiring,N}
        α::AbstractSparseVector{T}
        ω::AbstractSparseVector{T}
    end

Finite-state machine encoded as a sparse tensor.

# Fields
- `M`: sparse array whose stored entries hold the arcs of the machine.
- `α`: sparse vector of initial weights, indexed by state.
- `ω`: sparse vector of final weights, indexed by state.

# Throws
- `ArgumentError` if `α` and `ω` do not have the same size.
"""
struct SparseTensorFSM{T,N} <: AbstractFSM{T,N}
    M::AbstractSparseArray{<:DirectMatrixSemiring,N}
    α::AbstractSparseVector{T}
    ω::AbstractSparseVector{T}

    function SparseTensorFSM{T,N}(M::AbstractSparseArray{<:DirectMatrixSemiring,N},
                                  α::AbstractSparseVector{T},
                                  ω::AbstractSparseVector{T}) where {T,N}
        # Both vectors are indexed by state, so they must agree on the number
        # of states.
        size(α) == size(ω) || throw(ArgumentError("`α` and `ω` should have the same size"))
        return new{T,N}(M, α, ω)
    end
end
# Build a `SparseTensorFSM{T,N}` from an iterable of arcs and two iterables of
# `(state, weight)` pairs giving the initial and final weights.
#
# The arcs are grouped by symbol; each group becomes one sparse
# `numstate × numstate` matrix of weights, and the matrices are then gathered
# in a sparse vector indexed by symbol.
function SparseTensorFSM{T,N}(arcs, initweights, finalweights) where {T,N}
# The number of states is the largest state id seen in the arcs or in the
# initial/final weights.
numstate = 0
for arc in arcs numstate = max(numstate, max(arc.src, arc.dest)) end
for (state, weight) in initweights numstate = max(numstate, state) end
for (state, weight) in finalweights numstate = max(numstate, state) end
# NOTE(review): the key type is `Int`, but `FSMArc` declares
# `sym::NTuple{N,Int}`; if `arcs` holds `FSMArc`s the key type should
# presumably be `NTuple{N,Int}` — confirm.
K = Int
# NOTE(review): the coordinates pushed below are `(src, dest)` 2-tuples, which
# only fit `NTuple{N,Int}` storage when `N == 2` — confirm the intended type.
V = Tuple{Vector{NTuple{N,Int}},Vector{T}}
sym_matrices = Dict{K,V}()
# Accumulate, for each symbol, the (src, dest) coordinates and the weights of
# its arcs.
for arc in arcs
coordinates, values = get(sym_matrices, arc.sym, (NTuple{N,Int}[], T[]))
push!(coordinates, (arc.src, arc.dest))
push!(values, arc.weight)
sym_matrices[arc.sym] = (coordinates, values)
end
# `keys` and `values` of the unmodified dictionary are traversed in the same
# order, so `indices[i]` is the symbol of `matrices[i]`.
indices = collect(keys(sym_matrices))
matrices = SparseMatrix{T,numstate,numstate}[]
for (coordinates, values) in Base.values(sym_matrices)
Mi = sparse(coordinates, values, numstate, numstate)
push!(matrices, Mi)
end
# NOTE(review): `symboltable` is neither a parameter nor a local of this
# function; the length of the symbol dimension probably needs to come from
# somewhere else — confirm.
M = sparsevec(indices, matrices, length(symboltable))
α = sparsevec(map(first, initweights), map(last, initweights), numstate)
ω = sparsevec(map(first, finalweights), map(last, finalweights), numstate)
SparseTensorFSM{T,N}(M, α, ω)
end
# Enumerate the arcs of the machine: one arc per stored entry of every matrix
# stored in `fsa.M`. Each arc is pushed as a tuple and converted to
# `FSMArc{T,N}` by `push!`.
function arcs(fsa::SparseTensorFSM{T,N}) where {T,N}
    collected = FSMArc{T,N}[]
    for k in 1:nnz(fsa.M)
        label = (fsa.M.indices[k],)
        transitions = fsa.M.values[k]
        for j in 1:nnz(transitions)
            src, dest = transitions.coordinates[j]
            push!(collected, (src, label, transitions.values[j], dest))
        end
    end
    return collected
end
# State ids run from 1 to the length of the initial-weight vector.
function states(fsa::SparseTensorFSM)
    numstate = size(fsa.α, 1)
    return 1:numstate
end

# Initial weight of `state`.
function initweight(fsa::SparseTensorFSM, state)
    return fsa.α[state]
end

# Final weight of `state`.
function finalweight(fsa::SparseTensorFSM, state)
    return fsa.ω[state]
end

# Stored initial weights as a vector of `state => weight` pairs.
function initweights(fsa::SparseTensorFSM)
    α = fsa.α
    return [state => weight for (state, weight) in zip(α.indices, α.values)]
end

# Stored final weights as a vector of `state => weight` pairs.
function finalweights(fsa::SparseTensorFSM)
    ω = fsa.ω
    return [state => weight for (state, weight) in zip(ω.indices, ω.values)]
end
# Return a machine with the same stored entries, initial weights and final
# weights as `fsm`, but whose tensor `M` is rebuilt with the dimensions `size`.
function resize(fsm::SparseTensorFSM{T,N}, size::Dims{N}) where {T,N}
    tensor = fsm.M
    resized = sparsevec(tensor.indices, tensor.values, size...)
    return SparseTensorFSM{T,N}(resized, fsm.α, fsm.ω)
end
# SPDX-License-Identifier: CECILL-2.1
"""
    struct SparseTensorFSA{T} <: AbstractFSA{T}
        M::SparseVector{<:SparseMatrix{T}}
        α::SparseVector{T}
        ω::SparseVector{T}
        symboltable::SymbolTable
    end

Finite-state acceptor encoded with sparse arrays.

# Fields
- `M`: sparse vector indexed by symbol; each stored entry is a sparse matrix
  of arc weights indexed by `(src, dest)`.
- `α`: sparse vector of initial weights, indexed by state.
- `ω`: sparse vector of final weights, indexed by state.
- `symboltable`: symbol table of the acceptor.

# Throws
- `ArgumentError` if `α` and `ω` do not have the same size.
"""
struct SparseTensorFSA{T} <: AbstractFSA{T}
    M::SparseVector{<:SparseMatrix{T}}
    α::SparseVector{T}
    ω::SparseVector{T}
    symboltable::SymbolTable

    function SparseTensorFSA{T}(M::SparseVector{<:SparseMatrix{T}},
                                α::SparseVector{T},
                                ω::SparseVector{T},
                                symboltable::SymbolTable) where T
        # Both vectors are indexed by state, so they must agree on the number
        # of states.
        size(α) == size(ω) || throw(ArgumentError("`α` and `ω` should have the same size"))
        return new{T}(M, α, ω, symboltable)
    end
end
# Build a `SparseTensorFSA{T}` from an iterable of arcs, two iterables of
# `(state, weight)` pairs (initial and final weights) and a symbol table.
function SparseTensorFSA{T}(arcs, initweights, finalweights, symboltable) where T
    # The number of states is the largest state id mentioned anywhere.
    numstate = 0
    for arc in arcs
        numstate = max(numstate, arc.src, arc.dest)
    end
    for (state, _) in initweights
        numstate = max(numstate, state)
    end
    for (state, _) in finalweights
        numstate = max(numstate, state)
    end

    # Group the arcs by symbol: symbol => ((src, dest) coordinates, weights).
    bysymbol = Dict{Int,Tuple{Vector{NTuple{2,Int}},Vector{T}}}()
    for arc in arcs
        coords, weights = get!(() -> (NTuple{2,Int}[], T[]), bysymbol, arc.sym)
        push!(coords, (arc.src, arc.dest))
        push!(weights, arc.weight)
    end

    # `keys` and `values` traverse the dictionary in the same order, so the
    # i-th matrix belongs to the i-th symbol.
    symbolids = collect(keys(bysymbol))
    matrices = SparseMatrix{T,numstate,numstate}[
        sparse(coords, weights, numstate, numstate)
        for (coords, weights) in Base.values(bysymbol)
    ]

    M = sparsevec(symbolids, matrices, length(symboltable))
    α = sparsevec(map(first, initweights), map(last, initweights), numstate)
    ω = sparsevec(map(first, finalweights), map(last, finalweights), numstate)
    return SparseTensorFSA{T}(M, α, ω, symboltable)
end
# Enumerate the arcs of the acceptor: for every stored symbol matrix of
# `fsa.M`, one arc per stored `(src, dest)` entry. Each arc is pushed as a
# tuple and converted to `FSAArc{T}` by `push!`.
function arcs(fsa::SparseTensorFSA{T}) where T
    collected = FSAArc{T}[]
    for k in 1:nnz(fsa.M)
        label = fsa.M.indices[k]
        transitions = fsa.M.values[k]
        for j in 1:nnz(transitions)
            src, dest = transitions.coordinates[j]
            push!(collected, (src, label, transitions.values[j], dest))
        end
    end
    return collected
end
# State ids run from 1 to the length of the initial-weight vector.
function states(fsa::SparseTensorFSA)
    numstate = size(fsa.α, 1)
    return 1:numstate
end

# Initial weight of `state`.
function initweight(fsa::SparseTensorFSA, state)
    return fsa.α[state]
end

# Stored initial weights as a vector of `state => weight` pairs.
function initweights(fsa::SparseTensorFSA)
    α = fsa.α
    return [state => weight for (state, weight) in zip(α.indices, α.values)]
end

# Final weight of `state`.
function finalweight(fsa::SparseTensorFSA, state)
    return fsa.ω[state]
end

# Stored final weights as a vector of `state => weight` pairs.
function finalweights(fsa::SparseTensorFSA)
    ω = fsa.ω
    return [state => weight for (state, weight) in zip(ω.indices, ω.values)]
end

# Symbol table attached to the acceptor.
function symboltable(fsa::SparseTensorFSA)
    return fsa.symboltable
end
"""
    struct SparseTensorFST{T} <: AbstractFST{T}
        M::SparseMatrix{<:SparseMatrix{T}}
        α::SparseVector{T}
        ω::SparseVector{T}
        isymboltable::SymbolTable
        osymboltable::SymbolTable
    end

Finite-state transducer encoded with sparse arrays.

# Fields
- `M`: sparse matrix indexed by `(isym, osym)`; each stored entry is a sparse
  matrix of arc weights indexed by `(src, dest)`.
- `α`: sparse vector of initial weights, indexed by state.
- `ω`: sparse vector of final weights, indexed by state.
- `isymboltable`: input symbol table.
- `osymboltable`: output symbol table.

# Throws
- `ArgumentError` if `α` and `ω` do not have the same size.
"""
struct SparseTensorFST{T} <: AbstractFST{T}
    M::SparseMatrix{<:SparseMatrix{T}}
    α::SparseVector{T}
    ω::SparseVector{T}
    isymboltable::SymbolTable
    osymboltable::SymbolTable

    function SparseTensorFST{T}(M::SparseMatrix{<:SparseMatrix{T}},
                                α::SparseVector{T},
                                ω::SparseVector{T},
                                isymboltable::SymbolTable,
                                osymboltable::SymbolTable) where T
        # Both vectors are indexed by state, so they must agree on the number
        # of states.
        size(α) == size(ω) || throw(ArgumentError("`α` and `ω` should have the same size"))
        return new{T}(M, α, ω, isymboltable, osymboltable)
    end
end
# Build a `SparseTensorFST{T}` from an iterable of arcs, two iterables of
# `(state, weight)` pairs (initial and final weights) and the input and output
# symbol tables.
#
# The arcs are grouped by `(isym, osym)` label; each group becomes one sparse
# `numstate × numstate` matrix of weights, and the matrices are gathered in a
# sparse matrix indexed by `(isym, osym)`.
function SparseTensorFST{T}(arcs, initweights, finalweights, isymboltable, osymboltable) where T
    # The number of states is the largest state id mentioned anywhere.
    numstate = 0
    for arc in arcs numstate = max(numstate, max(arc.src, arc.dest)) end
    for (state, weight) in initweights numstate = max(numstate, state) end
    for (state, weight) in finalweights numstate = max(numstate, state) end

    K = Dims{2}
    V = Tuple{Vector{Dims{2}},Vector{T}}
    sym_matrices = Dict{K,V}()
    for arc in arcs
        sym = (arc.isym, arc.osym)
        statecoords, weights = get(sym_matrices, sym, (Dims{2}[], T[]))
        push!(statecoords, (arc.src, arc.dest))
        push!(weights, arc.weight)
        sym_matrices[sym] = (statecoords, weights)
    end

    # `keys` and `values` traverse the dictionary in the same order, so
    # `symcoords[i]` is the label of `matrices[i]`. Distinct names are used
    # for the label coordinates and the per-label state coordinates.
    symcoords = collect(keys(sym_matrices))
    matrices = SparseMatrix{T,numstate,numstate}[]
    for (statecoords, weights) in values(sym_matrices)
        push!(matrices, sparse(statecoords, weights, numstate, numstate))
    end

    M = sparse(symcoords, matrices, length(isymboltable), length(osymboltable))
    α = sparsevec(map(first, initweights), map(last, initweights), numstate)
    ω = sparsevec(map(first, finalweights), map(last, finalweights), numstate)
    return SparseTensorFST{T}(M, α, ω, isymboltable, osymboltable)
end
# Enumerate the arcs of the transducer: for every stored `(isym, osym)` matrix
# of `fst.M`, one arc per stored `(src, dest)` entry. Each arc is pushed as a
# tuple and converted to `FSTArc{T}` by `push!`.
function arcs(fst::SparseTensorFST{T}) where T
    collected = FSTArc{T}[]
    for k in 1:nnz(fst.M)
        isym, osym = fst.M.coordinates[k]
        transitions = fst.M.values[k]
        for j in 1:nnz(transitions)
            src, dest = transitions.coordinates[j]
            push!(collected, (src, isym, osym, transitions.values[j], dest))
        end
    end
    return collected
end
# State ids run from 1 to the length of the initial-weight vector.
function states(fst::SparseTensorFST)
    numstate = size(fst.α, 1)
    return 1:numstate
end

# Initial weight of `state`.
function initweight(fst::SparseTensorFST, state)
    return fst.α[state]
end

# Stored initial weights as a vector of `state => weight` pairs.
function initweights(fst::SparseTensorFST)
    α = fst.α
    return [state => weight for (state, weight) in zip(α.indices, α.values)]
end

# Final weight of `state`.
function finalweight(fst::SparseTensorFST, state)
    return fst.ω[state]
end

# Stored final weights as a vector of `state => weight` pairs.
function finalweights(fst::SparseTensorFST)
    ω = fst.ω
    return [state => weight for (state, weight) in zip(ω.indices, ω.values)]
end

# Input symbol table of the transducer.
function isymboltable(fst::SparseTensorFST)
    return fst.isymboltable
end

# Output symbol table of the transducer.
function osymboltable(fst::SparseTensorFST)
    return fst.osymboltable
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment