Skip to content
Snippets Groups Projects
Commit 6da96cc2 authored by Pablo Riera's avatar Pablo Riera
Browse files
parents 2c4cbc6a c666ab46
No related branches found
No related tags found
No related merge requests found
# Releases
## 1.4.1 - 01.03.2024
## Fixed
- `project_to` function was not exported
## 1.4.0 - 01.03.2024
## Added
- Limited support of automatic differentiation
## 1.3.0 - 29.02.2024
## Added
- Weightpushing algorithm
## 1.2.2 - 26.02.2024
## Fixed
- `tokenfst` accepted only sequences of a single token.
......
name = "TensorFSTs"
uuid = "9f8e39db-0d6b-46cc-a14d-daf437028eec"
authors = ["Lucas ONDEL YANG <lucas.ondel@cnrs.fr", "Cyril ALLAUZEN <allauzen@google.com", "Lukas Burget <burget@fit.vutbr.cz>", "Martin KOCOUR <ikocour@fit.vutbr.cz>", "Tina RAISSI <raissi@i6.informatik.rwth-aachen.de>", "Pablo RIERA <pablo.riera@gmail.com>", "Michael RILEY <riley@google.com>", "Jan TRMAL <yenda@jhu.edu>"]
version = "1.2.2"
authors = ["Lucas ONDEL YANG <lucas.ondel@cnrs.fr>",
"Cyril ALLAUZEN <allauzen@google.com>",
"Lukas BURGET <burget@fit.vutbr.cz>",
"Simon DEVAUCHELLE <devauchelle@lisn.upsaclay.fr>",
"Martin KOCOUR <ikocour@fit.vutbr.cz>",
"Tina RAISSI <raissi@i6.informatik.rwth-aachen.de>",
"Pablo RIERA <pablo.riera@gmail.com>",
"Michael RILEY <riley@google.com>",
"Jan TRMAL <yenda@jhu.edu>"]
version = "1.4.1"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
......@@ -9,6 +17,6 @@ Semirings = "900aad66-9ca5-44d4-b043-321c62cb7767"
SumSparseTensors = "472dc678-8c5a-4778-a6eb-28a4e9c7cb58"
[compat]
SumSparseTensors = "0.13"
Semirings = "1.5"
SumSparseTensors = "0.14"
Semirings = "1.7"
julia = "1.6"
No preview for this file type
......@@ -44,6 +44,9 @@ include("draw.jl")
export
compose,
connect,
weightpush,
toinitial,
tofinal,
closure,
weight,
shortestdistance,
......@@ -54,6 +57,8 @@ export
include("filter.jl")
include("ops.jl")
# `project_to` is exported on its own, separately from the main export list
# (NOTE(review): presumably it is defined in utils.jl, included right below — confirm).
export project_to
include("utils.jl")
# standalone libraries
include("../lib/LinearFSTs.jl")
......
......@@ -72,7 +72,9 @@ function _write_dot(io::IO, fst, isyms, osyms)
for arc in arcs(fst)
print(io, " ", stateid[arc.src], " -> ", stateid[arc.dest], " [label=\"")
print(io, get(isyms, arc.isym, arc.isym), ":", get(osyms, arc.osym, arc.osym))
! isone(arc.weight) ? println(io, "/", arc.weight, "\"];") : println(io, "\"];")
if !isone(arc.weight) print(io, "/", arc.weight) end
if iszero(arc.weight) print(io, "\" color=\"grey80") end
println(io, "\"];")
end
println(io, "}")
......
......@@ -151,6 +151,41 @@ function connect(A::TensorFST{S}) where S
end
# Direction argument of `weightpush`: push the path weights towards the initial
# states (`toinitial`) or towards the final states (`tofinal`).
@enum FSTPush toinitial tofinal
"""
weightpush(A, toinitial | tofinal)
Returns a FST ``B`` with the weights of each path pushed as much as possible towards the initial/final states.
"""
function weightpush(A::TensorFST{S}, p::FSTPush) where S
    dims = dimensions(A)
    nstatedims = length(dims.src)
    # Identity axis mapping (1, 2, ..., nstatedims) used for the state vectors α and ω.
    stateaxes = ntuple(identity, nstatedims)

    if p == toinitial
        # Shortest distance from q ∈ Q to F
        dist = mulsumstar(A.M, A.ω, dims.dest, dims.src)
        unit = ones(S, size(dist)...)
        invdist = map(inv, dist)

        # Arcs are rescaled by inv(dist) on one side and dist on the other;
        # the initial weights absorb dist and the final weights absorb inv(dist).
        M = mul(invdist, A.M, dist, A.dims.src, A.dims.dest)
        α = mul(unit, A.α, dist, stateaxes, stateaxes)
        ω = mul(invdist, A.ω, unit, stateaxes, stateaxes)
    else
        # Shortest distance from I to q ∈ Q
        dist = mulsumstar(A.α, A.M, dims.src, dims.dest)
        unit = ones(S, size(dist)...)
        invdist = map(inv, dist)

        # Mirror image of the branch above: dist and inv(dist) swap sides.
        M = mul(dist, A.M, invdist, A.dims.src, A.dims.dest)
        α = mul(dist, A.α, unit, stateaxes, stateaxes)
        ω = mul(unit, A.ω, invdist, stateaxes, stateaxes)
    end

    return TensorFST{S}(dims, M, α, ω)
end
# Sort keys passed as second argument to `sort` on a FST,
# e.g. `sort(A, state)` or `sort(B, inputlabel)`.
@enum FSTSort state inputlabel outputlabel
"""
......@@ -260,7 +295,7 @@ function compose(A::TensorFST{S}, B::TensorFST{S}; connect = true, flatten = con
AF = _compose_noeps(sort(Al, state), F; connect = true)
_compose_noeps(AF, sort(Bl, inputlabel); connect, flatten)
else
else
_compose_noeps(A, B; connect, flatten)
end
end
......
# handy functions for gradient projections
#
import SumSparseTensors: SortedSumSparseTensor, SumSparseTensor
"""
project_to([T=ProbSemiring], Ā, A)
Project the tangent `Ā` onto input `A` (i.e. original object).
The type `T` is a `Semiring` of the tangent, `ProbSemiring` by default.
"""
project_to(Ā, A::TensorFST{S}) where S = project_to(ProbSemiring{valtype(S)}, Ā, A)  # default: ProbSemiring over the value type of S
# Project a structural tangent `Ā` (named tuple with the fields `dims`, `M`, `α`
# and `ω`) onto the FST `A`. The result keeps the dimensions of `A` and uses the
# semiring type `T`; each tensor field is projected by the matching
# `project_to` method.
function project_to(T, Ā::NamedTuple{(:dims, :M, :α, :ω)}, A::TensorFST{S}) where S
    TensorFST{T}(
        dimensions(A),
        project_to(T, Ā.M, A.M),
        project_to(T, Ā.α, A.α),
        project_to(T, Ā.ω, A.ω)
    )
end
# Project the tangent `ΣĀ` of a `SumSparseTensor` onto `ΣA`: the axes and the
# non-zero coordinates are taken from `ΣA`, the values come from the tangent
# and are wrapped in the semiring type `T`.
function project_to(T, ΣĀ::NamedTuple{(:axes, :nzcoo, :nzval)}, ΣA::SumSparseTensor{S,N,M}) where {S,N,M}
    projected = map(ΣĀ.nzval) do tangent
        T(tangent.val)
    end
    return SumSparseTensor(ΣA.axes, ΣA.nzcoo, projected)
end
# Project the tangent `ΣĀ` of a `SortedSumSparseTensor` onto `ΣA`: the axes,
# pointers and non-zero coordinates are taken from `ΣA`, the values come from
# the tangent and are wrapped in the semiring type `T`.
function project_to(T, ΣĀ::NamedTuple{(:axes, :ptr, :nzcoo, :nzval)}, ΣA::SortedSumSparseTensor{S,N,I,M}) where {S,N,I,M}
    projected = map(ΣĀ.nzval) do tangent
        T(tangent.val)
    end
    return SortedSumSparseTensor{T,N,I}(ΣA.axes, ΣA.ptr, ΣA.nzcoo, projected)
end
......@@ -14,6 +14,23 @@ function issame(A::TensorFST, B::TensorFST)
Set(finalweights(A)) == Set(finalweights(B)) || return false
end
# Approximate comparison of two FSTs for the tests: the sorted arcs must have
# identical structure (src, isym, osym, dest) and approximately equal weight
# values; the initial and final weights are compared exactly, as sets.
# Returns `true` or `false`.
function isapprox(A::TensorFST, B::TensorFST)
    arcs_A, arcs_B = sort(arcs(A)), sort(arcs(B))

    # `zip` stops at the shorter collection, so a different number of arcs
    # has to be rejected explicitly.
    length(arcs_A) == length(arcs_B) || return false

    for (a_a, a_b) in zip(arcs_A, arcs_B)
        # src, isym, osym, dest
        if !((a_a.src == a_b.src) && (a_a.isym == a_b.isym) && (a_a.osym == a_b.osym) && (a_a.dest == a_b.dest))
            return false
        end

        # Weights: compared on the underlying values, up to floating-point tolerance.
        if !(val(a_a.weight) ≈ val(a_b.weight))
            return false
        end
    end

    Set(initweights(A)) == Set(initweights(B)) || return false
    Set(finalweights(A)) == Set(finalweights(B)) || return false
    return true
end
@testset "SymbolTable" begin
_syms = Dict("a" => 1, "b" => 2)
......@@ -270,6 +287,95 @@ end
@test issame(C, compose(A, sort(B, inputlabel); connect = true))
end
@testset verbose=true "Weight-pushing" begin
    S = TropicalSemiring{Float64}

    # Input FST: 4 states, initial state 1, final state 4.
    A = TensorFST(
        [
            Arc(src=1, isym=1, osym=2, weight=one(S), dest=2),
            Arc(src=1, isym=2, osym=1, weight=S(1), dest=2),
            Arc(src=1, isym=3, osym=2, weight=S(5), dest=2),
            Arc(src=1, isym=4, osym=1, weight=one(S), dest=3),
            Arc(src=1, isym=5, osym=2, weight=S(1), dest=3),
            Arc(src=2, isym=5, osym=1, weight=one(S), dest=4),
            Arc(src=2, isym=6, osym=2, weight=S(1), dest=4),
            Arc(src=3, isym=5, osym=1, weight=S(4), dest=4),
            Arc(src=3, isym=6, osym=2, weight=S(5), dest=4),
        ],
        [(1,) => one(S)],
        [(4,) => one(S)]
    )

    # Expected result of pushing A towards the initial state. It differs from A
    # only on the arcs through state 3: the weights of the arcs 1 -> 3 and
    # 3 -> 4 are exchanged.
    B = TensorFST(
        [
            Arc(src=1, isym=1, osym=2, weight=one(S), dest=2),
            Arc(src=1, isym=2, osym=1, weight=S(1), dest=2),
            Arc(src=1, isym=3, osym=2, weight=S(5), dest=2),
            Arc(src=1, isym=4, osym=1, weight=S(4), dest=3),
            Arc(src=1, isym=5, osym=2, weight=S(5), dest=3),
            Arc(src=2, isym=5, osym=1, weight=one(S), dest=4),
            Arc(src=2, isym=6, osym=2, weight=S(1), dest=4),
            Arc(src=3, isym=5, osym=1, weight=one(S), dest=4),
            Arc(src=3, isym=6, osym=2, weight=S(1), dest=4),
        ],
        [(1,) => one(S)],
        [(4,) => one(S)]
    )

    @test issame(weightpush(A, toinitial), B)
    # Pushing A towards the final state is expected to leave it unchanged.
    @test issame(weightpush(A, tofinal), A)

    PS = ProbSemiring{Float64}

    # Same topology as A, with weights in the probability semiring.
    C = TensorFST(
        [
            Arc(src=1, isym=1, osym=2, weight=zero(PS), dest=2),
            Arc(src=1, isym=2, osym=1, weight=PS(1), dest=2),
            Arc(src=1, isym=3, osym=2, weight=PS(5), dest=2),
            Arc(src=1, isym=4, osym=1, weight=zero(PS), dest=3),
            Arc(src=1, isym=5, osym=2, weight=PS(1), dest=3),
            Arc(src=2, isym=5, osym=1, weight=zero(PS), dest=4),
            Arc(src=2, isym=6, osym=2, weight=one(PS), dest=4),
            Arc(src=3, isym=5, osym=1, weight=PS(4), dest=4),
            Arc(src=3, isym=6, osym=2, weight=PS(5), dest=4),
        ],
        [(1,) => one(PS)],
        [(4,) => one(PS)]
    )

    # Expected result of pushing C towards the initial state: the initial
    # weight becomes 15 and the final weight stays 1.
    D = TensorFST(
        [
            Arc(src=1, isym=1, osym=2, weight=zero(PS), dest=2),
            Arc(src=1, isym=2, osym=1, weight=PS(1/15), dest=2),
            Arc(src=1, isym=3, osym=2, weight=PS(5/15), dest=2),
            Arc(src=1, isym=4, osym=1, weight=zero(PS), dest=3),
            Arc(src=1, isym=5, osym=2, weight=PS(9/15), dest=3),
            Arc(src=2, isym=5, osym=1, weight=zero(PS), dest=4),
            Arc(src=2, isym=6, osym=2, weight=one(PS), dest=4),
            Arc(src=3, isym=5, osym=1, weight=PS(4/9), dest=4),
            Arc(src=3, isym=6, osym=2, weight=PS(5/9), dest=4),
        ],
        [(1,) => PS(15)],
        [(4,) => PS(1)]
    )

    # Expected result of pushing C towards the final state: the initial weight
    # stays 1 and the final weight becomes 1/15.
    E = TensorFST(
        [
            Arc(src=1, isym=1, osym=2, weight=zero(PS), dest=2),
            Arc(src=1, isym=2, osym=1, weight=PS(1/6), dest=2),
            Arc(src=1, isym=3, osym=2, weight=PS(5/6), dest=2),
            Arc(src=1, isym=4, osym=1, weight=zero(PS), dest=3),
            Arc(src=1, isym=5, osym=2, weight=one(PS), dest=3),
            Arc(src=2, isym=5, osym=1, weight=zero(PS), dest=4),
            Arc(src=2, isym=6, osym=2, weight=PS(2/5), dest=4),
            Arc(src=3, isym=5, osym=1, weight=PS(4/15), dest=4),
            Arc(src=3, isym=6, osym=2, weight=PS(5/15), dest=4),
        ],
        [(1,) => one(PS)],
        [(4,) => PS(1/15)]
    )

    @test issame(weightpush(C, toinitial), D)
    # The arc weights of this result are only compared approximately.
    @test isapprox(weightpush(C, tofinal), E)
end
@testset verbose=true "Libaries" begin
include("linearfsts.jl")
include("ngramfsts.jl")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment