[WIP] test fixes (#130)

* no openmp for ClipByGlobalNorm

Signed-off-by: raver119 <raver119@gmail.com>

* one more bfloat16 rng test

Signed-off-by: raver119 <raver119@gmail.com>

* assertion fix

Signed-off-by: raver119 <raver119@gmail.com>

* - legacy IsMax gone
- linear IsMax gets shapeInfo argument

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of legacy IsMax tests

Signed-off-by: raver119 <raver119@gmail.com>

* IsMax is custom op now

Signed-off-by: raver119 <raver119@gmail.com>

* more blocks for ismax

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* - sqrt test
- some legacy code removed from CudaExecutioner
- Transforms.asin tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - TransformFloat fix

Signed-off-by: raver119 <raver119@gmail.com>

* - ismax fix
- SpaceToBatchND/BatchToSpaceND wrappers
- couple of legacy tests removed

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
Author: raver119, 2019-08-19 11:33:15 +03:00 (committed by GitHub)
Parent: bb80fe4f94
Commit: aceb915557
26 changed files with 368 additions and 467 deletions


@ -785,48 +785,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo);
switch (opNum) { dim3 launchDims(512, 512, 2048);
case transform::IsMax: {
bool scalarCheat = false;
if (extraParams == nullptr) {
scalarCheat = true;
}
void* special = lc->getAllocationPointer();
if (scalarCheat) {
auto scalarShape = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo(ShapeDescriptor::scalarDescriptor(nd4j::DataType::INT64)); //ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64);
/**
* In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call
*/
execIndexReduceScalar(lc, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, scalarShape.primaryAsT<Nd4jLong>(), special, scalarShape.specialAsT<Nd4jLong>());
Nd4jLong maxIdx = -119;
nd4j::DebugHelper::checkErrorCode(stream, "IsMax: execIndexReduce(...) failed");
cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream);
nd4j::DebugHelper::checkErrorCode(stream, "IsMax: cudaMemcpyAsync(...) failed");
int targetIdx = 0;
if (shape::order(hXShapeInfo) == 'c' || shape::order(hXShapeInfo) == 'f' && maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1] >= shape::length(hXShapeInfo))
targetIdx = maxIdx;
else
targetIdx = maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1];
dim3 launchDims(1, 512, 1024);
BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, dZ, shape::length(hZShapeInfo), targetIdx), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
//delete[] scalarShape;
}
}
break;
default: {
dim3 launchDims(512, 512, 16384);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
// TODO: remove after the release // TODO: remove after the release
auto res = cudaStreamSynchronize(*stream); auto res = cudaStreamSynchronize(*stream);
@ -884,7 +845,7 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
if (!DataTypeUtils::isR(zType)) if (!DataTypeUtils::isR(zType))
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
dim3 launchDims(512, 512, 16384); dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
// TODO: remove after the release // TODO: remove after the release


@ -653,37 +653,8 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
auto streamSpecial = reinterpret_cast<cudaStream_t&>(extraPointers[4]); auto streamSpecial = reinterpret_cast<cudaStream_t&>(extraPointers[4]);
LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast<int*>(extraPointers[6])); LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast<int*>(extraPointers[6]));
// FIXME: remove this once all operations are enabled
if (opNum == nd4j::transform::IsMax && extraParams != nullptr) {
auto hostYShapeInfo = reinterpret_cast<Nd4jLong *>(extraPointers[7]);
auto hostTShapeInfo = reinterpret_cast<Nd4jLong *>(extraPointers[19]);
auto tadMaxShapeInfo = reinterpret_cast<Nd4jLong *> (extraPointers[10]);
auto tadMaxOffsets = reinterpret_cast<Nd4jLong *> (extraPointers[11]);
int *dimension = reinterpret_cast<int *> (extraPointers[15]);
int *hDimension = reinterpret_cast<int *> (extraPointers[16]);
int dimensionLength = getDeviceId(extraPointers[18]);
auto special = reinterpret_cast<double *>(extraPointers[17]);
auto cshape = ShapeBuilders::createVectorShapeInfo(nd4j::DataType::INT32, dimensionLength);
// we call for IMax on specified dimension
execIndexReduce(extraPointers, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, hDimension, cshape, dimension, nullptr);
DEBUG_KERNEL(stream, opNum);
dim3 launchDims(256, 256, 16384);
auto zType = ArrayOptions::dataType(hZShapeInfo);
// at this point, all IMax indexes are gathered, and we execute filler
BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, special, dZ, dZShapeInfo, tadMaxShapeInfo, dimension, dimensionLength, tadMaxOffsets), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
delete[] cshape;
} else {
NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr); NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr);
} }
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void execTransformStrict(Nd4jPointer *extraPointers,int opNum, void execTransformStrict(Nd4jPointer *extraPointers,int opNum,
@ -712,7 +683,7 @@ void execTransformFloat(Nd4jPointer *extraPointers,int opNum,
auto tadOffsets = reinterpret_cast<Nd4jLong *>(extraPointers != nullptr ? extraPointers[11] : nullptr); auto tadOffsets = reinterpret_cast<Nd4jLong *>(extraPointers != nullptr ? extraPointers[11] : nullptr);
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dZ, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
} }


@ -25,21 +25,21 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
__global__ void execFillIsMax(void *vdZ, Nd4jLong length, long idx) { __global__ void execFillIsMax(void *vdZ, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) {
auto dz = reinterpret_cast<T*>(vdZ); auto dz = reinterpret_cast<T*>(vdZ);
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x)
dz[i] = (i == idx ? (T) 1 : (T) 0); dz[shape::getIndexOffset(i, xShapeInfo, length)] = (i == idx ? (T) 1 : (T) 0);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
__host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx) { __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) {
execFillIsMax<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(dx, length, idx); execFillIsMax<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(dx, xShapeInfo, length, idx);
nd4j::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed"); nd4j::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed");
} }
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong length, long idx), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES);
} }


@ -99,18 +99,18 @@ namespace functions {
if(xEws > 0 && zEws > 0 && xOrder == zOrder) { if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
for (int i = tid; i < length; i += totalThreads) for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params); z[i * zEws] = OpType::op(x[i * xEws], params);
} }
else { else {
if(vx == vz) { if(vx == vz) {
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { for (Nd4jLong i = tid; i < length; i+= totalThreads) {
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
z[xOffset] = OpType::op(x[xOffset], params); z[xOffset] = OpType::op(x[xOffset], params);
} }
} }
else { else {
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { for (Nd4jLong i = tid; i < length; i+= totalThreads) {
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
z[zOffset] = OpType::op(x[xOffset], params); z[zOffset] = OpType::op(x[xOffset], params);


@ -92,8 +92,7 @@
(21, Copy) (21, Copy)
#define TRANSFORM_ANY_OPS \ #define TRANSFORM_ANY_OPS \
(0, Assign) , \ (0, Assign)
(1, IsMax)
// these ops return bool // these ops return bool
#define TRANSFORM_BOOL_OPS \ #define TRANSFORM_BOOL_OPS \


@ -36,7 +36,7 @@
namespace nd4j { namespace nd4j {
template <typename T> template <typename T>
_CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx); _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx);
template <typename T> template <typename T>
_CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets); _CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets);


@ -28,7 +28,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -1) { CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
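Reviewer note: with ismax reworked as a configurable op, the C++ invocation follows the same pattern as the repeat test added later in this PR (execute takes inputs, tArgs and iArgs and returns a ResultSet). A minimal sketch, with illustrative values and axis that are not taken from this commit:

    #include <ops/declarable/CustomOperations.h>

    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 5.f, 2.f, 3.f, 4.f, 0.f});

    nd4j::ops::ismax op;
    auto result = op.execute({&x}, {}, {1});  // iArgs carry the optional dimension(s)
    auto z = result->at(0);                   // 1 at the max position along axis 1, 0 elsewhere
    delete result;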


@ -260,7 +260,7 @@ namespace nd4j {
* 0: axis * 0: axis
*/ */
#if NOT_EXCLUDED(OP_ismax) #if NOT_EXCLUDED(OP_ismax)
DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -1); DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2);
#endif #endif
/** /**


@ -34,11 +34,6 @@ namespace helpers {
template <typename T> template <typename T>
static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) { static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
void* extraParams = nullptr;
bool scalarCheat = false;
if (extraParams == nullptr) {
scalarCheat = true;
}
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto xRank = input->rankOf(); auto xRank = input->rankOf();
@ -49,29 +44,16 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
Nd4jLong* special = nullptr; Nd4jLong* special = nullptr;
PointersManager manager(context, "IsMaxHelper"); PointersManager manager(context, "IsMaxHelper");
if (dimensions.size() == 0) { if (dimensions.size() == 0) {
// auto scalarShape = ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64);
/** /**
* In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call
*/ */
auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
//NativeOpExecutioner::execIndexReduceScalar(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, scalarShape, special, nullptr); auto targetIdx = indexMax->e<Nd4jLong>(0);
//Nd4jLong maxIdx = -119;
//checkCudaErrors(cudaStreamSynchronize(*stream));
//cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream);
//checkCudaErrors(cudaStreamSynchronize(*stream));
int targetIdx = 0;
if (input->ordering() == 'c' || input->ordering() == 'f' && indexMax->e<Nd4jLong>(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1] >= input->lengthOf()) dim3 launchDims(128, 512, 1024);
targetIdx = indexMax->e<Nd4jLong>(0); BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES);
else manager.synchronize();
targetIdx = indexMax->e<Nd4jLong>(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1];
dim3 launchDims(1, 512, 1024);
BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->lengthOf(), targetIdx), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
//delete[] scalarShape;
delete indexMax; delete indexMax;
} else { } else {
Nd4jLong* hostYShapeInfo = nullptr; Nd4jLong* hostYShapeInfo = nullptr;
@ -82,13 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), copy.data(), copy.size()); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), copy.data(), copy.size());
auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
//indexMaxArr->printIndexedBuffer("Index max!!!");
// we call for IMax on specified dimension
//NativeOpExecutioner::execIndexReduce(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, const_cast<int*>(dimensions.data()), (int)dimensions.size(), nullptr, nullptr);
//DEBUG_KERNEL(stream, opNum);
dim3 launchDims(256, 256, 16384); dim3 launchDims(256, 256, 16384);
dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)); dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int));
@ -103,7 +79,11 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) { void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input});
} }
BUILD_SINGLE_TEMPLATE(template void ismax_, (nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void ismax_, (nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES);
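Reviewer note: the dimension-less branch above reduces to "find the argmax, then write a one-hot mask". A host-side sketch of that intent, with illustrative names only (on the device this is applyIndexReduce(IndexMax) followed by fillIsMaxGeneric):

    // plain CPU illustration of what the helper computes for the whole-array case
    void isMaxWholeArraySketch(const float* x, float* z, Nd4jLong length) {
        Nd4jLong maxIdx = 0;
        for (Nd4jLong i = 1; i < length; i++)   // IndexMax step
            if (x[i] > x[maxIdx])
                maxIdx = i;
        for (Nd4jLong i = 0; i < length; i++)   // filler step
            z[i] = i == maxIdx ? 1.0f : 0.0f;
    }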


@ -113,6 +113,14 @@ TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) {
ASSERT_TRUE(x.sumNumber().e<float>(0) > 0); ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
} }
TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) {
auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
RandomGenerator gen(119, 120);
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1);
ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
}
TEST_F(DataTypesValidationTests, cast_1) { TEST_F(DataTypesValidationTests, cast_1) {
float16 x = static_cast<float16>(1.f); float16 x = static_cast<float16>(1.f);


@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests16 : public testing::Test {
public:
DeclarableOpsTests16() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests16, test_repeat_119) {
auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
auto e = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {2, 0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}


@ -975,69 +975,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
TEST_F(JavaInteropTests, Test_Is_Max_1) {
auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
#endif
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
execTransformAny(extraPointers, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
nullptr);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
ASSERT_EQ(arrayE, arrayZ);
delete []extraPointers;
}
TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
#endif
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
execTransformAny(extraPointers, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
nullptr);
//arrayZ.printIndexedBuffer("JAVA ISMAX1");
NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
ASSERT_EQ(arrayE, arrayZ);
delete []extraPointers;
}
TEST_F(JavaInteropTests, Test_Is_Max_2) {
auto arrayX = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 10, 2, 3, 4, 5, -10, -9, -8, -7, -6, -5, 4, 3, 2, 1, 0, -1});
auto arrayZ = NDArrayFactory::create<bool>('c', {3, 2, 3});
Nd4jLong tad[] = {2, 2, 3, 3, 1, 524288, -1, 99};
Nd4jLong off[] = {0, 6, 12};
Nd4jLong *ex[] = {tad, off};
float ea[] = {2, 1, 2};
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
ea);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
}
TEST_F(JavaInteropTests, Test_IAMax_1) { TEST_F(JavaInteropTests, Test_IAMax_1) {
auto arrayX = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f}); auto arrayX = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f});
auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr);


@ -367,49 +367,6 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
delete result; delete result;
} }
TEST_F(LegacyOpsTests, Test_IsMax_1) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
auto z = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
x.linspace(1.0);
z.assign(-589);
double extra[] = {1.0, 0.0};
NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr);
// z.printIndexedBuffer("z");
for (Nd4jLong e = 0; e < z.lengthOf(); e++) {
ASSERT_TRUE(z.e<double>(e) >= 0);
}
}
TEST_F(LegacyOpsTests, Test_IsMax_2) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
auto z = NDArrayFactory::create<bool>('c', {2, 2, 2, 2, 2, 2});
x.linspace(1.0);
z.assign(false);
double extra[] = {1.0, 0.0};
NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr);
// z.printIndexedBuffer("z");
for (Nd4jLong e = 0; e < z.lengthOf(); e++) {
if (e >= z.lengthOf() / 2)
ASSERT_TRUE(z.e<bool>(e));
else
ASSERT_FALSE(z.e<bool>(e));
}
}
TEST_F(LegacyOpsTests, BroadcastingTests_1) { TEST_F(LegacyOpsTests, BroadcastingTests_1) {
auto x = NDArrayFactory::create<double>('c', {5, 5}); auto x = NDArrayFactory::create<double>('c', {5, 5});
x.assign(0.0f); x.assign(0.0f);


@ -1236,7 +1236,7 @@ public class DifferentialFunctionFactory {
} }
public SDVariable isMax(SDVariable ix) { public SDVariable isMax(SDVariable ix) {
return new IsMax(sameDiff(), ix, false).outputVariable(); return new IsMax(sameDiff(), ix).outputVariable();
} }
public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) { public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) {


@ -262,7 +262,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
labelCountsEachClass.addi(labels2d.sum(0).castTo(labelCountsEachClass.dataType())); labelCountsEachClass.addi(labels2d.sum(0).castTo(labelCountsEachClass.dataType()));
//For prediction counts: do an IsMax op, but we need to take masking into account... //For prediction counts: do an IsMax op, but we need to take masking into account...
INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p.dup(), 1)); INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p, p.ulike(), 1))[0];
if (maskArray != null) { if (maskArray != null) {
LossUtil.applyMask(isPredictedClass, maskArray); LossUtil.applyMask(isPredictedClass, maskArray);
} }


@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp; import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -34,47 +35,29 @@ import java.util.List;
* [1, 2, 3, 1] -> [0, 0, 1, 0] * [1, 2, 3, 1] -> [0, 0, 1, 0]
* @author Adam Gibson * @author Adam Gibson
*/ */
public class IsMax extends BaseTransformAnyOp { public class IsMax extends DynamicCustomOp {
public IsMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public IsMax(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, i_v, inPlace); super(sameDiff, i_v);
} }
public IsMax(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) {
super(sameDiff, i_v, shape, inPlace, extraArgs);
}
public IsMax(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) {
super(sameDiff, i_v, extraArgs);
}
public IsMax(INDArray x, INDArray z) { public IsMax(INDArray x, INDArray z) {
super(x, z); super(new INDArray[]{x}, new INDArray[]{z});
} }
public IsMax() {} public IsMax() {}
public IsMax(INDArray x) { public IsMax(INDArray x) {
super(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering())); this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()));
} }
public IsMax(INDArray x, INDArray z, int... dimensions) { public IsMax(INDArray x, INDArray z, int... dimensions) {
super(x, z); this(x, z);
this.extraArgs = new Object[dimensions.length + 1]; this.addIArgument(dimensions);
this.extraArgs[0] = dimensions.length;
for (int i = 0; i < dimensions.length; i++)
this.extraArgs[i + 1] = dimensions[i];
} }
public IsMax(INDArray x, int... dimensions) { public IsMax(INDArray x, int... dimensions) {
super(x, Nd4j.createUninitialized(x.dataType(), x.shape(), x.ordering())); this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()), dimensions);
this.extraArgs = new Object[dimensions.length + 1];
this.extraArgs[0] = dimensions.length;
for (int i = 0; i < dimensions.length; i++)
this.extraArgs[i + 1] = dimensions[i];
}
@Override
public int opNum() {
return 1;
} }
@Override @Override
@ -82,7 +65,6 @@ public class IsMax extends BaseTransformAnyOp {
return "ismax"; return "ismax";
} }
@Override @Override
public String onnxName() { public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No onnx op opName found for " + opName());
@ -93,14 +75,6 @@ public class IsMax extends BaseTransformAnyOp {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
} }
@Override
public DataBuffer extraArgsDataBuff(DataType dtype) {
if (Nd4j.getExecutioner().type() == OpExecutioner.ExecutionerType.CUDA)
return this.extraArgs == null ? null : Nd4j.createBuffer(DataType.LONG, 1, false);
else
return super.extraArgsDataBuff(dtype);
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().zerosLike(arg())); return Collections.singletonList(f().zerosLike(arg()));
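Reviewer note: with IsMax now a DynamicCustomOp, results come back in the INDArray[] returned by the executioner, which is exactly how the updated tests below consume it. A minimal usage sketch (imports omitted, values illustrative):

    INDArray x = Nd4j.createFromArray(1.0f, 3.0f, 2.0f);

    // whole-array variant: BOOL one-hot mask at the max position
    INDArray mask = Nd4j.getExecutioner().exec(new IsMax(x))[0];

    // dimensional variant: caller provides the output array, dimensions become iArgs
    INDArray row = Nd4j.createFromArray(1.0f, 3.0f, 2.0f).reshape(1, 3);
    INDArray alongDim1 = Nd4j.getExecutioner().exec(
            new IsMax(row, Nd4j.createUninitialized(DataType.BOOL, row.shape()), 1))[0];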


@ -77,7 +77,7 @@ public class BatchToSpace extends DynamicCustomOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
return "BatchToSpaceND"; return "BatchToSpace";
} }
@Override @Override


@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* N-dimensional batch to space operation. Transforms data from a tensor from batch dimension into M spatial dimensions
* according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally cropped,
* as specified in "crops", a tensor of dim (M, 2), denoting the crop range.
* <p>
* Example:
* input: [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
* input shape: [4, 1, 1, 1]
* blocks: [2, 2]
* crops: [[0, 0], [0, 0]]
* <p>
* output: [[[[1], [2]], [[3], [4]]]]
* output shape: [1, 2, 2, 1]
*
* @author Max Pumperla
*/
public class BatchToSpaceND extends DynamicCustomOp {
private int[] blocks;
private int[][] crops;
public BatchToSpaceND() {
}
public BatchToSpaceND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
super(null, sameDiff, args, inPlace);
this.blocks = blocks;
this.crops = crops;
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < crops.length; e++)
addIArgument(crops[e][0], crops[e][1]);
}
@Override
public String opName() {
return "batch_to_space_nd";
}
@Override
public String onnxName() {
return "batch_to_space_nd";
}
@Override
public String tensorflowName() {
return "BatchToSpaceND";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
// Inverse of batch to space is space to batch with same blocks and padding as crops
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops));
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
return Collections.singletonList(dataTypes.get(0));
}
}
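Reviewer note: a SameDiff-level sketch of the Javadoc example above, wired through the constructor added here (assumed usage; the raw DynamicCustomOp path exercised by testBatchToSpace further down also works):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.createFromArray(1f, 2f, 3f, 4f).reshape(4, 1, 1, 1));
    SDVariable out = new BatchToSpaceND(sd, new SDVariable[]{in},
            new int[]{2, 2}, new int[][]{{0, 0}, {0, 0}}, false).outputVariable();
    // per the example above, out should evaluate to shape [1, 2, 2, 1]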


@ -77,7 +77,7 @@ public class SpaceToBatch extends DynamicCustomOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
return "SpaceToBatchND"; return "SpaceToBatch";
} }
@Override @Override


@ -0,0 +1,95 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* N-dimensional space to batch operation. Transforms data from a tensor from M spatial dimensions into batch dimension
* according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally padded,
* as specified in "padding", a tensor of dim (M, 2), denoting the padding range.
* <p>
* Example:
* input: [[[[1], [2]], [[3], [4]]]]
* input shape: [1, 2, 2, 1]
* blocks: [2, 2]
* padding: [[0, 0], [0, 0]]
* <p>
* output: [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
* output shape: [4, 1, 1, 1]
*
* @author Max Pumperla
*/
public class SpaceToBatchND extends DynamicCustomOp {
protected int[] blocks;
protected int[][] padding;
public SpaceToBatchND() {
}
public SpaceToBatchND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
super(null, sameDiff, args, inPlace);
this.blocks = blocks;
this.padding = padding;
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < padding.length; e++)
addIArgument(padding[e][0], padding[e][1]);
}
@Override
public String opName() {
return "space_to_batch_nd";
}
@Override
public String onnxName() {
return "space_to_batch_nd";
}
@Override
public String tensorflowName() {
return "SpaceToBatchND";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
// Inverse of space to batch is batch to space with same blocks and crops as padding
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding));
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
return Collections.singletonList(dataTypes.get(0));
}
}
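Reviewer note: space_to_batch_nd and batch_to_space_nd invert each other when the padding/crops are zero, which is what doDiff above leans on. A round-trip sketch via the cnn() namespace methods referenced there (assumed shapes):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, new int[]{1, 2, 2, 1}));
    int[] blocks = {2, 2};
    int[][] zeros = {{0, 0}, {0, 0}};
    SDVariable batched  = sd.cnn().spaceToBatch(x, blocks, zeros);        // shape [4, 1, 1, 1]
    SDVariable restored = sd.cnn().batchToSpace(batched, blocks, zeros);  // back to [1, 2, 2, 1]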


@ -378,7 +378,7 @@ public class Transforms {
public static INDArray asin(INDArray in, boolean copy) { public static INDArray asin(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new ASin(((copy ? in.dup() : in)))); return Nd4j.getExecutioner().exec(new ASin(in, (copy ? in.ulike() : in)));
} }
public static INDArray atan(INDArray arr) { public static INDArray atan(INDArray arr) {
@ -999,7 +999,8 @@ public class Transforms {
} }
public static INDArray isMax(INDArray input, INDArray output) { public static INDArray isMax(INDArray input, INDArray output) {
return Nd4j.getExecutioner().exec(new IsMax(input, output)); Nd4j.getExecutioner().exec(new IsMax(input, output));
return output;
} }
@ -1035,7 +1036,7 @@ public class Transforms {
* @return * @return
*/ */
public static INDArray sqrt(INDArray ndArray, boolean dup) { public static INDArray sqrt(INDArray ndArray, boolean dup) {
return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray)); return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray, ndArray));
} }
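Reviewer note: after this change Transforms.sqrt(x, true) returns a fresh array and leaves x untouched, sqrt(x, false) overwrites x, and isMax(input, output) fills and returns the supplied output; the new testSqrt_1 below covers the sqrt behaviour. A small usage sketch (values illustrative):

    INDArray x = Nd4j.createFromArray(9.0, 16.0, 25.0, 36.0);

    INDArray copied  = Transforms.sqrt(x, true);    // new array, x unchanged
    INDArray inPlace = Transforms.sqrt(x, false);   // x itself now holds the square roots

    INDArray mask = Transforms.isMax(x, x.ulike()); // one-hot written into the provided output and returned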
/** /**


@ -1308,40 +1308,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
// IsMax
if (op.getOpType() == Op.Type.TRANSFORM_ANY && op.opNum() == 1 && op.extraArgs() != null && op.extraArgs().length > 0) {
// for IsMax along dimension we need special temporary buffer
dimension = new int[(int) op.extraArgs()[0]];
for (int i = 0; i < dimension.length; i++) {
dimension[i] = (int) op.extraArgs()[i + 1];
}
for (int i = 0; i < dimension.length; i++) {
if (dimension[i] < 0)
dimension[i] += op.x().rank();
}
//do op along all dimensions
if (dimension.length == op.x().rank())
dimension = new int[] {Integer.MAX_VALUE};
long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {}
: ArrayUtil.removeIndex(op.x().shape(), dimension);
ret = Nd4j.createUninitialized(DataType.LONG, retShape);
// FIXME: this maybe misleading use of this particular pointer
hostYShapeInfo = allocator.getPointer(ret.shapeInfoDataBuffer(), context);
retHostShape = allocator.getHostPointer(ret.shapeInfoDataBuffer());
//dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context);
DataBuffer dimensionBuffer = allocator.getConstantBuffer(dimension);
dimensionDevPointer = allocator.getPointer(dimensionBuffer, context);
dimensionHostPointer = allocator.getHostPointer(dimensionBuffer);
retPointer = allocator.getPointer(ret, context);
}
if (op.z() == null) { if (op.z() == null) {
ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering()); ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering());
@ -1365,37 +1331,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
op.validateDataTypes(experimentalMode.get()); op.validateDataTypes(experimentalMode.get());
// SoftMax, LogSoftMax, SoftMaxDerivative
if (op.getOpType() == Op.Type.TRANSFORM_STRICT && (op.opNum() >= 0 && op.opNum() <= 2)) {
tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[] {0});
tadMaxBuffers = tadManager.getTADOnlyShapeInfo(op.x().rank() == 1 ? op.x().reshape(1, -1) : op.x(), new int[] {1});
hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
devTadShapeInfo = allocator.getPointer(tadBuffers.getFirst(), context);
hostMaxTadShapeInfo = AddressRetriever.retrieveHostPointer(tadMaxBuffers.getFirst());
devMaxTadShapeInfo = allocator.getPointer(tadMaxBuffers.getFirst(), context);
DataBuffer offsets = tadBuffers.getSecond();
devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
DataBuffer maxOffsets = tadMaxBuffers.getSecond();
devMaxTadOffsets = maxOffsets == null ? null : allocator.getPointer(maxOffsets, context);
} else if (op.getOpType() == Op.Type.TRANSFORM_ANY && op.opNum() == 1) { // IsMax
tadBuffers = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);
DataBuffer offsets = tadBuffers.getSecond();
devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
if (retPointer == null)
retPointer = context.getBufferReduction();
}
Pointer z = allocator.getPointer(op.z(), context); Pointer z = allocator.getPointer(op.z(), context);
Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context); Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
@ -1462,7 +1397,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
case TRANSFORM_FLOAT: case TRANSFORM_FLOAT:
nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(),
null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo,
op.z().data().addressPointer(), (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
extraArgs); extraArgs);
break; break;
case TRANSFORM_BOOL: case TRANSFORM_BOOL:


@ -1516,7 +1516,7 @@ public class SameDiffTests extends BaseNd4jTest {
//then dL/dIn = 1 if in_i == min(in) or 0 otherwise //then dL/dIn = 1 if in_i == min(in) or 0 otherwise
//Note that we don't have an "IsMin" op, so use IsMax(neg(in)) which is equivalent //Note that we don't have an "IsMin" op, so use IsMax(neg(in)) which is equivalent
INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg())).castTo(Nd4j.defaultFloatingPointType()); INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType());
assertEquals(exp, dLdIn); assertEquals(exp, dLdIn);
} }
@ -1540,7 +1540,7 @@ public class SameDiffTests extends BaseNd4jTest {
//If L = max(in) //If L = max(in)
//then dL/dIn = 1 if in_i == max(in) or 0 otherwise //then dL/dIn = 1 if in_i == max(in) or 0 otherwise
INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup())).castTo(DataType.DOUBLE); INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE);
assertEquals(exp, dLdIn); assertEquals(exp, dLdIn);
} }


@ -72,6 +72,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps; import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
@ -261,7 +262,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testIsMaxVectorCase() { public void testIsMaxVectorCase() {
INDArray arr = Nd4j.create(new double[] {1, 2, 4, 3}, new long[] {2, 2}); INDArray arr = Nd4j.create(new double[] {1, 2, 4, 3}, new long[] {2, 2});
INDArray assertion = Nd4j.create(new boolean[] {false, false, true, false}, new long[] {2, 2}, DataType.BOOL); INDArray assertion = Nd4j.create(new boolean[] {false, false, true, false}, new long[] {2, 2}, DataType.BOOL);
INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr)); INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr))[0];
assertEquals(assertion, test); assertEquals(assertion, test);
} }
@ -719,7 +720,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
//Tests: full buffer... //Tests: full buffer...
//1d //1d
INDArray arr1 = Nd4j.create(new double[] {1, 2, 3, 1}); INDArray arr1 = Nd4j.create(new double[] {1, 2, 3, 1});
val res1 = Nd4j.getExecutioner().exec(new IsMax(arr1)); val res1 = Nd4j.getExecutioner().exec(new IsMax(arr1))[0];
INDArray exp1 = Nd4j.create(new boolean[] {false, false, true, false}); INDArray exp1 = Nd4j.create(new boolean[] {false, false, true, false});
assertEquals(exp1, res1); assertEquals(exp1, res1);
@ -736,8 +737,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray exp2d = Nd4j.create(new boolean[][] {{false, false, false}, {false, true, false}}); INDArray exp2d = Nd4j.create(new boolean[][] {{false, false, false}, {false, true, false}});
INDArray f = arr2d.dup('f'); INDArray f = arr2d.dup('f');
INDArray out2dc = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('c'))); INDArray out2dc = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('c')))[0];
INDArray out2df = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('f'))); INDArray out2df = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('f')))[0];
assertEquals(exp2d, out2dc); assertEquals(exp2d, out2dc);
assertEquals(exp2d, out2df); assertEquals(exp2d, out2df);
} }
@ -803,16 +804,48 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test @Test
public void testIsMaxEqualValues_2() { public void testIsMaxEqualValues_2() {
//[0 2] [0 1] //[0 2] [0 1]
//[2 1] -> [0 0] //[2 1] -> [0 0]
INDArray orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}}); INDArray orig = Nd4j.create(new double[][] {{0, 3}, {2, 1}});
INDArray exp = Nd4j.create(new double[][] {{0, 1}, {0, 0}}); INDArray exp = Nd4j.create(new double[][] {{0, 1}, {0, 0}});
INDArray outc = Transforms.isMax(orig.dup('c')); INDArray outc = Transforms.isMax(orig.dup('c'));
assertEquals(exp, outc); assertEquals(exp, outc);
INDArray outf = Transforms.isMax(orig.dup('f')); log.info("Orig: {}", orig.dup('f').data().asFloat());
INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike());
log.info("OutF: {}", outf.data().asFloat());
assertEquals(exp, outf); assertEquals(exp, outf);
} }
@Test
public void testIsMaxEqualValues_3() {
//[0 2]    [0 0]
//[3 1] -> [1 0]
INDArray orig = Nd4j.create(new double[][] {{0, 2}, {3, 1}});
INDArray exp = Nd4j.create(new double[][] {{0, 0}, {1, 0}});
INDArray outc = Transforms.isMax(orig.dup('c'));
assertEquals(exp, outc);
INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike());
assertEquals(exp, outf);
}
@Test
public void testSqrt_1() {
val x = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0);
val x2 = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0);
val e = Nd4j.createFromArray(3.0, 3.0, 3.0, 3.0);
val z1 = Transforms.sqrt(x, true);
val z2 = Transforms.sqrt(x2, false);
assertEquals(e, z2);
assertEquals(e, x2);
assertEquals(e, z1);
}
@Test @Test
public void testAssign_CF() { public void testAssign_CF() {
val orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}}); val orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}});
@ -828,8 +861,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
//1d: row vector //1d: row vector
INDArray orig = Nd4j.create(new double[] {1, 2, 3, 1}).reshape(1,4 ); INDArray orig = Nd4j.create(new double[] {1, 2, 3, 1}).reshape(1,4 );
INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0)); INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0))[0];
INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1)); INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1))[0];
INDArray expAlong0 = Nd4j.create(new boolean[]{true, true, true, true}).reshape(1,4); INDArray expAlong0 = Nd4j.create(new boolean[]{true, true, true, true}).reshape(1,4);
INDArray expAlong1 = Nd4j.create(new boolean[] {false, false, true, false}).reshape(1,4); INDArray expAlong1 = Nd4j.create(new boolean[] {false, false, true, false}).reshape(1,4);
@ -841,8 +874,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
//1d: col vector //1d: col vector
System.out.println("----------------------------------"); System.out.println("----------------------------------");
INDArray col = Nd4j.create(new double[] {1, 2, 3, 1}, new long[] {4, 1}); INDArray col = Nd4j.create(new double[] {1, 2, 3, 1}, new long[] {4, 1});
INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0)); INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0))[0];
INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()),1)); INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()),1))[0];
INDArray expAlong0col = Nd4j.create(new boolean[] {false, false, true, false}).reshape(4,1); INDArray expAlong0col = Nd4j.create(new boolean[] {false, false, true, false}).reshape(4,1);
INDArray expAlong1col = Nd4j.create(new boolean[] {true, true, true, true}).reshape(4,1); INDArray expAlong1col = Nd4j.create(new boolean[] {true, true, true, true}).reshape(4,1);
@ -877,10 +910,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
//[0 1 0] //[0 1 0]
System.out.println("---------------------"); System.out.println("---------------------");
INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}});
INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0)); INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0];
INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0)); INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0))[0];
INDArray alongDim1c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 1)); INDArray alongDim1c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 1))[0];
INDArray alongDim1f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 1)); INDArray alongDim1f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 1))[0];
INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}}); INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}});
INDArray expAlong1_2d = Nd4j.create(new boolean[][] {{false, false, true}, {false, true, false}}); INDArray expAlong1_2d = Nd4j.create(new boolean[][] {{false, false, true}, {false, true, false}});
@ -904,7 +937,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test @Test
public void testIsMaxSingleDim1() { public void testIsMaxSingleDim1() {
INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}});
INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0)); INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0];
INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}}); INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}});
System.out.println("Original shapeInfo: " + orig2d.dup('c').shapeInfoDataBuffer()); System.out.println("Original shapeInfo: " + orig2d.dup('c').shapeInfoDataBuffer());
@ -1056,8 +1089,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
+ Arrays.toString(shape) + ")"); + Arrays.toString(shape) + ")");
INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape); INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape);
INDArray arrF = arrC.dup('f'); INDArray arrF = arrC.dup('f');
val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension)); val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension))[0];
val resF = Nd4j.getExecutioner().exec(new IsMax(arrF, alongDimension)); val resF = Nd4j.getExecutioner().exec(new IsMax(arrF, alongDimension))[0];
double[] cBuffer = resC.data().asDouble(); double[] cBuffer = resC.data().asDouble();
@ -3932,7 +3965,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
v.assign(t); v.assign(t);
} }
val result = Nd4j.getExecutioner().exec(new IsMax(arr, Nd4j.createUninitialized(DataType.BOOL, arr.shape(), arr.ordering()), 1, 2)); val result = Nd4j.getExecutioner().exec(new IsMax(arr, Nd4j.createUninitialized(DataType.BOOL, arr.shape(), arr.ordering()), 1, 2))[0];
assertEquals(expected, result); assertEquals(expected, result);
} }
@ -3971,8 +4004,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
} }
} }
INDArray actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), Nd4j.createUninitialized(DataType.BOOL, arr.shape()),0, 1)); INDArray actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), Nd4j.createUninitialized(DataType.BOOL, arr.shape()),0, 1))[0];
INDArray actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), Nd4j.createUninitialized(DataType.BOOL, arr.shape(), 'f'), 0, 1)); INDArray actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), Nd4j.createUninitialized(DataType.BOOL, arr.shape(), 'f'), 0, 1))[0];
assertEquals(exp, actC); assertEquals(exp, actC);
assertEquals(exp, actF); assertEquals(exp, actF);
@ -4006,8 +4039,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
} }
} }
actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), 2, 3)); actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), arr.dup('c').ulike(), 2, 3))[0];
actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), 2, 3)); actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), arr.dup('f').ulike(), 2, 3))[0];
assertEquals(exp, actC); assertEquals(exp, actC);
assertEquals(exp, actF); assertEquals(exp, actF);
@ -6527,7 +6560,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertTrue(x.sumNumber().floatValue() > 0); assertTrue(x.sumNumber().floatValue() > 0);
x = Nd4j.randn(DataType.BFLOAT16 , 10); x = Nd4j.randn(DataType.BFLOAT16 , 10);
assertTrue(x.sumNumber().floatValue() > 0); assertTrue(x.sumNumber().floatValue() != 0.0);
} }
@Test @Test
@ -7962,7 +7995,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testBatchToSpace(){ public void testBatchToSpace(){
INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5); INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5);
DynamicCustomOp c = new BatchToSpace(); DynamicCustomOp c = new BatchToSpaceND();
c.addInputArgument( c.addInputArgument(
Nd4j.rand(DataType.FLOAT, new int[]{4, 4, 3}), Nd4j.rand(DataType.FLOAT, new int[]{4, 4, 3}),


@ -106,115 +106,6 @@ public class CudaTests extends BaseNd4jTest {
assertEquals(exp, arrayA); assertEquals(exp, arrayA);
} }
@Test(timeout = 40000L)
public void testContextSpam() throws Exception {
if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
return;
val success = new AtomicInteger(0);
val iterations = 101;
val threads = new ArrayList<Thread>();
for (int e = 0; e < iterations; e++) {
val f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.create(1);
if (f % 50 == 0)
log.info("Context {} created", f);
Nd4j.getMemoryManager().releaseCurrentContext();
success.incrementAndGet();
try {
Thread.sleep(1000L);
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
});
t.start();
threads.add(t);
}
for (val t: threads)
t.join();
assertEquals(iterations, success.get());
}
@Ignore
@Test(timeout = 180000L)
public void testContextSpam_2() throws Exception {
if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
return;
val success = new AtomicInteger(0);
val iterations = 101;
val threads = new ArrayList<Thread>();
for (int e = 0; e < iterations; e++) {
val f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.create(1);
if (f % 50 == 0)
log.info("Context {} created", f);
//Nd4j.getMemoryManager().releaseCurrentContext();
success.incrementAndGet();
try {
Thread.sleep(1000L);
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
});
t.start();
threads.add(t);
}
for (val t: threads)
t.join();
assertEquals(iterations, success.get());
}
@Test
public void testSequentialReleaseAndReacquire() throws Exception {
if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
return;
Nd4j.create(128);
Nd4j.getMemoryManager().releaseCurrentContext();
val array = Nd4j.create(128);
array.addi(1.0f);
}
@Test
@Ignore
public void test(){
if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
return;
val SD = SameDiff.create();
val in = SD.one("test", 5, 8, 3, 4);
SDVariable out = in.reshape(-1, 4);
SDVariable out1 = out.reshape(4, 15, -1);
SDVariable out2 = SD.dot(out1, out1, 2);
SDVariable out3 = out2.reshape(-1, 4); // <---- error here
System.out.println(Arrays.toString(out3.eval().toFloatMatrix()));
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';


@ -27,6 +27,18 @@
<packaging>jar</packaging> <packaging>jar</packaging>
<name>nd4j-parameter-server-node</name> <name>nd4j-parameter-server-node</name>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
</plugin>
</plugins>
</build>
<dependencies> <dependencies>
<dependency> <dependency>