2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// @author raver119@gmail.com
|
|
|
|
//
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <legacy/NativeOps.h>
|
|
|
|
#include <array/NDArray.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <ops/declarable/CustomOperations.h>
|
|
|
|
#include <ops/declarable/OpRegistrator.h>
|
|
|
|
#include <graph/GraphHolder.h>
|
|
|
|
#include <graph/FlatUtils.h>
|
|
|
|
#include "testlayers.h"
|
|
|
|
#include <array>
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
using namespace sd;
|
|
|
|
using namespace sd::ops;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
class JavaInteropTests : public testing::Test {
|
|
|
|
public:
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Asks calculateOutputShapes() for the conv2d output shape, passing only the
 * shape infos of the input and weights, and compares the single reported shape
 * with a reference {1, 3, 5, 4} array.
 */
TEST_F(JavaInteropTests, TestShapeExposure1) {
    auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
    auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
    auto expected = NDArrayFactory::create<float>('c', {1, 3, 5, 4});

    sd::ops::conv2d op;

    // no floating point arguments, nine integer arguments
    std::vector<double> tArgs;
    std::vector<Nd4jLong> iArgs = {2, 2, 1, 1, 0, 0, 1, 1, 1};

    // only shape descriptors are handed over, no data buffers
    Nd4jPointer shapePtrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};

    auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), shapePtrs, 2,
                                           tArgs.data(), tArgs.size(),
                                           iArgs.data(), iArgs.size());

    ASSERT_EQ(1, shapeList->size());

    auto outShapeInfo = (Nd4jLong *) shapeList->at(0);
    ASSERT_EQ(expected.rankOf(), shape::rank(outShapeInfo));

    // compare all four dimensions of the reported shape with the reference array
    auto outDims = shape::shapeOf(outShapeInfo);
    for (int dim = 0; dim < 4; dim++) {
        ASSERT_EQ(expected.sizeAt(dim), outDims[dim]);
    }

    // the list is released through the same API that produced it
    deleteShapeList((Nd4jPointer) shapeList);
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Asks calculateOutputShapes() for the shape_of output shape of a rank-4 input
 * and checks that it reports one shape with the same rank and first dimension
 * as the reference vector of length 4.
 */
TEST_F(JavaInteropTests, TestShapeExposure2) {
    auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
    auto expected = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});

    sd::ops::shape_of op;

    // the op is invoked without any T or I arguments
    std::vector<double> tArgs;
    std::vector<Nd4jLong> iArgs;

    Nd4jPointer shapePtrs[] = {(Nd4jPointer) input.getShapeInfo()};

    auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), shapePtrs, 1,
                                           tArgs.data(), tArgs.size(),
                                           iArgs.data(), iArgs.size());

    ASSERT_EQ(1, shapeList->size());

    auto outShapeInfo = (Nd4jLong *) shapeList->at(0);
    ASSERT_EQ(expected.rankOf(), shape::rank(outShapeInfo));
    ASSERT_EQ(expected.sizeAt(0), shape::shapeOf(outShapeInfo)[0]);

    // the list is released through the same API that produced it
    deleteShapeList((Nd4jPointer) shapeList);
}
|
|
|
|
|
|
|
|
/**
 * Asks calculateOutputShapes2() for the split_v output shapes of a {5, 30}
 * array split by sizes {4, 15, 11}, and checks that the three reported shapes
 * soft-match three sub-arrays taken from the same input.
 */
TEST_F(JavaInteropTests, TestShapeExposure3) {
    auto x = NDArrayFactory::create<float>('c', {5, 30});
    auto sizes = NDArrayFactory::create<int>('c', {3}, {4, 15, 11});

    // index intervals for the reference sub-arrays; the second-dimension widths
    // (4, 15, 11) mirror the values stored in `sizes`
    std::vector<Nd4jLong> firstRange = {0,0, 0,4};
    std::vector<Nd4jLong> secondRange = {0,0, 4,19};
    std::vector<Nd4jLong> thirdRange = {0,0, 19,30};

    auto sub0 = x(firstRange, true);
    auto sub1 = x(secondRange, true);
    auto sub2 = x(thirdRange, true);

    sub0.assign(0.0f);
    sub1.assign(1.0f);
    sub2.assign(2.0f);

    // layout used throughout this file: primary pointers first, special pointers after them
    Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.getSpecialBuffer(), sizes.getSpecialBuffer()};
    Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo(), x.getSpecialShapeInfo(), sizes.getSpecialShapeInfo()};

    sd::ops::split_v op;

    Nd4jLong iArgs[] = {1};
    auto opHash = op.getOpHash();

    auto shapeList = calculateOutputShapes2(nullptr, opHash, inputBuffers, inputShapes, 2,
                                            nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0);

    ASSERT_EQ(3, shapeList->size());

    ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0)));
    ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
    ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));

    // the list is released through the same API that produced it
    deleteShapeList((Nd4jPointer) shapeList);
}
|
|
|
|
|
|
|
|
/**
 * Executes the squeeze op through execCustomOp() with raw buffer/shape
 * pointers and checks that a {1, 6} input is written into a {6} output
 * holding the same six values.
 */
TEST_F(JavaInteropTests, Test_Squeeze_1) {
    auto x = NDArrayFactory::create<float>('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto z = NDArrayFactory::create<float>('c', {6});
    auto e = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});

    // Bracket the raw-pointer call the same way the other execCustomOp tests in
    // this file do (presumably this keeps primary/special buffers in sync around
    // the call - confirm against the NDArray implementation).
    NDArray::prepareSpecialUse({&z}, {&x});

    sd::ops::squeeze op;

    // layout used throughout this file: primary pointer first, special pointer after it
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

    // one input, one output, no T/I/B arguments, not in-place
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&z}, {&x});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(e, z);
}
|
|
|
|
|
|
|
|
/**
 * Executes reversedivide through execCustomOp() on x = {2, 2, 2} and
 * y = {4, 6, 8} and expects {2, 3, 4} in the output array.
 */
TEST_F(JavaInteropTests, Test_RDiv_1) {
    auto x = NDArrayFactory::create<double>('c', {3}, {2, 2, 2});
    auto y = NDArrayFactory::create<double>('c', {3}, {4, 6, 8});
    auto z = NDArrayFactory::create<double>('c', {3});
    auto expected = NDArrayFactory::create<double>('c', {3}, {2, 3, 4});

    NDArray::prepareSpecialUse({&z}, {&x, &y});

    sd::ops::reversedivide op;
    const auto opHash = op.getOpHash();

    // primary pointers of both inputs first, then their special counterparts
    Nd4jPointer inBuffers[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer outBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

    // two inputs, one output, no T/I/B arguments, not in-place
    auto status = execCustomOp(nullptr, opHash, inBuffers, inShapes, 2, outBuffers, outShapes, 1,
                               nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&z}, {&x, &y});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(expected, z);
}
|
|
|
|
|
|
|
|
/**
 * Executes sconv2d through execCustomOp() with four inputs (input, depth
 * weights, point weights, bias), verifies the call succeeds and spot-checks
 * the first element of the output.
 */
TEST_F(JavaInteropTests, TestSconv2d_1) {
    auto input = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
    auto weightsD = NDArrayFactory::create<float>('c', {1, 3, 1, 1});
    auto weightsP = NDArrayFactory::create<float>('c', {2, 3, 1, 1});
    auto bias = NDArrayFactory::create<float>('c', {2});
    auto output = NDArrayFactory::create<float>('c', {3, 2, 8, 8});
    output.assign(0.0);

    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    bias.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    sd::ops::sconv2d op;

    NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});

    // primary pointers of all four inputs first, then their special counterparts
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), (Nd4jPointer) weightsP.getBuffer(), (Nd4jPointer) bias.getBuffer(), (Nd4jPointer) input.getSpecialBuffer(), (Nd4jPointer) weightsD.getSpecialBuffer(), (Nd4jPointer) weightsP.getSpecialBuffer(), (Nd4jPointer) bias.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), (Nd4jPointer) weightsP.getShapeInfo(), (Nd4jPointer) bias.getShapeInfo(), (Nd4jPointer) input.getSpecialShapeInfo(), (Nd4jPointer) weightsD.getSpecialShapeInfo(), (Nd4jPointer) weightsP.getSpecialShapeInfo(), (Nd4jPointer) bias.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), (Nd4jPointer) output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), (Nd4jPointer) output.getSpecialShapeInfo()};

    // NOTE(review): ten values are declared but only the first nine are passed
    // to the op below; kept as-is to preserve the tested configuration.
    Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};

    // the status used to be dropped; capture it so a failed execution is
    // reported as such instead of surfacing only as a wrong output value
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
                               nullptr, 0, exp, 9, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_NEAR(1423, output.e<float>(0), 1e-5);
}
|
|
|
|
|
|
|
|
/**
 * Executes sconv2d through execCustomOp() with only the input and the depth
 * weights, verifies the call succeeds and spot-checks the first element of
 * the output.
 */
TEST_F(JavaInteropTests, TestSconv2d_2) {
    auto input = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
    auto weightsD = NDArrayFactory::create<float>('c', {1, 3, 1, 1});
    auto output = NDArrayFactory::create<float>('c', {3, 3, 8, 8});
    output.assign(0.0);

    input.linspace(1);
    weightsD.linspace(1);
    weightsD.permutei({2,3,1,0});

    sd::ops::sconv2d op;

    NDArray::prepareSpecialUse({&output}, {&input, &weightsD});

    // primary pointers of both inputs first, then their special counterparts
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), input.getSpecialBuffer(), weightsD.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), input.getSpecialShapeInfo(), weightsD.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    // nine integer arguments, all of which are passed to the op
    Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};

    // the status used to be dropped; capture it so a failed execution is
    // reported as such instead of surfacing only as a wrong output value
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input, &weightsD});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_NEAR(1, output.e<float>(0), 1e-5);
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Executes maxpool2d through execCustomOp() on a {1, 2, 4, 5} input and
 * only checks that the call reports ND4J_STATUS_OK.
 */
TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
    auto input = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
    auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
    input.linspace(1);

    NDArray::prepareSpecialUse({&output}, {&input});

    // primary pointer first, special pointer after it
    Nd4jPointer inBuffers[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jPointer outBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    // nine integer arguments; the count below is taken from the vector itself
    std::vector<Nd4jLong> iArgs = {2, 2, 1, 1, 0, 0, 1, 1, 1};

    sd::ops::maxpool2d op;

    Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), inBuffers, inShapes, 1, outBuffers, outShapes, 1,
                                     nullptr, 0, iArgs.data(), iArgs.size(), nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input});

    ASSERT_EQ(ND4J_STATUS_OK, status);
}
|
|
|
|
/**
 * Executes col2im through execCustomOp() on a rank-6 {1, 2, 2, 2, 4, 5} input,
 * verifies the call succeeds and checks that the {1, 2, 4, 5} output has a
 * positive mean.
 */
TEST_F(JavaInteropTests, TestCol2Im_1) {
    /*
        Values captured from a ConvolutionLayer log that this case reproduces:

        o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, 1, 40, 8, 0, -1, 99]
        o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, 4, 5, 20, 20, 5, 1, 0, 1, 99]
        o.d.n.l.c.ConvolutionLayer - Strides: [1, 1]
        o.d.n.l.c.ConvolutionLayer - Padding: [0, 0]
        o.d.n.l.c.ConvolutionLayer - Input: [4,5]
        o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1]
     */
    auto input = NDArrayFactory::create<float>('c', {1, 2, 2, 2, 4, 5});
    auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
    input.linspace(1);

    NDArray::prepareSpecialUse({&output}, {&input});

    // primary pointer first, special pointer after it
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    sd::ops::col2im op;

    // nine integer arguments, all of which are passed to the op
    Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};

    auto hash = op.getOpHash();

    // the status used to be dropped; capture it so a failed execution is
    // reported as such instead of surfacing only through the mean check
    auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
}
|
|
|
|
|
|
|
|
/**
 * Executes pnormpool2d through execCustomOp() on a {1, 3, 4, 4} input,
 * verifies the call succeeds and checks that the {1, 3, 3, 3} output has a
 * positive mean.
 */
TEST_F(JavaInteropTests, TestPNorm_1) {
    /*
        Values captured from a SubsamplingLayer log that this case reproduces:

        o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, 99]
        o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, 1, 99]
        o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2]
        o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1]
        o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0]
        o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1]
        o.d.n.l.c.s.SubsamplingLayer - Same: false
        o.d.n.l.c.s.SubsamplingLayer - pnorm: 2
     */
    auto input = NDArrayFactory::create<float>('c', {1, 3, 4, 4});
    auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
    input.linspace(1);

    NDArray::prepareSpecialUse({&output}, {&input});

    sd::ops::pnormpool2d op;

    // NOTE(review): twelve values are declared but only the first eleven are
    // passed to the op below; kept as-is to preserve the tested configuration.
    Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};

    // primary pointer first, special pointer after it
    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

    // the status used to be dropped; capture it so a failed execution is
    // reported as such instead of surfacing only through the mean check
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Executes clipbyvalue in-place through execCustomOp(): no output arrays are
 * supplied and the in-place flag is set. The input holds a linspace starting
 * at 1 and is clipped with T arguments {-1, 1}, so its mean is expected to be 1.
 */
TEST_F(JavaInteropTests, TestInplace_1) {
    auto input = NDArrayFactory::create<float>('c', {10, 10});
    input.linspace(1);

    NDArray::prepareSpecialUse({}, {&input});

    sd::ops::clipbyvalue op;

    // the two T arguments handed to the op
    double clipBounds[] = {-1.0, 1.0};

    // primary pointer first, special pointer after it
    Nd4jPointer inBuffers[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};

    // zero outputs + isInplace == true: the op writes back into the input array
    Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), inBuffers, inShapes, 1,
                                     nullptr, nullptr, 0,
                                     clipBounds, 2, nullptr, 0, nullptr, 0, true);

    NDArray::registerSpecialUse({}, {&input});

    ASSERT_EQ(ND4J_STATUS_OK, result);

    ASSERT_NEAR(1.0, input.meanNumber().e<float>(0), 1e-5);
}
|
|
|
|
|
|
|
|
/**
 * Looks up an op by the alias "RDiv" and by the name "reversedivide" and
 * checks that both lookups succeed and report the op name "reversedivide".
 */
TEST_F(JavaInteropTests, Test_Synonyms_1) {
    const std::string expectedName("reversedivide");

    auto registrator = OpRegistrator::getInstance();
    auto byAlias = registrator->getOperation("RDiv");
    auto byName = registrator->getOperation("reversedivide");

    // both lookups must resolve before their names are dereferenced
    ASSERT_TRUE(byAlias != nullptr);
    ASSERT_TRUE(byName != nullptr);

    std::string aliasResolved = *(byAlias->getOpName());
    std::string nameResolved = *(byName->getOpName());

    ASSERT_EQ(expectedName, nameResolved);
    ASSERT_EQ(nameResolved, aliasResolved);
}
|
|
|
|
|
|
|
|
/**
 * Looks up an op by the alias "RDiv" and by the name "reversedivide" and
 * checks that both lookups succeed and report the op name "reversedivide".
 * NOTE(review): the body is the same as Test_Synonyms_1 - confirm whether a
 * different alias was meant to be covered here.
 */
TEST_F(JavaInteropTests, Test_Synonyms_2) {
    const std::string expectedName("reversedivide");

    auto registrator = OpRegistrator::getInstance();
    auto byAlias = registrator->getOperation("RDiv");
    auto byName = registrator->getOperation("reversedivide");

    // both lookups must resolve before their names are dereferenced
    ASSERT_TRUE(byAlias != nullptr);
    ASSERT_TRUE(byName != nullptr);

    std::string aliasResolved = *(byAlias->getOpName());
    std::string nameResolved = *(byName->getOpName());

    ASSERT_EQ(expectedName, nameResolved);
    ASSERT_EQ(nameResolved, aliasResolved);
}
|
|
|
|
|
|
|
|
/**
 * Looks up an op by the alias "RDiv" and by the name "reversedivide" and
 * checks that both lookups succeed and report the op name "reversedivide".
 * NOTE(review): the body is the same as Test_Synonyms_1 - confirm whether a
 * different alias was meant to be covered here.
 */
TEST_F(JavaInteropTests, Test_Synonyms_3) {
    const std::string expectedName("reversedivide");

    auto registrator = OpRegistrator::getInstance();
    auto byAlias = registrator->getOperation("RDiv");
    auto byName = registrator->getOperation("reversedivide");

    // both lookups must resolve before their names are dereferenced
    ASSERT_TRUE(byAlias != nullptr);
    ASSERT_TRUE(byName != nullptr);

    std::string aliasResolved = *(byAlias->getOpName());
    std::string nameResolved = *(byName->getOpName());

    ASSERT_EQ(expectedName, nameResolved);
    ASSERT_EQ(nameResolved, aliasResolved);
}
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
/**
 * Feeds softmax an int input and an int output through a Context and expects
 * execute() to return something other than Status::OK()
 * (presumably because of the integer data types - confirm against the op's
 * declared types).
 */
TEST_F(JavaInteropTests, Test_FastPath_Validation_1) {
    auto in = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
    auto out = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});

    // arrays are attached to the context as raw buffer/shape pointers
    Context context(1);
    context.setInputArray(0, in.buffer(), in.shapeInfo(), in.specialBuffer(), in.specialShapeInfo());
    context.setOutputArray(0, out.buffer(), out.shapeInfo(), out.specialBuffer(), out.specialShapeInfo());

    sd::ops::softmax op;
    auto status = op.execute(&context);

    ASSERT_NE(Status::OK(), status);
}
|
|
|
|
|
|
|
|
/**
 * Feeds softmax a float input and an int output through a Context and expects
 * execute() to return something other than Status::OK()
 * (presumably because of the input/output type mismatch - confirm against the
 * op's declared types).
 */
TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
    auto in = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
    auto out = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});

    // arrays are attached to the context as raw buffer/shape pointers
    Context context(1);
    context.setInputArray(0, in.buffer(), in.shapeInfo(), in.specialBuffer(), in.specialShapeInfo());
    context.setOutputArray(0, out.buffer(), out.shapeInfo(), out.specialBuffer(), out.specialShapeInfo());

    sd::ops::softmax op;
    auto status = op.execute(&context);

    ASSERT_NE(Status::OK(), status);
}
|
|
|
|
|
2019-11-19 10:53:52 +01:00
|
|
|
/**
 * Feeds fake_quant_with_min_max_vars_per_channel three float inputs and a
 * double output through a Context and expects execute() to throw
 * (presumably because of the float/double mismatch - confirm against the op's
 * declared types).
 */
TEST_F(JavaInteropTests, Test_FastPath_Validation_3) {
    auto x = NDArrayFactory::create<float>('c', {3, 5}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
        0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
        0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});

    // five min/max values, matching the five columns of x
    auto min = NDArrayFactory::create<float>({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
    auto max = NDArrayFactory::create<float>({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});

    // output is deliberately created as double while all inputs are float
    auto z = NDArrayFactory::create<double>('c', {3, 5});

    Context context(1);
    context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
    context.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo());
    context.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo());
    context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());

    sd::ops::fake_quant_with_min_max_vars_per_channel op;

    ASSERT_ANY_THROW(op.execute(&context));
}
|
|
|
|
|
2019-11-13 15:15:18 +01:00
|
|
|
/**
 * Runs cast on an empty {1, 0, 2} bool array through a Context and expects
 * Status::OK() plus an output equal to an empty {1, 0, 2} Nd4jLong array.
 */
TEST_F(JavaInteropTests, Test_empty_cast_1) {
    auto source = NDArrayFactory::create<bool>('c', {1, 0, 2});
    auto target = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2});
    auto expected = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2});

    // single integer argument for the op
    // (presumably the code of the destination data type - confirm against DataType)
    Nd4jLong castArgs[] = {10};

    Context context(1);
    context.setInputArray(0, source.buffer(), source.shapeInfo(), source.specialBuffer(), source.specialShapeInfo());
    context.setOutputArray(0, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
    context.setIArguments(castArgs, 1);

    sd::ops::cast op;
    auto status = op.execute(&context);

    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(expected, target);
}
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
/*
|
2019-06-06 14:21:15 +02:00
|
|
|
TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|
|
|
int inOutH = 35;
|
|
|
|
int inOutW = 35;
|
|
|
|
int inOutC = 192;
|
|
|
|
|
|
|
|
auto x = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
auto z = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
x.linspace(1.0);
|
|
|
|
z.linspace(1.0);
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::prepareSpecialUse({&z}, {&x});
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::avgpool2d op;
|
2019-06-06 14:21:15 +02:00
|
|
|
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
|
|
|
|
|
|
|
Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1};
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::registerSpecialUse({&z}, {&x});
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(Status::OK(), result);
|
|
|
|
|
|
|
|
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
|
|
|
|
int padTop = totalPadHeight / 2;
|
|
|
|
int padBottom = totalPadHeight - totalPadHeight / 2;
|
|
|
|
|
|
|
|
int k = 3;
|
|
|
|
|
|
|
|
auto m = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
auto c = NDArrayFactory::create<float>('c', {1, inOutH, inOutW, inOutC});
|
|
|
|
|
|
|
|
for (int h = 0; h < inOutH; h++) {
|
|
|
|
for (int w = 0; w < inOutW; w++) {
|
|
|
|
int hFrom = h - padTop;
|
|
|
|
int wFrom = w - padBottom;
|
|
|
|
|
|
|
|
int hTo = hFrom + k;
|
|
|
|
int wTo = wFrom + k;
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
hFrom = sd::math::nd4j_max<int>(0, hFrom);
|
|
|
|
wFrom = sd::math::nd4j_max<int>(0, wFrom);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
hTo = sd::math::nd4j_min<int>(inOutH, hTo);
|
|
|
|
wTo = sd::math::nd4j_min<int>(inOutW, wTo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
int idxOut[4];
|
|
|
|
int idxIn[4];
|
|
|
|
for (int ch = 0; ch < inOutC; ch++) {
|
|
|
|
idxOut[1] = h;
|
|
|
|
idxOut[2] = w;
|
|
|
|
idxOut[3] = ch;
|
|
|
|
idxIn[3] = ch;
|
|
|
|
|
|
|
|
for (int kh = hFrom; kh < hTo; kh++) {
|
|
|
|
for (int kw = wFrom; kw < wTo; kw++) {
|
|
|
|
idxIn[1] = kh;
|
|
|
|
idxIn[2] = kw;
|
|
|
|
|
|
|
|
auto inVal = x.e<float>(0, kh, kw, ch);
|
|
|
|
m.p(0, h, w, ch, inVal + m.e<float>(0, h, w, ch));
|
|
|
|
c.p(0, h, w, ch, 1 + c.e<int>(0, h, w, ch));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
m /= c;
|
|
|
|
|
|
|
|
//z.printIndexedBuffer("z buffer", 100);
|
|
|
|
//m.printIndexedBuffer("m buffer", 100);
|
|
|
|
int cnt = 0;
|
|
|
|
int lim = 10;
|
|
|
|
for (int e = 0; e < z.lengthOf() && cnt < lim; e++) {
|
|
|
|
auto _m = m.e<float>(e);
|
|
|
|
auto _z = z.e<float>(e);
|
2020-03-02 10:49:41 +01:00
|
|
|
auto eq = sd::math::nd4j_eq<float>(_m, _z, 1e-5);
|
2019-06-06 14:21:15 +02:00
|
|
|
if (!eq) {
|
|
|
|
nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, _z);
|
|
|
|
cnt++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ASSERT_EQ(m, z);
|
|
|
|
}
|
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
2020-03-02 10:49:41 +01:00
|
|
|
uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
unregisterGraph(nullptr, 119);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
|
|
|
|
delete[] data;
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|
|
|
//Environment::getInstance()->setDebug(true);
|
|
|
|
//Environment::getInstance()->setVerbose(true);
|
|
|
|
|
|
|
|
auto exp0 = NDArrayFactory::create<float>('c', {3}, {3, 3, 3});
|
|
|
|
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
|
|
|
|
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
|
|
|
|
|
|
|
|
// we load the graph from a file, because we're not in Java here and don't have a buffer ready
|
2020-03-02 10:49:41 +01:00
|
|
|
uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// we ensure that no such graph was stored earlier
|
|
|
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
// register the graph, to call for it later
|
2019-07-22 13:34:08 +02:00
|
|
|
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// and ensure we're ok
|
|
|
|
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// run stuff
|
|
|
|
|
|
|
|
auto input_0 = NDArrayFactory::create<float>('c', {3, 3});
|
|
|
|
input_0.assign(1.0f);
|
|
|
|
|
|
|
|
int idx[] = {1};
|
|
|
|
|
|
|
|
Nd4jPointer inputs_0[] = {(Nd4jPointer) input_0.buffer()};
|
|
|
|
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
|
|
|
|
|
|
|
|
// now we're executing stored graph and providing replacement for input variable
|
2019-07-22 13:34:08 +02:00
|
|
|
auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
|
|
|
|
ASSERT_EQ(1, res_0->size());
|
|
|
|
|
|
|
|
auto z0 = res_0->at(0)->getNDArray();
|
|
|
|
ASSERT_TRUE(exp0.isSameShape(z0));
|
|
|
|
|
|
|
|
|
|
|
|
auto input_1 = NDArrayFactory::create<float>('c', {3, 3});
|
|
|
|
input_1.assign(2.0f);
|
|
|
|
|
|
|
|
Nd4jPointer inputs_1[] = {(Nd4jPointer) input_1.buffer()};
|
|
|
|
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
|
|
|
|
|
|
|
|
// doing it again
|
2019-07-22 13:34:08 +02:00
|
|
|
auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
|
|
|
ASSERT_EQ(1, res_1->size());
|
|
|
|
|
|
|
|
auto z1 = res_1->at(0)->getNDArray();
|
|
|
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
|
|
|
|
|
|
|
|
|
|
|
auto input_2 = NDArrayFactory::create<float>('c', {3, 3});
|
|
|
|
input_2.assign(3.0f);
|
|
|
|
|
|
|
|
Nd4jPointer inputs_2[] = {(Nd4jPointer) input_2.buffer()};
|
|
|
|
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
|
|
|
|
|
|
|
|
// and again
|
2019-07-22 13:34:08 +02:00
|
|
|
auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, res_2->status());
|
|
|
|
ASSERT_EQ(1, res_2->size());
|
|
|
|
|
|
|
|
auto z2 = res_2->at(0)->getNDArray();
|
|
|
|
ASSERT_TRUE(exp2.isSameShape(z2));
|
|
|
|
|
|
|
|
|
|
|
|
//////// clean out
|
2019-07-22 13:34:08 +02:00
|
|
|
unregisterGraph(nullptr, 119);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
|
|
|
|
|
|
|
|
|
|
|
delete[] data;
|
|
|
|
delete res_0;
|
|
|
|
delete res_1;
|
|
|
|
delete res_2;
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
|
|
|
|
// Runs the `greater` custom op through the execCustomOp interop entry point on
// two 2x2 float arrays and checks the element-wise boolean result.
TEST_F(JavaInteropTests, Test_Greater_1) {
    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
    auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
    // output array, initialised to all-true before the op overwrites it
    auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
    auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});

    NDArray::prepareSpecialUse({&o}, {&x, &y});

    sd::ops::greater op;

    // Pointer tables list the regular buffers/shapes of every array first,
    // followed by the matching special buffers/shapes in the same order.
    Nd4jPointer ptrsInBuffer[] = {static_cast<Nd4jPointer>(x.getBuffer()), static_cast<Nd4jPointer>(y.getBuffer()),
                                  x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), static_cast<Nd4jPointer>(y.getShapeInfo()),
                                  x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {static_cast<Nd4jPointer>(o.getBuffer()), o.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {static_cast<Nd4jPointer>(o.getShapeInfo()), o.getSpecialShapeInfo()};

    // 2 inputs, 1 output; all remaining argument arrays are empty.
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&o}, {&x, &y});

    // Fail with the op status instead of an opaque value mismatch when the
    // call itself did not succeed.
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(exp.equalsTo(&o));
}
|
|
|
|
|
|
|
|
|
|
|
|
// Same scenario as Test_Greater_1, but the inputs are spelled with explicit
// float literals and the op object is created before the arrays are prepared.
TEST_F(JavaInteropTests, Test_Greater_2) {
    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
    auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
    // output array, initialised to all-true before the op overwrites it
    auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
    auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});

    sd::ops::greater op;

    NDArray::prepareSpecialUse({&o}, {&x, &y});

    // Regular buffers/shapes first, then the matching special ones.
    Nd4jPointer ptrsInBuffer[] = {static_cast<Nd4jPointer>(x.getBuffer()), static_cast<Nd4jPointer>(y.getBuffer()),
                                  x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), static_cast<Nd4jPointer>(y.getShapeInfo()),
                                  x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {static_cast<Nd4jPointer>(o.getBuffer()), o.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {static_cast<Nd4jPointer>(o.getShapeInfo()), o.getSpecialShapeInfo()};

    // 2 inputs, 1 output; all remaining argument arrays are empty.
    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&o}, {&x, &y});

    // Surface a failed op call directly rather than through the value check.
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(exp.equalsTo(&o));
}
|
|
|
|
|
|
|
|
// Executes is_non_decreasing via execCustomOp on the vector {1, 2, 3, 4, 5}
// and expects the scalar boolean output to equal `exp`.
TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
    sd::ops::is_non_decreasing op;

    auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
    auto o = NDArrayFactory::create<bool>(false);
    auto exp = NDArrayFactory::create<bool>(1);

    NDArray::prepareSpecialUse({&o}, {&x});

    // single input: regular pointer followed by its special counterpart
    Nd4jPointer inBuffers[] = {static_cast<Nd4jPointer>(x.getBuffer()), x.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), x.getSpecialShapeInfo()};

    // single output, same layout
    Nd4jPointer outBuffers[] = {static_cast<Nd4jPointer>(o.getBuffer()), o.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {static_cast<Nd4jPointer>(o.getShapeInfo()), o.getSpecialShapeInfo()};

    auto opHash = op.getOpHash();
    auto rc = execCustomOp(nullptr, opHash, inBuffers, inShapes, 1, outBuffers, outShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&o}, {&x});

    ASSERT_EQ(Status::OK(), rc);

    ASSERT_TRUE(exp.equalsTo(&o));
}
|
|
|
|
|
|
|
|
|
|
|
|
// Feeds a 2x3 array to test_output_reshape with a caller-provided 2x3 output
// and expects the output to match the input values and shape.
TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto z = NDArrayFactory::create<float>('c', {2, 3});

    sd::ops::test_output_reshape op;

    NDArray::prepareSpecialUse({&z}, {&x});

    // input pointer tables: regular pointer, then special pointer
    Nd4jPointer inBuffers[] = {static_cast<Nd4jPointer>(x.getBuffer()), x.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), x.getSpecialShapeInfo()};

    // output pointer tables, same layout
    Nd4jPointer outBuffers[] = {static_cast<Nd4jPointer>(z.getBuffer()), z.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {static_cast<Nd4jPointer>(z.getShapeInfo()), z.getSpecialShapeInfo()};

    auto opHash = op.getOpHash();
    auto rc = execCustomOp(nullptr, opHash, inBuffers, inShapes, 1, outBuffers, outShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&z}, {&x});

    ASSERT_EQ(Status::OK(), rc);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}
|
|
|
|
|
|
|
|
|
|
|
|
// Adds a scalar to a 'c'-ordered 2x3 array while writing into a caller-provided
// 'f'-ordered output, then checks values, shape, and that the output ordering
// still differs from the 'c'-ordered expectation.
TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto y = NDArrayFactory::create<float>(2.0f);
    auto z = NDArrayFactory::create<float>('f', {2, 3});
    auto e = NDArrayFactory::create<float>('c', {2, 3}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f});

    sd::ops::add op;

    NDArray::prepareSpecialUse({&z}, {&x, &y});

    // Regular buffers/shapes of both inputs first, then the special ones.
    Nd4jPointer ptrsInBuffer[] = {static_cast<Nd4jPointer>(x.getBuffer()), static_cast<Nd4jPointer>(y.getBuffer()),
                                  x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), static_cast<Nd4jPointer>(y.getShapeInfo()),
                                  x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {static_cast<Nd4jPointer>(z.getBuffer()), z.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {static_cast<Nd4jPointer>(z.getShapeInfo()), z.getSpecialShapeInfo()};

    auto hash = op.getOpHash();
    auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    // The prepare call above must be closed with registerSpecialUse, as in the
    // other tests of this file; this test used to call prepareSpecialUse a
    // second time here instead.
    NDArray::registerSpecialUse({&z}, {&x, &y});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(z));
    // z was created with 'f' ordering and e with 'c'; the op must not change that
    ASSERT_FALSE(e.ordering() == z.ordering());
}
|
|
|
|
|
|
|
|
// Runs gather (integer argument 1) on a 2x3x4 input with a 1x6 index array,
// writing into a caller-provided 'f'-ordered output, and checks values, shape,
// and that the output ordering still differs from the 'c'-ordered expectation.
TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
    auto indices = NDArrayFactory::create<Nd4jLong>('c', {1, 6}, {0,1, 2,2, 1,2});
    auto output = NDArrayFactory::create<double>('f', {2, 1, 6, 4});
    auto e = NDArrayFactory::create<double>('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24});

    sd::ops::gather op;

    NDArray::prepareSpecialUse({&output}, {&input, &indices});

    // Regular buffers/shapes of both inputs first, then the special ones.
    Nd4jPointer ptrsInBuffer[] = {static_cast<Nd4jPointer>(input.getBuffer()), static_cast<Nd4jPointer>(indices.getBuffer()),
                                  input.getSpecialBuffer(), indices.getSpecialBuffer()};
    // The last entry must describe `indices`; it previously repeated the
    // special shape info of `input`.
    Nd4jPointer ptrsInShapes[] = {static_cast<Nd4jPointer>(input.getShapeInfo()), static_cast<Nd4jPointer>(indices.getShapeInfo()),
                                  input.getSpecialShapeInfo(), indices.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {static_cast<Nd4jPointer>(output.getBuffer()), output.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {static_cast<Nd4jPointer>(output.getShapeInfo()), output.getSpecialShapeInfo()};

    // single integer argument handed to the op
    Nd4jLong iArgs[] = {1};

    auto hash = op.getOpHash();
    auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);

    NDArray::registerSpecialUse({&output}, {&input, &indices});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(e.isSameShape(output));
    ASSERT_TRUE(e.equalsTo(output));
    // output was created with 'f' ordering and e with 'c'; the op must keep it
    ASSERT_FALSE(e.ordering() == output.ordering());
}
|
|
|
|
|
|
|
|
// Calls execReduce3Tad (op number 2) on two 3x4x5 arrays along dimensions
// {0, 1} with a length-5 output. The test carries no value assertions; it only
// exercises the call.
TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
    auto x = NDArrayFactory::create<float>('c', {3, 4, 5});
    auto y = NDArrayFactory::create<float>('c', {3, 4, 5});
    auto z = NDArrayFactory::create<float>('c', {5});

    auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
    dims.syncToHost();

    sd::LaunchContext* context = sd::LaunchContext::defaultContext();

    // Extra pointers are only populated for CUDA builds; otherwise nullptr is passed.
    Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
    extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()};
#endif

    // TAD packs for both operands over the same dimensions as `dims`
    auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {0,1});
    auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});

    NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});

    OpaqueDataBuffer xBuf(x.dataBuffer());
    OpaqueDataBuffer yBuf(y.dataBuffer());
    OpaqueDataBuffer zBuf(z.dataBuffer());
    OpaqueDataBuffer dimBuf(dims.dataBuffer());

    execReduce3Tad(extraPointers, 2,
                   &xBuf, x.shapeInfo(), x.specialShapeInfo(),
                   nullptr,
                   &yBuf, y.shapeInfo(), y.specialShapeInfo(),
                   &zBuf, z.shapeInfo(), z.specialShapeInfo(),
                   &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(),
                   tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
                   tadPackY.platformShapeInfo(), tadPackY.platformOffsets());

    NDArray::registerSpecialUse({&z}, {&x, &y, &dims});

    // deleting nullptr is a no-op, so this is safe when no CUDA pointers were allocated
    delete []extraPointers;
}
|
2019-08-02 19:01:03 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
/*
|
|
|
|
TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
|
|
|
|
Environment::getInstance()->setDebug(true);
|
|
|
|
Environment::getInstance()->setVerbose(false);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
auto pl = sd::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
|
2019-07-22 13:34:08 +02:00
|
|
|
auto ptr = executeFlatGraph(nullptr, pl);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
Environment::getInstance()->setDebug(false);
|
|
|
|
Environment::getInstance()->setVerbose(false);
|
|
|
|
|
|
|
|
delete[] pl;
|
|
|
|
delete ptr;
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
|
|
|
|
|
2019-11-30 14:02:07 +01:00
|
|
|
auto input = NDArrayFactory::create<double>('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, 
-4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.461874
|
|
|
|
auto z = NDArrayFactory::create<double>('c', {4, 4, 4, 3});
|
|
|
|
auto exp = NDArrayFactory::create<double>('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, 
-0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.4945225
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::avgpool2d op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::prepareSpecialUse({&z}, {&input});
|
|
|
|
|
|
|
|
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
|
|
|
|
|
|
|
|
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
|
|
|
|
|
|
|
auto hash = op.getOpHash();
|
2019-07-22 13:34:08 +02:00
|
|
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::registerSpecialUse({&z}, {&input});
|
|
|
|
ASSERT_EQ(Status::OK(), status);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(exp.isSameShape(z));
|
|
|
|
ASSERT_TRUE(exp.equalsTo(z));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Runs maxpool2d through execCustomOp on a 1x1x4x5 float input filled by
// linspace(1.0) and only checks that the call reports success.
TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
    auto input = NDArrayFactory::create<float>('c', {1, 1, 4, 5});
    auto z = NDArrayFactory::create<float>('c', {1, 1, 4, 5});

    input.linspace(1.0);

    NDArray::prepareSpecialUse({&z}, {&input});

    // regular pointer followed by the special pointer, for input and output alike
    Nd4jPointer inBuffers[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};

    Nd4jPointer outBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

    // eleven integer arguments forwarded to the op unchanged
    Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0};

    sd::ops::maxpool2d op;

    auto opHash = op.getOpHash();
    auto rc = execCustomOp(nullptr, opHash, inBuffers, inShapes, 1, outBuffers, outShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);

    NDArray::registerSpecialUse({&z}, {&input});

    ASSERT_EQ(Status::OK(), rc);
}
|
|
|
|
|
|
|
|
// Unstacks a 5x5 array (integer argument 0) into five caller-provided length-5
// outputs and only checks that the call reports success.
TEST_F(JavaInteropTests, Test_Unstack_1) {
    auto x = NDArrayFactory::create<double>('c', {5, 5});
    x.linspace(1.0);

    auto z0 = NDArrayFactory::create<double>('c',{5});
    auto z1 = NDArrayFactory::create<double>('c',{5});
    auto z2 = NDArrayFactory::create<double>('c',{5});
    auto z3 = NDArrayFactory::create<double>('c',{5});
    auto z4 = NDArrayFactory::create<double>('c',{5});

    NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});

    Nd4jPointer inBuffers[] = {reinterpret_cast<Nd4jPointer>(x.buffer()), x.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo()), x.getSpecialShapeInfo()};

    // all five regular pointers first, then the five special pointers
    Nd4jPointer outBuffers[] = {
        z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer(),
        z0.getSpecialBuffer(), z1.getSpecialBuffer(), z2.getSpecialBuffer(), z3.getSpecialBuffer(), z4.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {
        z0.shapeInfo(), z1.shapeInfo(), z2.shapeInfo(), z3.shapeInfo(), z4.shapeInfo(),
        z0.getSpecialShapeInfo(), z1.getSpecialShapeInfo(), z2.getSpecialShapeInfo(), z3.getSpecialShapeInfo(), z4.getSpecialShapeInfo()};

    Nd4jLong iArgs[] = {0};

    sd::ops::unstack op;

    auto opHash = op.getOpHash();
    auto rc = execCustomOp(nullptr, opHash, inBuffers, inShapes, 1, outBuffers, outShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);

    NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});

    ASSERT_EQ(Status::OK(), rc);
}
|
|
|
|
|
|
|
|
TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
|
|
|
|
|
|
|
|
auto input = NDArrayFactory::create<float>('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,2.91434479f,5.43639755f,-2.10573769f, 4.08528662f,5.86908436f,-4.46203756f,2.21057916f,5.35849190f,0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, 
-2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454
|
|
|
|
auto z = NDArrayFactory::create<float>('c', {4, 4, 4, 3});
|
|
|
|
auto exp = NDArrayFactory::create<float>('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 
8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f,
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::avgpool2d op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::prepareSpecialUse({&z}, {&input});
|
|
|
|
|
|
|
|
Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
|
|
|
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
|
|
|
|
|
|
|
auto hash = op.getOpHash();
|
2019-07-22 13:34:08 +02:00
|
|
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-02 19:01:03 +02:00
|
|
|
NDArray::registerSpecialUse({&z}, {&input});
|
|
|
|
ASSERT_EQ(Status::OK(), status);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(exp.isSameShape(z));
|
|
|
|
ASSERT_TRUE(exp.equalsTo(z));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Pairwise Add of an int array and a double array into a double output via
// execPairwiseTransform. The body is skipped unless the environment reports an
// experimental build.
TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
    const bool experimental = Environment::getInstance()->isExperimentalBuild();
    if (!experimental) {
        return;
    }

    auto arrayX = NDArrayFactory::create<int>({1, 2, 3, 4});
    auto arrayY = NDArrayFactory::create<double>({1, 2, 3, 4});
    auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
    auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});

    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});

    OpaqueDataBuffer xBuf(arrayX.dataBuffer());
    OpaqueDataBuffer yBuf(arrayY.dataBuffer());
    OpaqueDataBuffer zBuf(arrayZ.dataBuffer());

    // one (buffer, shape info, special shape info) triple per operand: x, y, then z
    execPairwiseTransform(nullptr, pairwise::Add,
                          &xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(),
                          &yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(),
                          &zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(),
                          nullptr);

    NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});

    ASSERT_EQ(arrayE, arrayZ);
}
|
|
|
|
|
|
|
|
// Adds two length-5 int arrays through execCustomOp with `x` acting as both
// first input and output, then checks that `x` holds the element-wise sum.
TEST_F(JavaInteropTests, Test_Add_1) {
    auto x = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
    auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
    auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});

    // x is listed as the output and as an input
    NDArray::prepareSpecialUse({&x}, {&x, &y});

    sd::ops::add op;

    // Regular buffers/shapes of both inputs first, then the special ones.
    Nd4jPointer ptrsInBuffer[] = {static_cast<Nd4jPointer>(x.getBuffer()), y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo(),};

    // output pointers alias the first input
    Nd4jPointer ptrsOutBuffers[] = {static_cast<Nd4jPointer>(x.getBuffer()), x.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), x.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&x}, {&x, &y});

    // Report a failed op call explicitly instead of via the value comparison.
    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(e, x);
}
|
|
|
|
|
|
|
|
// Runs the zeta op through execCustomOp on two 3x4 double arrays and compares
// the output with precomputed reference values.
TEST_F(JavaInteropTests, zeta_test10) {
    auto x = NDArrayFactory::create<double>('c', {3, 4}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12});
    auto q = NDArrayFactory::create<double>('c', {3, 4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12});
    auto z = NDArrayFactory::create<double>('c', {3, 4});

    auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});

    sd::ops::zeta op;

    NDArray::prepareSpecialUse({&z}, {&x, &q});

    // Regular buffers/shapes of both inputs first, then the special ones.
    Nd4jPointer ptrsInBuffer[] = {static_cast<Nd4jPointer>(x.getBuffer()), q.getBuffer(), x.getSpecialBuffer(), q.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {static_cast<Nd4jPointer>(x.getShapeInfo()), q.getShapeInfo(), x.specialShapeInfo(), q.getSpecialShapeInfo()};

    Nd4jPointer ptrsOutBuffers[] = {static_cast<Nd4jPointer>(z.getBuffer()), z.getSpecialBuffer()};
    Nd4jPointer ptrsOutShapes[] = {static_cast<Nd4jPointer>(z.getShapeInfo()), z.getSpecialShapeInfo()};

    auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

    NDArray::registerSpecialUse({&z}, {&x, &q});

    // Report a failed op call explicitly instead of via the value comparison.
    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(e, z);
}
|
|
|
|
|
|
|
|
// IndexAbsoluteMax over {-0.24, -0.26, -0.07, -0.01}: the element with the largest
// magnitude sits at index 1.
TEST_F(JavaInteropTests, Test_IAMax_1) {
    auto values = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f});

    auto winner = values.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr);
    auto expected = NDArrayFactory::create<Nd4jLong>(1);

    ASSERT_EQ(expected, winner);
}
|
|
|
|
|
|
|
|
// Asks calculateOutputShapes2 for the output shapes of `greater_equal` applied to
// two 10x10 arrays and verifies that a shape list was actually produced.
TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
    auto arrayX = NDArrayFactory::create<double>('c', {10, 10});
    auto arrayY = NDArrayFactory::create<double>('c', {10, 10});

    // pointer tables: primary pointers of every array first, special pointers after them
    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(arrayX.buffer()), reinterpret_cast<Nd4jPointer>(arrayY.buffer()), arrayX.getSpecialBuffer(), arrayY.getSpecialBuffer()};
    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(arrayX.shapeInfo()), reinterpret_cast<Nd4jPointer>(arrayY.shapeInfo()), arrayX.getSpecialShapeInfo(), arrayY.getSpecialShapeInfo()};

    // nothing is written here, both arrays are only read
    NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});

    sd::ops::greater_equal op;
    auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0);

    NDArray::registerSpecialUse({}, {&arrayX, &arrayY});

    // without these checks the test would pass even if no shapes were calculated
    ASSERT_TRUE(shapeList != nullptr);
    // non-fatal on purpose, so the list is still released below on a mismatch
    EXPECT_EQ(1, shapeList->size());

    delete shapeList;
}
|
|
|
|
|
|
|
|
// Scalar `l2_loss` executed through execCustomOp: the call must report OK and the
// output scalar must equal the expected value.
TEST_F(JavaInteropTests, Test_L2_Loss_3) {
    auto input = NDArrayFactory::create<double>(0.7787855863571167);
    auto expected = NDArrayFactory::create<double>(0.303254);
    auto output = NDArrayFactory::create<double>(0.0);

    NDArray::prepareSpecialUse({&output}, {&input});

    // each table holds the primary pointer followed by the special one
    Nd4jPointer inBuffers[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
    Nd4jPointer inShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};

    Nd4jPointer outBuffers[] = {reinterpret_cast<Nd4jPointer>(output.buffer()), output.getSpecialBuffer()};
    Nd4jPointer outShapes[] = {reinterpret_cast<Nd4jPointer>(output.shapeInfo()), output.getSpecialShapeInfo()};

    sd::ops::l2_loss lossOp;
    const auto opStatus = execCustomOp(nullptr, lossOp.getOpHash(), inBuffers, inShapes, 1, outBuffers, outShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
    ASSERT_EQ(Status::OK(), opStatus);

    NDArray::registerSpecialUse({&output}, {&input});

    ASSERT_EQ(expected, output);
}
|
|
|
|
|
|
|
|
// Adds two 3x2 float arrays through a manually populated Context and execCustomOp2,
// checking the context width, the returned status and the produced values.
TEST_F(JavaInteropTests, Test_Fastpath_3) {
    auto array0 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto array1 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto z = NDArrayFactory::create<float>('c', {3, 2});

    auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
    Context ctx(1);

    // z is written, array0 and array1 are read
    NDArray::prepareSpecialUse({&z}, {&array0, &array1});

    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.getSpecialBuffer(), array0.getSpecialShapeInfo());
    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.getSpecialBuffer(), array1.getSpecialShapeInfo());
    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo());

    // both inputs must be visible on the context before the op runs
    ASSERT_EQ(2, ctx.width());

    sd::ops::add op;
    auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

    NDArray::registerSpecialUse({&z}, {&array0, &array1});

    // a failed execution must be reported as such, not as a value mismatch
    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(exp, z);
}
|
|
|
|
|
|
|
|
// Runs the `tri` op, which takes no input arrays here, through execCustomOp2 with
// integer arguments {3, 5, 2} and compares the 3x5 output with the expected pattern.
TEST_F(JavaInteropTests, Test_Fastpath_4) {

    auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1,1,1,0,0, 1,1,1,1,0, 1,1,1,1,1});
    auto z = NDArrayFactory::create<double>('c', {3, 5});
    Nd4jLong iArgs[] = {3, 5, 2};

    // only the output array takes part in this call
    NDArray::prepareSpecialUse({&z}, {});

    Context ctx(1);

    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
    ctx.setIArguments(iArgs, 3);

    sd::ops::tri op;
    auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

    NDArray::registerSpecialUse({&z}, {});

    // a failed execution must be reported as such, not as a value mismatch
    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(exp, z);
}
|
|
|
|
|
|
|
|
// Multiplies two 3x3 float arrays with `matmul` through a Context and
// execCustomOp2 and checks that the call reports success.
TEST_F(JavaInteropTests, Test_Fastpath_5) {
    auto a = NDArrayFactory::create<float>('c', {3, 3});
    auto b = NDArrayFactory::create<float>('c', {3, 3});
    auto c = NDArrayFactory::create<float>('c', {3, 3});
    a.linspace(1.0);
    b.linspace(1.0);

    // the op reads a and b and writes c, so the read list has to name a and b
    // (it used to name b and c, leaving the first input out entirely)
    NDArray::prepareSpecialUse({&c}, {&a, &b});

    Context ctx(1);

    ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
    ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
    ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());

    sd::ops::matmul op;
    auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

    // mirror of the prepareSpecialUse call above
    NDArray::registerSpecialUse({&c}, {&a, &b});

    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
|
|
|
|
// `matmul_bp` through execCustomOp2: three input arrays, two output arrays and
// three integer arguments; only the returned status is checked.
TEST_F(JavaInteropTests, Test_Fastpath_6) {
    auto inA = NDArrayFactory::create<float>('c', {2, 3});
    auto inB = NDArrayFactory::create<float>('c', {3, 4});
    auto inGrad = NDArrayFactory::create<float>('c', {2, 4});

    auto outGradA = NDArrayFactory::create<float>('c', {2, 3});
    auto outGradB = NDArrayFactory::create<float>('c', {3, 4});

    // every input is filled with a linspace sequence starting at 1.0
    inA.linspace(1.0);
    inB.linspace(1.0);
    inGrad.linspace(1.0);

    NDArray::prepareSpecialUse({&outGradA, &outGradB}, {&inA, &inB, &inGrad});

    Context context(1);
    Nd4jLong intArgs[] = {0L, 0L, 0L};

    context.setInputArray(0, inA.buffer(), inA.shapeInfo(), inA.specialBuffer(), inA.specialShapeInfo());
    context.setInputArray(1, inB.buffer(), inB.shapeInfo(), inB.specialBuffer(), inB.specialShapeInfo());
    context.setInputArray(2, inGrad.buffer(), inGrad.shapeInfo(), inGrad.specialBuffer(), inGrad.specialShapeInfo());

    context.setOutputArray(0, outGradA.buffer(), outGradA.shapeInfo(), outGradA.specialBuffer(), outGradA.specialShapeInfo());
    context.setOutputArray(1, outGradB.buffer(), outGradB.shapeInfo(), outGradB.specialBuffer(), outGradB.specialShapeInfo());

    context.setIArguments(intArgs, 3);

    sd::ops::matmul_bp backpropOp;
    const auto opStatus = execCustomOp2(nullptr, backpropOp.getOpHash(), &context);

    NDArray::registerSpecialUse({&outGradA, &outGradB}, {&inA, &inB, &inGrad});

    ASSERT_EQ(Status::OK(), opStatus);
}
|
|
|
|
|
|
|
|
// `concat` of a two-element vector and a scalar through execCustomOp2; the result
// is expected to be the three-element vector {1, 2, 3}.
TEST_F(JavaInteropTests, Test_Fastpath_7) {
    auto vec = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
    auto scalar = NDArrayFactory::create<float>(3.f);
    auto joined = NDArrayFactory::create<float>('c', {3});
    auto expected = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});

    NDArray::prepareSpecialUse({&joined}, {&vec, &scalar});

    Context context(1);
    Nd4jLong intArgs[] = {0L, 0L, 0L};

    // only the first entry of intArgs is handed to the op
    context.setIArguments(intArgs, 1);

    sd::ops::concat concatOp;

    context.setInputArray(0, vec.buffer(), vec.shapeInfo(), vec.specialBuffer(), vec.specialShapeInfo());
    context.setInputArray(1, scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo());

    context.setOutputArray(0, joined.buffer(), joined.shapeInfo(), joined.specialBuffer(), joined.specialShapeInfo());

    const auto opStatus = execCustomOp2(nullptr, concatOp.getOpHash(), &context);

    NDArray::registerSpecialUse({&joined}, {&vec, &scalar});

    ASSERT_EQ(Status::OK(), opStatus);

    ASSERT_EQ(expected, joined);
}
|
|
|
|
|
2019-08-14 20:51:42 +02:00
|
|
|
// Fills a 10-element bfloat16 array via execRandom (UniformDistribution with
// arguments 0 and 1) and checks that the sum of the result is positive.
// The body runs only when the environment reports CPU.
TEST_F(JavaInteropTests, test_bfloat16_rng) {
    const bool runsOnCpu = Environment::getInstance()->isCPU();
    if (!runsOnCpu) {
        return;
    }

    auto z = NDArrayFactory::create<bfloat16>('c', {10});
    RandomGenerator generator(119, 323841120L);
    bfloat16 bounds[2] = {static_cast<bfloat16>(0.0f), static_cast<bfloat16>(1.0f)};

    // execRandom takes the output as an opaque data buffer wrapper
    OpaqueDataBuffer outBuffer(z.dataBuffer());
    execRandom(nullptr, sd::random::Ops::UniformDistribution, &generator, &outBuffer, z.shapeInfo(), z.specialShapeInfo(), bounds);

    // debugging aid: z.printIndexedBuffer("z");
    const auto total = z.sumNumber().e<float>(0);
    ASSERT_TRUE(total > 0);
}
|
|
|
|
|
2019-08-21 14:05:47 +02:00
|
|
|
// Executes `ismax` on a strided view (every second element along the last axis of
// a 2x3x40 array) and compares the output with a manually built expectation.
TEST_F(JavaInteropTests, test_ismax_view) {
    auto original = NDArrayFactory::create<double>('c', {2, 3, 40});
    auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)});
    v.assign(1.0);

    // expectation: same shape as the view, with ones assigned only to the sub-array e(0, {2})
    auto e = v.like();
    auto t = e(0, {2});
    t.assign(1.0);

    auto z = v.ulike();

    Nd4jLong iArgs[] = {2L, 0L};
    Context ctx(1);
    ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo());
    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
    // only the first integer argument (2) is passed to the op
    ctx.setIArguments(iArgs, 1);

    sd::ops::ismax op;
    auto status = op.execute(&ctx);

    // a failed execution must be reported as such, not as a value mismatch
    ASSERT_EQ(Status::OK(), status);
    ASSERT_EQ(e, z);
}
|
|
|
|
|
2019-08-23 11:31:12 +02:00
|
|
|
// `size` op on a three-element float vector, writing into a float scalar:
// the op must succeed and the scalar must hold 3.0.
TEST_F(JavaInteropTests, test_size_dtype_1) {
    auto input = NDArrayFactory::create<float>('c', {3}, {1.f, 1.f, 1.f});
    auto output = NDArrayFactory::create<float>(0.0f);
    auto expected = NDArrayFactory::create<float>(3.0f);

    Context context(1);
    context.setInputArray(0, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo());
    context.setOutputArray(0, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());

    sd::ops::size sizeOp;
    const auto opStatus = sizeOp.execute(&context);
    ASSERT_EQ(Status::OK(), opStatus);

    ASSERT_EQ(expected, output);
}
|
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
TEST_F(JavaInteropTests, test_expandable_array_op_1) {
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
auto x = NDArrayFactory::string( {2}, {"first string", "second"});
|
2020-03-02 10:49:41 +01:00
|
|
|
auto d = NDArrayFactory::string(" ", sd::DataType::UTF8);
|
2020-01-04 11:27:50 +01:00
|
|
|
|
|
|
|
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
auto z1 = NDArrayFactory::string( {3}, {"", "", ""});
|
2020-01-04 11:27:50 +01:00
|
|
|
|
|
|
|
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
|
2020-01-04 11:27:50 +01:00
|
|
|
|
|
|
|
InteropDataBuffer iz0(z0.dataBuffer());
|
|
|
|
InteropDataBuffer iz1(z1.dataBuffer());
|
|
|
|
|
|
|
|
Context ctx(1);
|
|
|
|
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
|
|
|
ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
|
|
|
|
ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo());
|
|
|
|
ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo());
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::compat_string_split op;
|
2020-01-04 11:27:50 +01:00
|
|
|
auto status = op.execute(&ctx);
|
|
|
|
ASSERT_EQ(Status::OK(), status);
|
|
|
|
|
|
|
|
ASSERT_EQ(exp0, z0);
|
|
|
|
ASSERT_EQ(exp1, z1);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Executes `maxpool2d_bp` on arrays whose primary buffers are slices of one local
// block of memory, imitating workspace-managed arrays. Only the returned status is
// checked. The body runs only when the environment reports CPU.
TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
    if (!Environment::getInstance()->isCPU())
        return;

    // x, y and z are used for their shape info and lengths only
    auto x = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
    auto y = NDArrayFactory::create<double>('c', {4, 3, 3, 3});
    auto z = NDArrayFactory::create<double>('c', {4, 3, 4, 4});

    // zero-initialised: the op reads its inputs from this storage, so leaving it
    // uninitialised would feed indeterminate values into the computation
    double buffer[2048] = {};

    InteropDataBuffer ix(0, DataType::DOUBLE, false);
    InteropDataBuffer iy(0, DataType::DOUBLE, false);
    InteropDataBuffer iz(0, DataType::DOUBLE, false);

    // we're imitating workspace-managed array here: consecutive slices of `buffer`,
    // starting at an offset of 64 elements
    ix.setPrimary(buffer + 64, x.lengthOf());
    iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf());
    iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf());

    Context ctx(1);
    ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo());
    ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo());
    ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo());

    ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0});

    sd::ops::maxpool2d_bp op;

    auto status = op.execute(&ctx);
    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
|
2020-04-16 13:53:56 +02:00
|
|
|
// Shape calculation for `lin_space` driven purely by arguments (no input arrays):
// two T args, one I arg and one data-type arg must yield exactly one output shape.
// The body runs only when the environment reports CPU.
TEST_F(JavaInteropTests, test_linspace_shape_1) {
    const bool runsOnCpu = Environment::getInstance()->isCPU();
    if (!runsOnCpu) {
        return;
    }

    sd::ops::lin_space linspaceOp;

    double realArgs[2] = {1.0, 10.0};
    Nd4jLong intArg = 10L;
    int typeArg = static_cast<int>(sd::DataType::FLOAT32);

    auto shapes = ::calculateOutputShapes2(nullptr, linspaceOp.getOpHash(), nullptr, nullptr, 0, realArgs, 2, &intArg, 1, nullptr, 0, &typeArg, 1);

    ASSERT_EQ(1, shapes->size());
    delete shapes;
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
/*
|
|
|
|
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
2019-07-22 13:34:08 +02:00
|
|
|
auto ptr = executeFlatGraph(nullptr, pl);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// at this point we have FlatResults
|
|
|
|
auto flatResult = GetFlatResult(ptr->pointer());
|
|
|
|
auto size = flatResult->variables()->size();
|
|
|
|
|
|
|
|
// we know exact number of outputs in this graph in given mode
|
|
|
|
ASSERT_EQ(184, size);
|
|
|
|
|
|
|
|
|
|
|
|
// now we're rolling through all variables and restore them one by one
|
|
|
|
for (int e = 0; e < size; e++) {
|
|
|
|
auto flatVar = flatResult->variables()->Get(e);
|
|
|
|
auto flatArray = flatVar->ndarray();
|
|
|
|
|
|
|
|
// checking var part first
|
|
|
|
// we just want to ensure we're not experiencing overruns here
|
|
|
|
auto name = flatVar->name()->str();
|
|
|
|
|
|
|
|
// checking array part now
|
|
|
|
auto shape = flatArray->shape();
|
|
|
|
auto rank = shape->Get(0);
|
|
|
|
|
|
|
|
ASSERT_TRUE(shape->size() > 0 && rank >= 0 && rank < MAX_RANK);
|
|
|
|
|
|
|
|
// building regular NDArray out of this FlatArray
|
2020-03-02 10:49:41 +01:00
|
|
|
auto ndarray = sd::graph::FlatUtils::fromFlatArray(flatArray);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// rank should match FlatArray
|
|
|
|
ASSERT_EQ(rank, ndarray->rankOf());
|
|
|
|
|
|
|
|
// array shouldn't have any NaN/Inf values
|
|
|
|
ASSERT_TRUE(ndarray->isFinite());
|
|
|
|
|
|
|
|
// array should be assignable
|
|
|
|
ndarray->assign(123.f);
|
|
|
|
|
|
|
|
// and safely removable after
|
|
|
|
delete ndarray;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
delete[] pl;
|
|
|
|
delete ptr;
|
|
|
|
|
|
|
|
// and we should have 0 leaks reported after this line :)
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
|
|
|
|
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
|
|
|
|
// std::array<float, 60> syn1;
|
|
|
|
// std::array<float, 100000> exp;
|
|
|
|
|
|
|
|
// for (int e = 0; e < syn1.size(); e++)
|
|
|
|
// syn1[e] = 0.0f;
|
|
|
|
|
|
|
|
// for (int e = 0; e < exp.size(); e++) {
|
|
|
|
// auto f = static_cast<double>(e);
|
2020-03-02 10:49:41 +01:00
|
|
|
// auto tmp = sd::math::nd4j_exp<double, double>((f / 100000.0 * 2.0 - 1.0) * 6.0);
|
2019-06-06 14:21:15 +02:00
|
|
|
// exp[e] = static_cast<float>(tmp / (tmp + 1.0));
|
|
|
|
// }
|
|
|
|
|
|
|
|
// auto maxTypes = 5;
|
|
|
|
// auto numAggregates = 1;
|
|
|
|
// auto opNum = 3;
|
|
|
|
// auto maxArgs = 6;
|
|
|
|
// auto maxShapes = 0;
|
|
|
|
// auto maxIntArrays = 2;
|
|
|
|
// auto maxIntArraySize = 40;
|
|
|
|
// auto maxIndexArguments = 10;
|
|
|
|
// auto maxRealArguments = 2;
|
|
|
|
|
|
|
|
// std::array<int, 100000> pointer;
|
|
|
|
|
|
|
|
// auto batchLimit = 512;
|
|
|
|
|
|
|
|
// int indexPos = maxTypes * batchLimit;
|
|
|
|
// int intArraysPos = indexPos + (maxIndexArguments * batchLimit);
|
|
|
|
// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit));
|
|
|
|
// int argsPos = (realPos + ((maxRealArguments * batchLimit))) / 2;
|
|
|
|
// int shapesPos = argsPos + (maxArgs * batchLimit);
|
|
|
|
|
|
|
|
// std::vector<int> intArray0({0, 0, 0, 0, 0});
|
|
|
|
// std::vector<int> intArray1({1, 0, 0, 0, 0});
|
|
|
|
|
|
|
|
// std::vector<int> indexingArgs0({1, 20, 5, 0, 100000, 3, 0, 0, 0});
|
|
|
|
// std::vector<int> indexingArgs1({0, 20, 5, 0, 100000, 3, 1, 0, 0});
|
|
|
|
|
|
|
|
// std::vector<float> realArgs0({0.024964055335354007f, 3.0768702268737162E18f});
|
|
|
|
|
|
|
|
// int argSize = 6;
|
|
|
|
// int shapesSize = 0;
|
|
|
|
// int indexingSize = 9;
|
|
|
|
// int realArgsSize = 2;
|
|
|
|
// int intArraysSize = 2;
|
|
|
|
|
|
|
|
// int e = 0;
|
|
|
|
|
|
|
|
// auto idx = e * maxTypes;
|
|
|
|
|
|
|
|
// // numbers of arguments
|
|
|
|
// pointer[idx] = 6; // arguments size
|
|
|
|
// pointer[idx+1] = 0; // shapes size
|
|
|
|
// pointer[idx+2] = 9; // indexing arguments size
|
|
|
|
// pointer[idx+3] = 2; // real args size
|
|
|
|
// pointer[idx+4] = 2; // intArray args size
|
|
|
|
|
|
|
|
// // indexing args
|
|
|
|
// auto idxArgs = e == 0 ? indexingArgs0 : indexingArgs1;
|
|
|
|
// for (int f = 0; f < idxArgs.size(); f++) {
|
|
|
|
// idx = indexPos + e * maxIndexArguments;
|
|
|
|
// pointer[idx + f] = idxArgs[f];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// // int array values
|
|
|
|
// int bsize = maxIntArrays * maxIntArraySize;
|
|
|
|
// for (int f = 0; f < intArraysSize; f++) {
|
|
|
|
// int step = (e * bsize) + (f * maxIntArraySize);
|
|
|
|
// auto intArr = f == 0 ? intArray0 : intArray1;
|
|
|
|
// for (int x = 0; x < intArr.size(); x++) {
|
|
|
|
// idx = intArraysPos + step + x;
|
|
|
|
// pointer[idx] = intArr[x];
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
|
|
|
|
// // real args
|
|
|
|
// auto ptr = reinterpret_cast<float *>(pointer.data());
|
|
|
|
// for (int f = 0; f < realArgsSize; f++) {
|
|
|
|
// idx = realPos + (e * maxRealArguments);
|
|
|
|
// ptr[idx + f] = realArgs0[f];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// //
|
|
|
|
// auto ptrptr = reinterpret_cast<void **>(pointer.data());
|
|
|
|
// idx = argsPos + e * maxArgs;
|
|
|
|
// ptrptr[idx] = reinterpret_cast<void*>(syn0.data());
|
|
|
|
// ptrptr[idx+1] = reinterpret_cast<void*>(syn1.data());
|
|
|
|
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
|
|
|
|
// }
|