Skip to content
Snippets Groups Projects
Verified Commit 8c50dd3a authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

mwe gradient

parent 0dac1121
Branches
Tags
No related merge requests found
......@@ -6,4 +6,5 @@ version = "0.2.2"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Semirings = "900aad66-9ca5-44d4-b043-321c62cb7767"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
......@@ -3,7 +3,9 @@
module FiniteStateAutomata
using LinearAlgebra
using ChainRulesCore
using SparseArrays
using Semirings
export
# Abstract types
......@@ -48,6 +50,7 @@ include("fst.jl")
include("dense_fsa.jl")
include("reverse.jl")
include("sum.jl")
include("ops.jl")
include("intersect.jl")
......
......@@ -4,6 +4,6 @@ Base.iterate(A::AbstractFST) = nstates(A) == 0 ? nothing : (α(A), α(A))
"""
    Base.iterate(A::AbstractFST, uₙ)

Advance the iteration over `A` by one step.

The current vector `uₙ` is multiplied by `T(A)'` to obtain `uₙ₊₁`, which is
returned as both the next item and the next iteration state. Iteration stops
(`nothing` is returned) as soon as `uₙ₊₁` has no stored entries according to
`nnz`.
"""
function Base.iterate(A::AbstractFST, uₙ)
    uₙ₊₁ = T(A)' * uₙ
    # The stopping test must look at the freshly propagated vector, not at the
    # incoming state `uₙ`.
    return nnz(uₙ₊₁) > 0 ? (uₙ₊₁, uₙ₊₁) : nothing
end
......@@ -25,22 +25,6 @@ function addskipedges(M, states, weights)
FST(αₙ, Tₙ, ωₙ, ρₙ, λ(M))
end
"""
    Base.sum(A[; init = α(A), n = nstates(A)])

Return the weight accumulated while propagating `init` through `A`.

The total starts as `dot(init, ω(A)) + ρ(A)`. The vector is then multiplied by
`T(A)'` `n - 1` times, and its `dot` product with `ω(A)` is added to the total
after each multiplication.
"""
function Base.sum(A::AbstractFST; init = α(A), n = nstates(A))
    frontier = init
    total = dot(frontier, ω(A)) + ρ(A)
    # One term has already been accumulated, so only `n - 1` steps remain.
    for _ in 1:(n - 1)
        frontier = T(A)' * frontier
        total += dot(frontier, ω(A))
    end
    return total
end
"""
cat(A1[, A2, ...])
......
# SPDX-License-Identifier: CECILL-2.1
"""
Base.sum(A; n = nstates(A))
Accumulate the weights of all the paths. For cyclic FSTs, stop accumulating
after `n` steps.
"""
#function Base.sum(A::AbstractFST; n = nstates(A))
# W = ρ(A)
# for (i, uₙ) in enumerate(A)
# i <= n || break
# W += dot(uₙ, ω(A))
# end
# W
#end
# Reverse-mode rule (ChainRulesCore) for `Base.sum` applied to an `AbstractFST`.
#
# Forward pass: start from `ρ(A)` and add `dot(uₙ, ω(A))` for every vector `uₙ`
# produced by iterating `A`; each `uₙ` is stored in `U` for the pullback.
#
# Returns `(W, pullback)`. `pullback(ΔY)` returns `(NoTangent(), ΔA)`, where `ΔA`
# is an `FST` built on the sparsity patterns (`findnz`) of `α(A)`, `T(A)` and
# `ω(A)`.
#
# NOTE(review): `ΔY` is never used, so the returned tangent is not scaled by the
# incoming cotangent — presumably a unit cotangent is assumed; confirm.
# NOTE(review): `sum_U` is paired with the pattern of `α(A)` and `sum_V` with the
# pattern of `ω(A)`; confirm this pairing is the intended one (and not swapped).
function ChainRulesCore.rrule(::typeof(Base.sum), A::AbstractFST{K}) where K
    W = ρ(A)
    U = []
    for uₙ in A
        push!(U, uₙ)
        W += dot(uₙ, ω(A))
    end

    function pullback(ΔY)
        # Vectors produced by iterating the reversed automaton.
        V = []
        for vₙ in reverse(A)
            push!(V, vₙ)
        end

        # Element-wise sum of all forward vectors, read at the stored indices
        # of α(A) below.
        sum_U = sum(U)
        I_α, _ = findnz(α(A))

        # Accumulate products of forward/backward entries on the stored
        # (row, column) positions of T(A).
        I_T, J_T, _ = findnz(T(A))
        V_T = zeros(K, length(J_T))
        k_1 = length(U)
        for i in 1:k_1
            for j in 1:(k_1 - i)
                uᵢ, vⱼ = U[i], V[k_1 - j]
                V_T .+= uᵢ[I_T] .* vⱼ[J_T]
            end
        end

        # Element-wise sum of all backward vectors, read at the stored indices
        # of ω(A) below.
        sum_V = sum(V)
        I_ω, _ = findnz(ω(A))

        ΔA = FST(
            sparsevec(I_α, sum_U[I_α], nstates(A)),
            sparse(I_T, J_T, V_T, nstates(A), nstates(A)),
            sparsevec(I_ω, sum_V[I_ω], nstates(A)),
            iszero(ρ(A)) ? zero(K) : one(K),
            λ(A)
        )
        return (NoTangent(), ΔA)
    end

    return W, pullback
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment