Skip to content
Snippets Groups Projects
Verified Commit 5e2b0e87 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

fsm concatenation

parent d8db1293
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,9 @@ shortestdistance(X)
# ╔═╡ 211012f4-551c-4364-a7a5-a3edb73f8502
summary(X)
# ╔═╡ c6595865-0db2-4c48-a43c-1cf7e493c44c
draw(cat(X, X); isymbols = symbols, osymbols = symbols)
# ╔═╡ Cell order:
# ╠═a433081c-4d58-11ee-2591-5972f009ea6d
# ╠═fcc7ff6f-7a24-4ef0-b47f-3c794558d641
......@@ -58,3 +61,4 @@ summary(X)
# ╠═542851a1-bcf4-4236-b983-f93000288a87
# ╠═e6c19def-7dea-4b50-8417-40b86c143383
# ╠═211012f4-551c-4364-a7a5-a3edb73f8502
# ╠═c6595865-0db2-4c48-a43c-1cf7e493c44c
......@@ -4,7 +4,6 @@ module SparseSemimodules
using ..Semirings
using LinearAlgebra
using CUDA
using Adapt
......@@ -48,6 +47,7 @@ export
closure,
collapse,
kron,
outerproduct,
permutedims,
resize,
sum,
......@@ -67,6 +67,7 @@ include("sptensorops/resize.jl")
include("sptensorops/sum.jl")
include("sptensorops/sumkron.jl")
include("sptensorops/vecmatmul.jl")
include("sptensorops/vcat.jl")
end
......@@ -38,8 +38,7 @@ struct OuterProduct{S,
y::Ty
end
function kron(x, yᵀ::Transpose)
y = parent(yᵀ)
function outerproduct(x, y)
OuterProduct((size(x, 1), size(y, 1)), x, y)
end
......
# SPDX-License-Identifier: CECILL-2.1
struct ConcatenatedVector{S,
Tx<:AbstractSumSparseTensor{S,1},
Ty<:AbstractSumSparseTensor{S,1}} <: AbstractSumSparseTensor{S,1}
size::Int
x::Tx
y::Ty
end
vcat(x, y) = ConcatenatedVector(size(x, 1) * size(y, 1), x, y)
function foreachnz(f, t::ConcatenatedVector)
foreachnz(t.x) do coo, val
f(coo, val)
end
foreachnz(t.y) do coo, val
f(length(t.x) + coo, val)
end
end
nnz(t::ConcatenatedVector) = nnz(t.x) + nnz(t.y)
Base.size(t::ConcatenatedVector) = (size(t.x, 1) + size(t.y, 1),)
......@@ -106,16 +106,15 @@ function Base.getindex(x::AbstractSparseTensor{S,1}, i::Integer) where S
return zero(S)
end
function vcat(x::AbstractSparseTensor{S,1}, y::AbstractSparseTensor{S,1}) where S
xI, xV = findnz(x)
yI, yV = findnz(y)
SparseTensor1DCOO(
size(x, 1) + size(y, 1),
Base.vcat(xI, size(x, 1) .+ yI),
Base.vcat(xV, yV)
)
end
#function vcat(x::AbstractSparseTensor{S,1}, y::AbstractSparseTensor{S,1}) where S
# xI, xV = findnz(x)
# yI, yV = findnz(y)
# SparseTensor1DCOO(
# size(x, 1) + size(y, 1),
# Base.vcat(xI, size(x, 1) .+ yI),
# Base.vcat(xV, yV)
# )
#end
function permutedims(x::AbstractSparseTensor{S,4}, perm) where S
perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4 && return x
......
......@@ -37,6 +37,10 @@ for S in Ss
@test all(x . dense)
@test nnz(x) == 2
@test all(Set(I) .== Set(findnz(x)[1]))
xx = SparseSemimodules.vcat(x, x)
@test nnz(xx) == 2*nnz(x)
@test all(xx . Base.vcat(dense, dense))
end
I, J, K, L = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment