Skip to content
Snippets Groups Projects
Commit 56abc1ae authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

reshaped tensor

parent 19c06919
No related branches found
No related tags found
No related merge requests found
......@@ -108,12 +108,32 @@ end
Take the union of `A` and `B`.
"""
function Base.union(A::AbstractTensorFST{S,O}, B::AbstractTensorFST{S,O}) where {S,O}
dims = (
findfirst(x -> x == :src, O),
findfirst(x -> x == :dest, O),
MA = getM(A)
MB = getM(B)
# Make sure that MA and MB have the same input and output label dimension.
MA = resize(
getM(A),
(
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 1), size(MB, 1)) : size(getM(A), 1),
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 2), size(MB, 2)) : size(getM(A), 2),
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 3), size(MB, 3)) : size(getM(A), 3),
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 4), size(MB, 4)) : size(getM(A), 4)
)
)
MB = resize(
getM(A),
(
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 1), size(MB, 1)) : size(getM(B), 1),
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 2), size(MB, 2)) : size(getM(B), 2),
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 3), size(MB, 3)) : size(getM(B), 3),
(O[1] === :ilabel || O[1] === :olabel) ? max(size(MA, 4), size(MB, 4)) : size(getM(B), 4)
)
)
dims = (findfirst(x -> x == :src, O), findfirst(x -> x == :dest, O))
TensorFST{S,O}(
blockdiag(getM(A), getM(B); dims),
blockdiag(MA, MB; dims),
SparseSemimodules.vcat(getα(A), getα(B)),
SparseSemimodules.vcat(getω(A), getω(B))
)
......@@ -138,9 +158,17 @@ end
Closure of `A`.
"""
function Base.cat(A::AbstractTensorFST{S,(:src,:dest,:ilabel,:olabel)}) where S
M = mutable(getM(A))
M = add!(M[:,:,1,1], getω(A) * transpose(getα(A)))
TensorFST{S,(:src,:dest,:ilabel,:olabel)}(M, getα(A), getω(A))
function closure(A::AbstractTensorFST{S,O}) where {S,O}
= (getω(A) * transpose(getα(A)))[:,:,1:1,1:1]
perm = (
findfirst(x -> x == order[1], O),
findfirst(x -> x == order[2], O),
findfirst(x -> x == order[3], O),
findfirst(x -> x == order[4], O),
)
= SparseSemimodules.permutedims(, perm)
TensorFSTS{S,O}(getM(A) + , getα(A), getω(A))
end
......@@ -16,20 +16,13 @@ export
AbstractSparseTensor,
eachnz,
findnz,
nnz
nnz,
resize
include("abstractsparsetensor.jl")
#=====================================================================#
# Slice
#=====================================================================#
#export
# SparseTensor3DSlice,
# SparseMatrixSlice,
# SparseVectorSlice
#include("slice.jl")
export addaxes
include("reshapedtensor.jl")
#=====================================================================#
# COOrdinates (COO) Format
......
......@@ -29,7 +29,33 @@ Return the number of non-zero elements in `X`.
"""
nnz
function Base.getindex(x::AbstractSparseTensor{S,4}, i::Integer, j::Integer, k::Integer, l::Integer) where S
"""
reshape(x, newshape)
Return a new tensor with shape `newshape`
"""
reshape
"""
resize(x, newsize)
Return a new tensor with size `newsize`. For each dimension `i`,
`newsize[i]` should be greater or equal to `size(x, i)`.
"""
resize
"""
abstract type AbstractNonZeroIterator end
Base class for all iterators over non-zero elements.
"""
abstract type AbstractNonZeroIterator end
#=====================================================================#
# Scalar indexing
#=====================================================================#
function Base.getindex(x::AbstractSparseTensor{S}, i::Integer, j::Integer, k::Integer, l::Integer) where S
for ((xi, xj, xk, xl), v) in eachnz(x)
if i == xi && j == xj && k == xk && l == xl
return v
......@@ -38,7 +64,25 @@ function Base.getindex(x::AbstractSparseTensor{S,4}, i::Integer, j::Integer, k::
return zero(S)
end
function Base.getindex(x::AbstractSparseTensor{S,1}, i::Integer) where S
function Base.getindex(x::AbstractSparseTensor{S}, i::Integer, j::Integer, k::Integer) where S
for ((xi, xj, xk), v) in eachnz(x)
if i == xi && j == xj && k == xk
return v
end
end
return zero(S)
end
function Base.getindex(x::AbstractSparseTensor{S}, i::Integer, j::Integer) where S
for ((xi, xj), v) in eachnz(x)
if i == xi && j == xj
return v
end
end
return zero(S)
end
function Base.getindex(x::AbstractSparseTensor{S}, i::Integer) where S
for (xi, v) in eachnz(x)
if i == xi
return v
......
# SPDX-License-Identifier: CECILL-2.1
struct ReshapedSparseTensor{S,N,T<:AbstractSparseTensor{S}} <: AbstractSparseTensor{S,N}
x::T
num_extra_axes::Int
end
function addaxes(x::AbstractSparseTensor{S,N}, numaxes) where {S,N}
ReshapedSparseTensor{S,N+numaxes,typeof(x)}(x, numaxes)
end
nnz(rx::ReshapedSparseTensor) = nnz(rx.x)
Base.size(rx::ReshapedSparseTensor) = (size(rx.x)..., ntuple(_ -> 1, rx.num_extra_axes)...)
function findnz(rx::ReshapedSparseTensor)
coo..., V = findnz(rx.x)
extra_coo = ntuple(_ -> ones(Int, nnz(rx.x)), rx.num_extra_axes)
coo..., extra_coo..., V
end
struct IteratorReshapedSparseTensor{T<:AbstractNonZeroIterator}
it::T
num_extra_axes::Int
end
eachnz(rx::ReshapedSparseTensor) = IteratorReshapedSparseTensor(eachnz(rx.x), rx.num_extra_axes)
Base.length(rit::IteratorReshapedSparseTensor) = length(rit.it)
function Base.iterate(rit::IteratorReshapedSparseTensor, state = nothing)
retval = isnothing(state) ? iterate(rit.it) : iterate(rit.it, state)
if isnothing(retval)
return nothing
else
cooval, state = retval
coo, val = cooval
return (((coo..., ntuple(_ -> 1, rit.num_extra_axes)...), val), state)
end
end
......@@ -42,10 +42,10 @@ end
Base.size(x::SparseTensor1DCOO) = (x.size,)
struct Iterator4DCOO{T<:SparseTensor4DCOO} x::T end
struct Iterator3DCOO{T<:SparseTensor3DCOO} x::T end
struct Iterator2DCOO{T<:SparseTensor2DCOO} x::T end
struct Iterator1DCOO{T<:SparseTensor1DCOO} x::T end
struct Iterator4DCOO{T<:SparseTensor4DCOO} <: AbstractNonZeroIterator x::T end
struct Iterator3DCOO{T<:SparseTensor3DCOO} <: AbstractNonZeroIterator x::T end
struct Iterator2DCOO{T<:SparseTensor2DCOO} <: AbstractNonZeroIterator x::T end
struct Iterator1DCOO{T<:SparseTensor1DCOO} <: AbstractNonZeroIterator x::T end
Base.length(it::Iterator4DCOO) = nnz(it.x)
Base.length(it::Iterator3DCOO) = nnz(it.x)
......@@ -74,42 +74,30 @@ findnz(x::SparseTensor3DCOO) = x.I, x.J, x.K, x.V
findnz(x::SparseTensor2DCOO) = x.I, x.J, x.V
findnz(x::SparseTensor1DCOO) = x.I, x.V
function Base.getindex(x::SparseTensor4DCOO{S}, i::Integer, j::Integer, k::Integer, l::Integer) where S
for ((xi, xj, xk, xl), v) in eachnz(x)
if i == xi && j == xj && k == xk && l == xl
return v
end
end
return zero(S)
function resize(x::SparseTensor4DCOO, newsize::NTuple{4,Int})
all(size(x) .<= newsize) || throw(ArgumentError("invalid new size: $newsize"))
SparseTensor4DCOO(newsize, x.I, x.J, x.K, x.L, x.V)
end
function Base.getindex(x::SparseTensor3DCOO{S}, i::Integer, j::Integer, k::Integer) where S
for ((xi, xj, xk), v) in eachnz(x)
if i == xi && j == xj && k == xk
return v
end
end
return zero(S)
function resize(x::SparseTensor3DCOO, newsize::NTuple{3,Int})
all(size(x) .<= newsize) || throw(ArgumentError("invalid new size: $newsize"))
SparseTensor3DCOO(newsize, x.I, x.J, x.K, x.V)
end
function Base.getindex(x::SparseTensor2DCOO{S}, i::Integer, j::Integer) where S
for ((xi, xj), v) in eachnz(x)
if i == xi && j == xj
return v
end
end
return zero(S)
function resize(x::SparseTensor2DCOO, newsize::NTuple{2,Int})
all(size(x) .<= newsize) || throw(ArgumentError("invalid new size: $newsize"))
SparseTensor2DCOO(newsize, x.I, x.J, x.V)
end
function Base.getindex(x::SparseTensor1DCOO{S}, i::Integer) where S
for (xi, v) in eachnz(x)
if i == xi
return v
end
end
return zero(S)
function resize(x::SparseTensor3DCOO, newsize::Integer)
size(x) .<= newsize || throw(ArgumentError("invalid new size: $newsize"))
SparseTensor1DCOO(newsize, x.I, x.V)
end
#=====================================================================#
# Slicing
#=====================================================================#
function Base.getindex(x::SparseTensor4DCOO{S}, i::Integer, ::Colon...) where S
ii = findall(coo -> coo == i, x.I)
SparseTensor3DCOO(x.size[2:end], x.J[ii], x.K[ii], x.L[ii], x.V[ii])
......
......@@ -87,6 +87,10 @@ for S in Ss
y = blockdiag(x, x; dims = (1, 2, 3, 4))
d1, d2, d3, d4 = size(x)
@test all(x[1:d1,1:d2,1:d3,1:d4] .== y[d1 .+ (1:d1), d2 .+ (1:d2), d3 .+ (1:d3), d4 .+ (1:d4)])
newsize = (4, 5, 3, 7)
y = resize(x, newsize)
@test all(size(y) .== newsize)
end
end
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment