diff --git a/sandbox/GNAT.jl b/sandbox/GNAT.jl index 4f95a034532b625c409db63c08979c7d97db983d..3b7cdb5bad85a14f50c13d284c1a7ceaaba2f868 100644 --- a/sandbox/GNAT.jl +++ b/sandbox/GNAT.jl @@ -6,7 +6,7 @@ using InteractiveUtils # ╔═╡ 510f7d8a-ce11-11ed-2832-cdc20fa6f700 begin - using Pkg; Pkg.activate("..") + using Pkg; Pkg.activate("./mmi") using Revise using FiniteStateAutomata using Semirings @@ -131,11 +131,14 @@ L # ╔═╡ 5ecadd35-feea-459f-9776-3a2762b0de9f # FSA(C) -# ╔═╡ 2ea1644a-8836-443b-a1ee-8d13dbd4a534 -H[:, 2] +# ╔═╡ 410d3e85-8946-4c74-9fd5-77c799e45277 +sum(G ∩ C) ≈ sum(FSA(G ∩ C)) -# ╔═╡ 68d8f4b3-48e1-49ca-b910-c65c2ec19768 -ones(K, 3) * H[:,2]' +# ╔═╡ 895f0152-e4c6-4a1d-97e0-f4a4308fe362 +sum(FSA(G ∩ C)) + +# ╔═╡ 584c1b63-3f12-4383-8bf0-8078e518910d +≈ # ╔═╡ Cell order: # ╠═510f7d8a-ce11-11ed-2832-cdc20fa6f700 @@ -163,5 +166,6 @@ ones(K, 3) * H[:,2]' # ╠═b2290feb-24a1-4f7d-8216-7c69ea363a67 # ╠═5c3893f7-b64f-410f-9be2-5c2023fd9093 # ╠═5ecadd35-feea-459f-9776-3a2762b0de9f -# ╠═2ea1644a-8836-443b-a1ee-8d13dbd4a534 -# ╠═68d8f4b3-48e1-49ca-b910-c65c2ec19768 +# ╠═410d3e85-8946-4c74-9fd5-77c799e45277 +# ╠═895f0152-e4c6-4a1d-97e0-f4a4308fe362 +# ╠═584c1b63-3f12-4383-8bf0-8078e518910d diff --git a/src/dense_fsa.jl b/src/dense_fsa.jl index 39efba1d11a8ce9e6e10231fbba93ed8d4a26585..339480d1fae075e83bb1cace16057d01381995b3 100644 --- a/src/dense_fsa.jl +++ b/src/dense_fsa.jl @@ -5,6 +5,8 @@ struct DenseFSA{K, L} <: AbstractAcyclicFSA{K, L} DenseFSA(H, Σ, ρ) = length(Σ) == size(H, 1) ? new{eltype(H), eltype(Σ)}(H, Σ, ρ) : error("Σ not compatible") end +DenseFSA(H::AbstractMatrix{K}, Σ::AbstractVector) where K = DenseFSA(H, Σ, zero(K)) + α(G::DenseFSA) = begin L = length(G.Σ) N = size(G.H, 2) @@ -111,3 +113,12 @@ end λ(I::IntersectedDenseFSA) = repeat(λ(I.B), size(I.A.H, 2) - 1) ρ(I::IntersectedDenseFSA) = ρ(I.B) * ρ(I.A) + +function Base.sum(I::IntersectedDenseFSA, n = size(I.A.H, 2) - 1) + CH = I.C' * I.A.H + v = CH[:, 1] .* α(I.B) + for n in 2:n + v = (T(I.B)' * v) .* CH[:, n] + end + dot(v, ω(I.B) .* CH[:, end]) + ρ(I) +end diff --git a/src/ops.jl b/src/ops.jl index f44d17e08155a85453897751fb43f9b3c60f2eac..fbfd4081922641fb6a3ef986da5c984505756de2 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -31,10 +31,10 @@ end Accumulate the weight of all the paths of length less or equal to `n`. """ -function Base.sum(A::AbstractFSA; init = α(A), n = nstates(A) + 2) +function Base.sum(A::AbstractFSA; init = α(A), n = nstates(A)) v = init σ = dot(v, ω(A)) + ρ(A) - for i in 2:n-1 + for _ in 2:n v = T(A)' * v σ += dot(v, ω(A)) end diff --git a/test/runtests.jl b/test/runtests.jl index 77a6e42b224359e83539017bae66f26bd3aa49e9..16ffe3258c4354bb835919c76907dcf6eb054023 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -162,6 +162,14 @@ for K in Ks, L in Ls @test s1.tval[2] == s2.tval[2] @test typeof(AcyclicFSA(A3) |> minimize) <: AbstractAcyclicFSA end + + @testset verbose=true "intersection" begin + Σ = ["a", "b", "c"] + H = K.randn(length(Σ), 4) + dfsa = DenseFSA(H, Σ) + @test sum(dfsa ∩ A3) ≈ sum(FSA(dfsa ∩ A3)) + @test sum(A3 ∩ dfsa) ≈ sum(FSA(dfsa ∩ A3)) + end end end