Skip to content
Snippets Groups Projects
Verified Commit cda9696d authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

FST creation

parent 5e55afea
No related branches found
No related tags found
No related merge requests found
......@@ -17,20 +17,17 @@ end
# ╔═╡ 4684d1ff-7a80-4e8e-9f48-f3ef5b11cb5a
S = LogSemiring{Float64,1}
# ╔═╡ 150efaf6-f277-4064-84f3-d21518943743
Arc(src=1, isym=2, osym=2, dest=2, weight=one(S))
# ╔═╡ ebe54b31-daa5-4970-9a25-4146d9f101d8
fst = TensorFST(
arcs = [
[
Arc(src=1, isym=2, osym=2, weight=one(S), dest=1),
Arc(src=1, isym=2, osym=2, weight=one(S), dest=2),
Arc(src=2, isym=2, osym=2, weight=one(S), dest=2),
Arc(src=2, isym=2, osym=2, weight=one(S), dest=3),
Arc(src=3, isym=2, osym=2, weight=one(S), dest=3)
],
initweights = [1 => one(S)],
finalweights = [3 => one(S)]
[1 => one(S)],
[3 => one(S)]
)
# ╔═╡ 27d2c155-518c-49fc-b273-0853f7a04c75
......@@ -42,11 +39,14 @@ Dict(initweights(fst)...)
# ╔═╡ 41df922b-543f-433e-9e30-b6863ca00366
draw(fst; isymbols = Dict(2 => "a"), osymbols = Dict(2 => "x"))
# ╔═╡ 642af191-36ac-4068-a5ca-411758cbaa34
TensorFSTs.numarcs(fst)
# ╔═╡ Cell order:
# ╠═c893a80a-83cd-11ee-1a82-d7c85de7ab7a
# ╟─4684d1ff-7a80-4e8e-9f48-f3ef5b11cb5a
# ╠═150efaf6-f277-4064-84f3-d21518943743
# ╠═ebe54b31-daa5-4970-9a25-4146d9f101d8
# ╠═27d2c155-518c-49fc-b273-0853f7a04c75
# ╠═c353c6d9-6ff4-478e-aee8-dc46c2765847
# ╠═41df922b-543f-433e-9e30-b6863ca00366
# ╠═642af191-36ac-4068-a5ca-411758cbaa34
......@@ -5,31 +5,28 @@ module TensorFSTs
using Semirings
using SumSparseTensors
#=====================================================================#
# AbstractTensorFST interface
#=====================================================================#
export
Arc,
TensorFST,
# accessors
arcs,
draw,
initweights,
finalweights,
# utilities
numarcs,
numstates,
orientation,
reorient,
semiring,
tensorsize
tensordim,
tensorsize,
include("tensorfst.jl")
# graphic representation
draw
#export
# draw,
# graphviz
#
#include("io.jl")
include("tensorfst.jl")
end
......@@ -41,9 +41,6 @@ function Base.show(io::IO, ::MIME"text/plain", arc::Arc)
", dest = ", arc.dest, ")")
end
"""
struct TensorFST{S<:Semiring,O}
M::AbstractSumSparseTensor{S,4}
......@@ -54,10 +51,10 @@ end
Finite State Transducer using a tensor-based representation. `S` is the
semiring type of the arcs' weight and `O` is the orientation of the
tensor `M`. An orientation is a 4-tuple containing the following
symbols `:src`, `:dest`, `:ilabel` and `:olabel` in any order
symbols `:src`, `:dest`, `:isym` and `:osym` in any order
indicating which dimension is associated to the source state,
destination state, input label and output label respectively. For
instance the orientation `(:src, :ilabel, :olabel, :dest)` indicates
instance the orientation `(:src, :isym, :osym, :dest)` indicates
that the first dimension of `M` represents the source state of the
transitions, the second dimension represents input label, and so on.
The vectors `α` and `ω` represent the states' initial and final weights.
......@@ -85,13 +82,12 @@ struct TensorFST{S<:Semiring,O}
end
end
TensorFST{S}(M, α, ω) where S = TensorFST{S,(:src,:dest,:ilabel,:olabel)}(M, α, ω)
TensorFST{S}(M, α, ω) where S = TensorFST{S,(:src,:dest,:isym,:osym)}(M, α, ω)
TensorFST(M, α, ω) = TensorFST{eltype(M)}(M, α, ω)
function TensorFST(;arcs::AbstractVector{Arc{S}},
function TensorFST(arcs::AbstractVector{Arc{S}},
initweights::AbstractVector{Pair{Int,S}},
finalweights::AbstractVector{Pair{Int,S}},
) where S
finalweights::AbstractVector{Pair{Int,S}}) where S
Q = maximum(
union(
Set(map(x -> x.src, arcs)),
......@@ -160,15 +156,15 @@ Return the arcs of `fst` as a list of named tuple
arcs
function arcs(fst::TensorFST{S}) where S
src, dest, ilabel, olabel = (
src, dest, isym, osym = (
tensordim(fst, :src),
tensordim(fst, :dest),
tensordim(fst, :ilabel),
tensordim(fst, :olabel)
tensordim(fst, :isym),
tensordim(fst, :osym)
)
arcs = Arc{S}[]
for (coo, v) in eachnz(collapse(fst.M))
arc = Arc(src=coo[src], isym=coo[ilabel], osym=coo[olabel], weight=v, dest=coo[dest])
arc = Arc(src=coo[src], isym=coo[isym], osym=coo[osym], weight=v, dest=coo[dest])
push!(arcs, arc)
end
arcs
......@@ -180,8 +176,6 @@ end
Return the initial weights of `fst` as a list of pairs
`state => weight`.
"""
initweights
function initweights(fst::TensorFST{S}) where S
iws = Pair{Int,S}[]
for (i, v) in eachnz(fst.α)
......@@ -206,10 +200,9 @@ end
"""
tensordim(fst, symdim::Symbol)
Return the dimension index of the underlying tensor of `fst` in the
corresponding to `symdim`. For transducer, `symdim` can be `:src`,
`:dest`, `:ilabel` and `:olabel` and for acceptors, `symdim` cab be
`:src`, `:dest`, `:label`.
Return the dimension index of the underlying tensor of `fst`
corresponding to `symdim`. `symdim` can be one of `:src`, `:dest`,
`:isym` and `:osym`.
"""
function tensordim(fst::TensorFST{S,O}, symdim::Symbol) where {S,O}
    # Position of `symdim` inside the orientation of `fst`; `findfirst`
    # yields `nothing` when no entry matches.
    return findfirst(==(symdim), orientation(fst))
end
......@@ -217,10 +210,9 @@ tensordim(fst::TensorFST{S,O}, symdim::Symbol) where {S,O} =
"""
tensorsize(fst, symdim::Symbol)
Return the size index of the underlying tensor of `fst` in the
symbolic dimension of `dim`. For transducer, `symdim` can be `:src`,
`:dest`, `:ilabel` and `:olabel` and for acceptors, `symdim` cab be
`:src`, `:dest`, `:label`.
Return the size of the underlying tensor of `fst` at the dimension
of `symdim`. `symdim` can be one of `:src`, `:dest`,
`:isym` and `:osym`.
"""
function tensorsize(fst::TensorFST, symdim::Symbol)
    # Translate the symbolic dimension into a tensor axis first, then
    # query the extent of `fst.M` along that axis.
    axis = tensordim(fst, symdim)
    return size(fst.M, axis)
end
......@@ -228,14 +220,14 @@ tensorsize(fst::TensorFST, symdim::Symbol) =
"""
numstates(fst)
Return the number of states in `fst`.
Return the number of states in `fst`.
"""
function numstates(fst::TensorFST)
    # The state count is taken as the tensor size along the `:src` dimension.
    return tensorsize(fst, :src)
end
"""
numarcs(fst)
Return the number of arcs in `fst`.
Return the number of arcs (i.e. transitions) in `fst`.
"""
function numarcs(fst::TensorFST)
    # Counts by building the full arc list with `arcs` and taking its length.
    arclist = arcs(fst)
    return length(arclist)
end
......@@ -247,7 +239,7 @@ function Base.zero(::Type{TensorFST{S,O}}) where {S,O}
)
end
Base.zero(::Type{TensorFST{S}}) where S = zero(TensorFST{S,(:src,:dest,:ilabel,:olabel)})
Base.zero(::Type{TensorFST{S}}) where S = zero(TensorFST{S,(:src,:dest,:isym,:osym)})
function Base.one(::Type{TensorFST{S,O}}) where {S,O}
TensorFST{S,O}(
......@@ -257,7 +249,7 @@ function Base.one(::Type{TensorFST{S,O}}) where {S,O}
)
end
Base.one(::Type{TensorFST{S}}) where S = one(TensorFST{S,(:src,:dest,:ilabel,:olabel)})
Base.one(::Type{TensorFST{S}}) where S = one(TensorFST{S,(:src,:dest,:isym,:osym)})
#=====================================================================#
# Pretty printing/drawing of FSTs
......@@ -369,3 +361,4 @@ Draw `fst` or `fsa` using graphviz.
"""
function draw(
    fst::TensorFST;
    isymbols = Dict(),
    osymbols = Dict(),
    openfst_compat = false,
    dpi = missing,
)
    # The keyword options are forwarded positionally to `DrawFST`, in the
    # same order as they are declared above.
    return DrawFST(fst, isymbols, osymbols, openfst_compat, dpi)
end
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Semirings = "900aad66-9ca5-44d4-b043-321c62cb7767"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@testset "Vector FST" begin
K = ProbSemiring{Float32}
symtab = Dict(
"a" => 1,
"b" => 2,
"c" => 3,
)
_arcs = [
[Arc(symtab["a"], symtab["a"], one(K), 2)],
[Arc(symtab["b"], symtab["b"], one(K), 2),
Arc(symtab["c"], symtab["c"], one(K), 3)],
[Arc(symtab["c"], symtab["c"], one(K), 3)],
@testset "TensorFST" begin
S = ProbSemiring{Float32}
arcs = [
Arc(src=1, isym=2, osym=2, weight=one(S), dest=1),
Arc(src=1, isym=2, osym=2, weight=one(S), dest=2),
Arc(src=2, isym=2, osym=2, weight=one(S), dest=2),
Arc(src=2, isym=2, osym=2, weight=one(S), dest=3),
Arc(src=3, isym=2, osym=2, weight=one(S), dest=3)
]
iweights = [1 => one(S)]
fweights = [1 => one(S)]
fst = TensorFST(arcs = arcs, initweights = iweights, finaleweights = fweights)
finalweights = [zero(K), zero(K), one(K)]
@test VectorFST{K}() isa VectorFST{K}
vfst = VectorFST(
......
using Semirings
using TensorFSTs
using TensorFSTs.Semirings
using TensorFSTs.SparseSemimodules
import LogExpFunctions: logaddexp
import LinearAlgebra: dot
using Test
@testset "SymbolTable" begin
include("symboltable.jl")
end
@testset "TensorFST" begin
S = ProbSemiring{Float32}
@testset "SparseSemimodules" begin
include("sparsesm.jl")
end
arcs = [
Arc(src=1, isym=2, osym=2, weight=one(S), dest=1),
Arc(src=1, isym=2, osym=2, weight=one(S), dest=2),
Arc(src=2, isym=2, osym=2, weight=one(S), dest=2),
Arc(src=2, isym=2, osym=2, weight=one(S), dest=3),
Arc(src=3, isym=2, osym=2, weight=one(S), dest=3)
]
iweights = [1 => one(S)]
fweights = [1 => one(S)]
fst = TensorFST(arcs, iweights, fweights)
#@testset "FST" begin
# include("fst.jl")
# include("io.jl")
#end
@test numarcs(fst) == 5
@test numstates(fst) == 3
@test orientation(fst) == (:src, :dest, :isym, :osym)
@test semiring(fst) == S
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment