Skip to content
Snippets Groups Projects
Commit d6b60316 authored by Martin Kocour's avatar Martin Kocour
Browse files

Move IntersectedDenseFSA

parent 29fc89c6
No related branches found
No related tags found
No related merge requests found
......@@ -44,7 +44,7 @@ export
include("abstractfsa.jl")
include("fsa.jl")
include("ops.jl")
include("intersect.jl")
include("dense_fsa.jl")
include("intersect.jl")
end
......@@ -39,86 +39,3 @@ end
ρ(G::DenseFSA) = G.ρ
λ(G::DenseFSA) = repeat(G.Σ, size(G.H, 2) - 1)
# Intersection with DenseFSA
struct IntersectedDenseFSA{K, L, T<:AbstractFSA{K, L}} <: AbstractAcyclicFSA{K, L}
A::DenseFSA{K, L}
B::T
C::AbstractMatrix{K} # of shape |Σ| x nstates(A)
end
_computeC(::Type{K}, Σ::AbstractVector, labels::AbstractVector) where K = begin
rows = []
cols = []
for (i,s) in enumerate(Σ)
for (j, l) in enumerate(labels)
if s == l
push!(rows, i)
push!(cols, j)
end
end
end
sparse(rows, cols, one(K), length(Σ), length(labels))
end
IntersectedDenseFSA(A::DenseFSA{K}, B::AbstractFSA{K}) where K = begin
C = _computeC(K, A.Σ, λ(B)) # TODO switch cols and rows
IntersectedDenseFSA(A, B, C)
# Hn = C' * A.H
# Hn[:, 1] .*= α(B)
# Hn[:, end] .*= ω(B)
# DenseFSA(Hn, λ(B), ρ(A) * ρ(B))
end
Base.intersect(A::DenseFSA, B::AbstractFSA) = IntersectedDenseFSA(A, B)
Base.intersect(A::AbstractFSA, B::DenseFSA) = IntersectedDenseFSA(B, A) # we assume ∩ to be commutative w.r.t any K
α(I::IntersectedDenseFSA) = begin
h1 = I.A.H[:, 1]
Q = nstates(I.B) * (size(I.A.H, 2) - 1)
vals = (I.C' * h1) .* α(I.B)
sparsevec(findnz(vals)..., Q)
end
ω(I::IntersectedDenseFSA) = begin
h = I.A.H[:, end]
Q = nstates(I.B) * (size(I.A.H, 2) - 1)
vals = (I.C' * h) .* ω(I.B)
I, V = findnz(vals)
sparsevec(I .+ (Q - length(vals)), V, Q)
end
T(I::IntersectedDenseFSA) = begin
H = I.A.H
Tb = T(I.B)
C = I.C
K = eltype(H)
Σ = length(I.A.Σ)
Λ = nstates(I.B)
Q = Λ * (size(H, 2) - 1)
rows = []
cols = []
vals = K[]
for n in 2:size(H, 2) - 1
tmp = ones(K, Λ) * (C' * H[:, n])' .* Tb
r, c, v = findnz(tmp)
append!(rows, (n - 2) * Λ .+ r)
append!(cols, (n - 1) * Λ .+ c)
append!(vals, v)
end
sparse(rows, cols, vals, Q, Q)
end
λ(I::IntersectedDenseFSA) = repeat(λ(I.B), size(I.A.H, 2) - 1)
ρ(I::IntersectedDenseFSA) = ρ(I.B) * ρ(I.A)
function Base.sum(I::IntersectedDenseFSA, n = size(I.A.H, 2) - 1)
CH = I.C' * I.A.H
v = CH[:, 1] .* α(I.B)
for n in 2:n
v = (T(I.B)' * v) .* CH[:, n]
end
dot(v, ω(I.B) .* CH[:, end]) + ρ(I)
end
# SPDX-License-Identifier: CECILL-2.1
# Generic intersection
struct IntersectedFSA{K, L, T_A<:AbstractFSA{K,L}, T_B<:AbstractFSA{K,L}} <: AbstractFSA{K, L}
A::T_A
B::T_B
......@@ -33,3 +36,86 @@ T(I::IntersectedFSA) = I.M' * kron(T(I.A), T(I.B)) * I.M
ω(I::IntersectedFSA) = I.M' * kron(ω(I.A), ω(I.B))
ρ(I::IntersectedFSA) = ρ(I.A) * ρ(I.B)
# Intersection with DenseFSA
struct IntersectedDenseFSA{K, L, T<:AbstractFSA{K, L}} <: AbstractAcyclicFSA{K, L}
A::DenseFSA{K, L}
B::T
C::AbstractMatrix{K} # of shape |Σ| x nstates(A)
end
_computeC(::Type{K}, Σ::AbstractVector, labels::AbstractVector) where K = begin
rows = []
cols = []
for (i,s) in enumerate(Σ)
for (j, l) in enumerate(labels)
if s == l
push!(rows, i)
push!(cols, j)
end
end
end
sparse(rows, cols, one(K), length(Σ), length(labels))
end
IntersectedDenseFSA(A::DenseFSA{K}, B::AbstractFSA{K}) where K = begin
C = _computeC(K, A.Σ, λ(B)) # TODO switch cols and rows
IntersectedDenseFSA(A, B, C)
# Hn = C' * A.H
# Hn[:, 1] .*= α(B)
# Hn[:, end] .*= ω(B)
# DenseFSA(Hn, λ(B), ρ(A) * ρ(B))
end
Base.intersect(A::DenseFSA, B::AbstractFSA) = IntersectedDenseFSA(A, B)
Base.intersect(A::AbstractFSA, B::DenseFSA) = IntersectedDenseFSA(B, A) # we assume ∩ to be commutative w.r.t any K
α(I::IntersectedDenseFSA) = begin
h1 = I.A.H[:, 1]
Q = nstates(I.B) * (size(I.A.H, 2) - 1)
vals = (I.C' * h1) .* α(I.B)
sparsevec(findnz(vals)..., Q)
end
ω(I::IntersectedDenseFSA) = begin
h = I.A.H[:, end]
Q = nstates(I.B) * (size(I.A.H, 2) - 1)
vals = (I.C' * h) .* ω(I.B)
I, V = findnz(vals)
sparsevec(I .+ (Q - length(vals)), V, Q)
end
T(I::IntersectedDenseFSA) = begin
H = I.A.H
Tb = T(I.B)
C = I.C
K = eltype(H)
Σ = length(I.A.Σ)
Λ = nstates(I.B)
Q = Λ * (size(H, 2) - 1)
rows = []
cols = []
vals = K[]
for n in 2:size(H, 2) - 1
tmp = ones(K, Λ) * (C' * H[:, n])' .* Tb
r, c, v = findnz(tmp)
append!(rows, (n - 2) * Λ .+ r)
append!(cols, (n - 1) * Λ .+ c)
append!(vals, v)
end
sparse(rows, cols, vals, Q, Q)
end
λ(I::IntersectedDenseFSA) = repeat(λ(I.B), size(I.A.H, 2) - 1)
ρ(I::IntersectedDenseFSA) = ρ(I.B) * ρ(I.A)
function Base.sum(I::IntersectedDenseFSA, n = size(I.A.H, 2) - 1)
CH = I.C' * I.A.H
v = CH[:, 1] .* α(I.B)
for n in 2:n
v = (T(I.B)' * v) .* CH[:, n]
end
dot(v, ω(I.B) .* CH[:, end]) + ρ(I)
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment