diff --git a/sandbox/example.jl b/sandbox/example.jl index e66addd6d5f271472ac12e4944c25fa4704ed4dc..f64685e139820c861908ea338768d73695d7a6cd 100644 --- a/sandbox/example.jl +++ b/sandbox/example.jl @@ -576,16 +576,16 @@ end edges # â•”â•â•¡ c8d7966a-61c1-4ab2-856c-6ca5d10648cd -fsadet(M3 |> renorm, edges, states) |> renorm +fsadet(M3, edges, states) |> renorm # â•”â•â•¡ 358d297d-4def-44dc-b79d-6d23aa67d179 -sum(M3 |> renorm) +M3 # â•”â•â•¡ c31ccfdc-3687-46cb-9eed-03791bde56d4 sparsevec([2, 3], 1, 5) # â•”â•â•¡ 75c4a277-9edb-4f47-bd5e-13dcb498bd96 -M3 +sum(union(AcyclicFSA(M3), AcyclicFSA(M3))) # â•”â•â•¡ 68b7380b-2514-4362-9504-552fb4e467f7 Z1 = spdiagm(M3.α) * C diff --git a/src/FSALinalg.jl b/src/FSALinalg.jl index 102a2b628a0ceadb349842552449fc6bf073cb0a..d34225d9acbc98d8b36810fe5ed3b2c22d60e7cb 100644 --- a/src/FSALinalg.jl +++ b/src/FSALinalg.jl @@ -32,6 +32,7 @@ export addskipedges, closure, globalrenorm, + propagate, renorm include("abstractfsa.jl") diff --git a/src/fsa.jl b/src/fsa.jl index 2b703ee458d76856e2dac330ef11725d21acda2e..023df25d24cdd08e28e76b988a0d42b8f1458e96 100644 --- a/src/fsa.jl +++ b/src/fsa.jl @@ -50,12 +50,13 @@ function Base.convert(f::Function, A::AbstractFSA{K,L}) where {K,L} end struct AcyclicFSA{K,L} <: AbstractAcyclicFSA{K,L} - fsa::FSA{K,L} + fsa::AbstractFSA{K,L} end AcyclicFSA(α, T, ω, Ï, λ) = AcyclicFSA(FSA(α, T, ω, Ï, λ)) -Base.parent(A::AcyclicFSA) = A.fsa +Base.parent(A::AcyclicFSA) = parent(A.fsa) -Base.convert(f, A::AbstractAcyclicFSA) = AcyclicFSA(convert(f, A.fsa)) +Base.convert(f::Function, A::AbstractAcyclicFSA{K,L}) where {K,L} = + AcyclicFSA(convert(f, parent(A))) diff --git a/src/ops.jl b/src/ops.jl index 1d3c560e3b8d8dfa99b5c919f4334a496c5da47f..dc90cca604ba657b73cac6f3093e777a973112b5 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -87,9 +87,9 @@ function Base.filter(f::Function, A::AbstractFSA) end """ - notaccessible(A::AbstractFSA{K}) where K + accessible(A::AbstractFSA{K}) where K -Returns a vector `x` of dimension `nstates(A)` where `x[i]` is `one(K)` +Return a vector `x` of dimension `nstates(A)` where `x[i]` is `one(K)` if the state `i` is not accessible, `zero(K)` otherwise. """ function accessible(A::AbstractFSA{K}) where K @@ -125,7 +125,7 @@ one. function renorm(A::AbstractFSA) Z = inv.((sum(T(A), dims=2) .+ ω(A))) Zâ‚ = inv(sum(α(A)) + Ï(A)) - FSA( + typeof(A)( Zâ‚ * α(A), Z .* T(A), Z[:, 1] .* ω(A), @@ -182,22 +182,38 @@ Base.union(A1::AbstractFSA{K}, AN::AbstractFSA{K}...) where K = #= Functions for acyclic FSA =# +renorm(A::AbstractAcyclicFSA) = AcyclicFSA(parent(A) |> renorm) +Base.reverse(A::AbstractAcyclicFSA) = AcyclicFSA(parent(A) |> reverse) +Base.cat(A1::AbstractAcyclicFSA, A2::AbstractAcyclicFSA) = + AcyclicFSA(cat(parent(A1), parent(A2))) +Base.union(A1::AbstractAcyclicFSA, A2::AbstractAcyclicFSA) = + AcyclicFSA(union(parent(A1), parent(A2))) + """ - globalrenorm(A::AbstractAcyclicFSA) + propagate(A::AbstractAcyclicFSA[; forward = true]) -Global renormalization of the weights of `A`. +Multiply the weight of each arc by the sum-product of all the prefix +paths. If `forward` is `false` the arc's weight is multiply by the +sum-product of all the suffix paths of this arc. """ -function globalrenorm(A::AbstractAcyclicFSA) - # Accumulate the weight backward starting from the end state. - v = ω(A) +function propagate(A::AbstractAcyclicFSA; forward = true) + v = forward ? α(A) : ω(A) + M = forward ? T(A)' : T(A) σ = v while nnz(v) > 0 - v = T(A) * v + v = M * v σ += v end σ D = spdiagm(σ) - FSA(σ .* α(A), T(A) * D, ω(A), Ï(A), λ(A)) |> renorm + AcyclicFSA(FSA(σ .* α(A), T(A) * D, ω(A), Ï(A), λ(A))) end +""" + globalrenorm(A::AbstractAcyclicFSA) + +Global renormalization of the weights of `A`. +""" +globalrenorm(A::AbstractAcyclicFSA) = propagate(A; forward = false) |> renorm + diff --git a/test/runtests.jl b/test/runtests.jl index df0e9fcd64ec6b9102542a344b9b2460b1f5c17f..f3f2cfb45eccc9c22068cb391718d8e5a6030bfd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,7 +41,7 @@ for K in Ks, L in Ls ω2 = sparsevec([2, 4], K[5, 6], 4) Ï2 = zero(K) λ2 = [L("a"), L("b"), L("c"), L("d")] - A2 = FSA(α2, T2, ω2, Ï2, λ2) + A2 = AcyclicFSA(α2, T2, ω2, Ï2, λ2) B2 = convert((w, l) -> f(L, A2, w, l), A2) Aϵ = FSA(spzeros(K, 0), spzeros(K, 0, 0), spzeros(K, 0), one(K), L[]) @@ -71,6 +71,7 @@ for K in Ks, L in Ls cs_B1_B2 = sum(B1; n) + sum(B2; n) @test cs_B12.tval[1] ≈ cs_B1_B2.tval[1] @test cs_B12.tval[2] == cs_B1_B2.tval[2] + @test typeof(union(A2, A2)) <: AbstractAcyclicFSA end @testset verbose=true "concatenation" begin @@ -79,6 +80,7 @@ for K in Ks, L in Ls cs_B12 = sum(B12; n = 2n) cs_B1_B2 = sum(B2; n) * sum(B1; n) @test cs_B12.tval[2] ⊇ cs_B1_B2.tval[2] + @test typeof(cat(A2, A2)) <: AbstractAcyclicFSA end # The number of iteration for the cumulative sum depends on the @@ -107,6 +109,7 @@ for K in Ks, L in Ls s1 = val(sum(B2; n).tval[2]) s2 = Set((StringMonoid ∘ reverse ∘ val).((val ∘ sum)(rB2)[2])) @test s1 == s2 + @test typeof(rA2) <: AbstractAcyclicFSA end @testset verbose=true "renorm" begin @@ -127,7 +130,9 @@ for K in Ks, L in Ls end @testset verbose=true "globalrenorm" begin - @test Base.isapprox(val(sum(AcyclicFSA(A2) |> globalrenorm; n = 100)), val(one(K)), atol=1e-6) + A = AcyclicFSA(A2) |> globalrenorm + @test Base.isapprox(val(sum(A; n = 100)), val(one(K)), atol=1e-6) + @test typeof(A) <: AbstractAcyclicFSA end end end