Skip to content
Snippets Groups Projects
Commit a0ac6c74 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

Merge branch '44-wrong-emission-ids' into 'main'

Resolve "Wrong emission IDs"

Closes #44

See merge request !53
parents df98627f f449fd2b
No related branches found
No related tags found
1 merge request!53Resolve "Wrong emission IDs"
### A Pluto.jl notebook ###
# v0.19.36
# v0.19.38
using Markdown
using InteractiveUtils
......@@ -14,8 +14,12 @@ begin
# packages will be install for this notebook automatically.
using Semirings
using TensorFSTs
using TensorFSTs.LinearFSTs
end
# ╔═╡ d5ec1212-e6ee-42eb-8c9b-8f2234ef7f7f
using TensorFSTs.SpeechRecognitionFSTs
# ╔═╡ a4127267-40fd-4550-9446-6d82b0ae8373
md"""# TensorFST
*Lucas ONDEL YANG, LISN CNRS, 24/11/2023*
......@@ -55,6 +59,33 @@ md"""
The function `summary(fst)` return some basic information about your fst.
"""
# ╔═╡ 2cf7395d-4c6d-4370-aa7f-b57734c38b7b
s = ProbSemiring{Float64}
# ╔═╡ 68454444-f7ea-48b7-ac59-2aed02b2d31a
t = tokenfst(
s,
[(1, 1, 1), (1, 2, 1), (2, 2, 1), (2, 3, 1), (3, 3, 1)],
[1 => 1],
[3 => 1],
ids(SymbolTable(["a", "b", "c"]))
)
# ╔═╡ a95b0c6f-5146-4c0c-938d-1c1b7fade296
draw(t)
# ╔═╡ 5d3c18f3-d466-4c16-b54f-e21ff58b080e
begin
topo = [(1, 1, 1), (1, 2, 1), (2, 2, 1), (2, 3, 1), (3, 3, 1)]
values = [topo_arc[2] for topo_arc in topo]
end
# ╔═╡ 974cf260-4e3a-44d1-90f4-1b7115d32dc3
nstates = topo[end][2]
# ╔═╡ 20fadb1b-7883-474b-916e-c48cfb514c32
# ╔═╡ f779c388-223c-4299-9523-3f3b3ba21769
md"""
To represent the FST as an directed acyclic graph first you can use the `draw(fst)` function.
......@@ -504,6 +535,13 @@ draw(
# ╟─b4279331-1f8c-49b7-9bc4-b46f912ace46
# ╟─d62946ec-112e-40c4-8b8c-24dd6f1c5103
# ╠═27d2c155-518c-49fc-b273-0853f7a04c75
# ╠═d5ec1212-e6ee-42eb-8c9b-8f2234ef7f7f
# ╠═2cf7395d-4c6d-4370-aa7f-b57734c38b7b
# ╠═68454444-f7ea-48b7-ac59-2aed02b2d31a
# ╠═a95b0c6f-5146-4c0c-938d-1c1b7fade296
# ╠═5d3c18f3-d466-4c16-b54f-e21ff58b080e
# ╠═974cf260-4e3a-44d1-90f4-1b7115d32dc3
# ╠═20fadb1b-7883-474b-916e-c48cfb514c32
# ╟─f779c388-223c-4299-9523-3f3b3ba21769
# ╠═41df922b-543f-433e-9e30-b6863ca00366
# ╠═e9365726-92bb-40e8-bce7-9bdbd73466db
......
......@@ -16,14 +16,14 @@ export EmissionMapping, lexiconfst, tokenfst
Default emission mapping which assumes the same number of emissions per token.
"""
struct EmissionMapping <: AbstractMatrix{Int}
numstates::Int
numpdfs::Int
numtokens::Int
end
Base.size(x::EmissionMapping) = (x.numstates, x.numtokens)
Base.getindex(x::EmissionMapping, i, j) = (j - 1) * x.numstates + i
Base.size(x::EmissionMapping) = (x.numpdfs, x.numtokens)
Base.getindex(x::EmissionMapping, i, j) = (j - 1) * x.numpdfs + i
Base.minimum(x::EmissionMapping) = 1
Base.maximum(x::EmissionMapping) = (x.numstates * x.num_emission)
Base.maximum(x::EmissionMapping) = (x.numpdfs * x.numtokens)
......@@ -41,36 +41,75 @@ function tokenfst(
initweights,
finalweights,
tokens,
mapping = EmissionMapping(length(topo), length(tokens))
mapping = nothing
)
emission_count = 0
states = Set(Int[])
arcs, init, final = [], [], []
if isnothing(mapping)
st = Set(Int[])
for _ in tokens
for top in topo
push!(st, top[1])
end
end
mapping = EmissionMapping(length(st), length(tokens))
end
for (i, token) in enumerate(tokens)
offset = length(states)
for (j, topo_arc) in enumerate(topo)
emission_count += 1
# Extra state (infinite loop)
init_state = length(mapping) + 1
push!(states, init_state)
for (j, token) in enumerate(tokens)
offset = length(states) - 1
for topo_arc in topo
src, dest, weight = offset + topo_arc[1], offset + topo_arc[2], topo_arc[3]
arc = Arc(
src = src,
isym = mapping[emission_count],
osym = j == 1 ? token : 1,
isym = mapping[topo_arc[2], j],
osym = 0,
dest = dest,
weight = S(weight)
)
# Init arc to sources
if topo_arc[1] == 1
init_arc = Arc(
src = init_state,
isym = 0,
osym = token,
dest = src,
weight = S(weight)
)
push!(arcs, init_arc)
end
# Final arcs to destinations
if topo_arc[1] == length(st)
final_arc = Arc(
src = src,
isym = 0,
osym = 0,
dest = init_state,
weight = S(weight)
)
push!(arcs, final_arc)
end
push!(states, src)
push!(states, dest)
push!(arcs, arc)
end
for (state, weight) in initweights
state = offset + state
push!(states, state)
push!(init, state => S(weight))
end
# for (state, weight) in initweights
# print(state)
# state = offset + state
# push!(states, state)
# push!(init, state => S(weight))
# end
for (state, weight) in finalweights
state = offset + state
......@@ -79,6 +118,11 @@ function tokenfst(
end
end
# Actually, there is just one init state
for (_, weight) in initweights
push!(init, init_state => S(weight))
end
TensorFST(arcs, init, final)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment