2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// @author raver119@gmail.com
|
|
|
|
//
|
|
|
|
|
|
|
|
#include <NativeOps.h>
|
|
|
|
#include <NDArray.h>
|
|
|
|
#include <ops/declarable/CustomOperations.h>
|
|
|
|
#include <ops/declarable/OpRegistrator.h>
|
|
|
|
#include <graph/GraphHolder.h>
|
|
|
|
#include <graph/FlatUtils.h>
|
|
|
|
#include "testlayers.h"
|
|
|
|
#include <array>
|
|
|
|
|
|
|
|
using namespace nd4j;
|
|
|
|
using namespace nd4j::ops;
|
|
|
|
|
|
|
|
/**
 * gtest fixture grouping the tests that drive ops through the flat
 * NativeOps entry points (the interface used from Java).
 * It carries no state and no SetUp/TearDown logic.
 */
class JavaInteropTests : public testing::Test {
};
|
|
|
|
|
|
|
|
|
|
|
|
// Asks calculateOutputShapes() for the conv2d output shape using only the
// input/weights shape descriptors, and checks it matches `exp`'s shape.
TEST_F(JavaInteropTests, TestShapeExposure1) {
    auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
    auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
    auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});

    nd4j::ops::conv2d op;

    std::vector<double> tArgs({});
    std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1});

    Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};

    auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
    ASSERT_TRUE(shapeList != nullptr);

    // Copy what we need out of the native list and release it BEFORE asserting:
    // a failing ASSERT_* returns from the test immediately, which would otherwise
    // skip deleteShapeList() and leak the list.
    const auto numShapes = shapeList->size();
    std::vector<Nd4jLong> outShape;
    if (numShapes > 0) {
        auto shapeInfo = (Nd4jLong *) shapeList->at(0);
        auto dims = shape::shapeOf(shapeInfo);
        outShape.assign(dims, dims + shape::rank(shapeInfo));
    }
    deleteShapeList((Nd4jPointer) shapeList);

    ASSERT_EQ(1, numShapes);
    ASSERT_EQ(exp.rankOf(), static_cast<int>(outShape.size()));
    for (int e = 0; e < exp.rankOf(); e++)
        ASSERT_EQ(exp.sizeAt(e), outShape[e]);
}
|
|
|
|
|
|
|
|
|
|
|
|
// Asks calculateOutputShapes() for the output shape of shape_of on a rank-4
// input; the result is expected to be a vector of length 4.
TEST_F(JavaInteropTests, TestShapeExposure2) {
    auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
    auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});

    nd4j::ops::shape_of op;

    std::vector<double> tArgs({});
    std::vector<Nd4jLong> iArgs({});

    Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};

    auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
    ASSERT_TRUE(shapeList != nullptr);

    // Extract the values to check, then free the native list before any
    // ASSERT_* can return early and leak it.
    const auto numShapes = shapeList->size();
    int outRank = -1;
    Nd4jLong outDim0 = -1;
    if (numShapes > 0) {
        auto shapeInfo = (Nd4jLong *) shapeList->at(0);
        outRank = shape::rank(shapeInfo);
        // only read the first dimension when there is one
        if (outRank > 0)
            outDim0 = shape::shapeOf(shapeInfo)[0];
    }
    deleteShapeList((Nd4jPointer) shapeList);

    ASSERT_EQ(1, numShapes);
    ASSERT_EQ(exp.rankOf(), outRank);
    ASSERT_EQ(exp.sizeAt(0), outDim0);
}
|
|
|
|
|
|
|
|
// Asks calculateOutputShapes2() for the shapes produced by split_v when a {5, 30}
// array is split (axis from iArgs[0] = 1) into widths {4, 15, 11}, and compares
// them with views of the corresponding column ranges.
TEST_F(JavaInteropTests, TestShapeExposure3) {
    auto x = NDArrayFactory::create<float>('c', {5, 30});
    auto sizes = NDArrayFactory::create<int>('c', {3}, {4, 15, 11});

    std::vector<Nd4jLong> list0 = {0,0, 0,4};
    std::vector<Nd4jLong> list1 = {0,0, 4,19};
    std::vector<Nd4jLong> list2 = {0,0, 19,30};

    auto sub0 = x(list0, true);
    auto sub1 = x(list1, true);
    auto sub2 = x(list2, true);

    sub0.assign(0.0f);
    sub1.assign(1.0f);
    sub2.assign(2.0f);

    // host pointers first, then the matching special (device) pointers
    Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.getSpecialBuffer(), sizes.getSpecialBuffer()};
    Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo(), x.getSpecialShapeInfo(), sizes.getSpecialShapeInfo()};

    nd4j::ops::split_v op;

    Nd4jLong iArgs[] = {1};
    auto hash = op.getOpHash();

    auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
    ASSERT_TRUE(shapeList != nullptr);

    // Evaluate every comparison while the native list is alive, then free it
    // before any ASSERT_* can return early and leak it.
    const auto numShapes = shapeList->size();
    const bool match0 = numShapes > 0 && shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0));
    const bool match1 = numShapes > 1 && shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1));
    const bool match2 = numShapes > 2 && shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2));
    deleteShapeList((Nd4jPointer) shapeList);

    ASSERT_EQ(3, numShapes);
    ASSERT_TRUE(match0);
    ASSERT_TRUE(match1);
    ASSERT_TRUE(match2);
}
|
|
|
|
|
|
|
|
// Runs squeeze through execCustomOp() and checks that the {1, 6} input lands in
// the preallocated {6} output with its values unchanged.
TEST_F(JavaInteropTests, Test_Squeeze_1) {
    auto x = NDArrayFactory::create<float>('c', {1, 6}, {1, 2, 3, 4, 5, 6});
    auto z = NDArrayFactory::create<float>('c', {6});
    auto e = NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6});

    nd4j::ops::squeeze op;

    // Bracket the raw native call with prepare/registerSpecialUse, as every other
    // execCustomOp() test in this file does, so host/device copies stay in sync.
    NDArray::prepareSpecialUse({&z}, {&x});

    // host pointers first, then the matching special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&z}, {&x});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(e, z);
}
|
|
|
|
|
|
|
|
// reversedivide via the raw execCustomOp() entry point: with x = {2, 2, 2} and
// y = {4, 6, 8}, the expected result {2, 3, 4} is y / x element-wise.
TEST_F(JavaInteropTests, Test_RDiv_1) {
    auto x = NDArrayFactory::create<double>('c', {3}, {2, 2, 2});
    auto y = NDArrayFactory::create<double>('c', {3}, {4, 6, 8});
    auto z = NDArrayFactory::create<double>('c', {3});
    auto e = NDArrayFactory::create<double>('c', {3}, {2, 3, 4});

    NDArray::prepareSpecialUse({&z}, {&x, &y});

    nd4j::ops::reversedivide op;

    // Both inputs' host pointers come first, followed by their special pointers.
    Nd4jPointer inBuffers[] = {
            (Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(),
            x.getSpecialBuffer(), y.getSpecialBuffer()
    };
    Nd4jPointer inShapes[] = {
            (Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(),
            x.getSpecialShapeInfo(), y.getSpecialShapeInfo()
    };

    Nd4jPointer outBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(),
                               inBuffers, inShapes, 2,
                               outBuffers, outShapes, 1,
                               nullptr, 0,
                               nullptr, 0,
                               nullptr, 0,
                               false);

    NDArray::registerSpecialUse({&z}, {&x, &y});

    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(e, z);
}
|
|
|
|
|
|
|
|
// Runs sconv2d (depthwise + pointwise weights and a bias) through execCustomOp()
// and spot-checks the first output element.
TEST_F(JavaInteropTests, TestSconv2d_1) {
    auto input = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
    auto weightsD = NDArrayFactory::create<float>('c', {1, 3, 1, 1});
    auto weightsP = NDArrayFactory::create<float>('c', {2, 3, 1, 1});
    auto bias = NDArrayFactory::create<float>('c', {1, 2});
    auto output = NDArrayFactory::create<float>('c', {3, 2, 8, 8});
    output.assign(0.0);

    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    bias.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    nd4j::ops::sconv2d op;

    NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});

    // host pointers for all four inputs, then their special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), (Nd4jPointer) weightsP.getBuffer(), (Nd4jPointer) bias.getBuffer(), (Nd4jPointer) input.getSpecialBuffer(), (Nd4jPointer) weightsD.getSpecialBuffer(), (Nd4jPointer) weightsP.getSpecialBuffer(), (Nd4jPointer) bias.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), (Nd4jPointer) weightsP.getShapeInfo(), (Nd4jPointer) bias.getShapeInfo(), (Nd4jPointer) input.getSpecialShapeInfo(), (Nd4jPointer) weightsD.getSpecialShapeInfo(), (Nd4jPointer) weightsP.getSpecialShapeInfo(), (Nd4jPointer) bias.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), (Nd4jPointer) output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), (Nd4jPointer) output.getSpecialShapeInfo()};

    // Nine integer args; presumably kH, kW, sH, sW, pH, pW, dH, dW, same-mode
    // (TODO confirm against the sconv2d declaration). The count is derived from
    // the container instead of a hard-coded literal.
    std::vector<Nd4jLong> iArgs({1, 1, 1, 1, 0, 0, 1, 1, 0});

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
                               nullptr, 0, iArgs.data(), iArgs.size(), nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});

    // the value check below is meaningless if the op itself failed
    ASSERT_EQ(Status::OK(), status);

    ASSERT_NEAR(1423, output.e<float>(0), 1e-5);
}
|
|
|
|
|
|
|
|
// Runs sconv2d with depthwise weights only (no pointwise weights, no bias)
// through execCustomOp() and spot-checks the first output element.
TEST_F(JavaInteropTests, TestSconv2d_2) {
    auto input = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
    auto weightsD = NDArrayFactory::create<float>('c', {1, 3, 1, 1});
    auto output = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
    output.assign(0.0);

    input.linspace(1);
    weightsD.linspace(1);
    weightsD.permutei({2,3,1,0});

    nd4j::ops::sconv2d op;

    NDArray::prepareSpecialUse({&output}, {&input, &weightsD});

    // host pointers first, then the matching special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), input.getSpecialBuffer(), weightsD.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), input.getSpecialShapeInfo(), weightsD.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    // presumably kH, kW, sH, sW, pH, pW, dH, dW, same-mode — TODO confirm
    std::vector<Nd4jLong> iArgs({1, 1, 1, 1, 0, 0, 1, 1, 0});

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), iArgs.size(), nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input, &weightsD});

    // the value check below is meaningless if the op itself failed
    ASSERT_EQ(Status::OK(), status);

    ASSERT_NEAR(1, output.e<float>(0), 1e-5);
}
|
|
|
|
|
|
|
|
|
|
|
|
// Runs maxpool2d through execCustomOp() and checks only that it succeeds.
TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
    auto input = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
    auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
    input.linspace(1);

    NDArray::prepareSpecialUse({&output}, {&input});

    // host pointers first, then the matching special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1});

    nd4j::ops::maxpool2d op;

    // pass the real argument count rather than a literal that can drift from iArgs
    Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), iArgs.size(), nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input});

    ASSERT_EQ(ND4J_STATUS_OK, status);
}
|
|
|
|
TEST_F(JavaInteropTests, TestCol2Im_1) {
    /*
    o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, 1, 40, 8, 0, -1, 99]
    o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, 4, 5, 20, 20, 5, 1, 0, 1, 99]
    o.d.n.l.c.ConvolutionLayer - Strides: [1, 1]
    o.d.n.l.c.ConvolutionLayer - Padding: [0, 0]
    o.d.n.l.c.ConvolutionLayer - Input: [4,5]
    o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1]
    */
    // Reproduces the layer configuration logged above: col2im folds a
    // {1, 2, 2, 2, 4, 5} column array back into a {1, 2, 4, 5} image.
    auto input = NDArrayFactory::create<float>('c', {1, 2, 2, 2, 4, 5});
    auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
    input.linspace(1);

    NDArray::prepareSpecialUse({&output}, {&input});

    // host pointers first, then the matching special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    nd4j::ops::col2im op;

    // strides, padding, image size (4, 5), dilation, plus a trailing flag —
    // matching the values in the log above
    Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};

    auto hash = op.getOpHash();

    auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input});

    // the output check below is meaningless if the op itself failed
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, TestPNorm_1) {
    /*
        o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, 99]
        o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, 1, 99]
        o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2]
        o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1]
        o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0]
        o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1]
        o.d.n.l.c.s.SubsamplingLayer - Same: false
        o.d.n.l.c.s.SubsamplingLayer - pnorm: 2
     */
    // Reproduces the layer configuration logged above via pnormpool2d.
    auto input = NDArrayFactory::create<float>('c', {1, 3, 4, 4});
    auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
    input.linspace(1);

    NDArray::prepareSpecialUse({&output}, {&input});

    nd4j::ops::pnormpool2d op;

    // kernel, strides, padding, dilation, same=false, pnorm=2, then 0 — eleven
    // values, matching the count that was passed to the op before
    std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0});

    // host pointers first, then the matching special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), iArgs.size(), nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input});

    // the output check below is meaningless if the op itself failed
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
}
|
|
|
|
|
|
|
|
|
|
|
|
// Runs clipbyvalue in place (no output buffers, trailing flag `true`) on a
// linspace(1) array. Every value is >= 1, so clipping to [-1, 1] should turn
// the whole array into 1s.
TEST_F(JavaInteropTests, TestInplace_1) {
    auto input = NDArrayFactory::create<float>('c', {10, 10});
    input.linspace(1);

    // The op both reads and overwrites `input`. Sync it before the call as a
    // read, and register it afterwards as a write, so the in-place result is not
    // treated as a mere read.
    NDArray::prepareSpecialUse({&input}, {&input});

    nd4j::ops::clipbyvalue op;

    double extras[] = {-1.0, 1.0};

    // host pointer first, then the matching special (device) pointer
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);

    NDArray::registerSpecialUse({&input}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result);

    ASSERT_NEAR(1.0, input.meanNumber().e<float>(0), 1e-5);
}
|
|
|
|
|
|
|
|
// The synonym "RDiv" must resolve to the same registered op as "reversedivide".
TEST_F(JavaInteropTests, Test_Synonyms_1) {
    auto registry = OpRegistrator::getInstance();
    auto viaSynonym = registry->getOperation("RDiv");
    auto viaCanonical = registry->getOperation("reversedivide");
    const std::string expectedName("reversedivide");

    ASSERT_TRUE(viaSynonym != nullptr);
    ASSERT_TRUE(viaCanonical != nullptr);

    const std::string synonymName = *(viaSynonym->getOpName());
    const std::string canonicalName = *(viaCanonical->getOpName());

    ASSERT_EQ(expectedName, canonicalName);
    ASSERT_EQ(canonicalName, synonymName);
}
|
|
|
|
|
|
|
|
// Looks up reversedivide under both its alias and its registered name, and
// requires both lookups to report the registered name.
TEST_F(JavaInteropTests, Test_Synonyms_2) {
    auto aliasOp = OpRegistrator::getInstance()->getOperation("RDiv");
    auto namedOp = OpRegistrator::getInstance()->getOperation("reversedivide");
    std::string wanted("reversedivide");

    ASSERT_TRUE(aliasOp != nullptr);
    ASSERT_TRUE(namedOp != nullptr);

    auto aliasName = std::string(*aliasOp->getOpName());
    auto namedName = std::string(*namedOp->getOpName());

    ASSERT_EQ(wanted, namedName);
    ASSERT_EQ(namedName, aliasName);
}
|
|
|
|
|
|
|
|
// Checks that OpRegistrator maps "RDiv" onto the op named "reversedivide".
TEST_F(JavaInteropTests, Test_Synonyms_3) {
    auto registrator = OpRegistrator::getInstance();
    auto lhs = registrator->getOperation("RDiv");
    auto rhs = registrator->getOperation("reversedivide");
    std::string reference("reversedivide");

    ASSERT_TRUE(lhs != nullptr);
    ASSERT_TRUE(rhs != nullptr);

    std::string lhsName = *(lhs->getOpName());
    std::string rhsName = *(rhs->getOpName());

    ASSERT_EQ(reference, rhsName);
    ASSERT_EQ(rhsName, lhsName);
}
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
/*
|
2019-06-06 14:21:15 +02:00
|
|
|
TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|
|
|
int inOutH = 35;
|
|
|
|
int inOutW = 35;
|
|
|
|
int inOutC = 192;
|
|
|
|
|
|
|
|
auto x = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
auto z = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
x.linspace(1.0);
|
|
|
|
z.linspace(1.0);
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::prepareSpecialUse({&z}, {&x});
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
nd4j::ops::avgpool2d op;
|
|
|
|
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
|
|
|
|
|
|
|
Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1};
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::registerSpecialUse({&z}, {&x});
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(Status::OK(), result);
|
|
|
|
|
|
|
|
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
|
|
|
|
int padTop = totalPadHeight / 2;
|
|
|
|
int padBottom = totalPadHeight - totalPadHeight / 2;
|
|
|
|
|
|
|
|
int k = 3;
|
|
|
|
|
|
|
|
auto m = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
auto c = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
|
|
|
|
for (int h = 0; h < inOutH; h++) {
|
|
|
|
for (int w = 0; w < inOutW; w++) {
|
|
|
|
int hFrom = h - padTop;
|
|
|
|
int wFrom = w - padBottom;
|
|
|
|
|
|
|
|
int hTo = hFrom + k;
|
|
|
|
int wTo = wFrom + k;
|
|
|
|
|
|
|
|
hFrom = nd4j::math::nd4j_max<int>(0, hFrom);
|
|
|
|
wFrom = nd4j::math::nd4j_max<int>(0, wFrom);
|
|
|
|
|
|
|
|
hTo = nd4j::math::nd4j_min<int>(inOutH, hTo);
|
|
|
|
wTo = nd4j::math::nd4j_min<int>(inOutW, wTo);
|
|
|
|
|
|
|
|
int idxOut[4];
|
|
|
|
int idxIn[4];
|
|
|
|
for (int ch = 0; ch < inOutC; ch++) {
|
|
|
|
idxOut[1] = h;
|
|
|
|
idxOut[2] = w;
|
|
|
|
idxOut[3] = ch;
|
|
|
|
idxIn[3] = ch;
|
|
|
|
|
|
|
|
for (int kh = hFrom; kh < hTo; kh++) {
|
|
|
|
for (int kw = wFrom; kw < wTo; kw++) {
|
|
|
|
idxIn[1] = kh;
|
|
|
|
idxIn[2] = kw;
|
|
|
|
|
|
|
|
auto inVal = x.e<float>(0, kh, kw, ch);
|
|
|
|
m.p(0, h, w, ch, inVal + m.e<float>(0, h, w, ch));
|
|
|
|
c.p(0, h, w, ch, 1 + c.e<int>(0, h, w, ch));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
m /= c;
|
|
|
|
|
|
|
|
//z.printIndexedBuffer("z buffer", 100);
|
|
|
|
//m.printIndexedBuffer("m buffer", 100);
|
|
|
|
int cnt = 0;
|
|
|
|
int lim = 10;
|
|
|
|
for (int e = 0; e < z.lengthOf() && cnt < lim; e++) {
|
|
|
|
auto _m = m.e<float>(e);
|
|
|
|
auto _z = z.e<float>(e);
|
|
|
|
auto eq = nd4j::math::nd4j_eq<float>(_m, _z, 1e-5);
|
|
|
|
if (!eq) {
|
|
|
|
nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, _z);
|
|
|
|
cnt++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ASSERT_EQ(m, z);
|
|
|
|
}
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
|
|
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
unregisterGraph(nullptr, 119);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
|
|
|
|
delete[] data;
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|
|
|
//Environment::getInstance()->setDebug(true);
|
|
|
|
//Environment::getInstance()->setVerbose(true);
|
|
|
|
|
|
|
|
auto exp0 = NDArrayFactory::create<float>('c', {3}, {3, 3, 3});
|
|
|
|
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
|
|
|
|
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
|
|
|
|
|
|
|
|
// we load graph from file, because we're not in java here, and dont have buffer ready
|
|
|
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
|
|
|
|
|
|
|
// we ensure that there's no such a graph stored earlier
|
|
|
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
// register the graph, to call for it later
|
2019-07-22 13:34:08 +02:00
|
|
|
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// and ensure we're ok
|
|
|
|
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// run stuff
|
|
|
|
|
|
|
|
auto input_0 = NDArrayFactory::create<float>('c', {3, 3});
|
|
|
|
input_0.assign(1.0f);
|
|
|
|
|
|
|
|
int idx[] = {1};
|
|
|
|
|
|
|
|
Nd4jPointer inputs_0[] = {(Nd4jPointer) input_0.buffer()};
|
|
|
|
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
|
|
|
|
|
|
|
|
// now we're executing stored graph and providing replacement for input variable
|
2019-07-22 13:34:08 +02:00
|
|
|
auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
|
|
|
|
ASSERT_EQ(1, res_0->size());
|
|
|
|
|
|
|
|
auto z0 = res_0->at(0)->getNDArray();
|
|
|
|
ASSERT_TRUE(exp0.isSameShape(z0));
|
|
|
|
|
|
|
|
|
|
|
|
auto input_1 = NDArrayFactory::create<float>('c', {3, 3});
|
|
|
|
input_1.assign(2.0f);
|
|
|
|
|
|
|
|
Nd4jPointer inputs_1[] = {(Nd4jPointer) input_1.buffer()};
|
|
|
|
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
|
|
|
|
|
|
|
|
// doing it again
|
2019-07-22 13:34:08 +02:00
|
|
|
auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
|
|
|
ASSERT_EQ(1, res_1->size());
|
|
|
|
|
|
|
|
auto z1 = res_1->at(0)->getNDArray();
|
|
|
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
|
|
|
|
|
|
|
|
|
|
|
auto input_2 = NDArrayFactory::create<float>('c', {3, 3});
|
|
|
|
input_2.assign(3.0f);
|
|
|
|
|
|
|
|
Nd4jPointer inputs_2[] = {(Nd4jPointer) input_2.buffer()};
|
|
|
|
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
|
|
|
|
|
|
|
|
// and again
|
2019-07-22 13:34:08 +02:00
|
|
|
auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
|
2019-06-06 14:21:15 +02:00
|
|
|
    ASSERT_EQ(ND4J_STATUS_OK, res_2->status());
|
|
|
|
ASSERT_EQ(1, res_2->size());
|
|
|
|
|
|
|
|
auto z2 = res_2->at(0)->getNDArray();
|
|
|
|
ASSERT_TRUE(exp2.isSameShape(z2));
|
|
|
|
|
|
|
|
|
|
|
|
//////// clean out
|
2019-07-22 13:34:08 +02:00
|
|
|
unregisterGraph(nullptr, 119);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
|
|
|
|
delete[] data;
|
|
|
|
delete res_0;
|
|
|
|
delete res_1;
|
|
|
|
delete res_2;
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
|
|
|
|
// Element-wise x > y through execCustomOp(), written into a bool output array.
TEST_F(JavaInteropTests, Test_Greater_1) {
    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
    auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
    // Pre-filled with 1s: the first two expected values are 0, so an output the
    // op never wrote would fail the comparison.
    auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});

    auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});

    NDArray::prepareSpecialUse({&o}, {&x, &y});

    nd4j::ops::greater op;

    // host pointers for both inputs, then their special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&o}, {&x, &y});

    // a failed op would otherwise only show up as a confusing value mismatch
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(exp.equalsTo(&o));
}
|
|
|
|
|
|
|
|
|
|
|
|
// Same element-wise x > y check as Test_Greater_1, declaring the op before
// preparing the arrays.
TEST_F(JavaInteropTests, Test_Greater_2) {
    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
    auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
    auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});

    auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});

    nd4j::ops::greater op;

    NDArray::prepareSpecialUse({&o}, {&x, &y});

    // host pointers for both inputs, then their special (device) pointers
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&o}, {&x, &y});

    // a failed op would otherwise only show up as a confusing value mismatch
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(exp.equalsTo(&o));
}
|
|
|
|
|
|
|
|
// is_non_decreasing on {1, 2, 3, 4, 5} should write `true` into the scalar
// bool output (initialised to false) when invoked through execCustomOp().
TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
    nd4j::ops::is_non_decreasing op;

    auto x = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
    auto o = NDArrayFactory::create<bool>(false);
    auto exp = NDArrayFactory::create<bool>(1);

    NDArray::prepareSpecialUse({&o}, {&x});

    // Each array lists the host pointer, then the special pointer.
    Nd4jPointer inBuffers[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
    Nd4jPointer outBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};

    auto opHash = op.getOpHash();
    auto status = execCustomOp(nullptr, opHash,
                               inBuffers, inShapes, 1,
                               outBuffers, outShapes, 1,
                               nullptr, 0, nullptr, 0, nullptr, 0,
                               false);

    NDArray::registerSpecialUse({&o}, {&x});

    ASSERT_EQ(Status::OK(), status);
    ASSERT_TRUE(exp.equalsTo(&o));
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
    // test_output_reshape is expected to copy the input values into the provided output unchanged.
    auto source   = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto expected = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto target   = NDArrayFactory::create<float>('c', {2, 3});

    nd4j::ops::test_output_reshape op;

    NDArray::prepareSpecialUse({&target}, {&source});

    // Host pointer first, then the special pointer, for both input and output.
    Nd4jPointer inBuffers[]  = {source.getBuffer(), source.getSpecialBuffer()};
    Nd4jPointer inShapes[]   = {source.getShapeInfo(), source.getSpecialShapeInfo()};
    Nd4jPointer outBuffers[] = {target.getBuffer(), target.getSpecialBuffer()};
    Nd4jPointer outShapes[]  = {target.getShapeInfo(), target.getSpecialShapeInfo()};

    const auto opHash = op.getOpHash();
    const auto status = execCustomOp(nullptr, opHash,
                                     inBuffers, inShapes, 1,
                                     outBuffers, outShapes, 1,
                                     nullptr, 0, nullptr, 0, nullptr, 0,
                                     false);

    NDArray::registerSpecialUse({&target}, {&source});

    ASSERT_EQ(Status::OK(), status);
    ASSERT_TRUE(expected.isSameShape(target));
    ASSERT_TRUE(expected.equalsTo(target));
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto y = NDArrayFactory::create<float>(2.0f);
    // z is 'f'-ordered while the expectation e is 'c'-ordered: the test checks that
    // add writes correct values into an output whose ordering differs from e's.
    auto z = NDArrayFactory::create<float>('f', {2, 3});
    auto e = NDArrayFactory::create<float>('c', {2, 3}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f});

    nd4j::ops::add op;

    NDArray::prepareSpecialUse({&z}, {&x, &y});

    // Host pointers for both inputs first, then their special pointers.
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

    auto hash = op.getOpHash();
    auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    // After the op runs, the output must be *registered*. Calling prepareSpecialUse
    // again here would never mark z as written. Every other test pairs
    // prepareSpecialUse (before) with registerSpecialUse (after).
    NDArray::registerSpecialUse({&z}, {&x, &y});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(z));
    ASSERT_NE(e.ordering(), z.ordering());
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
    auto input = NDArrayFactory::create<float>('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
    auto indices = NDArrayFactory::create<Nd4jLong>('c', {1, 6}, {0,1, 2,2, 1,2});
    // 'f'-ordered output compared against a 'c'-ordered expectation.
    auto output = NDArrayFactory::create<float>('f', {2, 1, 6, 4});
    auto e = NDArrayFactory::create<float>('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24});

    nd4j::ops::gather op;

    NDArray::prepareSpecialUse({&output}, {&input, &indices});

    // Layout: host pointers of all inputs, then the special pointers in the same order.
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) indices.getBuffer(), input.getSpecialBuffer(), indices.getSpecialBuffer()};
    // Each special shape must belong to its own array. The 4th slot is the special
    // shape of `indices`; it previously repeated `input`'s special shape by mistake.
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) indices.getShapeInfo(), input.getSpecialShapeInfo(), indices.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    // Gather along axis 1. {2,3,4} gathered with {1,6} indices on axis 1 gives {2,1,6,4}.
    Nd4jLong iArgs[] = {1};

    auto hash = op.getOpHash();
    auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input, &indices});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(e.isSameShape(output));
    ASSERT_TRUE(e.equalsTo(output));
    ASSERT_NE(e.ordering(), output.ordering());
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
    // Smoke test: execReduce3Tad over dimensions {0, 1} of two {3,4,5} arrays into a {5} result.
    // No assertions are made on the result; the test only checks that the call completes.
    auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
    auto y = NDArrayFactory::create<float>('c', {3, 4, 5});
    auto z = NDArrayFactory::create<float>('c', {5});

    auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
    dims.syncToHost();

    auto context = nd4j::LaunchContext::defaultContext();

    // The extra pointer table is only populated when __CUDABLAS__ is defined;
    // otherwise nullptr is passed through.
    Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
    extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()};
#endif

    auto tadHelper = nd4j::ConstantTadHelper::getInstance();
    auto packX = tadHelper->tadForDimensions(x.getShapeInfo(), {0,1});
    auto packY = tadHelper->tadForDimensions(y.getShapeInfo(), {0,1});

    NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});

    execReduce3Tad(extraPointers, 2,
                   x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
                   nullptr,
                   y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
                   z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
                   dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(),
                   packX.platformShapeInfo(), packX.platformOffsets(),
                   packY.platformShapeInfo(), packY.platformOffsets());

    NDArray::registerSpecialUse({&z}, {&x, &y, &dims});

    delete[] extraPointers;
}
|
2019-08-02 19:01:03 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
/*
|
|
|
|
TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
|
|
|
|
Environment::getInstance()->setDebug(true);
|
|
|
|
Environment::getInstance()->setVerbose(false);
|
|
|
|
|
|
|
|
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
|
2019-07-22 13:34:08 +02:00
|
|
|
auto ptr = executeFlatGraph(nullptr, pl);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
Environment::getInstance()->setDebug(false);
|
|
|
|
Environment::getInstance()->setVerbose(false);
|
|
|
|
|
|
|
|
delete[] pl;
|
|
|
|
delete ptr;
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
|
|
|
|
|
|
|
|
auto input = NDArrayFactory::create<float>('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, 
-4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.4618740
|
|
|
|
auto z = NDArrayFactory::create<float>('c', {4, 4, 4, 3});
|
|
|
|
auto exp = NDArrayFactory::create<float>('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, 
-0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257
|
|
|
|
|
|
|
|
nd4j::ops::avgpool2d op;
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::prepareSpecialUse({&z}, {&input});
|
|
|
|
|
|
|
|
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
|
|
|
|
|
|
|
|
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
|
|
|
|
|
|
|
auto hash = op.getOpHash();
|
2019-07-22 13:34:08 +02:00
|
|
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::registerSpecialUse({&z}, {&input});
|
|
|
|
ASSERT_EQ(Status::OK(), status);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(exp.isSameShape(z));
|
|
|
|
ASSERT_TRUE(exp.equalsTo(z));
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
    // Runs maxpool2d on a {1,1,4,5} linspace input and checks only the returned status.
    auto input = NDArrayFactory::create<float>('c', {1, 1, 4, 5});
    auto z = NDArrayFactory::create<float>('c', {1, 1, 4, 5});

    input.linspace(1.0);

    NDArray::prepareSpecialUse({&z}, {&input});

    // Host pointer first, then the special pointer.
    Nd4jPointer inBuffers[]  = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
    Nd4jPointer inShapes[]   = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
    Nd4jPointer outBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
    Nd4jPointer outShapes[]  = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

    // Pooling integer arguments (11 values).
    // NOTE(review): presumably kernel, stride, padding, dilation and mode flags — confirm against maxpool2d.
    Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0};
    const int numIArgs = static_cast<int>(sizeof(iArgs) / sizeof(iArgs[0]));

    nd4j::ops::maxpool2d op;

    const auto status = execCustomOp(nullptr, op.getOpHash(),
                                     inBuffers, inShapes, 1,
                                     outBuffers, outShapes, 1,
                                     nullptr, 0, iArgs, numIArgs, nullptr, 0,
                                     false);

    NDArray::registerSpecialUse({&z}, {&input});

    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Unstack_1) {
    // Unstacks a {5,5} array into five rank-1 arrays of length 5 and checks only the status.
    auto x = NDArrayFactory::create<double>('c', {5, 5});
    x.linspace(1.0);

    auto z0 = NDArrayFactory::create<double>('c',{5});
    auto z1 = NDArrayFactory::create<double>('c',{5});
    auto z2 = NDArrayFactory::create<double>('c',{5});
    auto z3 = NDArrayFactory::create<double>('c',{5});
    auto z4 = NDArrayFactory::create<double>('c',{5});

    NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});

    Nd4jPointer ptrsInBuffer[] = {x.buffer(), x.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {x.shapeInfo(), x.getSpecialShapeInfo()};

    // Output tables: host pointers of every output occupy the first numOutputs
    // slots, followed by the special pointers in the same order.
    constexpr int numOutputs = 5;
    NDArray* outputs[numOutputs] = {&z0, &z1, &z2, &z3, &z4};

    Nd4jPointer ptrsOutBuffers[2 * numOutputs];
    Nd4jPointer ptrsOutShapes[2 * numOutputs];
    for (int i = 0; i < numOutputs; ++i) {
        ptrsOutBuffers[i] = outputs[i]->buffer();
        ptrsOutBuffers[i + numOutputs] = outputs[i]->getSpecialBuffer();
        ptrsOutShapes[i] = outputs[i]->shapeInfo();
        ptrsOutShapes[i + numOutputs] = outputs[i]->getSpecialShapeInfo();
    }

    // NOTE(review): single integer argument, presumably the unstack axis — confirm.
    Nd4jLong iArgs[] = {0};

    nd4j::ops::unstack op;

    const auto status = execCustomOp(nullptr, op.getOpHash(),
                                     ptrsInBuffer, ptrsInShapes, 1,
                                     ptrsOutBuffers, ptrsOutShapes, numOutputs,
                                     nullptr, 0, iArgs, 1, nullptr, 0,
                                     false);

    NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});

    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
|
|
|
|
|
|
|
|
auto input = NDArrayFactory::create<float>('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,2.91434479f,5.43639755f,-2.10573769f, 4.08528662f,5.86908436f,-4.46203756f,2.21057916f,5.35849190f,0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, 
-2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454
|
|
|
|
auto z = NDArrayFactory::create<float>('c', {4, 4, 4, 3});
|
|
|
|
auto exp = NDArrayFactory::create<float>('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 
8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f,
|
|
|
|
|
|
|
|
nd4j::ops::avgpool2d op;
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::prepareSpecialUse({&z}, {&input});
|
|
|
|
|
|
|
|
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
|
|
|
|
|
|
|
auto hash = op.getOpHash();
|
2019-07-22 13:34:08 +02:00
|
|
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::registerSpecialUse({&z}, {&input});
|
|
|
|
ASSERT_EQ(Status::OK(), status);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(exp.isSameShape(z));
|
|
|
|
ASSERT_TRUE(exp.equalsTo(z));
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
    // Mixed-type (int + double -> double) pairwise add; only exercised on experimental builds.
    if (!Environment::getInstance()->isExperimentalBuild())
        return;

    auto ints     = NDArrayFactory::create<int>({1, 2, 3, 4});
    auto doubles  = NDArrayFactory::create<double>({1, 2, 3, 4});
    auto sum      = NDArrayFactory::create<double>({0, 0, 0, 0});
    auto expected = NDArrayFactory::create<double>({2, 4, 6, 8});

    NDArray::prepareSpecialUse({&sum}, {&ints, &doubles});

    execPairwiseTransform(nullptr, pairwise::Add,
                          ints.buffer(), ints.shapeInfo(), ints.getSpecialBuffer(), ints.getSpecialShapeInfo(),
                          doubles.buffer(), doubles.shapeInfo(), doubles.getSpecialBuffer(), doubles.getSpecialShapeInfo(),
                          sum.buffer(), sum.shapeInfo(), sum.getSpecialBuffer(), sum.getSpecialShapeInfo(),
                          nullptr);

    NDArray::registerSpecialUse({&sum}, {&ints, &doubles});

    ASSERT_EQ(expected, sum);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Add_1) {
    // x is both an input and the output, so the sum x + y is written back into x.
    auto x = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
    auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
    auto expected = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});

    nd4j::ops::add op;

    NDArray::prepareSpecialUse({&x}, {&x, &y});

    // Host pointers of both inputs first, then their special pointers.
    Nd4jPointer inBuffers[]  = {x.getBuffer(), y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer inShapes[]   = {x.getShapeInfo(), y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
    Nd4jPointer outBuffers[] = {x.getBuffer(), x.getSpecialBuffer()};
    Nd4jPointer outShapes[]  = {x.getShapeInfo(), x.getSpecialShapeInfo()};

    execCustomOp(nullptr, op.getOpHash(),
                 inBuffers, inShapes, 2,
                 outBuffers, outShapes, 1,
                 nullptr, 0, nullptr, 0, nullptr, 0,
                 false);

    NDArray::registerSpecialUse({&x}, {&x, &y});

    ASSERT_EQ(expected, x);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, zeta_test10) {
    // Element-wise zeta(x, q) over a {3,4} grid, compared against precomputed values.
    auto x = NDArrayFactory::create<double>('c', {3, 4}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12});
    auto q = NDArrayFactory::create<double>('c', {3, 4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12});
    auto result = NDArrayFactory::create<double>('c', {3, 4});
    auto expected = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});

    nd4j::ops::zeta op;

    NDArray::prepareSpecialUse({&result}, {&x, &q});

    Nd4jPointer inBuffers[] = {x.getBuffer(), q.getBuffer(), x.getSpecialBuffer(), q.getSpecialBuffer()};
    // NOTE(review): x uses specialShapeInfo() while q uses getSpecialShapeInfo(); kept as-is — confirm the difference is intended.
    Nd4jPointer inShapes[] = {x.getShapeInfo(), q.getShapeInfo(), x.specialShapeInfo(), q.getSpecialShapeInfo()};

    Nd4jPointer outBuffers[] = {result.getBuffer(), result.getSpecialBuffer()};
    Nd4jPointer outShapes[]  = {result.getShapeInfo(), result.getSpecialShapeInfo()};

    execCustomOp(nullptr, op.getOpHash(),
                 inBuffers, inShapes, 2,
                 outBuffers, outShapes, 1,
                 nullptr, 0, nullptr, 0, nullptr, 0,
                 false);

    NDArray::registerSpecialUse({&result}, {&x, &q});

    ASSERT_EQ(expected, result);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Is_Max_1) {
    // IsMax marks the position of the maximum (2 at index 1) with true in a bool output.
    auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
    auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
    auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});

    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();

    // Populated only when __CUDABLAS__ is defined; otherwise nullptr is passed through.
    Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
    extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
#endif

    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
    execTransformAny(extraPointers, transform::IsMax,
                     arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
                     arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                     nullptr);
    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});

    // Release before asserting. A failing ASSERT_* returns from the test body,
    // which would skip a delete[] placed after it and leak the table.
    delete[] extraPointers;

    ASSERT_EQ(arrayE, arrayZ);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
    // Same as Test_Is_Max_1, but IsMax writes into a float output (1 at the max position).
    auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
    auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
    auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});

    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();

    // Populated only when __CUDABLAS__ is defined; otherwise nullptr is passed through.
    Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
    extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
#endif

    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
    execTransformAny(extraPointers, transform::IsMax,
                     arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
                     arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                     nullptr);
    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});

    // Release before asserting. A failing ASSERT_* returns from the test body,
    // which would skip a delete[] placed after it and leak the table.
    delete[] extraPointers;

    ASSERT_EQ(arrayE, arrayZ);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Is_Max_2) {
    // Smoke test: IsMax over a {3,2,3} array with hand-built TAD info. No assertions are made.
    auto arrayX = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 10, 2, 3, 4, 5, -10, -9, -8, -7, -6, -5, 4, 3, 2, 1, 0, -1});
    auto arrayZ = NDArrayFactory::create<bool>('c', {3, 2, 3});

    // Passed through the extra-pointers slot: a shape-info buffer plus one offset
    // per sub-array (0, 6, 12 -> three 6-element blocks of the 18-element input).
    // NOTE(review): the shape info presumably describes rank 2, shape {2,3}, 'c' order (99) — confirm.
    Nd4jLong tadShapeInfo[] = {2, 2, 3, 3, 1, 524288, -1, 99};
    Nd4jLong tadOffsets[] = {0, 6, 12};
    Nd4jLong* extras[] = {tadShapeInfo, tadOffsets};

    float extraParams[] = {2, 1, 2};

    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
    // NOTE(review): transform::IsMax is passed to execTransformBool here, while Test_Is_Max_1
    // passes it to execTransformAny — confirm the op number is valid for the bool transform family.
    execTransformBool(reinterpret_cast<void **>(extras), transform::IsMax,
                      arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
                      arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                      extraParams);
    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_IAMax_1) {
    // The largest absolute value is |-0.26f|, located at index 1.
    auto input = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f});
    auto expected = NDArrayFactory::create<Nd4jLong>(1);

    auto actual = input.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr);

    ASSERT_EQ(expected, actual);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
    // Exercises output-shape calculation for greater_equal only; the op itself is not executed.
    auto lhs = NDArrayFactory::create<double>('c', {10, 10});
    auto rhs = NDArrayFactory::create<double>('c', {10, 10});

    // Host pointers of both inputs first, then their special pointers.
    Nd4jPointer inBuffers[] = {lhs.buffer(), rhs.buffer(), lhs.getSpecialBuffer(), rhs.getSpecialBuffer()};
    Nd4jPointer inShapes[]  = {lhs.shapeInfo(), rhs.shapeInfo(), lhs.getSpecialShapeInfo(), rhs.getSpecialShapeInfo()};

    NDArray::prepareSpecialUse({}, {&lhs, &rhs});

    nd4j::ops::greater_equal op;
    auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), inBuffers, inShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);

    NDArray::registerSpecialUse({}, {&lhs, &rhs});

    delete shapeList;
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_L2_Loss_3) {
    auto x = NDArrayFactory::create<double>(0.7787855863571167);
    // 0.7787855863571167^2 / 2 ~= 0.303254
    auto e = NDArrayFactory::create<double>(0.303254);
    auto z = NDArrayFactory::create<double>(0.0);

    NDArray::prepareSpecialUse({&z}, {&x});

    // Host pointer first, then the special pointer.
    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer()), x.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo()), x.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

    nd4j::ops::l2_loss op;
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    // Register the output before any fatal assertion. A failing ASSERT_* returns
    // from the test and would otherwise skip this pairing with prepareSpecialUse.
    NDArray::registerSpecialUse({&z}, {&x});

    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(e, z);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Fastpath_3) {
    // Element-wise add of two identical 3x2 arrays: every value doubles.
    auto lhs = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto rhs = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto sum = NDArrayFactory::create<float>('c', {3, 2});
    auto expected = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});

    Context context(1);

    NDArray::prepareSpecialUse({&sum}, {&lhs, &rhs});

    // Arrays are handed to the context by raw host + special pointers.
    context.setInputArray(0, lhs.buffer(), lhs.shapeInfo(), lhs.getSpecialBuffer(), lhs.getSpecialShapeInfo());
    context.setInputArray(1, rhs.buffer(), rhs.shapeInfo(), rhs.getSpecialBuffer(), rhs.getSpecialShapeInfo());
    context.setOutputArray(0, sum.buffer(), sum.shapeInfo(), sum.getSpecialBuffer(), sum.getSpecialShapeInfo());

    // Two inputs were registered above.
    ASSERT_EQ(2, context.width());

    nd4j::ops::add op;
    execCustomOp2(nullptr, op.getOpHash(), &context);

    NDArray::registerSpecialUse({&sum}, {&lhs, &rhs});

    ASSERT_EQ(expected, sum);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Fastpath_4) {
    // tri produces a 3x5 lower-triangular mask whose ones extend two columns past the main diagonal.
    auto expected = NDArrayFactory::create<double>('c', {3, 5}, {1,1,1,0,0, 1,1,1,1,0, 1,1,1,1,1});
    auto mask = NDArrayFactory::create<double>('c', {3, 5});

    // NOTE(review): presumably {rows, columns, diagonal offset} -- consistent with the 3x5 expectation.
    Nd4jLong triArgs[] = {3, 5, 2};

    // No inputs: the op only writes its output.
    NDArray::prepareSpecialUse({&mask}, {});

    Context context(1);
    context.setOutputArray(0, mask.buffer(), mask.shapeInfo(), mask.specialBuffer(), mask.specialShapeInfo());
    context.setIArguments(triArgs, 3);

    nd4j::ops::tri op;
    execCustomOp2(nullptr, op.getOpHash(), &context);

    NDArray::registerSpecialUse({&mask}, {});

    ASSERT_EQ(expected, mask);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Fastpath_5) {
    // Square a 3x3 matrix holding 1..9 via matmul through the Context fast path.
    auto a = NDArrayFactory::create<float>('c', {3, 3});
    auto b = NDArrayFactory::create<float>('c', {3, 3});
    auto c = NDArrayFactory::create<float>('c', {3, 3});

    // [[1,2,3],[4,5,6],[7,8,9]] x [[1,2,3],[4,5,6],[7,8,9]]
    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {30.f, 36.f, 42.f, 66.f, 81.f, 96.f, 102.f, 126.f, 150.f});

    a.linspace(1.0);
    b.linspace(1.0);

    // a and b are the inputs and c is the sole output. The previous list {&b, &c}
    // left `a` unsynchronised and wrongly treated the output `c` as an input.
    NDArray::prepareSpecialUse({&c}, {&a, &b});

    Context ctx(1);
    ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
    ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
    ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());

    nd4j::ops::matmul op;
    auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

    NDArray::registerSpecialUse({&c}, {&a, &b});

    ASSERT_EQ(Status::OK(), status);
    // Check the actual product, not just the status code.
    ASSERT_EQ(exp, c);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Fastpath_6) {
    // matmul_bp through the Context fast path: inputs a (2x3), b (3x4) and an incoming
    // gradient (2x4); outputs are gradients shaped like a and b. Only the status is checked.
    auto a = NDArrayFactory::create<float>('c', {2, 3});
    auto b = NDArrayFactory::create<float>('c', {3, 4});
    auto gradOut = NDArrayFactory::create<float>('c', {2, 4});

    auto gradA = NDArrayFactory::create<float>('c', {2, 3});
    auto gradB = NDArrayFactory::create<float>('c', {3, 4});

    a.linspace(1.0);
    b.linspace(1.0);
    gradOut.linspace(1.0);

    NDArray::prepareSpecialUse({&gradA, &gradB}, {&a, &b, &gradOut});

    Context context(1);
    // NOTE(review): three zero integer args -- presumably "no transpose" flags; confirm against matmul_bp.
    Nd4jLong flags[] = {0L, 0L, 0L};

    context.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
    context.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
    context.setInputArray(2, gradOut.buffer(), gradOut.shapeInfo(), gradOut.specialBuffer(), gradOut.specialShapeInfo());

    context.setOutputArray(0, gradA.buffer(), gradA.shapeInfo(), gradA.specialBuffer(), gradA.specialShapeInfo());
    context.setOutputArray(1, gradB.buffer(), gradB.shapeInfo(), gradB.specialBuffer(), gradB.specialShapeInfo());

    context.setIArguments(flags, 3);

    nd4j::ops::matmul_bp op;
    auto status = execCustomOp2(nullptr, op.getOpHash(), &context);

    NDArray::registerSpecialUse({&gradA, &gradB}, {&a, &b, &gradOut});

    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_Fastpath_7) {
    // Concatenating the vector {1, 2} with the scalar 3 should give {1, 2, 3}.
    auto vec = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
    auto scalar = NDArrayFactory::create<float>(3.f);
    auto joined = NDArrayFactory::create<float>('c', {3});
    auto expected = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});

    NDArray::prepareSpecialUse({&joined}, {&vec, &scalar});

    Context context(1);

    // Only the first entry is handed to the context (count == 1): concatenation axis 0.
    Nd4jLong axisArgs[] = {0L, 0L, 0L};
    context.setIArguments(axisArgs, 1);

    context.setInputArray(0, vec.buffer(), vec.shapeInfo(), vec.specialBuffer(), vec.specialShapeInfo());
    context.setInputArray(1, scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo());
    context.setOutputArray(0, joined.buffer(), joined.shapeInfo(), joined.specialBuffer(), joined.specialShapeInfo());

    nd4j::ops::concat op;
    auto status = execCustomOp2(nullptr, op.getOpHash(), &context);

    NDArray::registerSpecialUse({&joined}, {&vec, &scalar});

    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(expected, joined);
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
|
|
|
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
2019-07-22 13:34:08 +02:00
|
|
|
auto ptr = executeFlatGraph(nullptr, pl);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// at this point we have FlatResults
|
|
|
|
auto flatResult = GetFlatResult(ptr->pointer());
|
|
|
|
auto size = flatResult->variables()->size();
|
|
|
|
|
|
|
|
// we know exact number of outputs in this graph in given mode
|
|
|
|
ASSERT_EQ(184, size);
|
|
|
|
|
|
|
|
|
|
|
|
// now we're rolling through all variables and restore them one by one
|
|
|
|
for (int e = 0; e < size; e++) {
|
|
|
|
auto flatVar = flatResult->variables()->Get(e);
|
|
|
|
auto flatArray = flatVar->ndarray();
|
|
|
|
|
|
|
|
// checking var part first
|
|
|
|
// we just want to ensure we're not experiencing overruns here
|
|
|
|
auto name = flatVar->name()->str();
|
|
|
|
|
|
|
|
// checking array part now
|
|
|
|
auto shape = flatArray->shape();
|
|
|
|
auto rank = shape->Get(0);
|
|
|
|
|
|
|
|
ASSERT_TRUE(shape->size() > 0 && rank >= 0 && rank < MAX_RANK);
|
|
|
|
|
|
|
|
// building regular NDArray out of this FlatArray
|
|
|
|
auto ndarray = nd4j::graph::FlatUtils::fromFlatArray(flatArray);
|
|
|
|
|
|
|
|
// rank should match FlatArray
|
|
|
|
ASSERT_EQ(rank, ndarray->rankOf());
|
|
|
|
|
|
|
|
// array shouldn't have any NaN/Inf values
|
|
|
|
ASSERT_TRUE(ndarray->isFinite());
|
|
|
|
|
|
|
|
// array should be assignable
|
|
|
|
ndarray->assign(123.f);
|
|
|
|
|
|
|
|
// and safely removable after
|
|
|
|
delete ndarray;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
delete[] pl;
|
|
|
|
delete ptr;
|
|
|
|
|
|
|
|
// and we should have 0 leaks reported after this line :)
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
|
|
|
|
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
|
|
|
|
// std::array<float, 60> syn1;
|
|
|
|
// std::array<float, 100000> exp;
|
|
|
|
|
|
|
|
// for (int e = 0; e < syn1.size(); e++)
|
|
|
|
// syn1[e] = 0.0f;
|
|
|
|
|
|
|
|
// for (int e = 0; e < exp.size(); e++) {
|
|
|
|
// auto f = static_cast<double>(e);
|
|
|
|
// auto tmp = nd4j::math::nd4j_exp<double, double>((f / 100000.0 * 2.0 - 1.0) * 6.0);
|
|
|
|
// exp[e] = static_cast<float>(tmp / (tmp + 1.0));
|
|
|
|
// }
|
|
|
|
|
|
|
|
// auto maxTypes = 5;
|
|
|
|
// auto numAggregates = 1;
|
|
|
|
// auto opNum = 3;
|
|
|
|
// auto maxArgs = 6;
|
|
|
|
// auto maxShapes = 0;
|
|
|
|
// auto maxIntArrays = 2;
|
|
|
|
// auto maxIntArraySize = 40;
|
|
|
|
// auto maxIndexArguments = 10;
|
|
|
|
// auto maxRealArguments = 2;
|
|
|
|
|
|
|
|
// std::array<int, 100000> pointer;
|
|
|
|
|
|
|
|
// auto batchLimit = 512;
|
|
|
|
|
|
|
|
// int indexPos = maxTypes * batchLimit;
|
|
|
|
// int intArraysPos = indexPos + (maxIndexArguments * batchLimit);
|
|
|
|
// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit));
|
|
|
|
// int argsPos = (realPos + ((maxRealArguments * batchLimit))) / 2;
|
|
|
|
// int shapesPos = argsPos + (maxArgs * batchLimit);
|
|
|
|
|
|
|
|
// std::vector<int> intArray0({0, 0, 0, 0, 0});
|
|
|
|
// std::vector<int> intArray1({1, 0, 0, 0, 0});
|
|
|
|
|
|
|
|
// std::vector<int> indexingArgs0({1, 20, 5, 0, 100000, 3, 0, 0, 0});
|
|
|
|
// std::vector<int> indexingArgs1({0, 20, 5, 0, 100000, 3, 1, 0, 0});
|
|
|
|
|
|
|
|
// std::vector<float> realArgs0({0.024964055335354007f, 3.0768702268737162E18f});
|
|
|
|
|
|
|
|
// int argSize = 6;
|
|
|
|
// int shapesSize = 0;
|
|
|
|
// int indexingSize = 9;
|
|
|
|
// int realArgsSize = 2;
|
|
|
|
// int intArraysSize = 2;
|
|
|
|
|
|
|
|
// int e = 0;
|
|
|
|
|
|
|
|
// auto idx = e * maxTypes;
|
|
|
|
|
|
|
|
// // numbers of arguments
|
|
|
|
// pointer[idx] = 6; // arguments size
|
|
|
|
// pointer[idx+1] = 0; // shapes size
|
|
|
|
// pointer[idx+2] = 9; // indexing arguments size
|
|
|
|
// pointer[idx+3] = 2; // real args size
|
|
|
|
// pointer[idx+4] = 2; // intArray args size
|
|
|
|
|
|
|
|
// // indexing args
|
|
|
|
// auto idxArgs = e == 0 ? indexingArgs0 : indexingArgs1;
|
|
|
|
// for (int f = 0; f < idxArgs.size(); f++) {
|
|
|
|
// idx = indexPos + e * maxIndexArguments;
|
|
|
|
// pointer[idx + f] = idxArgs[f];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// // int array values
|
|
|
|
// int bsize = maxIntArrays * maxIntArraySize;
|
|
|
|
// for (int f = 0; f < intArraysSize; f++) {
|
|
|
|
// int step = (e * bsize) + (f * maxIntArraySize);
|
|
|
|
// auto intArr = f == 0 ? intArray0 : intArray1;
|
|
|
|
// for (int x = 0; x < intArr.size(); x++) {
|
|
|
|
// idx = intArraysPos + step + x;
|
|
|
|
// pointer[idx] = intArr[x];
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
|
|
|
|
// // real args
|
|
|
|
// auto ptr = reinterpret_cast<float *>(pointer.data());
|
|
|
|
// for (int f = 0; f < realArgsSize; f++) {
|
|
|
|
// idx = realPos + (e * maxRealArguments);
|
|
|
|
// ptr[idx + f] = realArgs0[f];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// //
|
|
|
|
// auto ptrptr = reinterpret_cast<void **>(pointer.data());
|
|
|
|
// idx = argsPos + e * maxArgs;
|
|
|
|
// ptrptr[idx] = reinterpret_cast<void*>(syn0.data());
|
|
|
|
// ptrptr[idx+1] = reinterpret_cast<void*>(syn1.data());
|
|
|
|
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
|
|
|
|
// }
|