2586 lines
110 KiB
C++
2586 lines
110 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
//
|
||
|
// @author raver119@gmail.com
|
||
|
//
|
||
|
|
||
|
#include "testlayers.h"
|
||
|
#include <ops/declarable/CustomOperations.h>
|
||
|
#include <helpers/helper_hash.h>
|
||
|
#include <NDArray.h>
|
||
|
#include <array/NDArrayList.h>
|
||
|
|
||
|
|
||
|
using namespace nd4j;
|
||
|
using namespace nd4j::graph;
|
||
|
|
||
|
class DeclarableOpsTests4 : public testing::Test {
|
||
|
public:
|
||
|
|
||
|
DeclarableOpsTests4() {
|
||
|
printf("\n");
|
||
|
fflush(stdout);
|
||
|
|
||
|
nd4j::ops::adjust_hue op0;
|
||
|
nd4j::ops::adjust_saturation op1;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <typename T>
|
||
|
class TypedDeclarableOpsTests4 : public testing::Test {
|
||
|
public:
|
||
|
|
||
|
TypedDeclarableOpsTests4() {
|
||
|
printf("\n");
|
||
|
fflush(stdout);
|
||
|
|
||
|
nd4j::ops::adjust_hue op0;
|
||
|
nd4j::ops::adjust_saturation op1;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
typedef ::testing::Types<double, float> TestingTypes;
|
||
|
TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes);
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_1) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 4, 4, 2});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_2) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 4, 4, 2});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_3) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 4, 4, 2});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::maxpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_4) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 4, 4, 2});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::maxpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_5) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 5, 5, 2});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_6) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 5, 5, 2});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_7) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 5, 5});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, 82.f, 84.f, 92.f, 94.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::maxpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_8) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 5, 5});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 3, 3}, {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_9) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 5, 5});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 3, 3}, {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 5, 5});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 3, 3}, {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 3, 3});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 2}, {3, 4, 6, 7});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) {
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 3, 3});
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f});
|
||
|
|
||
|
x.linspace(1);
|
||
|
|
||
|
|
||
|
nd4j::ops::avgpool2d op;
|
||
|
auto result = op.execute({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
//z->printShapeInfo("z shape:");
|
||
|
//z->printBuffer("z buffer:");
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_BiasAdd_NHWC_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 3, 3, 2});
|
||
|
auto bias = NDArrayFactory::create<double>('c', {1, 2}, {1, 2});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
|
||
|
|
||
|
nd4j::ops::biasadd op;
|
||
|
auto result = op.execute({&x, &bias}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_BiasAdd_NCHW_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 3});
|
||
|
auto bias = NDArrayFactory::create<double>('c', {1, 2}, {1, 2});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 2, 3, 3}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
|
||
|
|
||
|
nd4j::ops::biasadd op;
|
||
|
auto result = op.execute({&x, &bias}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Fill_1) {
|
||
|
auto x = NDArrayFactory::create<int>('c', {1, 3}, {3, 2, 4});
|
||
|
auto v = NDArrayFactory::create<double>(2.);
|
||
|
auto exp = NDArrayFactory::create<double>('c', {3, 2, 4});
|
||
|
exp.assign(2.0f);
|
||
|
|
||
|
nd4j::ops::fill op;
|
||
|
auto result = op.execute({&x, &v}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {4, 3});
|
||
|
|
||
|
x.linspace(1);
|
||
|
exp.linspace(1);
|
||
|
|
||
|
nd4j::ops::reshape op;
|
||
|
auto result = op.execute({&x}, {}, {-99, 4, 3});
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Gemv_Transpose_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||
|
auto y = NDArrayFactory::create<double>('c', {4, 1});
|
||
|
auto exp = NDArrayFactory::create<double>('c',{ 3, 1}, {70, 80, 90});
|
||
|
|
||
|
x.linspace(1);
|
||
|
y.linspace(1);
|
||
|
|
||
|
nd4j::ops::matmul op;
|
||
|
auto result = op.execute({&x, &y}, {}, {1, 0});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {5, 30});
|
||
|
auto sizes = NDArrayFactory::create<int>('c', {1, 3}, {4, 15, 11});
|
||
|
|
||
|
std::vector<Nd4jLong> list0({0,0, 0,4});
|
||
|
std::vector<Nd4jLong> list1({0,0, 4,19});
|
||
|
std::vector<Nd4jLong> list2({0,0, 19,30});
|
||
|
|
||
|
auto sub0 = x(list0, true);
|
||
|
auto sub1 = x(list1, true);
|
||
|
auto sub2 = x(list2, true);
|
||
|
|
||
|
sub0.assign(0.0);
|
||
|
sub1.assign(1.0);
|
||
|
sub2.assign(2.0);
|
||
|
|
||
|
|
||
|
nd4j::ops::split_v op;
|
||
|
auto result = op.execute({&x, &sizes}, {}, {1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
ASSERT_EQ(3, result->size());
|
||
|
|
||
|
auto z0 = result->at(0);
|
||
|
auto z1 = result->at(1);
|
||
|
auto z2 = result->at(2);
|
||
|
|
||
|
ASSERT_TRUE(sub0.isSameShape(z0));
|
||
|
ASSERT_TRUE(sub1.isSameShape(z1));
|
||
|
ASSERT_TRUE(sub2.isSameShape(z2));
|
||
|
|
||
|
ASSERT_TRUE(sub0.equalsTo(z0));
|
||
|
ASSERT_TRUE(sub1.equalsTo(z1));
|
||
|
ASSERT_TRUE(sub2.equalsTo(z2));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
// special test for TF mode, when axis goes first
|
||
|
TEST_F(DeclarableOpsTests4, Test_Split_2) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {5, 12});
|
||
|
auto axis = NDArrayFactory::create<double>('c', {1, 1}, {1.f});
|
||
|
|
||
|
std::vector<Nd4jLong> list0 = {0,0, 0,3};
|
||
|
std::vector<Nd4jLong> list1 = {0,0, 3,6};
|
||
|
std::vector<Nd4jLong> list2 = {0,0, 6,9};
|
||
|
std::vector<Nd4jLong> list3 = {0,0, 9,12};
|
||
|
|
||
|
auto sub0 = x(list0, true);
|
||
|
auto sub1 = x(list1, true);
|
||
|
auto sub2 = x(list2, true);
|
||
|
auto sub3 = x(list3, true);
|
||
|
|
||
|
sub0.assign(0.0f);
|
||
|
sub1.assign(1.0f);
|
||
|
sub2.assign(2.0f);
|
||
|
sub3.assign(3.0f);
|
||
|
|
||
|
|
||
|
nd4j::ops::split op;
|
||
|
auto result = op.execute({&axis, &x}, {}, {4});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z0 = result->at(0);
|
||
|
auto z1 = result->at(1);
|
||
|
auto z2 = result->at(2);
|
||
|
auto z3 = result->at(3);
|
||
|
|
||
|
ASSERT_TRUE(sub0.isSameShape(z0));
|
||
|
ASSERT_TRUE(sub1.isSameShape(z1));
|
||
|
ASSERT_TRUE(sub2.isSameShape(z2));
|
||
|
ASSERT_TRUE(sub3.isSameShape(z3));
|
||
|
|
||
|
ASSERT_TRUE(sub0.equalsTo(z0));
|
||
|
ASSERT_TRUE(sub1.equalsTo(z1));
|
||
|
ASSERT_TRUE(sub2.equalsTo(z2));
|
||
|
ASSERT_TRUE(sub3.equalsTo(z3));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
// special test for TF mode, when axis goes first
|
||
|
TEST_F(DeclarableOpsTests4, Test_Split_3) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {6, 12});
|
||
|
auto axis = NDArrayFactory::create<double>('c', {1, 1}, {0.f});
|
||
|
|
||
|
std::vector<Nd4jLong> list0 = {0,2, 0,0};
|
||
|
std::vector<Nd4jLong> list1 = {2,4, 0,0};
|
||
|
std::vector<Nd4jLong> list2 = {4,6, 0,0};
|
||
|
|
||
|
auto sub0 = x(list0, true);
|
||
|
auto sub1 = x(list1, true);
|
||
|
auto sub2 = x(list2, true);
|
||
|
|
||
|
sub0.assign(0.0f);
|
||
|
sub1.assign(1.0f);
|
||
|
sub2.assign(2.0f);
|
||
|
|
||
|
nd4j::ops::split op;
|
||
|
auto result = op.execute({&axis, &x}, {}, {3});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z0 = result->at(0);
|
||
|
auto z1 = result->at(1);
|
||
|
auto z2 = result->at(2);
|
||
|
|
||
|
ASSERT_TRUE(sub0.isSameShape(z0));
|
||
|
ASSERT_TRUE(sub1.isSameShape(z1));
|
||
|
ASSERT_TRUE(sub2.isSameShape(z2));
|
||
|
|
||
|
ASSERT_TRUE(sub0.equalsTo(z0));
|
||
|
ASSERT_TRUE(sub1.equalsTo(z1));
|
||
|
ASSERT_TRUE(sub2.equalsTo(z2));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Stack_4) {
|
||
|
auto t = NDArrayFactory::create<double>('c', {2, 3, 5});
|
||
|
auto u = NDArrayFactory::create<double>('c', {2, 3, 5});
|
||
|
auto v = NDArrayFactory::create<double>('c', {2, 3, 5});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {3, 2, 3, 5});
|
||
|
|
||
|
nd4j::ops::stack op;
|
||
|
auto result = op.execute({&t, &u, &v}, {}, {-4});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
|
||
|
|
||
|
nd4j::ops::squeeze op;
|
||
|
auto result = op.execute({&x}, {}, {1, 3});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
|
||
|
auto y = NDArrayFactory::create<double>('c', {2}, {1.f, 3.f});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
|
||
|
|
||
|
nd4j::ops::squeeze op;
|
||
|
auto result = op.execute({&x, &y}, {}, {});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
|
||
|
|
||
|
nd4j::ops::squeeze op;
|
||
|
auto result = op.execute({&x}, {}, {-2, -3});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_BiasAdd_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||
|
auto row = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
|
||
|
|
||
|
nd4j::ops::biasadd op;
|
||
|
auto result = op.execute({&x, &row}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_1D_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||
|
|
||
|
nd4j::ops::unstack op;
|
||
|
auto result = op.execute({&x}, {}, {1});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
ASSERT_EQ(3, result->size());
|
||
|
|
||
|
for (int e = 0; e < 3; e++)
|
||
|
ASSERT_EQ(1, result->at(e)->rankOf());
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
|
||
|
nd4j::ops::space_to_depth op;
|
||
|
auto result = op.execute({&x}, {}, {2, 1});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12});
|
||
|
|
||
|
nd4j::ops::space_to_depth op;
|
||
|
auto result = op.execute({&x}, {}, {2, 0});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
|
||
|
nd4j::ops::depth_to_space op;
|
||
|
auto result = op.execute({&x}, {}, {2, 1});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
|
||
|
nd4j::ops::depth_to_space op;
|
||
|
auto result = op.execute({&x}, {}, {2, 0});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {4, 4, 16, 16});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {4, 16, 64, 1});
|
||
|
|
||
|
nd4j::ops::depth_to_space op;
|
||
|
auto result = op.execute({&x}, {}, {4, 1});
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Cross_1) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
|
||
|
auto b = NDArrayFactory::create<double>('c', {3}, {6, 7, 8});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {3}, {-5, 10, -5});
|
||
|
|
||
|
nd4j::ops::cross op;
|
||
|
auto result = op.execute({&a, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Cross_2) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
|
||
|
auto b = NDArrayFactory::create<double>('c', {2, 3}, {6, 7, 8, 6, 7, 8});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-5, 10, -5, -5, 10, -5});
|
||
|
|
||
|
nd4j::ops::cross op;
|
||
|
auto result = op.execute({&a, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Cross_3) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||
|
auto b = NDArrayFactory::create<double>('c', {3, 3}, {2, 3, 4, 7, 6, 5, 6, 3, 2});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {3, 3}, { -1, 2, -1, -11, 22, -11, -11, 40, -27});
|
||
|
|
||
|
nd4j::ops::cross op;
|
||
|
auto result = op.execute({&a, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_1) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {3}, {30, 70, 110});
|
||
|
|
||
|
nd4j::ops::matmul op;
|
||
|
auto result = op.execute({&a, &b}, {}, {});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_2) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||
|
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {3}, {70, 80, 90});
|
||
|
|
||
|
nd4j::ops::matmul op;
|
||
|
auto result = op.execute({&a, &b}, {}, {});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_3) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||
|
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1, 3}, {70, 80, 90});
|
||
|
|
||
|
nd4j::ops::matmul op;
|
||
|
auto result = op.execute({&a, &b}, {}, {});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Add_119) {
|
||
|
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||
|
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1, 4}, {2, 4, 6, 8});
|
||
|
|
||
|
nd4j::ops::add op;
|
||
|
auto result = op.execute({&a, &b}, {}, {});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_EQ(2, z->rankOf());
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||
|
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||
|
|
||
|
nd4j::ops::reshape op;
|
||
|
auto result = op.execute({&x, &shape}, {}, {});
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {2, 1, 3});
|
||
|
auto exp = NDArrayFactory::create<double>('c', {2, 4, 3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f,
|
||
|
4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f});
|
||
|
x.linspace(1.f);
|
||
|
|
||
|
nd4j::ops::tile_to_shape op;
|
||
|
auto result = op.execute({&x},{}, {2, 4, 3}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {3, 4, 5});
|
||
|
x.linspace(1);
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
||
|
exp.linspace(1);
|
||
|
|
||
|
nd4j::ops::strided_slice op;
|
||
|
auto result = op.execute({&x}, {}, {0,0,0,1,0, -999,0,0,0, -999,3,4,5, -999,1,1,1});
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {3, 4, 5});
|
||
|
auto begin = NDArrayFactory::create<double>('c', {4}, {-999,0,0,0});
|
||
|
auto end = NDArrayFactory::create<double>('c', {4}, {-999,3,4,5});
|
||
|
auto stride = NDArrayFactory::create<double>('c', {4}, {-999,1,1,1});
|
||
|
x.linspace(1);
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
||
|
exp.linspace(1);
|
||
|
|
||
|
nd4j::ops::strided_slice op;
|
||
|
auto result = op.execute({&x, &begin, &end, &stride}, {}, {0,0,0,1,0});
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {1}, {10});
|
||
|
auto begin = NDArrayFactory::create<double>('c', {1}, {0.});
|
||
|
auto end = NDArrayFactory::create<double>('c', {1}, {0.});
|
||
|
auto stride = NDArrayFactory::create<double>('c', {1}, {1});
|
||
|
//x.linspace(1);
|
||
|
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
||
|
//exp.linspace(1);
|
||
|
|
||
|
nd4j::ops::strided_slice op;
|
||
|
auto result = op.execute({&x, &begin, &end, &stride}, {}, {1,0,0,0,0});
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
z->printShapeInfo("Emply shape expected");
|
||
|
ASSERT_TRUE(z->isEmpty());
|
||
|
|
||
|
delete result;
|
||
|
}
|
||
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
|
||
|
auto x = NDArrayFactory::create<double>('c', {1,3}, {1, 2, 3});
|
||
|
auto begin = NDArrayFactory::create<double>('c', {2}, {0, 0});
|
||
|
auto end = NDArrayFactory::create<double>('c', {2}, {0,1});
|
||
|
auto stride = NDArrayFactory::create<double>('c', {2}, {1,1});
|
||
|
// x.linspace(1);
|
||
|
auto exp = NDArrayFactory::create<double>('c', {1}, {1});
|
||
|
//exp.linspace(1);
|
||
|
|
||
|
nd4j::ops::strided_slice op;
|
||
|
auto result = op.execute({&x, &begin, &end, &stride}, {}, {1,0,1,0,2});
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), result->status());
|
||
|
|
||
|
auto z = result->at(0);
|
||
|
z->printBuffer("Strided Slice");
|
||
|
z->printShapeInfo("Vector size 1 shape expected");
|
||
|
exp.printShapeInfo("Expected shape");
|
||
|
ASSERT_TRUE(z->lengthOf() == 1);
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
delete result;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test1) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>('c', {2,2,2});
|
||
|
auto x2 = NDArrayFactory::create<double>('c', {2,2,2});
|
||
|
auto x3 = NDArrayFactory::create<double>('c', {2,2,2});
|
||
|
x1.linspace(1);
|
||
|
x2.linspace(9);
|
||
|
x3.linspace(17);
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {3,2,2,2});
|
||
|
expected.linspace(1);
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1, &x2, &x3}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test2) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>('c', {1,2}, {1,2});
|
||
|
auto x2 = NDArrayFactory::create<double>('c', {1,2}, {3,4});
|
||
|
auto x3 = NDArrayFactory::create<double>('c', {1,2}, {5,6});
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {3,1,2}, {1,2,3,4,5,6});
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1, &x2, &x3}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test3) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>('c', {2,1}, {1,2});
|
||
|
auto x2 = NDArrayFactory::create<double>('c', {2,1}, {3,4});
|
||
|
auto x3 = NDArrayFactory::create<double>('c', {2,1}, {5,6});
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {3,2,1}, {1,2,3,4,5,6});
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1, &x2, &x3}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
\
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test4) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>('c', {2}, {1,2});
|
||
|
auto x2 = NDArrayFactory::create<double>('c', {2}, {3,4});
|
||
|
auto x3 = NDArrayFactory::create<double>('c', {2}, {5,6});
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {3,2}, {1,2,3,4,5,6});
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1, &x2, &x3}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test5) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>('c', {1}, {1});
|
||
|
auto x2 = NDArrayFactory::create<double>('c', {1}, {3});
|
||
|
auto x3 = NDArrayFactory::create<double>('c', {1}, {5});
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {3,1}, {1,3,5});
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1, &x2, &x3}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test6) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>(1.);
|
||
|
auto x2 = NDArrayFactory::create<double>(3.);
|
||
|
auto x3 = NDArrayFactory::create<double>(5.);
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {3}, {1,3,5});
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1, &x2, &x3}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, parallel_stack_test7) {
|
||
|
|
||
|
auto x1 = NDArrayFactory::create<double>(1.);
|
||
|
auto expected = NDArrayFactory::create<double>('c', {1}, {1.});
|
||
|
|
||
|
nd4j::ops::parallel_stack op;
|
||
|
auto results = op.execute({&x1}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test1) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
||
|
auto in1 = NDArrayFactory::create<double>('c', {3}, {10, 20, 30});
|
||
|
auto in2 = NDArrayFactory::create<double>('c', {4}, {100, 200, 300, 400});
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {0});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
// out0->printIndexedBuffer();
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test2) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
||
|
auto in1 = NDArrayFactory::create<double>('c', {3}, {10, 20, 30});
|
||
|
auto in2 = NDArrayFactory::create<double>('c', {4}, {100, 200, 300, 400});
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test3) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
||
|
auto in1 = NDArrayFactory::create<double>('c', {1,3}, {10, 20, 30});
|
||
|
auto in2 = NDArrayFactory::create<double>('c', {2,2}, {100, 200, 300, 400});
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test4) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>('c', {1,2}, {1, 2});
|
||
|
auto in1 = NDArrayFactory::create<double>('c', {3,1}, {10, 20, 30});
|
||
|
auto in2 = NDArrayFactory::create<double>('c', {1,4,1}, {100, 200, 300, 400});
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {0});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test5) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>(1);
|
||
|
auto in1 = NDArrayFactory::create<double>(2);
|
||
|
auto in2 = NDArrayFactory::create<double>(3);
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {1,1,1}, {1});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {1,1,1}, {2});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {1,1,1}, {3});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {0});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test6) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>('c', {2,2},{1,2,3,4});
|
||
|
auto in1 = NDArrayFactory::create<double>(5);
|
||
|
auto in2 = NDArrayFactory::create<double>(6);
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {4,1,1}, {1,2,3,4});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {4,1,1}, {5,5,5,5});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {4,1,1}, {6,6,6,6});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {0});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test7) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>('c', {2,2},{1,2,3,4});
|
||
|
auto in1 = NDArrayFactory::create<double>(5);
|
||
|
auto in2 = NDArrayFactory::create<double>(6);
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {1,4,1}, {1,2,3,4});
|
||
|
auto exp1 = NDArrayFactory::create<double>('c', {1,4,1}, {5,5,5,5});
|
||
|
auto exp2 = NDArrayFactory::create<double>('c', {1,4,1}, {6,6,6,6});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0, &in1, &in2}, {}, {1});
|
||
|
auto out0 = results->at(0);
|
||
|
auto out1 = results->at(1);
|
||
|
auto out2 = results->at(2);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test8) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>(5);
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {1}, {5});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0}, {}, {0});
|
||
|
auto out0 = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, meshgrid_test9) {
|
||
|
|
||
|
auto in0 = NDArrayFactory::create<double>(5);
|
||
|
auto exp0 = NDArrayFactory::create<double>('c', {1}, {5});
|
||
|
|
||
|
nd4j::ops::meshgrid op;
|
||
|
auto results = op.execute({&in0}, {}, {1});
|
||
|
auto out0 = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) {
|
||
|
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f});
|
||
|
auto targets = NDArrayFactory::create<double>('c', {2, 3}, {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f});
|
||
|
auto weight = NDArrayFactory::create<double>(0.7f);
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3}, {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887});
|
||
|
|
||
|
//Targets {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f};
|
||
|
//----------
|
||
|
//Inputs {11.f, 13.f, 4.f, 15.f, 6.f, 3.f};
|
||
|
//----------
|
||
|
//Weights [0.7]
|
||
|
//Result {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}
|
||
|
|
||
|
nd4j::ops::weighted_cross_entropy_with_logits op;
|
||
|
auto results = op.execute({&targets, &input, &weight}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
// output->printIndexedBuffer();
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) {
|
||
|
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f});
|
||
|
auto targets = NDArrayFactory::create<double>('c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f});
|
||
|
auto weights = NDArrayFactory::create<double>({0.5f, 0.7f, 1.0f}) ;
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3}, {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f});
|
||
|
|
||
|
nd4j::ops::weighted_cross_entropy_with_logits op;
|
||
|
auto results = op.execute({&targets, &input, &weights}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
// output->printIndexedBuffer("Result is ");
|
||
|
// expected.printIndexedBuffer("Expected is ");
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
///////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, lstm_test1) {
|
||
|
|
||
|
const int time = 5;
|
||
|
const int batchSize = 3;
|
||
|
const int inSize = 3;
|
||
|
const int numProj = 3;
|
||
|
const int numUnits = 3;
|
||
|
|
||
|
auto x = NDArrayFactory::create<double>('c', {time, batchSize, inSize});
|
||
|
auto h0 = NDArrayFactory::create<double>('c', {batchSize, numProj});
|
||
|
auto c0 = NDArrayFactory::create<double>('c', {batchSize, numUnits});
|
||
|
auto Wx = NDArrayFactory::create<double>('c', {inSize, 4*numUnits});
|
||
|
auto Wh = NDArrayFactory::create<double>('c', {numProj, 4*numUnits});
|
||
|
auto Wc = NDArrayFactory::create<double>('c', {3*numUnits});
|
||
|
auto Wp = NDArrayFactory::create<double>('c', {numUnits, numProj});
|
||
|
auto b = NDArrayFactory::create<double>('c', {4*numUnits});
|
||
|
|
||
|
x.linspace(0.5, 0.5);
|
||
|
h0 = 1.;
|
||
|
c0 = 2.;
|
||
|
Wx = 0.003;
|
||
|
Wh = 0.006;
|
||
|
Wc = 0.;
|
||
|
Wp = 0.;
|
||
|
b = 0.5;
|
||
|
|
||
|
auto expH = NDArrayFactory::create<double>('c', {time, batchSize, numProj}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434,
|
||
|
0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338,
|
||
|
0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287,
|
||
|
0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327,
|
||
|
0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625});
|
||
|
|
||
|
auto expClast = NDArrayFactory::create<double>('c', {1, batchSize, numProj}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861});
|
||
|
|
||
|
nd4j::ops::lstm op;
|
||
|
auto results = op.execute({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0});
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||
|
|
||
|
auto *h = results->at(0);
|
||
|
auto *c = results->at(1);
|
||
|
auto cLast = (*c)({4,5,0,0,0,0},true);
|
||
|
|
||
|
ASSERT_TRUE(expH.isSameShape(h));
|
||
|
ASSERT_TRUE(expH.equalsTo(h));
|
||
|
|
||
|
ASSERT_TRUE(expClast.isSameShape(&cLast));
|
||
|
ASSERT_TRUE(expClast.equalsTo(&cLast));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
///////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, relu6_test1) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2,4}, {-13.,10,-5,0,2,7,6,12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2,4}, {0., 6., 0., 0.,2., 6., 6., 6.});
|
||
|
|
||
|
nd4j::ops::relu6 op;
|
||
|
auto results = op.execute({&input}, {0.}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||
|
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
///////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, relu6_bp_test1) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2,4}, {-13.,10, -5, 0, 2, 7, 6, 5});
|
||
|
auto gradO = NDArrayFactory::create<double>('c', {2,4}, {-1., -2., 0., 4., 5., 6., 7., 8.});
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2,4}, {0., 0., 0., 0., 5., 0., 0., 8.});
|
||
|
|
||
|
nd4j::ops::relu6_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {0.}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||
|
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) {
|
||
|
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5,
|
||
|
8.6, 0., 0., 0.4,
|
||
|
1.5, 1., 1.3, 1.5,
|
||
|
2.6, 2., 3., 1.4}
|
||
|
);
|
||
|
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {
|
||
|
0.98386997, 0., 0.05358852, 0.9824562,
|
||
|
0.99330735, 0., 0., 0.37139067,
|
||
|
0.72760683, 0.4850712, 0.5848977, 0.67488194,
|
||
|
0.7581754, 0.58321184, 0.86747235, 0.4048204}
|
||
|
);
|
||
|
|
||
|
nd4j::ops::lrn op;
|
||
|
auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
auto out = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp.isSameShape(out));
|
||
|
// out->printIndexedBuffer("LRN out");
|
||
|
// exp.printIndexedBuffer("LRN exp");
|
||
|
ASSERT_TRUE(exp.equalsTo(out));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) {
|
||
|
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5,
|
||
|
8.6, 0., 0., 0.4,
|
||
|
1.5, 1., 1.3, 1.5,
|
||
|
2.6, 2., 3., 1.4});
|
||
|
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {
|
||
|
0.98386997, 0., 0.05358852, 0.9824562,
|
||
|
0.99330735, 0., 0., 0.37139067,
|
||
|
0.72760683, 0.4850712, 0.5848977, 0.67488194,
|
||
|
0.7581754, 0.58321184, 0.86747235, 0.4048204});
|
||
|
|
||
|
nd4j::ops::lrn op;
|
||
|
auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
auto out = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp.isSameShape(out));
|
||
|
// out->printIndexedBuffer("LRN out");
|
||
|
// exp.printIndexedBuffer("LRN exp");
|
||
|
ASSERT_TRUE(exp.equalsTo(out));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) {
|
||
|
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
|
||
|
|
||
|
5.5, 0., 0.3, 5.5,
|
||
|
1.5, 0., 1.3, 6.5,
|
||
|
8.6, 0., 0., 0.4,
|
||
|
2.5, 1., 0.3, 4.5,
|
||
|
1.5, 1., 1.3, 1.5,
|
||
|
3.5, 0., 1.3, 2.5,
|
||
|
2.6, 2., 3., 1.4,
|
||
|
4.5, 1., 0.3, 0.5}
|
||
|
);
|
||
|
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
|
||
|
0.9824562, 0., 0.03822664, 0.9824562,
|
||
|
0.67488194, 0., 0.18924236, 0.96960944,
|
||
|
0.99330735, 0., 0., 0.37139067,
|
||
|
0.86567914, 0.18702209, 0.05610663, 0.9520745,
|
||
|
0.6154575, 0.34942827, 0.45425674, 0.6154575,
|
||
|
0.905509, 0. , 0.2824086, 0.8361251,
|
||
|
0.57063663, 0.41959068, 0.629386, 0.3504383,
|
||
|
0.9520745, 0.21039814, 0.06311944, 0.3268602 }
|
||
|
);
|
||
|
|
||
|
nd4j::ops::lrn op;
|
||
|
auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
auto out = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp.isSameShape(out));
|
||
|
// out->printIndexedBuffer("LRN out");
|
||
|
// exp.printIndexedBuffer("LRN exp");
|
||
|
ASSERT_TRUE(exp.equalsTo(out));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) {
|
||
|
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
|
||
|
|
||
|
5.5, 0., 0.3, 5.5,
|
||
|
1.5, 0., 1.3, 6.5,
|
||
|
8.6, 0., 0., 0.4,
|
||
|
2.5, 1., 0.3, 4.5,
|
||
|
1.5, 1., 1.3, 1.5,
|
||
|
3.5, 0., 1.3, 2.5,
|
||
|
2.6, 2., 3., 1.4,
|
||
|
4.5, 1., 0.3, 0.5}
|
||
|
);
|
||
|
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
|
||
|
0.70082176, 0., 0.03822664, 0.70082176,
|
||
|
0.21835658, 0., 0.18924236, 0.9462118,
|
||
|
0.9922489, 0., 0., 0.04615111,
|
||
|
0.46755522, 0.18702209, 0.05610663, 0.8415994,
|
||
|
0.5241424, 0.34942827, 0.45425674, 0.5241424,
|
||
|
0.76033086, 0., 0.2824086, 0.54309344,
|
||
|
0.54546785, 0.41959068, 0.629386, 0.29371348,
|
||
|
0.94679165, 0.21039814, 0.06311944, 0.10519907}
|
||
|
);
|
||
|
|
||
|
nd4j::ops::lrn op;
|
||
|
auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
|
||
|
auto out = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp.isSameShape(out));
|
||
|
// out->printIndexedBuffer("LRN out");
|
||
|
// exp.printIndexedBuffer("LRN exp");
|
||
|
ASSERT_TRUE(exp.equalsTo(out));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) {
|
||
|
|
||
|
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
|
||
|
|
||
|
5.5,0., 0.3, 5.5,
|
||
|
1.5,0., 1.3, 6.5,
|
||
|
8.6,0., 0., 0.4,
|
||
|
2.5,1., 0.3, 4.5,
|
||
|
1.5,1., 1.3, 1.5,
|
||
|
3.5,0., 1.3, 2.5,
|
||
|
2.6,2., 3., 1.4,
|
||
|
4.5,1., 0.3, 0.5}
|
||
|
);
|
||
|
|
||
|
auto eps = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
|
||
|
0.70082176, 0., 0.03822664, 0.70082176,
|
||
|
0.21835658, 0., 0.18924236, 0.9462118,
|
||
|
|
||
|
0.9922489, 0., 0. , 0.04615111,
|
||
|
0.46755522, 0.18702209, 0.05610663, 0.8415994,
|
||
|
|
||
|
|
||
|
0.5241424, 0.34942827, 0.45425674, 0.5241424,
|
||
|
0.76033086, 0., 0.2824086 , 0.54309344,
|
||
|
|
||
|
0.54546785, 0.41959068, 0.629386 , 0.29371348,
|
||
|
0.94679165, 0.21039814, 0.06311944, 0.10519907}
|
||
|
);
|
||
|
|
||
|
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4});
|
||
|
|
||
|
nd4j::ops::lrn_bp op;
|
||
|
auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, typeid(TypeParam) == typeid(float) ? nd4j::DataType::FLOAT32 : nd4j::DataType::DOUBLE);
|
||
|
auto out = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(exp.isSameShape(out));
|
||
|
// out->printIndexedBuffer("LRN out");
|
||
|
// exp.printIndexedBuffer("LRN exp");
|
||
|
// ASSERT_TRUE(exp.equalsTo(out));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_test1) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=2,oH=2,oW=2;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {10.5, 11.5, 13.5, 14.5, 22.5, 23.5, 25.5, 26.5, 46.5, 47.5, 49.5, 50.5, 58.5, 59.5, 61.5, 62.5,
|
||
|
82.5, 83.5, 85.5, 86.5, 94.5, 95.5, 97.5, 98.5,118.5,119.5,121.5,122.5,130.5,131.5,133.5,134.5,
|
||
|
154.5,155.5,157.5,158.5,166.5,167.5,169.5,170.5,190.5,191.5,193.5,194.5,202.5,203.5,205.5,206.5});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::avgpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_test2) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=3,oH=4,oW=3;
|
||
|
int paddingMode = 1; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 25. , 26. , 27. , 28. , 29. , 30. , 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 34. , 35. , 36. , 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 43. , 44. , 45. , 43. , 44. , 45. , 46. , 47. , 48. , 47.5, 48.5, 49.5,
|
||
|
61. , 62. , 63. , 64. , 65. , 66. , 65.5, 66.5, 67.5, 65.5, 66.5, 67.5, 68.5, 69.5, 70.5, 70. , 71. , 72. , 74.5, 75.5, 76.5, 77.5, 78.5, 79.5, 79. , 80. , 81. , 79. , 80. , 81. , 82. , 83. , 84. , 83.5, 84.5, 85.5,
|
||
|
79. , 80. , 81. , 82. , 83. , 84. , 83.5, 84.5, 85.5, 83.5, 84.5, 85.5, 86.5, 87.5, 88.5, 88. , 89. , 90. , 92.5, 93.5, 94.5, 95.5, 96.5, 97.5, 97. , 98. , 99. , 97. , 98. , 99. ,100. ,101. ,102. ,101.5,102.5,103.5,
|
||
|
133. ,134. ,135. ,136. ,137. ,138. ,137.5,138.5,139.5,137.5,138.5,139.5,140.5,141.5,142.5,142. ,143. ,144. ,146.5,147.5,148.5,149.5,150.5,151.5,151. ,152. ,153. ,151. ,152. ,153. ,154. ,155. ,156. ,155.5,156.5,157.5,
|
||
|
169. ,170. ,171. ,172. ,173. ,174. ,173.5,174.5,175.5,173.5,174.5,175.5,176.5,177.5,178.5,178. ,179. ,180. ,182.5,183.5,184.5,185.5,186.5,187.5,187. ,188. ,189. ,187. ,188. ,189. ,190. ,191. ,192. ,191.5,192.5,193.5,
|
||
|
187. ,188. ,189. ,190. ,191. ,192. ,191.5,192.5,193.5,191.5,192.5,193.5,194.5,195.5,196.5,196. ,197. ,198. ,200.5,201.5,202.5,203.5,204.5,205.5,205. ,206. ,207. ,205. ,206. ,207. ,208. ,209. ,210. ,209.5,210.5,211.5});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::avgpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_test3) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=2,oH=2,oW=2;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 65.5, 66.5, 67.5, 68.5, 69.5, 70.5,
|
||
|
74.5, 75.5, 76.5, 77.5, 78.5, 79.5,137.5,138.5,139.5,140.5,141.5,142.5,146.5,147.5,148.5,149.5,150.5,151.5,
|
||
|
173.5,174.5,175.5,176.5,177.5,178.5,182.5,183.5,184.5,185.5,186.5,187.5});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::avgpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_test4) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
|
||
|
int oD=4,oH=4,oW=4;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{0.416667, 1.00, 1.333333, 0.75, 1.00, 2.25, 2.75, 1.50, 1.75, 3.75, 4.25, 2.25, 1.416667, 3.00, 3.333333, 1.75, 2.833333, 6.00, 6.666667, 3.50, 5.00, 10.50, 11.50, 6.00, 6.50,
|
||
|
13.50, 14.50, 7.50, 4.833333, 10.00, 10.666667, 5.50, 6.833333, 14.00, 14.666667, 7.50, 11.00, 22.50, 23.50, 12.00, 12.50, 25.50, 26.50, 13.50, 8.833333, 18.00, 18.666666, 9.50,
|
||
|
4.416667, 9.00, 9.333333, 4.75, 7.00, 14.25, 14.75, 7.50, 7.75, 15.75, 16.25, 8.25, 5.416667, 11.00, 11.333333, 5.75, 6.416667, 13.00, 13.333333, 6.75, 10.00, 20.25, 20.75,
|
||
|
10.50, 10.75, 21.75, 22.25, 11.25, 7.416667, 15.00, 15.333333, 7.75, 14.833333, 30.00, 30.666666, 15.50, 23.00, 46.50, 47.50, 24.00, 24.50, 49.50, 50.50, 25.50, 16.833334,
|
||
|
34.00, 34.666668, 17.50, 18.833334, 38.00, 38.666668, 19.50, 29.00, 58.50, 59.50, 30.00, 30.50, 61.50, 62.50, 31.50, 20.833334, 42.00, 42.666668, 21.50, 10.416667, 21.00,
|
||
|
21.333334, 10.75, 16.00, 32.25, 32.75, 16.50, 16.75, 33.75, 34.25, 17.25, 11.416667, 23.00, 23.333334, 11.75, 12.416667, 25.00, 25.333334, 12.75, 19.00, 38.25, 38.75, 19.50,
|
||
|
19.75, 39.75, 40.25, 20.25, 13.416667, 27.00, 27.333334, 13.75, 26.833334, 54.00, 54.666668, 27.50, 41.00, 82.50, 83.50, 42.00, 42.50, 85.50, 86.50, 43.50, 28.833334, 58.00,
|
||
|
58.666668, 29.50, 30.833334, 62.00, 62.666668, 31.50, 47.00, 94.50, 95.50, 48.00, 48.50, 97.50, 98.50, 49.50, 32.833332, 66.00, 66.666664, 33.50, 16.416666, 33.00, 33.333332,
|
||
|
16.75, 25.00, 50.25, 50.75, 25.50, 25.75, 51.75, 52.25, 26.25, 17.416666, 35.00, 35.333332, 17.75, 18.416666, 37.00, 37.333332, 18.75, 28.00, 56.25, 56.75, 28.50, 28.75,
|
||
|
57.75, 58.25, 29.25, 19.416666, 39.00, 39.333332, 19.75, 38.833332, 78.00, 78.666664, 39.50, 59.00, 118.50, 119.50, 60.00, 60.50, 121.50, 122.50, 61.50, 40.833332, 82.00,
|
||
|
82.666664, 41.50, 42.833332, 86.00, 86.666664, 43.50, 65.00, 130.50, 131.50, 66.00, 66.50, 133.50, 134.50, 67.50, 44.833332, 90.00, 90.666664, 45.50, 22.416666, 45.00,
|
||
|
45.333332, 22.75, 34.00, 68.25, 68.75, 34.50, 34.75, 69.75, 70.25, 35.25, 23.416666, 47.00, 47.333332, 23.75, 24.416666, 49.00, 49.333332, 24.75, 37.00, 74.25, 74.75,
|
||
|
37.50, 37.75, 75.75, 76.25, 38.25, 25.416666, 51.00, 51.333332, 25.75, 50.833332, 102.00, 102.666664, 51.50, 77.00, 154.50, 155.50, 78.00, 78.50, 157.50, 158.50, 79.50,
|
||
|
52.833332, 106.00, 106.666664, 53.50, 54.833332, 110.00, 110.666664, 55.50, 83.00, 166.50, 167.50, 84.00, 84.50, 169.50, 170.50, 85.50, 56.833332, 114.00, 114.666664,
|
||
|
57.50, 28.416666, 57.00, 57.333332, 28.75, 43.00, 86.25, 86.75, 43.50, 43.75, 87.75, 88.25, 44.25, 29.416666, 59.00, 59.333332, 29.75, 30.416666, 61.00, 61.333332, 30.75,
|
||
|
46.00, 92.25, 92.75, 46.50, 46.75, 93.75, 94.25, 47.25, 31.416666, 63.00, 63.333332, 31.75, 62.833332, 126.00, 126.666664, 63.50, 95.00, 190.50, 191.50, 96.00, 96.50,
|
||
|
193.50, 194.50, 97.50, 64.833336, 130.00, 130.666672, 65.50, 66.833336, 134.00, 134.666672, 67.50, 101.00, 202.50, 203.50, 102.00, 102.50, 205.50, 206.50, 103.50,
|
||
|
68.833336, 138.00, 138.666672, 69.50, 34.416668, 69.00, 69.333336, 34.75, 52.00, 104.25, 104.75, 52.50, 52.75, 105.75, 106.25, 53.25, 35.416668, 71.00, 71.333336, 35.75});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::avgpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_bp_test1) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=2,oH=2,oW=2;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,
|
||
|
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,
|
||
|
0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,
|
||
|
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,
|
||
|
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,
|
||
|
0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,
|
||
|
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,
|
||
|
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,
|
||
|
0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667});
|
||
|
input.linspace(1.);
|
||
|
gradO = 2.;
|
||
|
|
||
|
nd4j::ops::avgpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_bp_test2) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
|
||
|
int oD=4,oH=4,oW=4;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,
|
||
|
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333});
|
||
|
input.linspace(1.);
|
||
|
gradO = 2.;
|
||
|
|
||
|
nd4j::ops::avgpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_bp_test3) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=3,oH=4,oW=3;
|
||
|
int paddingMode = 1; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,
|
||
|
0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,
|
||
|
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 ,
|
||
|
1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 ,
|
||
|
0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,
|
||
|
0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,
|
||
|
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 ,
|
||
|
1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 });
|
||
|
input.linspace(1.);
|
||
|
gradO = 2.;
|
||
|
|
||
|
nd4j::ops::avgpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, avgpool3d_bp_test4) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=3,oH=4,oW=3;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.16667, 0.16667, 0.16667,0.33333, 0.33333, 0.33333,0.5 , 0.5 , 0.5 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75 , 1.75 ,
|
||
|
0.91667, 0.91667, 0.91667,1.83333, 1.83333, 1.83333,2.75, 2.75 , 2.75 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.66667, 0.66667, 0.66667,1.33333, 1.33333, 1.33333,2. , 2. , 2. ,
|
||
|
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,1.83333, 1.83333, 1.83333,3.66667, 3.66667, 3.66667,5.5 , 5.5 , 5.5 ,0.5 , 0.5 , 0.5 ,1. , 1. , 1. ,1.5 , 1.5 , 1.5 ,
|
||
|
1. , 1. , 1. ,2. , 2. , 2. ,3. , 3. , 3. ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25 , 5.25 ,2.75 , 2.75 , 2.75 ,5.5 , 5.5 , 5.5 ,8.25, 8.25 , 8.25 ,
|
||
|
0.16667, 0.16667, 0.16667,0.33333, 0.33333, 0.33333,0.5 , 0.5 , 0.5 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75 , 1.75 ,
|
||
|
0.91667, 0.91667, 0.91667,1.83333, 1.83333, 1.83333,2.75, 2.75 , 2.75 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.66667, 0.66667, 0.66667,1.33333, 1.33333, 1.33333,2. , 2. , 2. ,
|
||
|
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,1.83333, 1.83333, 1.83333,3.66667, 3.66667, 3.66667,5.5 , 5.5 , 5.5 ,0.5 , 0.5 , 0.5 ,1. , 1. , 1. ,1.5 , 1.5 , 1.5 ,
|
||
|
1. , 1. , 1. ,2. , 2. , 2. ,3. , 3. , 3. ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25 , 5.25 ,2.75 , 2.75 , 2.75 ,5.5 , 5.5 , 5.5 ,8.25, 8.25 , 8.25 });
|
||
|
input.linspace(1.);
|
||
|
gradO = 2.;
|
||
|
|
||
|
nd4j::ops::avgpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_test1) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=2,oH=2,oW=2;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {20., 21., 23., 24., 32., 33., 35., 36., 56., 57., 59., 60., 68., 69., 71., 72., 92., 93., 95., 96.,104.,105.,107.,108.,
|
||
|
128.,129.,131.,132.,140.,141.,143.,144.,164.,165.,167.,168.,176.,177.,179.,180.,200.,201.,203.,204.,212.,213.,215.,216.});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_test2) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=3,oH=4,oW=3;
|
||
|
int paddingMode = 1; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 49., 50., 51., 52., 53., 54., 52., 53., 54., 58., 59., 60., 61., 62., 63., 61., 62., 63., 67., 68., 69., 70., 71., 72., 70., 71., 72., 67., 68., 69., 70., 71., 72., 70., 71., 72.,
|
||
|
85., 86., 87., 88., 89., 90., 88., 89., 90., 94., 95., 96., 97., 98., 99., 97., 98., 99.,103., 104., 105.,106., 107., 108.,106., 107., 108.,103., 104., 105.,106., 107., 108.,106., 107., 108.,
|
||
|
85., 86., 87., 88., 89., 90., 88., 89., 90., 94., 95., 96., 97., 98., 99., 97., 98., 99.,103., 104., 105.,106., 107., 108.,106., 107., 108.,103., 104., 105.,106., 107., 108.,106., 107., 108.,
|
||
|
157., 158., 159.,160., 161., 162.,160., 161., 162.,166., 167., 168.,169., 170., 171.,169., 170., 171.,175., 176., 177.,178., 179., 180.,178., 179., 180.,175., 176., 177.,178., 179., 180.,178., 179., 180.,
|
||
|
193., 194., 195.,196., 197., 198.,196., 197., 198.,202., 203., 204.,205., 206., 207.,205., 206., 207.,211., 212., 213.,214., 215., 216.,214., 215., 216.,211., 212., 213.,214., 215., 216.,214., 215., 216.,
|
||
|
193., 194., 195.,196., 197., 198.,196., 197., 198.,202., 203., 204.,205., 206., 207.,205., 206., 207.,211., 212., 213.,214., 215., 216.,214., 215., 216.,211., 212., 213.,214., 215., 216.,214., 215., 216.});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_test3) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=2,oH=2,oW=2;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, {58., 59., 60., 61., 62., 63., 67., 68., 69., 70., 71., 72., 94., 95., 96., 97., 98., 99.,103., 104., 105.,106., 107., 108.,
|
||
|
166., 167., 168.,169., 170., 171.,175., 176., 177.,178., 179., 180.,202., 203., 204.,205., 206., 207.,211., 212., 213.,214., 215., 216.});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_test4) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
|
||
|
int oD=4,oH=4,oW=4;
|
||
|
int paddingMode = 0; // -SAME, 0-VALID
|
||
|
int dataFormat = 0; // -NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{ 4., 5., 6., 6., 7., 8., 9., 9., 10., 11., 12., 12., 10., 11., 12., 12., 16., 17., 18., 18., 19., 20., 21., 21., 22., 23., 24., 24., 22., 23., 24., 24., 28., 29., 30., 30., 31., 32., 33., 33., 34., 35., 36., 36., 34., 35., 36., 36.,
|
||
|
28., 29., 30., 30., 31., 32., 33., 33., 34., 35., 36., 36., 34., 35., 36., 36., 40., 41., 42., 42., 43., 44., 45., 45., 46., 47., 48., 48., 46., 47., 48., 48., 52., 53., 54., 54., 55., 56., 57., 57., 58., 59., 60., 60., 58., 59., 60., 60.,
|
||
|
64., 65., 66., 66., 67., 68., 69., 69., 70., 71., 72., 72., 70., 71., 72., 72., 64., 65., 66., 66., 67., 68., 69., 69., 70., 71., 72., 72., 70., 71., 72., 72., 76., 77., 78., 78., 79., 80., 81., 81., 82., 83., 84., 84., 82., 83., 84., 84.,
|
||
|
88., 89., 90., 90., 91., 92., 93., 93., 94., 95., 96., 96., 94., 95., 96., 96.,100., 101., 102., 102.,103., 104., 105., 105.,106., 107., 108., 108.,106., 107., 108., 108.,100., 101., 102., 102.,103., 104., 105., 105.,106., 107., 108., 108.,106., 107., 108., 108.,
|
||
|
112., 113., 114., 114.,115., 116., 117., 117.,118., 119., 120., 120.,118., 119., 120., 120.,124., 125., 126., 126.,127., 128., 129., 129.,130., 131., 132., 132.,130., 131., 132., 132.,136., 137., 138., 138.,139., 140., 141., 141.,142., 143., 144., 144.,142., 143., 144., 144.,
|
||
|
136., 137., 138., 138.,139., 140., 141., 141.,142., 143., 144., 144.,142., 143., 144., 144.,148., 149., 150., 150.,151., 152., 153., 153.,154., 155., 156., 156.,154., 155., 156., 156.,160., 161., 162., 162.,163., 164., 165., 165.,166., 167., 168., 168.,166., 167., 168., 168.,
|
||
|
172., 173., 174., 174.,175., 176., 177., 177.,178., 179., 180., 180.,178., 179., 180., 180.,172., 173., 174., 174.,175., 176., 177., 177.,178., 179., 180., 180.,178., 179., 180., 180.,184., 185., 186., 186.,187., 188., 189., 189.,190., 191., 192., 192.,190., 191., 192., 192.,
|
||
|
196., 197., 198., 198.,199., 200., 201., 201.,202., 203., 204., 204.,202., 203., 204., 204.,208., 209., 210., 210.,211., 212., 213., 213.,214., 215., 216., 216.,214., 215., 216., 216.,208., 209., 210., 210.,211., 212., 213., 213.,214., 215., 216., 216.,214., 215., 216., 216.});
|
||
|
input.linspace(1.);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew op;
|
||
|
auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_bp_test1) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=2,oH=2,oW=2;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.1, 0.2,0. , 0.3, 0.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.5, 0.6,0. , 0.7, 0.8,
|
||
|
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.9, 1. ,0. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.3, 1.4,0. , 1.5, 1.6,
|
||
|
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.7, 1.8,0. , 1.9, 2. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.1, 2.2,0. , 2.3, 2.4,
|
||
|
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.5, 2.6,0. , 2.7, 2.8,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.9, 3. ,0. , 3.1, 3.2,
|
||
|
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 3.3, 3.4,0. , 3.5, 3.6,0. , 0. , 0. ,0. , 0. , 0. ,0. , 3.7, 3.8,0. , 3.9, 4. ,
|
||
|
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 4.1, 4.2,0. , 4.3, 4.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 4.5, 4.6,0. , 4.7, 4.8});
|
||
|
input.linspace(1.);
|
||
|
gradO.linspace(0.1, 0.1);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_bp_test2) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
|
||
|
int oD=4,oH=4,oW=4;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.000e+00, 0.000e+00, 0.000e+00,1.000e-01, 2.000e-01, 7.000e-01,5.000e-01, 6.000e-01, 1.500e+00,2.200e+00, 2.400e+00, 5.400e+00,0.000e+00, 0.000e+00, 0.000e+00,1.700e+00, 1.800e+00, 3.900e+00,2.100e+00, 2.200e+00, 4.700e+00,5.400e+00, 5.600e+00, 1.180e+01,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,8.200e+00, 8.400e+00, 1.740e+01,9.000e+00, 9.200e+00, 1.900e+01,2.040e+01, 2.080e+01, 4.280e+01,0.000e+00, 0.000e+00, 0.000e+00,6.500e+00, 6.600e+00, 1.350e+01,6.900e+00, 7.000e+00, 1.430e+01,1.500e+01, 1.520e+01, 3.100e+01,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,8.100e+00, 8.200e+00, 1.670e+01,8.500e+00, 8.600e+00, 1.750e+01,1.820e+01, 1.840e+01, 3.740e+01,0.000e+00, 0.000e+00, 0.000e+00,2.100e+01, 2.120e+01, 4.300e+01,2.180e+01, 2.200e+01, 4.460e+01,4.600e+01, 4.640e+01, 9.400e+01,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,1.290e+01, 1.300e+01, 2.630e+01,1.330e+01, 1.340e+01, 2.710e+01,2.780e+01, 2.800e+01, 5.660e+01,0.000e+00, 0.000e+00, 0.000e+00,1.450e+01, 1.460e+01, 2.950e+01,1.490e+01, 1.500e+01, 3.030e+01,3.100e+01, 3.120e+01, 6.300e+01,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,3.380e+01, 3.400e+01, 6.860e+01,3.460e+01, 3.480e+01, 7.020e+01,7.160e+01, 7.200e+01, 1.452e+02,0.000e+00, 0.000e+00, 0.000e+00,1.930e+01, 1.940e+01, 3.910e+01,1.970e+01, 1.980e+01, 3.990e+01,4.060e+01, 4.080e+01, 8.220e+01,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,2.090e+01, 2.100e+01, 4.230e+01,2.130e+01, 2.140e+01, 4.310e+01,4.380e+01, 4.400e+01, 8.860e+01,0.000e+00, 0.000e+00, 0.000e+00,4.660e+01, 4.680e+01, 9.420e+01,4.740e+01, 4.760e+01, 9.580e+01,9.720e+01, 9.760e+01, 1.964e+02,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,2.570e+01, 2.580e+01, 5.190e+01,2.610e+01, 2.620e+01, 5.270e+01,5.340e+01, 5.360e+01, 1.078e+02,0.000e+00, 0.000e+00, 0.000e+00,2.730e+01, 2.740e+01, 5.510e+01,2.770e+01, 2.780e+01, 5.590e+01,5.660e+01, 5.680e+01, 1.142e+02,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,5.940e+01, 5.960e+01, 1.198e+02,6.020e+01, 6.040e+01, 1.214e+02,1.228e+02, 1.232e+02, 2.476e+02,0.000e+00, 0.000e+00, 0.000e+00,3.210e+01, 3.220e+01, 6.470e+01,3.250e+01, 3.260e+01, 6.550e+01,6.620e+01, 6.640e+01, 1.334e+02,
|
||
|
0.000e+00, 0.000e+00, 0.000e+00,3.370e+01, 3.380e+01, 6.790e+01,3.410e+01, 3.420e+01, 6.870e+01,6.940e+01, 6.960e+01, 1.398e+02,0.000e+00, 0.000e+00, 0.000e+00,7.220e+01, 7.240e+01, 1.454e+02,7.300e+01, 7.320e+01, 1.470e+02,1.484e+02, 1.488e+02, 2.988e+02});
|
||
|
input.linspace(1.);
|
||
|
gradO.linspace(0.1, 0.1);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_bp_test3) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=3,oH=4,oW=3;
|
||
|
int paddingMode = 1; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, { 0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
||
|
0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1, 0.2 , 0.3, 1.1, 1.3 , 1.5,
|
||
|
0., 0., 0., 1. , 1.1, 1.2, 2.9, 3.1 , 3.3, 0. , 0. , 0. , 4.7, 4.9 , 5.1, 11.2, 11.6 , 12. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
||
|
0., 0., 0., 11. , 11.2, 11.4, 23.8, 24.2 , 24.6, 0. , 0. , 0. , 12.8, 13. , 13.2, 27.4, 27.8 , 28.2, 0. , 0. , 0. , 31. , 31.4 , 31.8, 65.6, 66.39999, 67.2,
|
||
|
0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
||
|
0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 10.9, 11. , 11.1, 22.7, 22.9 , 23.1,
|
||
|
0., 0., 0., 11.8, 11.9, 12. , 24.5, 24.7 , 24.9, 0. , 0. , 0. , 26.3, 26.5 , 26.7, 54.4, 54.8 , 55.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
||
|
0., 0., 0., 32.6, 32.8, 33. , 67. , 67.4 , 67.8, 0. , 0. , 0. , 34.4, 34.6 , 34.8, 70.6, 71. , 71.4, 0. , 0. , 0. , 74.2, 74.6 , 75. ,152. , 152.8 ,153.6});
|
||
|
input.linspace(1.);
|
||
|
gradO.linspace(0.1, 0.1);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool3d_bp_test4) {
|
||
|
|
||
|
int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
|
||
|
int oD=3,oH=4,oW=3;
|
||
|
int paddingMode = 0; // 1-SAME, 0-VALID
|
||
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
|
||
|
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0.2, 0.3, 1.1, 1.3, 1.5, 0, 0, 0, 5.7, 6, 6.3,
|
||
|
14.1, 14.7, 15.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 11.2, 11.4, 23.8, 24.2,
|
||
|
24.6, 0, 0, 0, 43.8, 44.4, 45, 93, 94.2, 95.4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||
|
10.9, 11, 11.1, 22.7, 22.9, 23.1, 0, 0, 0, 38.1, 38.4, 38.7, 78.9, 79.5, 80.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32.6, 32.8, 33, 67, 67.4, 67.8, 0, 0, 0, 108.6, 109.2, 109.8, 222.6, 223.8, 225,});
|
||
|
input.linspace(1.);
|
||
|
gradO.linspace(0.1, 0.1);
|
||
|
|
||
|
nd4j::ops::maxpool3dnew_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TYPED_TEST(TypedDeclarableOpsTests4, maxpool2d_test1) {
|
||
|
|
||
|
int bS = 3; // batch size (number of samples)
|
||
|
int iC = 3; // input channels
|
||
|
int iH = 28, iW = 28; // input height/width
|
||
|
int kH = 2, kW = 2; // kernel (filter) height/width
|
||
|
int sH = 1, sW = 1; // stride height/width
|
||
|
int pH = 0, pW = 0; // padding height/width
|
||
|
int dH = 1, dW = 1; // dilation height/width
|
||
|
|
||
|
int oH = 27, oW = 27; // output height/width
|
||
|
|
||
|
int isSameMode = 0; // 1-SAME, 0-VALID
|
||
|
|
||
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
|
||
|
|
||
|
nd4j::ops::maxpool2d op;
|
||
|
auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
ASSERT_TRUE(output->isSameShape({bS, iC, oH, oW}));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test1) {
|
||
|
|
||
|
const int rows = 3;
|
||
|
const int cols = 5;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows, cols});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
// output->printIndexedBuffer();
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test2) {
|
||
|
|
||
|
const int rows = 3;
|
||
|
const int cols = 5;
|
||
|
const int diag = 2;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows, cols, diag});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test3) {
|
||
|
|
||
|
const int rows = 3;
|
||
|
const int cols = 5;
|
||
|
const int diag = -1;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows, cols, diag});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test4) {
|
||
|
|
||
|
const int rows = 3;
|
||
|
const int cols = 5;
|
||
|
const int diag = -2;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows, cols, diag});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test5) {
|
||
|
|
||
|
const int rows = 5;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, rows}, {1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test6) {
|
||
|
|
||
|
const int rows = 3;
|
||
|
const int cols = 5;
|
||
|
const int diag = -20;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows, cols, diag});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, tri_test7) {
|
||
|
|
||
|
const int rows = 3;
|
||
|
const int cols = 5;
|
||
|
const int diag = 20;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||
|
|
||
|
nd4j::ops::tri op;
|
||
|
auto results = op.execute({}, {}, {rows, cols, diag});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test1) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test2) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3,4, 5, 6,0, 8, 9,0, 0, 12});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {-1});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test3) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2,3, 4,0, 6,7, 8,9,10,0,12});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {-1});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test4) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2,0, 4,0, 0,7, 8,0, 10,0, 0});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test5) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0, 2,0, 0,0, 0,0, 8,0, 0,0, 0});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {1});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test6) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0, 0,0, 0,0, 0,0, 0,0, 0,0, 0});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {10});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test7) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {-10});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test8) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {6}, {1, 2, 3, 4, 5, 6});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {6, 6}, {1, 2, 3, 4, 5, 6,0, 2, 3, 4, 5, 6,0, 0, 3, 4, 5, 6,0, 0, 0, 4, 5, 6,0, 0, 0, 0, 5, 6,0, 0, 0, 0, 0, 6});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test9) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {6}, {1, 2, 3, 4, 5, 6});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {-3});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test10) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {6}, {1, 2, 3, 4, 5, 6});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {3});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_test11) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {6}, {1, 2, 3, 4, 5, 6});
|
||
|
auto expected = NDArrayFactory::create<double>('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6});
|
||
|
|
||
|
nd4j::ops::triu op;
|
||
|
auto results = op.execute({&input}, {}, {-58});
|
||
|
auto output = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_bp_test1) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto gradO = NDArrayFactory::create<double>('c', {2, 3, 2});
|
||
|
gradO = 0.5;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.});
|
||
|
|
||
|
nd4j::ops::triu_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {1});
|
||
|
auto gradI = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_bp_test2) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||
|
auto gradO = NDArrayFactory::create<double>('c', {2, 3, 2});
|
||
|
gradO = 0.5;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0.5,0.5,0. ,0.5,0. ,0. ,0.5,0.5,0. ,0.5,0. ,0.});
|
||
|
|
||
|
nd4j::ops::triu_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {});
|
||
|
auto gradI = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_bp_test3) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {6}, {1, 2, 3, 4, 5, 6});
|
||
|
auto gradO = NDArrayFactory::create<double>('c', {6,6});
|
||
|
gradO = 0.5;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {6,6}, {0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0. , 0.5, 0.5, 0.5, 0.5,0. , 0. , 0. , 0.5, 0.5, 0.5});
|
||
|
|
||
|
nd4j::ops::triu_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {-2});
|
||
|
auto gradI = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////
|
||
|
TEST_F(DeclarableOpsTests4, triu_bp_test4) {
|
||
|
|
||
|
auto input = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 3, 4, 5, 6});
|
||
|
auto gradO = NDArrayFactory::create<double>('c', {2,3});
|
||
|
gradO = 0.5;
|
||
|
|
||
|
auto expected = NDArrayFactory::create<double>('c', {2,3}, {0., 0., 0., 0., 0., 0.});
|
||
|
|
||
|
nd4j::ops::triu_bp op;
|
||
|
auto results = op.execute({&input, &gradO}, {}, {10});
|
||
|
auto gradI = results->at(0);
|
||
|
|
||
|
ASSERT_EQ(Status::OK(), results->status());
|
||
|
|
||
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||
|
|
||
|
delete results;
|
||
|
}
|
||
|
|