Skip to content
Snippets Groups Projects
Verified Commit 817fe900 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

fixed test tensor sum

parent 239c93f9
No related branches found
No related tags found
No related merge requests found
......@@ -19,9 +19,6 @@ end
Base.zero(::Type{BoolSemiring}) = BoolSemiring(false)
Base.one(::Type{BoolSemiring}) = BoolSemiring(true)
"""
struct LogSemiring{T,a} <: Semiring
val::T
......@@ -42,7 +39,7 @@ end
(x::LogSemiring{T,a}, y::LogSemiring{T,a}) where {T,a} = LogSemiring{T,a}(x.val + y.val)
(x::LogSemiring{T,a}, y::LogSemiring{T,a}) where {T,a} = LogSemiring{T,a}(x.val - y.val)
Base.inv(x::LogSemiring{T,a}) where {T,a} = LogSemiring{T,a}(-x.val)
Base.zero(K::Type{<:LogSemiring{T,a}}) where {T,a} = K(a > 0 ? T(-Inf) : T(Inf))
Base.zero(K::Type{<:LogSemiring{T,a}}) where {T,a} = K(ifelse(a > 0, T(-Inf), T(Inf)))
Base.one(K::Type{<:LogSemiring{T}}) where T = K(T(0))
"""
......@@ -97,3 +94,4 @@ Base.zero(K::Type{<:ProductSemiring{A,B}}) where A where B = K(zero(A), zero(B))
Base.one(K::Type{<:ProductSemiring{A,B}}) where A where B = K(one(A), one(B))
val(x::ProductSemiring) = (x.val1, x.val2)
......@@ -10,7 +10,7 @@ using Test
end
@testset "Linear Algebra" begin
include("sparsetensor.jl")
include("sparsearrays.jl")
end
@testset "FST" begin
......
......@@ -33,7 +33,7 @@ end
end
@testset "sparse tensor" begin
for S in [BoolSemiring, LogSemiring{Float32,1}, ArcticSemiring{Float64}]
for S in [BoolSemiring, LogSemiring{Float32,-1}, ArcticSemiring{Float64}]
Y = [
zero(S) zero(S) zero(S)
one(S) zero(S) one(S) one(S)
......@@ -57,12 +57,9 @@ end
@test all(X .== Y)
@test all(sum(X, dims=(1, 2)) .== sum(Y, dims=(1, 2)))
@test all(sum(X, dims=(2, 3)) .== sum(Y, dims=(2, 3)))
display(sum(X, dims=(2, 3)))
display(sum(Y, dims=(2,3)))
display(all(sum(X, dims=(2,3)) .== sum(Y, dims=(2,3))))
@test all(sum(X, dims=(3, 4)) .== sum(Y, dims=(3, 4)))
@test all(val.(sum(X, dims=(1, 2))) . val.(sum(Y, dims=(1, 2))))
@test all(val.(sum(X, dims=(2, 3))) . val.(sum(Y, dims=(2, 3))))
@test all(val.(sum(X, dims=(3, 4))) . val.(sum(Y, dims=(3, 4))))
XI, XJ, XK, XL, XV = findnz(X)
@test Set(XI) == Set(I)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment