Skip to content
Snippets Groups Projects
Verified Commit 7e0baba5 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

merge

parents 2b02063f 093c3c2f
No related branches found
No related tags found
1 merge request!60[WIP] Resolve "Read IARPA format LM"
# Releases
## 1.4.1 - 01.03.2024
## Fixed
- `project_to` function was not exported
## 1.4.0 - 01.03.2024
## Added
- Limited support of automatic differentiation
## 1.3.0 - 29.02.2024
## Added
- Weightpushing algorithm
## 1.2.2 - 26.02.2024
## Fixed
- tokenfst: it accepted only sequences of 1 token.
......
name = "TensorFSTs"
uuid = "9f8e39db-0d6b-46cc-a14d-daf437028eec"
authors = ["Lucas ONDEL YANG <lucas.ondel@cnrs.fr", "Cyril ALLAUZEN <allauzen@google.com", "Lukas Burget <burget@fit.vutbr.cz>", "Martin KOCOUR <ikocour@fit.vutbr.cz>", "Tina RAISSI <raissi@i6.informatik.rwth-aachen.de>", "Pablo RIERA <pablo.riera@gmail.com>", "Michael RILEY <riley@google.com>", "Jan TRMAL <yenda@jhu.edu>"]
version = "1.2.2"
authors = ["Lucas ONDEL YANG <lucas.ondel@cnrs.fr>", "Cyril ALLAUZEN <allauzen@google.com>", "Lukas BURGET <burget@fit.vutbr.cz>", "Simon DEVAUCHELLE <devauchelle@lisn.upsaclay.fr>", "Martin KOCOUR <ikocour@fit.vutbr.cz>", "Tina RAISSI <raissi@i6.informatik.rwth-aachen.de>", "Pablo RIERA <pablo.riera@gmail.com>", "Michael RILEY <riley@google.com>", "Jan TRMAL <yenda@jhu.edu>"]
version = "1.4.1"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
Semirings = "900aad66-9ca5-44d4-b043-321c62cb7767"
SumSparseTensors = "472dc678-8c5a-4778-a6eb-28a4e9c7cb58"
[compat]
SumSparseTensors = "0.13"
Semirings = "1.5"
julia = "1.6"
Semirings = "1.7"
SumSparseTensors = "0.14"
......@@ -4,7 +4,14 @@ module NGramFSTs
using ..TensorFSTs
export countngrams, ngramfst
export countngrams,
ngramfst,
read_arpafile
read_arpafile,
ngram_get_score,
ngram_score_sentence
include("ngram_arpa.jl")
const BOS = 0
......
using GZip
"""
::List[::Dict] = read_arpafile_as_dicts(arpafile::String)
read_arpafile(fd)
Read srilm arpa lm in plain format or .gz
It produces Vector of Dict(Tuple, Tuple) holding n-gram tables:
[Dict{(w1,) => (logprob, backoff)},
Dict{(w1,w2) => (logprob, backoff)},
Dict{(w1,w2,w2) => (logprob, backoff)},
...
]
Read an SRILM ARPA language model from the file descriptor `fd`.
It produces dictionary of tuples like:
```
(0,0,1gram) => (logprob, backoff)
(0,1gram,2gram) => (logprob, backoff)
(1gram,2gram,3gram) => (logprob, backoff)
(1gram,2gram,3gram) => (logprob, backoff)
```
In case backoff value is not present in ARPA file, its value is `nothing`.
"""
function ngram_read_arpafile(arpafile::String)
if occursin(r".*\.gz$", arpafile)
fd = GZip.open(arpafile, "r")
else
fd = open(arpafile)
end
function ngram_read_arpafile(fd)
ngram_count_in_header = Dict()
# header reading
......@@ -102,9 +94,7 @@ function ngram_read_arpafile(arpafile::String)
end
end
close(fd)
return ngram_dicts
ngram_dicts
end
......
No preview for this file type
......@@ -44,6 +44,9 @@ include("draw.jl")
export
compose,
connect,
weightpush,
toinitial,
tofinal,
closure,
weight,
shortestdistance,
......@@ -54,11 +57,10 @@ export
include("filter.jl")
include("ops.jl")
# arpa format
export ngram_read_arpafile,
ngram_get_score,
ngram_score_sentence
include("ngram_arpa.jl")
export project_to
include("utils.jl")
# standalone libraries
include("../lib/LinearFSTs.jl")
......
......@@ -72,7 +72,9 @@ function _write_dot(io::IO, fst, isyms, osyms)
for arc in arcs(fst)
print(io, " ", stateid[arc.src], " -> ", stateid[arc.dest], " [label=\"")
print(io, get(isyms, arc.isym, arc.isym), ":", get(osyms, arc.osym, arc.osym))
! isone(arc.weight) ? println(io, "/", arc.weight, "\"];") : println(io, "\"];")
if !isone(arc.weight) print(io, "/", arc.weight) end
if iszero(arc.weight) print(io, "\" color=\"grey80") end
println(io, "\"];")
end
println(io, "}")
......
......@@ -151,6 +151,41 @@ function connect(A::TensorFST{S}) where S
end
@enum FSTPush toinitial tofinal

"""
    weightpush(A, toinitial | tofinal)

Return a FST ``B`` equivalent to `A` in which the weight of every path has
been moved as far as possible towards the initial states (`toinitial`) or
towards the final states (`tofinal`).
"""
function weightpush(A::TensorFST{S}, p::FSTPush) where S
    fstdims = dimensions(A)
    axesids = ntuple(identity, length(fstdims.src))

    M, α, ω = if p == toinitial
        # Potential of a state: shortest distance from q ∈ Q to F
        pot = mulsumstar(A.M, A.ω, fstdims.dest, fstdims.src)
        unit = ones(S, size(pot)...)
        (
            mul(map(inv, pot), A.M, pot, A.dims.src, A.dims.dest),
            mul(unit, A.α, pot, axesids, axesids),
            mul(map(inv, pot), A.ω, unit, axesids, axesids),
        )
    else
        # Potential of a state: shortest distance from I to q ∈ Q
        pot = mulsumstar(A.α, A.M, fstdims.src, fstdims.dest)
        unit = ones(S, size(pot)...)
        (
            mul(pot, A.M, map(inv, pot), A.dims.src, A.dims.dest),
            mul(pot, A.α, unit, axesids, axesids),
            mul(unit, A.ω, map(inv, pot), axesids, axesids),
        )
    end

    return TensorFST{S}(fstdims, M, α, ω)
end
@enum FSTSort state inputlabel outputlabel
"""
......@@ -260,7 +295,7 @@ function compose(A::TensorFST{S}, B::TensorFST{S}; connect = true, flatten = con
AF = _compose_noeps(sort(Al, state), F; connect = true)
_compose_noeps(AF, sort(Bl, inputlabel); connect, flatten)
else
else
_compose_noeps(A, B; connect, flatten)
end
end
......
# handy functions for gradient projections
#

import SumSparseTensors: SortedSumSparseTensor, SumSparseTensor

"""
    project_to([T=ProbSemiring], Ā, A)

Project the tangent `Ā` onto input `A` (i.e. original object).

`Ā` is the structural tangent of `A`: a `NamedTuple` carrying the same field
names as the object it mirrors. The result is a new `TensorFST{T}` that reuses
the dimensions of `A` and whose `M`, `α` and `ω` tensors are the projections
of the corresponding tangent fields.

The type `T` is a `Semiring` of the tangent. When omitted it defaults to
`ProbSemiring` parameterized by the value type of `S`, the semiring of `A`.
"""
# NOTE(review): `valtype(S)` replaces `valtype{S}`, which is not a valid call;
# this assumes Semirings defines `valtype` for semiring types — confirm.
project_to(Ā, A::TensorFST{S}) where S = project_to(ProbSemiring{valtype(S)}, Ā, A)

function project_to(T, Ā::NamedTuple{(:dims, :M, :α, :ω)}, A::TensorFST{S}) where S
    return TensorFST{T}(
        dimensions(A),
        project_to(T, Ā.M, A.M),
        project_to(T, Ā.α, A.α),
        project_to(T, Ā.ω, A.ω)
    )
end
# Projection of a `SumSparseTensor` tangent: the axes and coordinates come from
# the original tensor `ΣA`; only the values are taken from the tangent `ΣĀ`,
# each one re-wrapped in the semiring type `T` through its `val` field.
function project_to(
    T,
    ΣĀ::NamedTuple{(:axes, :nzcoo, :nzval)},
    ΣA::SumSparseTensor{S,N,M}
) where {S,N,M}
    tangentvals = map(ΣĀ.nzval) do t
        T(t.val)
    end
    return SumSparseTensor(ΣA.axes, ΣA.nzcoo, tangentvals)
end
# Projection of a `SortedSumSparseTensor` tangent: the axes, pointers and
# coordinates come from the original tensor `ΣA`; only the values are taken
# from the tangent `ΣĀ`, each one re-wrapped in the semiring type `T` through
# its `val` field.
function project_to(
    T,
    ΣĀ::NamedTuple{(:axes, :ptr, :nzcoo, :nzval)},
    ΣA::SortedSumSparseTensor{S,N,I,M}
) where {S,N,I,M}
    tangentvals = map(ΣĀ.nzval) do t
        T(t.val)
    end
    return SortedSumSparseTensor{T,N,I}(ΣA.axes, ΣA.ptr, ΣA.nzcoo, tangentvals)
end
......@@ -14,6 +14,23 @@ function issame(A::TensorFST, B::TensorFST)
Set(finalweights(A)) == Set(finalweights(B)) || return false
end
# Test helper: approximate equality of two FSTs.
#
# Two FSTs are considered approximately equal when, after sorting their arcs,
# they have the same number of arcs, every pair of arcs agrees exactly on
# (src, isym, osym, dest) and approximately (`≈`) on the weight value, and
# their sets of initial and final weights are equal. Returns a `Bool`.
function isapprox(A::TensorFST, B::TensorFST)
    arcs_A, arcs_B = sort(arcs(A)), sort(arcs(B))

    # `zip` stops at the shorter collection: without this check an FST whose
    # arcs are a prefix of the other's would be reported as equal.
    length(arcs_A) == length(arcs_B) || return false

    for (a_a, a_b) in zip(arcs_A, arcs_B)
        # src, isym, osym, dest
        if !((a_a.src == a_b.src) && (a_a.isym == a_b.isym) && (a_a.osym == a_b.osym) && (a_a.dest == a_b.dest))
            return false
        end

        # Weights
        if !(val(a_a.weight) ≈ val(a_b.weight))
            return false
        end
    end

    # Initial and final weights are compared exactly, as in `issame`.
    Set(initweights(A)) == Set(initweights(B)) || return false
    Set(finalweights(A)) == Set(finalweights(B)) || return false
    return true
end
@testset "SymbolTable" begin
_syms = Dict("a" => 1, "b" => 2)
......@@ -270,6 +287,95 @@ end
@test issame(C, compose(A, sort(B, inputlabel); connect = true))
end
# Checks `weightpush` in both directions on two small 4-state FSTs
# (initial state 1, final state 4): one over the tropical semiring and one
# over the probability semiring.
@testset verbose=true "Weight-pushing" begin
S = TropicalSemiring{Float64}
# A: tropical input FST (the chained assignment also binds the name `TFST`).
A = TFST = TensorFST(
[
Arc(src=1, isym=1, osym=2, weight = one(S), dest=2),
Arc(src=1, isym=2, osym=1, weight = S(1), dest=2),
Arc(src=1, isym=3, osym=2, weight = S(5), dest=2),
Arc(src=1, isym=4, osym=1, weight = one(S), dest=3),
Arc(src=1, isym=5, osym=2, weight = S(1), dest=3),
Arc(src=2, isym=5, osym=1, weight = one(S), dest=4),
Arc(src=2, isym=6, osym=2, weight = S(1), dest=4),
Arc(src=3, isym=5, osym=1, weight = S(4), dest=4),
Arc(src=3, isym=6 ,osym=2, weight = S(5), dest=4),
],
[(1,) => one(S)],
[(4,) => one(S)]
)
# B: expected result of pushing the weights of A towards the initial state;
# it differs from A only on the arcs entering and leaving state 3.
B = TFST = TensorFST(
[
Arc(src=1, isym=1, osym=2, weight = one(S), dest=2),
Arc(src=1, isym=2, osym=1, weight = S(1), dest=2),
Arc(src=1, isym=3, osym=2, weight = S(5), dest=2),
Arc(src=1, isym=4, osym=1, weight = S(4), dest=3),
Arc(src=1, isym=5, osym=2, weight = S(5), dest=3),
Arc(src=2, isym=5, osym=1, weight = one(S), dest=4),
Arc(src=2, isym=6, osym=2, weight = S(1), dest=4),
Arc(src=3, isym=5, osym=1, weight = one(S), dest=4),
Arc(src=3, isym=6 ,osym=2, weight = S(1), dest=4),
],
[(1,) => one(S)],
[(4,) => one(S)]
)
# Pushing towards the final state is expected to leave A unchanged.
@test issame(weightpush(A, toinitial), B)
@test issame(weightpush(A, tofinal), A)
PS = ProbSemiring{Float64}
# C: probability-semiring input FST; several arcs carry the zero weight.
C = TensorFST(
[
Arc(src=1, isym=1, osym=2, weight=zero(PS), dest=2),
Arc(src=1, isym=2, osym=1, weight=PS(1), dest=2),
Arc(src=1, isym=3, osym=2, weight=PS(5), dest=2),
Arc(src=1, isym=4, osym=1, weight=zero(PS), dest=3),
Arc(src=1, isym=5, osym=2, weight=PS(1), dest=3),
Arc(src=2, isym=5, osym=1, weight=zero(PS), dest=4),
Arc(src=2, isym=6, osym=2, weight=one(PS), dest=4),
Arc(src=3, isym=5, osym=1, weight=PS(4), dest=4),
Arc(src=3, isym=6 ,osym=2, weight=PS(5), dest=4),
],
[(1,) => one(PS)],
[(4,) => one(PS)]
)
# D: expected result of pushing C towards the initial state; the initial
# weight becomes PS(15) and the final weight stays PS(1).
D = TensorFST(
[
Arc(src=1, isym=1, osym=2, weight=zero(PS), dest=2),
Arc(src=1, isym=2, osym=1, weight=PS(1/15), dest=2),
Arc(src=1, isym=3, osym=2, weight=PS(5/15), dest=2),
Arc(src=1, isym=4, osym=1, weight=zero(PS), dest=3),
Arc(src=1, isym=5, osym=2, weight=PS(9/15), dest=3),
Arc(src=2, isym=5, osym=1, weight=zero(PS), dest=4),
Arc(src=2, isym=6, osym=2, weight=one(PS), dest=4),
Arc(src=3, isym=5, osym=1, weight=PS(4/9), dest=4),
Arc(src=3, isym=6 ,osym=2, weight=PS(5/9), dest=4),
],
[(1,) => PS(15)],
[(4,) => PS(1)]
)
# E: expected result of pushing C towards the final state; the initial
# weight stays one(PS) and the final weight becomes PS(1/15).
E = TensorFST(
[
Arc(src=1, isym=1, osym=2, weight=zero(PS), dest=2),
Arc(src=1, isym=2, osym=1, weight=PS(1/6), dest=2),
Arc(src=1, isym=3, osym=2, weight=PS(5/6), dest=2),
Arc(src=1, isym=4, osym=1, weight=zero(PS), dest=3),
Arc(src=1, isym=5, osym=2, weight=one(PS), dest=3),
Arc(src=2, isym=5, osym=1, weight=zero(PS), dest=4),
Arc(src=2, isym=6, osym=2, weight=PS(2/5), dest=4),
Arc(src=3, isym=5, osym=1, weight=PS(4/15), dest=4),
Arc(src=3, isym=6 ,osym=2, weight=PS(5/15), dest=4),
],
[(1,) => one(PS)],
[(4,) => PS(1/15)]
)
# The `toinitial` result is compared with `issame`; the `tofinal` result is
# compared with the `isapprox` helper, which checks arc weights with `≈`.
@test issame(weightpush(C, toinitial), D)
@test isapprox(weightpush(C, tofinal), E)
end
@testset verbose=true "Libaries" begin
include("linearfsts.jl")
include("ngramfsts.jl")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment