Skip to content
Snippets Groups Projects
Verified Commit 239c93f9 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

added skeleton example

parent 80390a4c
No related branches found
No related tags found
No related merge requests found
1 2 1 1 -5
1 3 2 2 -9
2 3 2 2 -8
2 4 2 2 -12
3 4 3 3 -6
4 -5
# SPDX-License-Identifier: CECILL-2.1
# These packages are just for plotting.
using ImageInTerminal
using Sixel
# We just use the "TensorFSTs" without installing it.
push!(LOAD_PATH, "../..")
using TensorFSTs
using TensorFSTs.Semirings
using TensorFSTs.SparseSemimodules
using BenchmarkTools
include("utils.jl")
# Semiring to use. Possible choices are:
# - LogSemiring{Float32|Float64,a} # a is the temperature parameter
# - TropicalSemiring{Float32|Float64} # equivalent to LogSemiring{Float32|Float64,Inf}
S = LogSemiring{Float32,-1}
# Load the load symbol tables.
symbols = open(loadsymbols, "../data/symboltables/latin.syms")
isymbols = symbols
osymbols = symbols
# Load the FSTs to compose.
A = open("../data/fsts/push.fst.txt", "r") do f
convert(TensorFST{S}, compile(f; semiring=S))
end
showimage(A; isymbols, osymbols)
......@@ -5,29 +5,21 @@
Base class for all sparse arrays.
"""
abstract type AbstractSparseTensor{S,N} <: AbstractArray{S,N} end
"""
slicevals(X, row)
Return the cartesian indices of the non-zero values of the row-slice of
`row`.
"""
slicevals(X::AbstractSparseTensor, row)
abstract type AbstractSparseArray{S,N} <: AbstractArray{S,N} end
"""
nonzeros(X)
Return the non-zero values of the sparse tensor `X`.
"""
nonzeros(::AbstractSparseTensor)
nonzeros(::AbstractSparseArray)
"""
nnz(X)
Return the number of non-zero values in `X`.
"""
nnz(::AbstractSparseTensor)
nnz(::AbstractSparseArray)
"""
C, V = findnz(X)
......@@ -35,59 +27,12 @@ nnz(::AbstractSparseTensor)
Returns the Cartesian indices `C` and the corresponding non-zero
values of `X`.
"""
findnz(X::AbstractSparseTensor)
#=====================================================================#
# Compressed Sparse Row (CSR) format.
#=====================================================================#
"""
abstract type AbstractTensorCSR{S,N} <: AbstractSparseTensor{S,N} end
Base class for row-compressed sparse tensors.
"""
abstract type AbstractSparseTensorCSR{S,N} <: AbstractSparseTensor{S,N} end
"""
getrowptr(X)
Return the row pointer of the tensor `X`.
"""
getrowptr(::AbstractSparseTensorCSR)
"""
nzrange(M, r)
Return an iterator over the indices of non-zero values of `X`.
"""
nzrange(::AbstractSparseTensorCSR, r)
#=====================================================================#
# COOrdinate (COO) format.
#=====================================================================#
"""
abstract type AbstractTensorCOO{S,N} <: AbstractSparseTensor{S,N} end
Base class for coordinate (COO) format sparse tensors.
"""
abstract type AbstractSparseTensorCOO{S,N} <: AbstractSparseTensor{S,N} end
#=====================================================================#
# Dictionary of Keys (DOK) format.
#=====================================================================#
"""
abstract type AbstractTensorDOK{S,N} <: AbstractSparseTensor{S,N} end
Base class sparse tensors stored as dicionary of keys.
"""
abstract type AbstractSparseTensorDOK{S,N} <: AbstractSparseTensor{S,N} end
findnz(X::AbstractSparseArray)
"""
sparse(I, J, ..., V, nrows, ncols)
sparse(...)
Create a sparse tensor in COO format.
Create a sparse array.
"""
sparse
......@@ -17,7 +17,46 @@ function Base.permutedims(X::SparseTensorCSR{S}, perm) where S
end
#=====================================================================#
# Sum of tensor products
# Sum
#=====================================================================#
function Base.sum(X::SparseTensorCSR{S}; dims::NTuple{2,Int}) where S
newI = Vector{Int}(undef, nnz(X))
newJ = Vector{Int}(undef, nnz(X))
newV = zeros(S, nnz(X))
d1, d2 = setdiff((1, 2, 3, 4), dims)
J, K, L = fibervals(X)
V = nonzeros(X)
for i in 1:size(X, 1)
for nzi = nzrange(X, i)
j, k, l, v = J[nzi], K[nzi], L[nzi], V[nzi]
newI[nzi] = (i, j, k, l)[d1]
newJ[nzi] = (i, j, k, l)[d2]
newV[nzi] = V[nzi]
end
end
nzflatdim = ones(Int, nnz(X))
fibers = (
d1 == 1 ? newI : (d2 == 1 ? newJ : nzflatdim),
d1 == 2 ? newI : (d2 == 2 ? newJ : nzflatdim),
d1 == 3 ? newI : (d2 == 3 ? newJ : nzflatdim),
d1 == 4 ? newI : (d2 == 4 ? newJ : nzflatdim),
)
dims = (
d1 == 1 || d2 == 1 ? size(X, 1) : 1,
d1 == 2 || d2 == 2 ? size(X, 2) : 1,
d1 == 3 || d2 == 3 ? size(X, 3) : 1,
d1 == 4 || d2 == 4 ? size(X, 4) : 1,
)
sparse(fibers..., newV; dims)
end
#=====================================================================#
# Sum - tensor products
#=====================================================================#
function sumtensorproduct(X::AbstractArray{S,4}, Y::AbstractArray{S,4}) where S
......
# SPDX-License-Identifier: CECILL-2.1
struct SparseMatrixCSR{S} <: AbstractSparseTensorCSR{S,2}
struct SparseMatrixCSR{S} <: AbstractSparseArray{S,2}
dims::NTuple{2,Int}
rowptr::Vector{Int}
nzcol::Vector{Int}
......
# SPDX-License-Identifier: CECILL-2.1
struct SparseTensorCSR{S} <: AbstractSparseTensorCSR{S,4}
struct SparseTensorCSR{S} <: AbstractSparseArray{S,4}
dims::NTuple{4,Int}
rowptr::Vector{Int}
nzfib2::Vector{Int}
......
# SPDX-License-Identifier: CECILL-2.1
"""
const AbstractSparseVector{S} = AbstractSparseTensor{S,1}
const AbstractSparseVector{S} = AbstractSparseArray{S,1}
Base class for sparse vectors (1-dimensional sparse tensors).
"""
const AbstractSparseVector{S} = AbstractSparseTensor{S,1}
const AbstractSparseVector{S} = AbstractSparseArray{S,1}
struct SparseVector{S} <: AbstractSparseVector{S}
n::Int
......
......@@ -17,7 +17,7 @@ end
@testset "Logarithmic semiring" begin
for T in [Float64, Float32]
a = 2.4
for a in [1, 4.5]
K = LogSemiring{T,a}
x, y = K(2), K(3)
@test val(x y) logaddexp(a*val(x), a*val(y)) / a
......@@ -26,8 +26,9 @@ end
@test val(inv(x)) -val(x)
@test val(zero(x)) == T(-Inf)
@test val(one(x)) == T(0)
end
a = - 2.3
for a in [-1, -4.5]
K = LogSemiring{T,a}
x, y = K(2), K(3)
@test val(x y) logaddexp(a*val(x), a*val(y)) / a
......@@ -38,6 +39,7 @@ end
@test val(one(x)) == T(0)
end
end
end
@testset "Probability semiring" begin
for T in [Float32, Float64]
......
......@@ -33,7 +33,7 @@ end
end
@testset "sparse tensor" begin
for S in [BoolSemiring, LogSemiring{Float32,-1}, ArcticSemiring{Float64}]
for S in [BoolSemiring, LogSemiring{Float32,1}, ArcticSemiring{Float64}]
Y = [
zero(S) zero(S) zero(S)
one(S) zero(S) one(S) one(S)
......@@ -57,6 +57,13 @@ end
@test all(X .== Y)
@test all(sum(X, dims=(1, 2)) .== sum(Y, dims=(1, 2)))
@test all(sum(X, dims=(2, 3)) .== sum(Y, dims=(2, 3)))
display(sum(X, dims=(2, 3)))
display(sum(Y, dims=(2,3)))
display(all(sum(X, dims=(2,3)) .== sum(Y, dims=(2,3))))
@test all(sum(X, dims=(3, 4)) .== sum(Y, dims=(3, 4)))
XI, XJ, XK, XL, XV = findnz(X)
@test Set(XI) == Set(I)
@test Set(XJ) == Set(J)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment