cavis/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp

142 lines
4.0 KiB
C++
Raw Normal View History

2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 04.08.17.
//
#include "testlayers.h"
#include <memory>
#include <NDArray.h>
#include "ops/specials_sparse.h"
using namespace nd4j;
//////////////////////////////////////////////////////////////////////
class SparseUtilsTest : public testing::Test {
public:
static const Nd4jLong nnz = 40;
static const int rank = 3;
};
//////////////////////////////////////////////////////////////////////
// Checks SparseUtils<float>::sortCooIndicesGeneric on a fixed data set of
// `nnz` index tuples with `rank` coordinates each.
//
// The index buffer and the values buffer are handed to the routine and are
// then compared against the expected contents, i.e. the test expects both to
// be reordered in place: the index tuples end up in ascending lexicographic
// order and every value stays attached to its tuple.
//
// The index buffers are owned by std::unique_ptr so they are released even
// when a fatal ASSERT_* aborts the test body early.
TEST_F(SparseUtilsTest, SortCOOindices_Test) {
    const Nd4jLong total = nnz * rank;

    // Unsorted input: one index tuple per row.
    std::unique_ptr<Nd4jLong[]> indices(new Nd4jLong[nnz * rank]{
        0,2,7,
        2,36,35,
        3,30,17,
        5,12,22,
        5,43,45,
        6,32,11,
        8,8,32,
        9,29,11,
        5,11,22,
        15,26,16,
        17,48,49,
        24,28,31,
        26,6,23,
        31,21,31,
        35,46,45,
        37,13,14,
        6,38,18,
        7,28,20,
        8,29,39,
        8,32,30,
        9,42,43,
        11,15,18,
        13,18,45,
        29,26,39,
        30,8,25,
        42,31,24,
        28,33,5,
        31,27,1,
        35,43,26,
        36,8,37,
        39,22,14,
        39,24,42,
        42,48,2,
        43,26,48,
        44,23,49,
        45,18,34,
        46,28,5,
        46,32,17,
        48,34,44,
        49,38,39,
    });

    // The same tuples in the order the sort is expected to produce.
    std::unique_ptr<Nd4jLong[]> expectedIndices(new Nd4jLong[nnz * rank]{
        0, 2, 7,
        2, 36, 35,
        3, 30, 17,
        5, 11, 22,
        5, 12, 22,
        5, 43, 45,
        6, 32, 11,
        6, 38, 18,
        7, 28, 20,
        8, 8, 32,
        8, 29, 39,
        8, 32, 30,
        9, 29, 11,
        9, 42, 43,
        11, 15, 18,
        13, 18, 45,
        15, 26, 16,
        17, 48, 49,
        24, 28, 31,
        26, 6, 23,
        28, 33, 5,
        29, 26, 39,
        30, 8, 25,
        31, 21, 31,
        31, 27, 1,
        35, 43, 26,
        35, 46, 45,
        36, 8, 37,
        37, 13, 14,
        39, 22, 14,
        39, 24, 42,
        42, 31, 24,
        42, 48, 2,
        43, 26, 48,
        44, 23, 49,
        45, 18, 34,
        46, 28, 5,
        46, 32, 17,
        48, 34, 44,
        49, 38, 39,
    });

    // Value i initially belongs to input row i, so after sorting each value
    // identifies the original row that now occupies that position.
    auto values = NDArrayFactory::create<float>('c', {40}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
                                                            22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39});
    auto expValues = NDArrayFactory::create<float>('c', {40}, {0, 1, 2, 8, 3, 4, 5, 16, 17, 6, 18, 19, 7, 20, 21, 22, 9,
                                                               10, 11, 12, 26, 23, 24, 13, 27, 28, 14, 29, 15, 30, 31, 25, 32, 33,
                                                               34, 35, 36, 37, 38, 39
                                                              });

    nd4j::sparse::SparseUtils<float >::sortCooIndicesGeneric(indices.get(), reinterpret_cast<float*>(values.getBuffer()), nnz, rank);

    // Compare the index buffer element by element against the expectation.
    for (Nd4jLong i = 0; i < total; ++i) {
        ASSERT_EQ(expectedIndices[i], indices[i]) << "index mismatch at flat position " << i;
    }

    ASSERT_TRUE(expValues.equalsTo(values));
}