294 lines
9.0 KiB
C++
294 lines
9.0 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
//
|
|
// Created by raver119 on 01.11.2017.
|
|
//
|
|
|
|
#include "testlayers.h"
|
|
#include <helpers/ShapeUtils.h>
|
|
#include <array/NDArray.h>
|
|
|
|
|
|
using namespace sd;
|
|
using namespace sd::graph;
|
|
|
|
// Empty GoogleTest fixture: it carries no state or helpers and only provides
// the suite name under which the ShapeUtils test cases are registered.
class ShapeUtilsTests : public testing::Test {
};
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// evalDimsToExclude(3, {0}) must return exactly two values: 1 followed by 2.
TEST_F(ShapeUtilsTests, evalDimsToExclude_1) {
    const std::vector<int> expected = {1, 2};

    const std::vector<int> actual = ShapeUtils::evalDimsToExclude(3, {0});

    // Vector equality covers both the element count and each element.
    ASSERT_EQ(expected, actual);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// evalDimsToExclude(4, {2, 3}) must return exactly two values: 0 followed by 1.
TEST_F(ShapeUtilsTests, evalDimsToExclude_2) {
    const std::vector<int> expected = {0, 1};

    const std::vector<int> actual = ShapeUtils::evalDimsToExclude(4, {2, 3});

    // Vector equality covers both the element count and each element.
    ASSERT_EQ(expected, actual);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Broadcasting the two arrays below is expected to yield a shape-info buffer
// that is strictly equal to the left operand's own buffer.
// NOTE(review): the buffers presumably follow the libnd4j layout
// {rank, dims..., strides..., extras, ews, order} -- confirm against shape.h.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1) {
    Nd4jLong lhsInfo[]      = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99};
    Nd4jLong rhsInfo[]      = {2, 1, 2, 2, 1, 8192, 1, 99};
    Nd4jLong expectedInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99};

    NDArray lhs(lhsInfo);
    NDArray rhs(rhsInfo);

    // Starts out null, is handed to evalBroadcastShapeInfo, then compared.
    Nd4jLong* broadcastInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(lhs, rhs, false, broadcastInfo, nullptr);

    ASSERT_TRUE(shape::equalsStrict(expectedInfo, broadcastInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Broadcasting a buffer with leading field 4 against one with leading field 3
// is expected to produce the buffer {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99}.
// NOTE(review): the buffers presumably follow the libnd4j layout
// {rank, dims..., strides..., extras, ews, order} -- confirm against shape.h.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2) {
    Nd4jLong lhsInfo[]      = {4, 8, 1, 6, 1, 6, 6, 1, 1, 8192, 1, 99};
    Nd4jLong rhsInfo[]      = {3, 7, 1, 5, 5, 5, 1, 8192, 1, 99};
    Nd4jLong expectedInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99};

    NDArray lhs(lhsInfo);
    NDArray rhs(rhsInfo);

    // Starts out null, is handed to evalBroadcastShapeInfo, then compared.
    Nd4jLong* broadcastInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(lhs, rhs, false, broadcastInfo, nullptr);

    ASSERT_TRUE(shape::equalsStrict(expectedInfo, broadcastInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Both operands carry the same leading field (3); the broadcast result is
// expected to be strictly equal to the left operand's buffer.
// NOTE(review): the buffers presumably follow the libnd4j layout
// {rank, dims..., strides..., extras, ews, order} -- confirm against shape.h.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3) {
    Nd4jLong lhsInfo[]      = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99};
    Nd4jLong rhsInfo[]      = {3, 15, 1, 5, 5, 5, 1, 8192, 1, 99};
    Nd4jLong expectedInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99};

    NDArray lhs(lhsInfo);
    NDArray rhs(rhsInfo);

    // Starts out null, is handed to evalBroadcastShapeInfo, then compared.
    Nd4jLong* broadcastInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(lhs, rhs, false, broadcastInfo, nullptr);

    ASSERT_TRUE(shape::equalsStrict(expectedInfo, broadcastInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Broadcasting a buffer with leading field 3 against one with leading field 2
// is expected to produce the buffer {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99}.
// NOTE(review): the buffers presumably follow the libnd4j layout
// {rank, dims..., strides..., extras, ews, order} -- confirm against shape.h.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4) {
    Nd4jLong lhsInfo[]      = {3, 8, 1, 3, 3, 3, 1, 8192, 1, 99};
    Nd4jLong rhsInfo[]      = {2, 4, 3, 3, 1, 8192, 1, 99};
    Nd4jLong expectedInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99};

    NDArray lhs(lhsInfo);
    NDArray rhs(rhsInfo);

    // Starts out null, is handed to evalBroadcastShapeInfo, then compared.
    Nd4jLong* broadcastInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(lhs, rhs, false, broadcastInfo, nullptr);

    // Debug aid, kept disabled: dumps the resulting buffer.
    // for (int i = 0; i < 2 * broadcastInfo[0] + 4; ++i)
    //     std::cout << broadcastInfo[i] << " ";
    // std::cout << std::endl;

    ASSERT_TRUE(shape::equalsStrict(expectedInfo, broadcastInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Reducing a {2,3,4,5} array along dimension 1 (no fourth argument passed)
// must give a shape equal to that of a {2,4,5} array.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test1) {
    auto input     = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
    auto reference = NDArrayFactory::create<float>('c', {2, 4, 5});
    std::vector<int> reduceDims = {1};

    auto reducedInfo = ShapeUtils::evalReduceShapeInfo('c', reduceDims, input.getShapeInfo());

    ASSERT_TRUE(shape::shapeEquals(reference.getShapeInfo(), reducedInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Reducing a {2,3,4,5} array along dimension 1 with the fourth argument set
// to true must give a shape equal to that of a {2,1,4,5} array.
// NOTE(review): the trailing `true` is presumably a keep-dims flag -- confirm.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2) {
    auto input     = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
    auto reference = NDArrayFactory::create<float>('c', {2, 1, 4, 5});
    std::vector<int> reduceDims = {1};

    auto reducedInfo = ShapeUtils::evalReduceShapeInfo('c', reduceDims, input.getShapeInfo(), true);

    ASSERT_TRUE(shape::shapeEquals(reference.getShapeInfo(), reducedInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Reducing a {2,3,4,5} array along dimensions 0, 1 and 2 with the fourth
// argument set to true must give a shape equal to that of a {1,1,1,5} array.
// NOTE(review): the trailing `true` is presumably a keep-dims flag -- confirm.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3) {
    auto input     = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
    auto reference = NDArrayFactory::create<float>('c', {1, 1, 1, 5});
    std::vector<int> reduceDims = {0, 1, 2};

    auto reducedInfo = ShapeUtils::evalReduceShapeInfo('c', reduceDims, input.getShapeInfo(), true);

    ASSERT_TRUE(shape::shapeEquals(reference.getShapeInfo(), reducedInfo));
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Reducing a {2,3,4,5} array along all four dimensions with the fourth
// argument set to true must give a shape equal to that of a {1,1,1,1} array.
// NOTE(review): the trailing `true` is presumably a keep-dims flag -- confirm.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4) {
    auto input     = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
    auto reference = NDArrayFactory::create<float>('c', {1, 1, 1, 1});
    std::vector<int> reduceDims = {0, 1, 2, 3};

    auto reducedInfo = ShapeUtils::evalReduceShapeInfo('c', reduceDims, input.getShapeInfo(), true);

    ASSERT_TRUE(shape::shapeEquals(reference.getShapeInfo(), reducedInfo));
}
|
|
|
|
// shapeAsString on a {2, 3, 4, 5} array must render as "[2, 3, 4, 5]".
TEST_F(ShapeUtilsTests, Test_Strings_1) {
    auto array = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
    const std::string expected = "[2, 3, 4, 5]";

    auto rendered = ShapeUtils::shapeAsString(&array);

    ASSERT_EQ(expected, rendered);
}
|
|
|
|
// For a {4, 3} array measured against a {2, 4, 3} array,
// evalBroadcastBackwardAxis must return the single axis 0.
TEST_F(ShapeUtilsTests, Test_Backward_Axis_1) {
    auto larger  = NDArrayFactory::create<float>('c', {2, 4, 3});
    auto smaller = NDArrayFactory::create<float>('c', {4, 3});
    const std::vector<int> expectedAxes = {0};

    auto axes = ShapeUtils::evalBroadcastBackwardAxis(smaller.shapeInfo(), larger.shapeInfo());

    ASSERT_EQ(expectedAxes, axes);
}
|
|
|
|
// For a {4, 1, 3} array measured against a {2, 4, 4, 3} array,
// evalBroadcastBackwardAxis must return axes 0 and 2.
TEST_F(ShapeUtilsTests, Test_Backward_Axis_2) {
    auto larger  = NDArrayFactory::create<float>('c', {2, 4, 4, 3});
    auto smaller = NDArrayFactory::create<float>('c', {4, 1, 3});
    const std::vector<int> expectedAxes = {0, 2};

    auto axes = ShapeUtils::evalBroadcastBackwardAxis(smaller.shapeInfo(), larger.shapeInfo());

    ASSERT_EQ(expectedAxes, axes);
}
|
|
|
|
|
|
// For a {2, 1, 1, 3} array measured against a {2, 4, 4, 3} array,
// evalBroadcastBackwardAxis must return axes 1 and 2.
TEST_F(ShapeUtilsTests, Test_Backward_Axis_3) {
    auto larger  = NDArrayFactory::create<float>('c', {2, 4, 4, 3});
    auto smaller = NDArrayFactory::create<float>('c', {2, 1, 1, 3});
    const std::vector<int> expectedAxes = {1, 2};

    auto axes = ShapeUtils::evalBroadcastBackwardAxis(smaller.shapeInfo(), larger.shapeInfo());

    ASSERT_EQ(expectedAxes, axes);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Mapping shape {1,2,3,4} onto {3,4,1,2} must give the permutation {2,3,0,1}.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test1) {
    const int a = 1, b = 2, c = 3, d = 4;
    const std::vector<int> expected = {2, 3, 0, 1};

    const std::vector<int> result = ShapeUtils::evalPermutFromTo({a, b, c, d}, {c, d, a, b});

    // Compare whole vectors: the previous three-iterator std::equal never
    // checked result's length, so a shorter result was read out of bounds and
    // a longer one passed unnoticed.
    ASSERT_EQ(expected, result);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Mapping shape {1,2,3,4} onto {1,2,4,3} must give the permutation {0,1,3,2}.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test2) {
    const int a = 1, b = 2, c = 3, d = 4;
    const std::vector<int> expected = {0, 1, 3, 2};

    const std::vector<int> result = ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, d, c});

    // Compare whole vectors: the previous three-iterator std::equal never
    // checked result's length, so a shorter result was read out of bounds and
    // a longer one passed unnoticed.
    ASSERT_EQ(expected, result);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Same permutation as the previous case but with repeated extents:
// mapping {2,2,3,2} onto {2,2,2,3} must give {0,1,3,2}.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test3) {
    const int a = 2, b = 2, c = 3, d = 2;
    const std::vector<int> expected = {0, 1, 3, 2};

    const std::vector<int> result = ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, d, c});

    // Compare whole vectors: the previous three-iterator std::equal never
    // checked result's length, so a shorter result was read out of bounds and
    // a longer one passed unnoticed.
    ASSERT_EQ(expected, result);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// When source and target shapes are identical ({2,3,4,5} both times),
// evalPermutFromTo must return an empty vector.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test4) {
    const int dim0 = 2, dim1 = 3, dim2 = 4, dim3 = 5;

    const std::vector<int> permutation =
        ShapeUtils::evalPermutFromTo({dim0, dim1, dim2, dim3}, {dim0, dim1, dim2, dim3});

    ASSERT_TRUE(permutation.empty());
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Placeholder: the real check is disabled, so this case currently verifies
// nothing and always passes. The locals that only the disabled check used
// were removed to avoid unused-variable warnings.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test5) {
    // Disabled check (target shape contains 8, which is absent from the source):
    // EXPECT_THROW(ShapeUtils::evalPermutFromTo({1,2,3,4}, {3,4,1,8}), const char*);
    ASSERT_TRUE(true);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// Placeholder: the real check is disabled, so this case currently verifies
// nothing and always passes. The locals that only the disabled check used
// were removed to avoid unused-variable warnings.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test6) {
    // Disabled check (target shape has one more element than the source):
    // EXPECT_THROW(ShapeUtils::evalPermutFromTo({1,2,3,4}, {1,2,3,4,4}), const char*);
    ASSERT_TRUE(true);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// The index list {1,0,2,3} must be reported as requiring a permutation.
TEST_F(ShapeUtilsTests, isPermutNecessary_test1) {
    const bool needed = ShapeUtils::isPermutNecessary({1, 0, 2, 3});

    ASSERT_TRUE(needed);
}
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
// The index list {0,1,2,3} must be reported as not requiring a permutation.
TEST_F(ShapeUtilsTests, isPermutNecessary_test2) {
    const bool needed = ShapeUtils::isPermutNecessary({0, 1, 2, 3});

    ASSERT_FALSE(needed);
}
|
|
|
|
|