Skip to content
Snippets Groups Projects
Commit 19c06919 authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

union for any tensor orientation

parent a5ff9d22
No related branches found
No related tags found
No related merge requests found
......@@ -67,7 +67,7 @@ function reorient(fst::TensorFST{S,O}, order::NTuple{4,Symbol}) where {S,O}
findfirst(x -> x == order[3], O),
findfirst(x -> x == order[4], O),
)
TensorFST{S,order}(permutedims(fst.M, perm), fst.α, fst.ω)
TensorFST{S,order}(SparseSemimodules.permutedims(fst.M, perm), fst.α, fst.ω)
end
"""
......@@ -107,12 +107,15 @@ end
Take the union of `A` and `B`.
"""
function Base.union(A::AbstractTensorFST{S,(:src,:dest,:ilabel,:olabel)},
B::AbstractTensorFST{S,(:src,:dest,:ilabel,:olabel)}) where S
TensorFST{S,(:src,:dest,:ilabel,:olabel)}(
blockdiag(getM(A), getM(B)),
vcat(getα(A), getα(B)),
vcat(getω(A), getω(B))
function Base.union(A::AbstractTensorFST{S,O}, B::AbstractTensorFST{S,O}) where {S,O}
dims = (
findfirst(x -> x == :src, O),
findfirst(x -> x == :dest, O),
)
TensorFST{S,O}(
blockdiag(getM(A), getM(B); dims),
SparseSemimodules.vcat(getα(A), getα(B)),
SparseSemimodules.vcat(getω(A), getω(B))
)
end
......
......@@ -72,7 +72,9 @@ include("sparsetensorcoo.jl")
#=====================================================================#
export
blockdiag
blockdiag,
permutedims,
vcat
include("linalg.jl")
......
......@@ -4,13 +4,13 @@
# Vector concatenation
#=====================================================================#
function Base.vcat(x::AbstractSparseTensor{S,1}, y::AbstractSparseTensor{S,1}) where S
function vcat(x::AbstractSparseTensor{S,1}, y::AbstractSparseTensor{S,1}) where S
xI, xV = findnz(x)
yI, yV = findnz(y)
SparseTensor1DCOO(
size(x, 1) + size(y, 1),
vcat(xI, size(x, 1) .+ yI),
vcat(xV, yV)
Base.vcat(xI, size(x, 1) .+ yI),
Base.vcat(xV, yV)
)
end
......@@ -18,7 +18,7 @@ end
# Permute dimensions
#=====================================================================#
function Base.permutedims(x::AbstractSparseTensor{S,4}, perm) where S
function permutedims(x::AbstractSparseTensor{S,4}, perm) where S
perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4 && return x
I, J, K, L, V = findnz(x)
coo = (I, J, K, L)
......@@ -42,28 +42,33 @@ end
#=====================================================================#
"""
blockdiag(x, y)
blockdiag(x, y; dims)
Block diagonal.
"""
function blockdiag(x::AbstractSparseTensor{S,4}, y::AbstractSparseTensor{S,4}) where S
(size(x, 3) != size(y, 3) || size(x, 4) != size(y, 4)) && throw(DimensionMismatch())
function blockdiag(x::AbstractSparseTensor{S,4}, y::AbstractSparseTensor{S,4}; dims::Tuple) where S
length(dims) == 0 && throw(ArgumentError("should provide at least 1 dimension"))
notdims = setdiff((1, 2, 3, 4), dims)
for d in notdims
size(x, d) != size(y, d) && throw(DimensionMismatch())
end
xI, xJ, xK, xL, xV = findnz(x)
yI, yJ, yK, yL, yV = findnz(y)
newsize = (
size(x, 1) + size(y, 1),
size(x, 2) + size(y, 2),
size(x, 3),
size(x, 4)
size(x, 1) + (1 dims ? size(y, 1) : 0 ),
size(x, 2) + (2 dims ? size(y, 2) : 0),
size(x, 3) + (3 dims ? size(y, 3) : 0),
size(x, 4) + (4 dims ? size(y, 4) : 0)
)
SparseTensor4DCOO(
newsize,
vcat(xI, size(x, 1) .+ yI),
vcat(xJ, size(x, 2) .+ yJ),
vcat(xK, yK),
vcat(xL, yL),
vcat(xV, yV)
Base.vcat(xI, 1 dims ? size(x, 1) .+ yI : yI),
Base.vcat(xJ, 2 dims ? size(x, 2) .+ yJ : yJ),
Base.vcat(xK, 3 dims ? size(x, 3) .+ yK : yK),
Base.vcat(xL, 4 dims ? size(x, 4) .+ yL : yL),
Base.vcat(xV, yV)
)
end
......
......@@ -77,12 +77,16 @@ for S in Ss
end
perm = (2, 3, 1, 4)
@test permutedims(x, perm) isa AbstractSparseTensor
@test all(permutedims(x, perm) .== permutedims(dense, perm))
@test SM.permutedims(x, perm) isa AbstractSparseTensor
@test all(SM.permutedims(x, perm) .== Base.permutedims(dense, perm))
y = blockdiag(x, x)
y = blockdiag(x, x; dims = (1, 2))
d1, d2 = size(x, 1), size(x, 2)
@test all(x[1:d1,1:d2,:,:] .== y[d1 .+ (1:d1), d2 .+ (1:d2),:,:])
y = blockdiag(x, x; dims = (1, 2, 3, 4))
d1, d2, d3, d4 = size(x)
@test all(x[1:d1,1:d2,1:d3,1:d4] .== y[d1 .+ (1:d1), d2 .+ (1:d2), d3 .+ (1:d3), d4 .+ (1:d4)])
end
end
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment