cavis/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp
Oleh d52e67209e
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor more corrections to support convertors

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading

* StringUtils for utf convertor #8613 tests added some corrections to optimize build

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some corrections and code clean up

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion

* StringUtils for utf convertor #8613 some fixes and tets

* StringUtils for utf convertor #8613 some more staff to support different unicode

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fix linking bug

* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed

* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing

* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)

* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 merge master and sync with server

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up

* StringUtils for utf convertor #8613 fixed cast and copy issues

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fixed cuda and update tests

* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc

* - avoid ambiguity of NDArray ctrs overloading in some tests

Signed-off-by: Yurii <iuriish@yahoo.com>

* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613  void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 several more fixes, improvements and updates

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 minor fixes string element size define

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 revert last changes as mistake

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 16:30:49 +03:00

1482 lines
125 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <NativeOps.h>
#include <NDArray.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/OpRegistrator.h>
#include <graph/GraphHolder.h>
#include <graph/FlatUtils.h>
#include "testlayers.h"
#include <array>
using namespace nd4j;
using namespace nd4j::ops;
class JavaInteropTests : public testing::Test {
public:
};
TEST_F(JavaInteropTests, TestShapeExposure1) {
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});
nd4j::ops::conv2d op;
std::vector<double> tArgs({});
std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1});
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};
auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
ASSERT_EQ(1, shapeList->size());
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]);
ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]);
ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]);
//int *ptr = (int *) shapeList[0];
//delete[] ptr;
//delete shapeList;
deleteShapeList((Nd4jPointer) shapeList);
}
TEST_F(JavaInteropTests, TestShapeExposure2) {
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});
nd4j::ops::shape_of op;
std::vector<double> tArgs({});
std::vector<Nd4jLong> iArgs({});
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};
auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
ASSERT_EQ(1, shapeList->size());
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
deleteShapeList((Nd4jPointer) shapeList);
}
TEST_F(JavaInteropTests, TestShapeExposure3) {
auto x = NDArrayFactory::create<float>('c', {5, 30});
auto sizes = NDArrayFactory::create<int>('c', {3}, {4, 15, 11});
std::vector<Nd4jLong> list0 = {0,0, 0,4};
std::vector<Nd4jLong> list1 = {0,0, 4,19};
std::vector<Nd4jLong> list2 = {0,0, 19,30};
auto sub0 = x(list0, true);
auto sub1 = x(list1, true);
auto sub2 = x(list2, true);
sub0.assign(0.0f);
sub1.assign(1.0f);
sub2.assign(2.0f);
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.getSpecialBuffer(), sizes.getSpecialBuffer()};
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo(), x.getSpecialShapeInfo(), sizes.getSpecialShapeInfo()};
nd4j::ops::split_v op;
Nd4jLong iArgs[] = {1};
auto hash = op.getOpHash();
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0);
ASSERT_EQ(3, shapeList->size());
ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0)));
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));
deleteShapeList((Nd4jPointer) shapeList);
}
TEST_F(JavaInteropTests, Test_Squeeze_1) {
auto x = NDArrayFactory::create<float>('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto z = NDArrayFactory::create<float>('c', {6});
auto e = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
nd4j::ops::squeeze op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, Test_RDiv_1) {
auto x = NDArrayFactory::create<double>('c', {3}, {2, 2, 2});
auto y = NDArrayFactory::create<double>('c', {3}, {4, 6, 8});
auto z = NDArrayFactory::create<double>('c', {3});
auto e = NDArrayFactory::create<double>('c', {3}, {2, 3, 4});
NDArray::prepareSpecialUse({&z}, {&x, &y});
nd4j::ops::reversedivide op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&x, &y});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, TestSconv2d_1) {
auto input = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
auto weightsD = NDArrayFactory::create<float>('c', {1, 3, 1, 1});
auto weightsP = NDArrayFactory::create<float>('c', {2, 3, 1, 1});
auto bias = NDArrayFactory::create<float>('c', {2});
auto output = NDArrayFactory::create<float>('c', {3, 2, 8, 8});
output.assign(0.0);
input.linspace(1);
weightsD.linspace(1);
weightsP.linspace(1);
bias.linspace(1);
weightsD.permutei({2,3,1,0});
weightsP.permutei({2,3,1,0});
auto expOutput = NDArrayFactory::create<float>('c', {3, 2, 8, 8});
nd4j::ops::sconv2d op;
NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), (Nd4jPointer) weightsP.getBuffer(), (Nd4jPointer) bias.getBuffer(), (Nd4jPointer) input.getSpecialBuffer(), (Nd4jPointer) weightsD.getSpecialBuffer(), (Nd4jPointer) weightsP.getSpecialBuffer(), (Nd4jPointer) bias.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), (Nd4jPointer) weightsP.getShapeInfo(), (Nd4jPointer) bias.getShapeInfo(), (Nd4jPointer) input.getSpecialShapeInfo(), (Nd4jPointer) weightsD.getSpecialShapeInfo(), (Nd4jPointer) weightsP.getSpecialShapeInfo(), (Nd4jPointer) bias.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), (Nd4jPointer) output.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), (Nd4jPointer) output.getSpecialShapeInfo()};
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
nullptr, 0, exp, 9, nullptr, 0, false);
//output.printBuffer("output");
NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});
ASSERT_NEAR(1423, output.e<float>(0), 1e-5);
//nd4j_printf("Iter %i passed...\n", e);
}
TEST_F(JavaInteropTests, TestSconv2d_2) {
auto input = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
auto weightsD = NDArrayFactory::create<float>('c', {1, 3, 1, 1});
auto output = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
output.assign(0.0);
input.linspace(1);
weightsD.linspace(1);
weightsD.permutei({2,3,1,0});
auto expOutput = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
nd4j::ops::sconv2d op;
NDArray::prepareSpecialUse({&output}, {&input, &weightsD});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), input.getSpecialBuffer(), weightsD.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), input.getSpecialShapeInfo(), weightsD.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
NDArray::registerSpecialUse({&output}, {&input, &weightsD});
ASSERT_NEAR(1, output.e<float>(0), 1e-5);
}
TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
auto input = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
input.linspace(1);
NDArray::prepareSpecialUse({&output}, {&input});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1});
nd4j::ops::maxpool2d op;
Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
NDArray::registerSpecialUse({&output}, {&input});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
TEST_F(JavaInteropTests, TestCol2Im_1) {
/*
o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, 1, 40, 8, 0, -1, 99]
o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, 4, 5, 20, 20, 5, 1, 0, 1, 99]
o.d.n.l.c.ConvolutionLayer - Strides: [1, 1]
o.d.n.l.c.ConvolutionLayer - Padding: [0, 0]
o.d.n.l.c.ConvolutionLayer - Input: [4,5]
o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1]
*/
auto input = NDArrayFactory::create<float>('c', {1, 2, 2, 2, 4, 5});
auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
input.linspace(1);
NDArray::prepareSpecialUse({&output}, {&input});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
nd4j::ops::col2im op;
Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};
auto hash = op.getOpHash();
execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
NDArray::registerSpecialUse({&output}, {&input});
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
}
TEST_F(JavaInteropTests, TestPNorm_1) {
/*
o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, 99]
o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, 1, 99]
o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2]
o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1]
o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0]
o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1]
o.d.n.l.c.s.SubsamplingLayer - Same: false
o.d.n.l.c.s.SubsamplingLayer - pnorm: 2
*/
auto input = NDArrayFactory::create<float>('c', {1, 3, 4, 4});
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
input.linspace(1);
NDArray::prepareSpecialUse({&output}, {&input});
nd4j::ops::pnormpool2d op;
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
NDArray::registerSpecialUse({&output}, {&input});
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
}
TEST_F(JavaInteropTests, TestInplace_1) {
auto input = NDArrayFactory::create<float>('c', {10, 10});
//auto exp('c', {10, 10});
input.linspace(1);
NDArray::prepareSpecialUse({}, {&input});
nd4j::ops::clipbyvalue op;
double extras[] = {-1.0f, 1.0f};
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
NDArray::registerSpecialUse({}, {&input});
ASSERT_EQ(ND4J_STATUS_OK, result);
ASSERT_NEAR(1.0, input.meanNumber().e<float>(0), 1e-5);
}
TEST_F(JavaInteropTests, Test_Synonyms_1) {
auto op = OpRegistrator::getInstance()->getOperation("RDiv");
auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide");
std::string nameExp("reversedivide");
ASSERT_TRUE(op != nullptr);
ASSERT_TRUE(opRef != nullptr);
std::string name = *(op->getOpName());
std::string nameRef = *(opRef->getOpName());
ASSERT_EQ(nameExp, nameRef);
ASSERT_EQ(nameRef, name);
}
TEST_F(JavaInteropTests, Test_Synonyms_2) {
auto op = OpRegistrator::getInstance()->getOperation("RDiv");
auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide");
std::string nameExp("reversedivide");
ASSERT_TRUE(op != nullptr);
ASSERT_TRUE(opRef != nullptr);
std::string name = *(op->getOpName());
std::string nameRef = *(opRef->getOpName());
ASSERT_EQ(nameExp, nameRef);
ASSERT_EQ(nameRef, name);
}
TEST_F(JavaInteropTests, Test_Synonyms_3) {
auto op = OpRegistrator::getInstance()->getOperation("RDiv");
auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide");
std::string nameExp("reversedivide");
ASSERT_TRUE(op != nullptr);
ASSERT_TRUE(opRef != nullptr);
std::string name = *(op->getOpName());
std::string nameRef = *(opRef->getOpName());
ASSERT_EQ(nameExp, nameRef);
ASSERT_EQ(nameRef, name);
}
TEST_F(JavaInteropTests, Test_FastPath_Validation_1) {
auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
auto z = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::softmax op;
auto status = op.execute(&ctx);
ASSERT_NE(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
auto x = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
auto z = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::softmax op;
auto status = op.execute(&ctx);
ASSERT_NE(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_FastPath_Validation_3) {
auto x = NDArrayFactory::create<float>('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
auto min = NDArrayFactory::create<float>({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
auto max = NDArrayFactory::create<float>({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
auto z = NDArrayFactory::create<double>('c', {3, 5});
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo());
ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
ASSERT_ANY_THROW(op.execute(&ctx));
}
TEST_F(JavaInteropTests, Test_empty_cast_1) {
auto x = NDArrayFactory::create<bool>('c', {1, 0, 2});
auto z = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2});
auto e = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2});
Nd4jLong iArgs[] = {10};
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 1);
nd4j::ops::cast op;
auto result = op.execute(&ctx);
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
/*
TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
int inOutH = 35;
int inOutW = 35;
int inOutC = 192;
auto x = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
auto z = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
x.linspace(1.0);
z.linspace(1.0);
NDArray::prepareSpecialUse({&z}, {&x});
nd4j::ops::avgpool2d op;
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1};
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&x});
ASSERT_EQ(Status::OK(), result);
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
int padTop = totalPadHeight / 2;
int padBottom = totalPadHeight - totalPadHeight / 2;
int k = 3;
auto m = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
auto c = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
for (int h = 0; h < inOutH; h++) {
for (int w = 0; w < inOutW; w++) {
int hFrom = h - padTop;
int wFrom = w - padBottom;
int hTo = hFrom + k;
int wTo = wFrom + k;
hFrom = nd4j::math::nd4j_max<int>(0, hFrom);
wFrom = nd4j::math::nd4j_max<int>(0, wFrom);
hTo = nd4j::math::nd4j_min<int>(inOutH, hTo);
wTo = nd4j::math::nd4j_min<int>(inOutW, wTo);
int idxOut[4];
int idxIn[4];
for (int ch = 0; ch < inOutC; ch++) {
idxOut[1] = h;
idxOut[2] = w;
idxOut[3] = ch;
idxIn[3] = ch;
for (int kh = hFrom; kh < hTo; kh++) {
for (int kw = wFrom; kw < wTo; kw++) {
idxIn[1] = kh;
idxIn[2] = kw;
auto inVal = x.e<float>(0, kh, kw, ch);
m.p(0, h, w, ch, inVal + m.e<float>(0, h, w, ch));
c.p(0, h, w, ch, 1 + c.e<int>(0, h, w, ch));
}
}
}
}
}
m /= c;
//z.printIndexedBuffer("z buffer", 100);
//m.printIndexedBuffer("m buffer", 100);
int cnt = 0;
int lim = 10;
for (int e = 0; e < z.lengthOf() && cnt < lim; e++) {
auto _m = m.e<float>(e);
auto _z = z.e<float>(e);
auto eq = nd4j::math::nd4j_eq<float>(_m, _z, 1e-5);
if (!eq) {
nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, _z);
cnt++;
}
}
ASSERT_EQ(m, z);
}
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
registerGraph(nullptr, 119, (Nd4jPointer) data);
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
unregisterGraph(nullptr, 119);
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
delete[] data;
}
TEST_F(JavaInteropTests, Test_GraphReuse_2) {
//Environment::getInstance()->setDebug(true);
//Environment::getInstance()->setVerbose(true);
auto exp0 = NDArrayFactory::create<float>('c', {3}, {3, 3, 3});
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
// we load graph from file, because we're not in java here, and dont have buffer ready
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
// we ensure that there's no such a graph stored earlier
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
// register the graph, to call for it later
registerGraph(nullptr, 119, (Nd4jPointer) data);
// and ensure we're ok
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
// run stuff
auto input_0 = NDArrayFactory::create<float>('c', {3, 3});
input_0.assign(1.0f);
int idx[] = {1};
Nd4jPointer inputs_0[] = {(Nd4jPointer) input_0.buffer()};
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
// now we're executing stored graph and providing replacement for input variable
auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
ASSERT_EQ(1, res_0->size());
auto z0 = res_0->at(0)->getNDArray();
ASSERT_TRUE(exp0.isSameShape(z0));
auto input_1 = NDArrayFactory::create<float>('c', {3, 3});
input_1.assign(2.0f);
Nd4jPointer inputs_1[] = {(Nd4jPointer) input_1.buffer()};
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
// doing it again
auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_1->size());
auto z1 = res_1->at(0)->getNDArray();
ASSERT_TRUE(exp1.isSameShape(z1));
auto input_2 = NDArrayFactory::create<float>('c', {3, 3});
input_2.assign(3.0f);
Nd4jPointer inputs_2[] = {(Nd4jPointer) input_2.buffer()};
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
// and again
auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_2->size());
auto z2 = res_2->at(0)->getNDArray();
ASSERT_TRUE(exp2.isSameShape(z2));
//////// clean out
unregisterGraph(nullptr, 119);
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
delete[] data;
delete res_0;
delete res_1;
delete res_2;
}
*/
TEST_F(JavaInteropTests, Test_Greater_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
// auto o = NDArrayFactory::create<float>('c', {2, 2}, {3, 3, 3, 3});
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
NDArray::prepareSpecialUse({&o}, {&x, &y});
nd4j::ops::greater op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&o}, {&x, &y});
ASSERT_TRUE(exp.equalsTo(&o));
}
TEST_F(JavaInteropTests, Test_Greater_2) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
nd4j::ops::greater op;
NDArray::prepareSpecialUse({&o}, {&x, &y});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&o}, {&x, &y});
ASSERT_TRUE(exp.equalsTo(&o));
}
TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
nd4j::ops::is_non_decreasing op;
auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
auto o = NDArrayFactory::create<bool>(false);
auto exp = NDArrayFactory::create<bool>(1);
NDArray::prepareSpecialUse({&o}, {&x});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&o}, {&x});
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.equalsTo(&o));
}
TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto z = NDArrayFactory::create<float>('c', {2, 3});
nd4j::ops::test_output_reshape op;
NDArray::prepareSpecialUse({&z}, {&x});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&x});
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto y = NDArrayFactory::create<float>(2.0f);
auto z = NDArrayFactory::create<float>('f', {2, 3});
auto e = NDArrayFactory::create<float>('c', {2, 3}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
nd4j::ops::add op;
NDArray::prepareSpecialUse({&z}, {&x, &y});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::prepareSpecialUse({&z}, {&x, &y});
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z));
ASSERT_FALSE(e.ordering() == z.ordering());
}
TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
auto input = NDArrayFactory::create<double>('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
auto indices = NDArrayFactory::create<Nd4jLong>('c', {1, 6}, {0,1, 2,2, 1,2});
auto output = NDArrayFactory::create<double>('f', {2, 1, 6, 4});
auto e = NDArrayFactory::create<double>('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24});
nd4j::ops::gather op;
NDArray::prepareSpecialUse({&output}, {&input, &indices});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) indices.getBuffer(), input.getSpecialBuffer(), indices.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) indices.getShapeInfo(), input.getSpecialShapeInfo(), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
Nd4jLong iArgs[] = {1};
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
NDArray::registerSpecialUse({&output}, {&input, &indices});
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(e.isSameShape(output));
ASSERT_TRUE(e.equalsTo(output));
ASSERT_FALSE(e.ordering() == output.ordering());
}
TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
auto y = NDArrayFactory::create<float>('c', {3, 4, 5});
auto z = NDArrayFactory::create<float>('c', {5});
auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
dims.syncToHost();
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()};
#endif
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {0,1});
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dims.dataBuffer());
execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(),
packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dims});
delete []extraPointers;
}
/*
TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
Environment::getInstance()->setDebug(true);
Environment::getInstance()->setVerbose(false);
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
auto ptr = executeFlatGraph(nullptr, pl);
Environment::getInstance()->setDebug(false);
Environment::getInstance()->setVerbose(false);
delete[] pl;
delete ptr;
}
*/
TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
auto input = NDArrayFactory::create<double>('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, -2.45065260, 4.56992292, -3.02825928,-3.74182844,5.11069250,-0.91531068,-2.31385994,1.83399653,3.39370203,-3.60886002});
auto z = NDArrayFactory::create<double>('c', {4, 4, 4, 3});
auto exp = NDArrayFactory::create<double>('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, 0.86383581, -1.91504073});
nd4j::ops::avgpool2d op;
NDArray::prepareSpecialUse({&z}, {&input});
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&input});
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
auto input = NDArrayFactory::create<float>('c', {1, 1, 4, 5});
auto z = NDArrayFactory::create<float>('c', {1, 1, 4, 5});
input.linspace(1.0);
NDArray::prepareSpecialUse({&z}, {&input});
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0};
nd4j::ops::maxpool2d op;
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&input});
ASSERT_EQ(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_Unstack_1) {
auto x = NDArrayFactory::create<double>('c', {5, 5});
x.linspace(1.0);
auto z0 = NDArrayFactory::create<double>('c',{5});
auto z1 = NDArrayFactory::create<double>('c',{5});
auto z2 = NDArrayFactory::create<double>('c',{5});
auto z3 = NDArrayFactory::create<double>('c',{5});
auto z4 = NDArrayFactory::create<double>('c',{5});
NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer()), x.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo()), x.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer(), z0.getSpecialBuffer(), z1.getSpecialBuffer(), z2.getSpecialBuffer(), z3.getSpecialBuffer(), z4.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {z0.shapeInfo(), z1.shapeInfo(), z2.shapeInfo(), z3.shapeInfo(), z4.shapeInfo(), z0.getSpecialShapeInfo(), z1.getSpecialShapeInfo(), z2.getSpecialShapeInfo(), z3.getSpecialShapeInfo(), z4.getSpecialShapeInfo()};
Nd4jLong iArgs[] = {0};
nd4j::ops::unstack op;
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});
ASSERT_EQ(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
auto input = NDArrayFactory::create<float>('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,2.91434479f,5.43639755f,-2.10573769f, 4.08528662f,5.86908436f,-4.46203756f,2.21057916f,5.35849190f,0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f});
auto z = NDArrayFactory::create<float>('c', {4, 4, 4, 3});
auto exp = NDArrayFactory::create<float>('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f});
nd4j::ops::avgpool2d op;
NDArray::prepareSpecialUse({&z}, {&input});
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
auto hash = op.getOpHash();
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&input});
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
if (!Environment::getInstance()->isExperimentalBuild())
return;
auto arrayX = NDArrayFactory::create<int>({1, 2, 3, 4});
auto arrayY = NDArrayFactory::create<double>({1, 2, 3, 4});
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});
OpaqueDataBuffer xBuf(arrayX.dataBuffer());
OpaqueDataBuffer yBuf(arrayY.dataBuffer());
OpaqueDataBuffer zBuf(arrayZ.dataBuffer());
execPairwiseTransform(nullptr, pairwise::Add,
&xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(),
&yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(),
&zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(),
nullptr);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});
ASSERT_EQ(arrayE, arrayZ);
}
TEST_F(JavaInteropTests, Test_Add_1) {
auto x = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});
NDArray::prepareSpecialUse({&x}, {&x, &y});
nd4j::ops::add op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo(),};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&x}, {&x, &y});
ASSERT_EQ(e, x);
}
TEST_F(JavaInteropTests, zeta_test10) {
auto x = NDArrayFactory::create<double>('c', {3, 4}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12});
auto q = NDArrayFactory::create<double>('c', {3, 4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12});
auto z = NDArrayFactory::create<double>('c', {3, 4});
auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
nd4j::ops::zeta op;
NDArray::prepareSpecialUse({&z}, {&x, &q});
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer(), x.getSpecialBuffer(), q.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), q.getShapeInfo(), x.specialShapeInfo(), q.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NDArray::registerSpecialUse({&z}, {&x, &q});
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, Test_IAMax_1) {
auto arrayX = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f});
auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr);
auto exp = NDArrayFactory::create<Nd4jLong>(1);
ASSERT_EQ(exp, arrayZ);
}
TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
auto arrayX = NDArrayFactory::create<double>('c', {10, 10});
auto arrayY = NDArrayFactory::create<double>('c', {10, 10});
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(arrayX.buffer()), reinterpret_cast<Nd4jPointer>(arrayY.buffer()), arrayX.getSpecialBuffer(), arrayY.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(arrayX.shapeInfo()), reinterpret_cast<Nd4jPointer>(arrayY.shapeInfo()), arrayX.getSpecialShapeInfo(), arrayY.getSpecialShapeInfo()};
NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});
nd4j::ops::greater_equal op;
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0);
NDArray::registerSpecialUse({}, {&arrayX, &arrayY});
delete shapeList;
}
TEST_F(JavaInteropTests, Test_L2_Loss_3) {
auto x = NDArrayFactory::create<double>(0.7787855863571167);
auto e = NDArrayFactory::create<double>(0.303254);
auto z = NDArrayFactory::create<double>(0.0);
NDArray::prepareSpecialUse({&z}, {&x});
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer()), x.getSpecialBuffer()};
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo()), x.getSpecialShapeInfo()};
Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
nd4j::ops::l2_loss op;
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);
NDArray::registerSpecialUse({&z}, {&x});
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, Test_Fastpath_3) {
auto array0 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto array1 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto z = NDArrayFactory::create<float>('c', {3, 2});
auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
Context ctx(1);
NDArray::prepareSpecialUse({&z}, {&array0, &array1});
ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.getSpecialBuffer(), array0.getSpecialShapeInfo());
ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.getSpecialBuffer(), array1.getSpecialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo());
ASSERT_EQ(2, ctx.width());
nd4j::ops::add op;
execCustomOp2(nullptr, op.getOpHash(), &ctx);
NDArray::registerSpecialUse({&z}, {&array0, &array1});
ASSERT_EQ(exp, z);
}
TEST_F(JavaInteropTests, Test_Fastpath_4) {
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1,1,1,0,0, 1,1,1,1,0, 1,1,1,1,1});
auto z = NDArrayFactory::create<double>('c', {3, 5});
Nd4jLong iArgs[] = {3, 5, 2};
NDArray::prepareSpecialUse({&z}, {});
Context ctx(1);
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 3);
nd4j::ops::tri op;
execCustomOp2(nullptr, op.getOpHash(), &ctx);
NDArray::registerSpecialUse({&z}, {});
ASSERT_EQ(exp, z);
}
TEST_F(JavaInteropTests, Test_Fastpath_5) {
auto a = NDArrayFactory::create<float>('c', {3, 3});
auto b = NDArrayFactory::create<float>('c', {3, 3});
auto c = NDArrayFactory::create<float>('c', {3, 3});
a.linspace(1.0);
b.linspace(1.0);
NDArray::prepareSpecialUse({&c}, {&b, &c});
Context ctx(1);
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());
nd4j::ops::matmul op;
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
NDArray::registerSpecialUse({&c}, {&b, &c});
ASSERT_EQ(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_Fastpath_6) {
auto a = NDArrayFactory::create<float>('c', {2, 3});
auto b = NDArrayFactory::create<float>('c', {3, 4});
auto gI = NDArrayFactory::create<float>('c', {2, 4});
auto gA = NDArrayFactory::create<float>('c', {2, 3});
auto gB = NDArrayFactory::create<float>('c', {3, 4});
a.linspace(1.0);
b.linspace(1.0);
gI.linspace(1.0);
NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI});
Context ctx(1);
Nd4jLong iArgs[] = {0L, 0L, 0L};
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
ctx.setInputArray(2, gI.buffer(), gI.shapeInfo(), gI.specialBuffer(), gI.specialShapeInfo());
ctx.setOutputArray(0, gA.buffer(), gA.shapeInfo(), gA.specialBuffer(), gA.specialShapeInfo());
ctx.setOutputArray(1, gB.buffer(), gB.shapeInfo(), gB.specialBuffer(), gB.specialShapeInfo());
ctx.setIArguments(iArgs, 3);
nd4j::ops::matmul_bp op;
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI});
ASSERT_EQ(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_Fastpath_7) {
auto a = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
auto b = NDArrayFactory::create<float>(3.f);
auto z = NDArrayFactory::create<float>('c', {3});
auto e = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
NDArray::prepareSpecialUse({&z}, {&a, &b});
Context ctx(1);
Nd4jLong iArgs[] = {0L, 0L, 0L};
ctx.setIArguments(iArgs, 1);
nd4j::ops::concat op;
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
NDArray::registerSpecialUse({&z}, {&a, &b});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, test_bfloat16_rng) {
if (!Environment::getInstance()->isCPU())
return;
auto z = NDArrayFactory::create<bfloat16>('c', {10});
RandomGenerator rng(119, 323841120L);
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
OpaqueDataBuffer zBuf(z.dataBuffer());
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args);
//z.printIndexedBuffer("z");
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
}
TEST_F(JavaInteropTests, test_ismax_view) {
auto original = NDArrayFactory::create<double>('c', {2, 3, 40});
auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)});
v.assign(1.0);
auto e = v.like();
auto t = e.tensorAlongDimension(0, {0, 1});
t.assign(1.0);
auto z = v.ulike();
Nd4jLong iArgs[] = {2L, 0L};
Context ctx(1);
ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 1);
nd4j::ops::ismax op;
op.execute(&ctx);
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1.f, 1.f, 1.f});
auto z = NDArrayFactory::create<float>(0.0f);
auto e = NDArrayFactory::create<float>(3.0f);
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::size op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, test_expandable_array_op_1) {
auto x = NDArrayFactory::string( {2}, {"first string", "second"});
auto d = NDArrayFactory::string(" ", nd4j::DataType::UTF8);
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
auto z1 = NDArrayFactory::string( {3}, {"", "", ""});
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
InteropDataBuffer iz0(z0.dataBuffer());
InteropDataBuffer iz1(z1.dataBuffer());
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo());
ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo());
nd4j::ops::compat_string_split op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp0, z0);
ASSERT_EQ(exp1, z1);
}
TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3, 3, 3});
auto z = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
double buffer[2048];
InteropDataBuffer ix(0, DataType::DOUBLE, false);
InteropDataBuffer iy(0, DataType::DOUBLE, false);
InteropDataBuffer iz(0, DataType::DOUBLE, false);
// we're imitating workspace-managed array here
ix.setPrimary(buffer + 64, x.lengthOf());
iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf());
iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf());
Context ctx(1);
ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo());
ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo());
ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo());
ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0});
nd4j::ops::maxpool2d_bp op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
}
/*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
auto ptr = executeFlatGraph(nullptr, pl);
// at this point we have FlatResults
auto flatResult = GetFlatResult(ptr->pointer());
auto size = flatResult->variables()->size();
// we know exact number of outputs in this graph in given mode
ASSERT_EQ(184, size);
// now we're rolling through all variables and restore them one by one
for (int e = 0; e < size; e++) {
auto flatVar = flatResult->variables()->Get(e);
auto flatArray = flatVar->ndarray();
// checking var part first
// we just want to ensure we're not experiencing overruns here
auto name = flatVar->name()->str();
// checking array part now
auto shape = flatArray->shape();
auto rank = shape->Get(0);
ASSERT_TRUE(shape->size() > 0 && rank >= 0 && rank < MAX_RANK);
// building regular NDArray out of this FlatArray
auto ndarray = nd4j::graph::FlatUtils::fromFlatArray(flatArray);
// rank should match FlatArray
ASSERT_EQ(rank, ndarray->rankOf());
// array shouldn't have any NaN/Inf values
ASSERT_TRUE(ndarray->isFinite());
// array should be assignable
ndarray->assign(123.f);
// and safely removable after
delete ndarray;
}
delete[] pl;
delete ptr;
// and we should have 0 leaks reported after this line :)
}
*/
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
// std::array<float, 60> syn1;
// std::array<float, 100000> exp;
// for (int e = 0; e < syn1.size(); e++)
// syn1[e] = 0.0f;
// for (int e = 0; e < exp.size(); e++) {
// auto f = static_cast<double>(e);
// auto tmp = nd4j::math::nd4j_exp<double, double>((f / 100000.0 * 2.0 - 1.0) * 6.0);
// exp[e] = static_cast<float>(tmp / (tmp + 1.0));
// }
// auto maxTypes = 5;
// auto numAggregates = 1;
// auto opNum = 3;
// auto maxArgs = 6;
// auto maxShapes = 0;
// auto maxIntArrays = 2;
// auto maxIntArraySize = 40;
// auto maxIndexArguments = 10;
// auto maxRealArguments = 2;
// std::array<int, 100000> pointer;
// auto batchLimit = 512;
// int indexPos = maxTypes * batchLimit;
// int intArraysPos = indexPos + (maxIndexArguments * batchLimit);
// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit));
// int argsPos = (realPos + ((maxRealArguments * batchLimit))) / 2;
// int shapesPos = argsPos + (maxArgs * batchLimit);
// std::vector<int> intArray0({0, 0, 0, 0, 0});
// std::vector<int> intArray1({1, 0, 0, 0, 0});
// std::vector<int> indexingArgs0({1, 20, 5, 0, 100000, 3, 0, 0, 0});
// std::vector<int> indexingArgs1({0, 20, 5, 0, 100000, 3, 1, 0, 0});
// std::vector<float> realArgs0({0.024964055335354007f, 3.0768702268737162E18f});
// int argSize = 6;
// int shapesSize = 0;
// int indexingSize = 9;
// int realArgsSize = 2;
// int intArraysSize = 2;
// int e = 0;
// auto idx = e * maxTypes;
// // numbers of arguments
// pointer[idx] = 6; // arguments size
// pointer[idx+1] = 0; // shapes size
// pointer[idx+2] = 9; // indexing arguments size
// pointer[idx+3] = 2; // real args size
// pointer[idx+4] = 2; // intArray args size
// // indexing args
// auto idxArgs = e == 0 ? indexingArgs0 : indexingArgs1;
// for (int f = 0; f < idxArgs.size(); f++) {
// idx = indexPos + e * maxIndexArguments;
// pointer[idx + f] = idxArgs[f];
// }
// // int array values
// int bsize = maxIntArrays * maxIntArraySize;
// for (int f = 0; f < intArraysSize; f++) {
// int step = (e * bsize) + (f * maxIntArraySize);
// auto intArr = f == 0 ? intArray0 : intArray1;
// for (int x = 0; x < intArr.size(); x++) {
// idx = intArraysPos + step + x;
// pointer[idx] = intArr[x];
// }
// }
// // real args
// auto ptr = reinterpret_cast<float *>(pointer.data());
// for (int f = 0; f < realArgsSize; f++) {
// idx = realPos + (e * maxRealArguments);
// ptr[idx + f] = realArgs0[f];
// }
// //
// auto ptrptr = reinterpret_cast<void **>(pointer.data());
// idx = argsPos + e * maxArgs;
// ptrptr[idx] = reinterpret_cast<void*>(syn0.data());
// ptrptr[idx+1] = reinterpret_cast<void*>(syn1.data());
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
// }