[WIP] tests fixes (#130)
* no openmp for ClipByGlobalNorm
* one more bfloat16 rng test
* assertion fix
* - legacy IsMax gone
  - linear IsMax gets shapeInfo argument
* get rid of legacy IsMax tests
* IsMax is custom op now
* more blocks for ismax
* one more test
* - sqrt test
  - some legacy code removed from CudaExecutioner
  - Transforms.asin tweaks
* - TransformFloat fix
* - ismax fix
  - SpaceToBatchND/BatchToSpaceND wrappers
  - couple of legacy tests removed

Signed-off-by: raver119 <raver119@gmail.com>

parent bb80fe4f94
commit aceb915557
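The user-visible effect of the IsMax migration is the execution path: IsMax now runs as a DynamicCustomOp, so OpExecutioner.exec(...) returns an INDArray[] of outputs instead of a single INDArray. A minimal before/after sketch (the import path is illustrative, not confirmed by this diff):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; // illustrative package path
import org.nd4j.linalg.factory.Nd4j;

public class IsMaxMigrationSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.createFromArray(1.0f, 2.0f, 1.0f, 1.0f);

        // Before this PR (legacy transform op): exec(...) returned the result array directly
        // INDArray z = Nd4j.getExecutioner().exec(new IsMax(x.dup()));

        // After this PR (custom op): exec(...) returns all outputs; take the first
        INDArray z = Nd4j.getExecutioner().exec(new IsMax(x.dup()))[0];
        System.out.println(z); // [false, true, false, false], BOOL by default
    }
}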
@@ -785,48 +785,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
     auto xType = ArrayOptions::dataType(hXShapeInfo);
     auto zType = ArrayOptions::dataType(hZShapeInfo);
 
-    switch (opNum) {
-        case transform::IsMax: {
-            bool scalarCheat = false;
-            if (extraParams == nullptr) {
-                scalarCheat = true;
-            }
-
-            void* special = lc->getAllocationPointer();
-
-            if (scalarCheat) {
-                auto scalarShape = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo(ShapeDescriptor::scalarDescriptor(nd4j::DataType::INT64)); //ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64);
-                /**
-                 * In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call
-                 */
-                execIndexReduceScalar(lc, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, scalarShape.primaryAsT<Nd4jLong>(), special, scalarShape.specialAsT<Nd4jLong>());
-                Nd4jLong maxIdx = -119;
-                nd4j::DebugHelper::checkErrorCode(stream, "IsMax: execIndexReduce(...) failed");
-
-                cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream);
-                nd4j::DebugHelper::checkErrorCode(stream, "IsMax: cudaMemcpyAsync(...) failed");
-                int targetIdx = 0;
-
-                if (shape::order(hXShapeInfo) == 'c' || shape::order(hXShapeInfo) == 'f' && maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1] >= shape::length(hXShapeInfo))
-                    targetIdx = maxIdx;
-                else
-                    targetIdx = maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1];
-
-                dim3 launchDims(1, 512, 1024);
-                BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, dZ, shape::length(hZShapeInfo), targetIdx), LIBND4J_TYPES);
-
-                nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
-
-                //delete[] scalarShape;
-            }
-        }
-        break;
-        default: {
-            dim3 launchDims(512, 512, 16384);
-
-            BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
-        }
-    }
+    dim3 launchDims(512, 512, 2048);
+    BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
 
     // TODO: remove after the release
     auto res = cudaStreamSynchronize(*stream);
@@ -884,7 +845,7 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
     if (!DataTypeUtils::isR(zType))
         throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
 
-    dim3 launchDims(512, 512, 16384);
+    dim3 launchDims(512, 512, 2048);
     BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
 
     // TODO: remove after the release
@@ -653,36 +653,7 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
     auto streamSpecial = reinterpret_cast<cudaStream_t&>(extraPointers[4]);
     LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast<int*>(extraPointers[6]));
 
-    // FIXME: remove this once all operations are enabled
-    if (opNum == nd4j::transform::IsMax && extraParams != nullptr) {
-        auto hostYShapeInfo = reinterpret_cast<Nd4jLong *>(extraPointers[7]);
-        auto hostTShapeInfo = reinterpret_cast<Nd4jLong *>(extraPointers[19]);
-        auto tadMaxShapeInfo = reinterpret_cast<Nd4jLong *> (extraPointers[10]);
-        auto tadMaxOffsets = reinterpret_cast<Nd4jLong *> (extraPointers[11]);
-        int *dimension = reinterpret_cast<int *> (extraPointers[15]);
-        int *hDimension = reinterpret_cast<int *> (extraPointers[16]);
-        int dimensionLength = getDeviceId(extraPointers[18]);
-        auto special = reinterpret_cast<double *>(extraPointers[17]);
-
-        auto cshape = ShapeBuilders::createVectorShapeInfo(nd4j::DataType::INT32, dimensionLength);
-
-        // we call for IMax on specified dimension
-        execIndexReduce(extraPointers, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, hDimension, cshape, dimension, nullptr);
-
-        DEBUG_KERNEL(stream, opNum);
-
-        dim3 launchDims(256, 256, 16384);
-        auto zType = ArrayOptions::dataType(hZShapeInfo);
-
-        // at this point, all IMax indexes are gathered, and we execute filler
-        BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, special, dZ, dZShapeInfo, tadMaxShapeInfo, dimension, dimensionLength, tadMaxOffsets), LIBND4J_TYPES);
-
-        nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
-
-        delete[] cshape;
-    } else {
-        NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr);
-    }
+    NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr);
 }
 
 ////////////////////////////////////////////////////////////////////////
@@ -712,7 +683,7 @@ void execTransformFloat(Nd4jPointer *extraPointers,int opNum,
     auto tadOffsets = reinterpret_cast<Nd4jLong *>(extraPointers != nullptr ? extraPointers[11] : nullptr);
 
     LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
-    NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dZ, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
+    NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
 }
@@ -25,21 +25,21 @@ namespace nd4j {
 
    ////////////////////////////////////////////////////////////////////////
    template <typename T>
-    __global__ void execFillIsMax(void *vdZ, Nd4jLong length, long idx) {
+    __global__ void execFillIsMax(void *vdZ, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) {
        auto dz = reinterpret_cast<T*>(vdZ);
        int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
        for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x)
-            dz[i] = (i == idx ? (T) 1 : (T) 0);
+            dz[shape::getIndexOffset(i, xShapeInfo, length)] = (i == idx ? (T) 1 : (T) 0);
    }
 
    ////////////////////////////////////////////////////////////////////////
    template <typename T>
-    __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx) {
-        execFillIsMax<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(dx, length, idx);
+    __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) {
+        execFillIsMax<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(dx, xShapeInfo, length, idx);
        nd4j::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed");
    }
 
-    BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong length, long idx), LIBND4J_TYPES);
+    BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES);
}
@@ -99,18 +99,18 @@ namespace functions {
 
        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
 
-            for (int i = tid; i < length; i += totalThreads)
+            for (Nd4jLong i = tid; i < length; i += totalThreads)
                z[i * zEws] = OpType::op(x[i * xEws], params);
        }
        else {
            if(vx == vz) {
-                for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
+                for (Nd4jLong i = tid; i < length; i+= totalThreads) {
                    auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                    z[xOffset] = OpType::op(x[xOffset], params);
                }
            }
            else {
-                for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
+                for (Nd4jLong i = tid; i < length; i+= totalThreads) {
                    auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                    auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
                    z[zOffset] = OpType::op(x[xOffset], params);
@@ -92,8 +92,7 @@
        (21, Copy)
 
#define TRANSFORM_ANY_OPS \
-        (0, Assign) , \
-        (1, IsMax)
+        (0, Assign)
 
// these ops return bool
#define TRANSFORM_BOOL_OPS \
@@ -36,7 +36,7 @@
namespace nd4j {
 
    template <typename T>
-    _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx);
+    _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx);
 
    template <typename T>
    _CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets);
@@ -28,7 +28,7 @@
namespace nd4j {
namespace ops {
 
-    CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -1) {
+    CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) {
 
        auto x = INPUT_VARIABLE(0);
        auto z = OUTPUT_VARIABLE(0);
@@ -260,7 +260,7 @@ namespace nd4j {
         * 0: axis
         */
        #if NOT_EXCLUDED(OP_ismax)
-        DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -1);
+        DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2);
        #endif
 
        /**
@@ -34,11 +34,6 @@ namespace helpers {
 
template <typename T>
static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
-    void* extraParams = nullptr;
-    bool scalarCheat = false;
-    if (extraParams == nullptr) {
-        scalarCheat = true;
-    }
    auto stream = context->getCudaStream();
 
    auto xRank = input->rankOf();
@@ -49,29 +44,16 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
    Nd4jLong* special = nullptr;
    PointersManager manager(context, "IsMaxHelper");
    if (dimensions.size() == 0) {
-        // auto scalarShape = ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64);
        /**
-         * In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call
+         * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call
         */
        auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
-        //NativeOpExecutioner::execIndexReduceScalar(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, scalarShape, special, nullptr);
-        //Nd4jLong maxIdx = -119;
-        //checkCudaErrors(cudaStreamSynchronize(*stream));
-        //cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream);
-        //checkCudaErrors(cudaStreamSynchronize(*stream));
-        int targetIdx = 0;
-
-        if (input->ordering() == 'c' || input->ordering() == 'f' && indexMax->e<Nd4jLong>(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1] >= input->lengthOf())
-            targetIdx = indexMax->e<Nd4jLong>(0);
-        else
-            targetIdx = indexMax->e<Nd4jLong>(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1];
-
-        dim3 launchDims(1, 512, 1024);
-        BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->lengthOf(), targetIdx), LIBND4J_TYPES);
-
-        nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
-
-        //delete[] scalarShape;
+        auto targetIdx = indexMax->e<Nd4jLong>(0);
+
+        dim3 launchDims(128, 512, 1024);
+        BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES);
+        manager.synchronize();
 
        delete indexMax;
    } else {
        Nd4jLong* hostYShapeInfo = nullptr;
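For the dimensionless (whole-array) case, the helper now reduces IsMax to two steps: an IndexMax reduction that finds the winning index, then a filler kernel that writes a one-hot pattern through the output's shape info. A hedged Java rendering of the same two-step logic (plain loops stand in for the IndexMax reduction and the filler kernel; this is not the CUDA path itself):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class IsMaxWholeArraySketch {
    static INDArray isMaxFlat(INDArray x) {
        // Step 1: IndexMax reduction - find the linear index of the maximum
        long targetIdx = 0;
        for (long i = 1; i < x.length(); i++)
            if (x.getDouble(i) > x.getDouble(targetIdx))
                targetIdx = i;

        // Step 2: filler - zeros everywhere, one at targetIdx
        INDArray z = Nd4j.zeros(DataType.BOOL, x.shape());
        z.putScalar(targetIdx, 1.0); // cast to the output type, as fillIsMaxGeneric does
        return z;
    }

    public static void main(String[] args) {
        System.out.println(isMaxFlat(Nd4j.createFromArray(1.0, 2.0, 1.0, 1.0)));
    }
}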
@@ -82,13 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
 
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), copy.data(), copy.size());
 
        auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
-        //indexMaxArr->printIndexedBuffer("Index max!!!");
-        // we call for IMax on specified dimension
-        //NativeOpExecutioner::execIndexReduce(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, const_cast<int*>(dimensions.data()), (int)dimensions.size(), nullptr, nullptr);
-
-        //DEBUG_KERNEL(stream, opNum);
 
        dim3 launchDims(256, 256, 16384);
        dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int));
@@ -103,7 +79,11 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
 
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
+    NDArray::prepareSpecialUse({output}, {input});
+
    BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES);
+
+    NDArray::registerSpecialUse({output}, {input});
}
 
BUILD_SINGLE_TEMPLATE(template void ismax_, (nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES);
@@ -113,6 +113,14 @@ TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) {
    ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
}
 
+TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) {
+    auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
+    RandomGenerator gen(119, 120);
+    RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1);
+
+    ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
+}
+
TEST_F(DataTypesValidationTests, cast_1) {
 
    float16 x = static_cast<float16>(1.f);
@@ -0,0 +1,54 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+
+#include "testlayers.h"
+#include <ops/declarable/CustomOperations.h>
+#include <NDArray.h>
+#include <ops/ops.h>
+#include <GradCheck.h>
+#include <array>
+
+using namespace nd4j;
+
+class DeclarableOpsTests16 : public testing::Test {
+public:
+
+    DeclarableOpsTests16() {
+        printf("\n");
+        fflush(stdout);
+    }
+};
+
+TEST_F(DeclarableOpsTests16, test_repeat_119) {
+    auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
+    auto e = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
+
+    nd4j::ops::repeat op;
+    auto result = op.execute({&x}, {}, {2, 0});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    ASSERT_EQ(e, *z);
+
+    delete result;
+}
@@ -975,69 +975,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
    ASSERT_EQ(e, z);
}
 
-TEST_F(JavaInteropTests, Test_Is_Max_1) {
-    auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
-    auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
-    auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
-
-    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
-
-    Nd4jPointer* extraPointers = nullptr;
-#ifdef __CUDABLAS__
-    extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
-#endif
-
-    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
-    execTransformAny(extraPointers, transform::IsMax,
-            arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
-            arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
-            nullptr);
-
-    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
-    ASSERT_EQ(arrayE, arrayZ);
-
-    delete []extraPointers;
-}
-
-TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
-    auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
-    auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
-    auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
-
-    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
-
-    Nd4jPointer* extraPointers = nullptr;
-#ifdef __CUDABLAS__
-    extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
-#endif
-
-    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
-    execTransformAny(extraPointers, transform::IsMax,
-            arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
-            arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
-            nullptr);
-    //arrayZ.printIndexedBuffer("JAVA ISMAX1");
-    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
-    ASSERT_EQ(arrayE, arrayZ);
-    delete []extraPointers;
-}
-
-TEST_F(JavaInteropTests, Test_Is_Max_2) {
-    auto arrayX = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 10, 2, 3, 4, 5, -10, -9, -8, -7, -6, -5, 4, 3, 2, 1, 0, -1});
-    auto arrayZ = NDArrayFactory::create<bool>('c', {3, 2, 3});
-    Nd4jLong tad[] = {2, 2, 3, 3, 1, 524288, -1, 99};
-    Nd4jLong off[] = {0, 6, 12};
-    Nd4jLong *ex[] = {tad, off};
-    float ea[] = {2, 1, 2};
-
-    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
-    execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
-            arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
-            arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
-            ea);
-    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
-}
-
TEST_F(JavaInteropTests, Test_IAMax_1) {
    auto arrayX = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f});
    auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr);
@@ -367,49 +367,6 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
    delete result;
}
 
-TEST_F(LegacyOpsTests, Test_IsMax_1) {
-    if (!Environment::getInstance()->isCPU())
-        return;
-
-    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
-    auto z = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
-    x.linspace(1.0);
-    z.assign(-589);
-
-    double extra[] = {1.0, 0.0};
-
-    NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
-            z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr);
-
-    // z.printIndexedBuffer("z");
-    for (Nd4jLong e = 0; e < z.lengthOf(); e++) {
-        ASSERT_TRUE(z.e<double>(e) >= 0);
-    }
-}
-
-TEST_F(LegacyOpsTests, Test_IsMax_2) {
-    if (!Environment::getInstance()->isCPU())
-        return;
-
-    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
-    auto z = NDArrayFactory::create<bool>('c', {2, 2, 2, 2, 2, 2});
-    x.linspace(1.0);
-    z.assign(false);
-
-    double extra[] = {1.0, 0.0};
-
-    NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
-            z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr);
-
-    // z.printIndexedBuffer("z");
-    for (Nd4jLong e = 0; e < z.lengthOf(); e++) {
-        if (e >= z.lengthOf() / 2)
-            ASSERT_TRUE(z.e<bool>(e));
-        else
-            ASSERT_FALSE(z.e<bool>(e));
-    }
-}
-
TEST_F(LegacyOpsTests, BroadcastingTests_1) {
    auto x = NDArrayFactory::create<double>('c', {5, 5});
    x.assign(0.0f);
@@ -1236,7 +1236,7 @@ public class DifferentialFunctionFactory {
    }
 
    public SDVariable isMax(SDVariable ix) {
-        return new IsMax(sameDiff(), ix, false).outputVariable();
+        return new IsMax(sameDiff(), ix).outputVariable();
    }
 
    public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) {
@@ -262,7 +262,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
        labelCountsEachClass.addi(labels2d.sum(0).castTo(labelCountsEachClass.dataType()));
        //For prediction counts: do an IsMax op, but we need to take masking into account...
-        INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p.dup(), 1));
+        INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p, p.ulike(), 1))[0];
        if (maskArray != null) {
            LossUtil.applyMask(isPredictedClass, maskArray);
        }
@@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
@@ -34,47 +35,29 @@ import java.util.List;
 * [1, 2, 3, 1] -> [0, 0, 1, 0]
 * @author Adam Gibson
 */
-public class IsMax extends BaseTransformAnyOp {
-    public IsMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public IsMax(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) {
-        super(sameDiff, i_v, shape, inPlace, extraArgs);
-    }
-
-    public IsMax(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) {
-        super(sameDiff, i_v, extraArgs);
+public class IsMax extends DynamicCustomOp {
+    public IsMax(SameDiff sameDiff, SDVariable i_v) {
+        super(sameDiff, i_v);
    }
 
    public IsMax(INDArray x, INDArray z) {
-        super(x, z);
+        super(new INDArray[]{x}, new INDArray[]{z});
    }
 
    public IsMax() {}
 
    public IsMax(INDArray x) {
-        super(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()));
+        this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()));
    }
 
    public IsMax(INDArray x, INDArray z, int... dimensions) {
-        super(x, z);
-        this.extraArgs = new Object[dimensions.length + 1];
-        this.extraArgs[0] = dimensions.length;
-        for (int i = 0; i < dimensions.length; i++)
-            this.extraArgs[i + 1] = dimensions[i];
+        this(x, z);
+        this.addIArgument(dimensions);
    }
 
    public IsMax(INDArray x, int... dimensions) {
-        super(x, Nd4j.createUninitialized(x.dataType(), x.shape(), x.ordering()));
-        this.extraArgs = new Object[dimensions.length + 1];
-        this.extraArgs[0] = dimensions.length;
-        for (int i = 0; i < dimensions.length; i++)
-            this.extraArgs[i + 1] = dimensions[i];
-    }
-
-    @Override
-    public int opNum() {
-        return 1;
+        this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()), dimensions);
    }
 
    @Override
@@ -82,7 +65,6 @@ public class IsMax extends BaseTransformAnyOp {
        return "ismax";
    }
 
-
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
@@ -93,14 +75,6 @@ public class IsMax extends BaseTransformAnyOp {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }
 
-    @Override
-    public DataBuffer extraArgsDataBuff(DataType dtype) {
-        if (Nd4j.getExecutioner().type() == OpExecutioner.ExecutionerType.CUDA)
-            return this.extraArgs == null ? null : Nd4j.createBuffer(DataType.LONG, 1, false);
-        else
-            return super.extraArgsDataBuff(dtype);
-    }
-
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return Collections.singletonList(f().zerosLike(arg()));
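With IsMax as a DynamicCustomOp, dimension arguments travel as iArgs (addIArgument) rather than through extraArgs, and the default output buffer is BOOL. A usage sketch following the pattern the updated tests use (import path illustrative):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; // illustrative package path
import org.nd4j.linalg.factory.Nd4j;

public class IsMaxAlongDimensionSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}});

        // IsMax along dimension 0: one 'true' per column, at that column's max
        INDArray z = Nd4j.createUninitialized(DataType.BOOL, x.shape());
        INDArray out = Nd4j.getExecutioner().exec(new IsMax(x.dup(), z, 0))[0];
        System.out.println(out); // [[false, false, true], [true, true, false]]
    }
}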
@@ -77,7 +77,7 @@ public class BatchToSpace extends DynamicCustomOp {
 
    @Override
    public String tensorflowName() {
-        return "BatchToSpaceND";
+        return "BatchToSpace";
    }
 
    @Override
@@ -0,0 +1,93 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import lombok.val;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * N-dimensional batch to space operation. Transforms data from a tensor from batch dimension into M spatial dimensions
+ * according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally cropped,
+ * as specified in "crops", a tensor of dim (M, 2), denoting the crop range.
+ * <p>
+ * Example:
+ * input:        [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+ * input shape:  [4, 1, 1, 1]
+ * blocks:       [2, 2]
+ * crops:        [[0, 0], [0, 0]]
+ * <p>
+ * output:       [[[[1], [2]], [[3], [4]]]]
+ * output shape: [1, 2, 2, 1]
+ *
+ * @author Max Pumperla
+ */
+public class BatchToSpaceND extends DynamicCustomOp {
+
+    private int[] blocks;
+    private int[][] crops;
+
+    public BatchToSpaceND() {
+    }
+
+    public BatchToSpaceND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
+        super(null, sameDiff, args, inPlace);
+
+        this.blocks = blocks;
+        this.crops = crops;
+
+        for (val b : blocks)
+            addIArgument(b);
+
+        for (int e = 0; e < crops.length; e++)
+            addIArgument(crops[e][0], crops[e][1]);
+    }
+
+    @Override
+    public String opName() {
+        return "batch_to_space_nd";
+    }
+
+    @Override
+    public String onnxName() {
+        return "batch_to_space_nd";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "BatchToSpaceND";
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> i_v) {
+        // Inverse of batch to space is space to batch with same blocks and padding as crops
+        SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
+        return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops));
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
+        return Collections.singletonList(dataTypes.get(0));
+    }
+}
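A hedged sketch of exercising the new wrapper end to end, mirroring the javadoc example above; it assumes the generic DynamicCustomOp builder can reach the libnd4j op by name and that the executioner infers the output shape:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class BatchToSpaceNdSketch {
    public static void main(String[] args) {
        // input shape [4, 1, 1, 1], blocks [2, 2], zero crops -> output shape [1, 2, 2, 1]
        INDArray in = Nd4j.createFromArray(1f, 2f, 3f, 4f).reshape(4, 1, 1, 1);

        DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space_nd")
                .addInputs(in)
                .addIntegerArguments(2, 2,      // blocks, one int per spatial dimension
                        0, 0, 0, 0)             // crops, flattened (M, 2) row-major
                .build();

        INDArray out = Nd4j.getExecutioner().exec(op)[0];
        System.out.println(java.util.Arrays.toString(out.shape())); // expect [1, 2, 2, 1]
    }
}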
@@ -77,7 +77,7 @@ public class SpaceToBatch extends DynamicCustomOp {
 
    @Override
    public String tensorflowName() {
-        return "SpaceToBatchND";
+        return "SpaceToBatch";
    }
 
    @Override
@@ -0,0 +1,95 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import lombok.val;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * N-dimensional space to batch operation. Transforms data from a tensor from M spatial dimensions into batch dimension
+ * according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally padded,
+ * as specified in "padding", a tensor of dim (M, 2), denoting the padding range.
+ * <p>
+ * Example:
+ * input:        [[[[1], [2]], [[3], [4]]]]
+ * input shape:  [1, 2, 2, 1]
+ * blocks:       [2, 2]
+ * padding:      [[0, 0], [0, 0]]
+ * <p>
+ * output:       [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+ * output shape: [4, 1, 1, 1]
+ *
+ * @author Max Pumperla
+ */
+public class SpaceToBatchND extends DynamicCustomOp {
+
+    protected int[] blocks;
+    protected int[][] padding;
+
+    public SpaceToBatchND() {
+    }
+
+    public SpaceToBatchND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
+        super(null, sameDiff, args, inPlace);
+
+        this.blocks = blocks;
+        this.padding = padding;
+
+        for (val b : blocks)
+            addIArgument(b);
+
+        for (int e = 0; e < padding.length; e++)
+            addIArgument(padding[e][0], padding[e][1]);
+    }
+
+    @Override
+    public String opName() {
+        return "space_to_batch_nd";
+    }
+
+    @Override
+    public String onnxName() {
+        return "space_to_batch_nd";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "SpaceToBatchND";
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> i_v) {
+        // Inverse of space to batch is batch to space with same blocks and crops as padding
+        SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
+        return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding));
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
+        return Collections.singletonList(dataTypes.get(0));
+    }
+}
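The doDiff implementations above rely on the two ops being mutual inverses (crops and padding swap roles). A quick hedged round-trip check under zero padding/crops, reusing the generic builder from the previous sketch:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class SpaceToBatchRoundTripSketch {
    static INDArray run(String opName, INDArray in) {
        DynamicCustomOp op = DynamicCustomOp.builder(opName)
                .addInputs(in)
                .addIntegerArguments(2, 2, 0, 0, 0, 0) // blocks [2, 2], zero padding/crops
                .build();
        return Nd4j.getExecutioner().exec(op)[0];
    }

    public static void main(String[] args) {
        INDArray x = Nd4j.rand(new int[]{1, 2, 2, 1});
        INDArray roundTrip = run("batch_to_space_nd", run("space_to_batch_nd", x));
        System.out.println(x.equals(roundTrip)); // expect true: the ops invert each other
    }
}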
@@ -378,7 +378,7 @@ public class Transforms {
 
    public static INDArray asin(INDArray in, boolean copy) {
-        return Nd4j.getExecutioner().exec(new ASin(((copy ? in.dup() : in))));
+        return Nd4j.getExecutioner().exec(new ASin(in, (copy ? in.ulike() : in)));
    }
 
    public static INDArray atan(INDArray arr) {
@@ -999,7 +999,8 @@ public class Transforms {
    }
 
    public static INDArray isMax(INDArray input, INDArray output) {
-        return Nd4j.getExecutioner().exec(new IsMax(input, output));
+        Nd4j.getExecutioner().exec(new IsMax(input, output));
+        return output;
    }
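Because the custom-op executioner hands back an output array rather than the op's z directly, the helper now returns its own output reference explicitly. A small hedged usage sketch:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class TransformsIsMaxSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.createFromArray(1.0, 2.0, 4.0, 3.0).reshape(2, 2);
        INDArray out = Transforms.isMax(in, Nd4j.createUninitialized(DataType.BOOL, in.shape()));
        System.out.println(out); // [[false, false], [true, false]] - max is 4, at (1, 0)
    }
}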
@@ -1035,7 +1036,7 @@ public class Transforms {
     * @return
     */
    public static INDArray sqrt(INDArray ndArray, boolean dup) {
-        return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray));
+        return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray, ndArray));
    }
 
    /**
@@ -1308,40 +1308,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
    var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
 
-    // IsMax
-    if (op.getOpType() == Op.Type.TRANSFORM_ANY && op.opNum() == 1 && op.extraArgs() != null && op.extraArgs().length > 0) {
-        // for IsMax along dimension we need special temporary buffer
-        dimension = new int[(int) op.extraArgs()[0]];
-
-        for (int i = 0; i < dimension.length; i++) {
-            dimension[i] = (int) op.extraArgs()[i + 1];
-        }
-
-        for (int i = 0; i < dimension.length; i++) {
-            if (dimension[i] < 0)
-                dimension[i] += op.x().rank();
-        }
-        //do op along all dimensions
-        if (dimension.length == op.x().rank())
-            dimension = new int[] {Integer.MAX_VALUE};
-
-        long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {}
-                : ArrayUtil.removeIndex(op.x().shape(), dimension);
-
-        ret = Nd4j.createUninitialized(DataType.LONG, retShape);
-
-        // FIXME: this maybe misleading use of this particular pointer
-        hostYShapeInfo = allocator.getPointer(ret.shapeInfoDataBuffer(), context);
-        retHostShape = allocator.getHostPointer(ret.shapeInfoDataBuffer());
-
-        //dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context);
-        DataBuffer dimensionBuffer = allocator.getConstantBuffer(dimension);
-        dimensionDevPointer = allocator.getPointer(dimensionBuffer, context);
-        dimensionHostPointer = allocator.getHostPointer(dimensionBuffer);
-
-        retPointer = allocator.getPointer(ret, context);
-    }
-
    if (op.z() == null) {
        ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering());
@@ -1365,37 +1331,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
    op.validateDataTypes(experimentalMode.get());
 
-    // SoftMax, LogSoftMax, SoftMaxDerivative
-    if (op.getOpType() == Op.Type.TRANSFORM_STRICT && (op.opNum() >= 0 && op.opNum() <= 2)) {
-        tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[] {0});
-        tadMaxBuffers = tadManager.getTADOnlyShapeInfo(op.x().rank() == 1 ? op.x().reshape(1, -1) : op.x(), new int[] {1});
-
-        hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
-        devTadShapeInfo = allocator.getPointer(tadBuffers.getFirst(), context);
-
-        hostMaxTadShapeInfo = AddressRetriever.retrieveHostPointer(tadMaxBuffers.getFirst());
-        devMaxTadShapeInfo = allocator.getPointer(tadMaxBuffers.getFirst(), context);
-
-        DataBuffer offsets = tadBuffers.getSecond();
-        devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
-
-        DataBuffer maxOffsets = tadMaxBuffers.getSecond();
-        devMaxTadOffsets = maxOffsets == null ? null : allocator.getPointer(maxOffsets, context);
-    } else if (op.getOpType() == Op.Type.TRANSFORM_ANY && op.opNum() == 1) { // IsMax
-        tadBuffers = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
-
-        hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
-        devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);
-
-        DataBuffer offsets = tadBuffers.getSecond();
-        devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
-
-        if (retPointer == null)
-            retPointer = context.getBufferReduction();
-    }
-
    Pointer z = allocator.getPointer(op.z(), context);
    Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
@@ -1462,7 +1397,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            case TRANSFORM_FLOAT:
                nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(),
                        null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo,
-                        op.z().data().addressPointer(), (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
+                        null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
                        extraArgs);
                break;
            case TRANSFORM_BOOL:
@@ -1516,7 +1516,7 @@ public class SameDiffTests extends BaseNd4jTest {
        //then dL/dIn = 1 if in_i == min(in) or 0 otherwise
 
        //Note that we don't have an "IsMin" op, so use IsMax(neg(in)) which is equivalent
-        INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg())).castTo(Nd4j.defaultFloatingPointType());
+        INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType());
 
        assertEquals(exp, dLdIn);
    }
@@ -1540,7 +1540,7 @@ public class SameDiffTests extends BaseNd4jTest {
        //If L = max(in)
        //then dL/dIn = 1 if in_i == max(in) or 0 otherwise
 
-        INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup())).castTo(DataType.DOUBLE);
+        INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE);
 
        assertEquals(exp, dLdIn);
    }
@@ -72,6 +72,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
@@ -261,7 +262,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
    public void testIsMaxVectorCase() {
        INDArray arr = Nd4j.create(new double[] {1, 2, 4, 3}, new long[] {2, 2});
        INDArray assertion = Nd4j.create(new boolean[] {false, false, true, false}, new long[] {2, 2}, DataType.BOOL);
-        INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr));
+        INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr))[0];
        assertEquals(assertion, test);
    }
@@ -719,7 +720,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
        //Tests: full buffer...
        //1d
        INDArray arr1 = Nd4j.create(new double[] {1, 2, 3, 1});
-        val res1 = Nd4j.getExecutioner().exec(new IsMax(arr1));
+        val res1 = Nd4j.getExecutioner().exec(new IsMax(arr1))[0];
        INDArray exp1 = Nd4j.create(new boolean[] {false, false, true, false});
 
        assertEquals(exp1, res1);
|
||||||
INDArray exp2d = Nd4j.create(new boolean[][] {{false, false, false}, {false, true, false}});
|
INDArray exp2d = Nd4j.create(new boolean[][] {{false, false, false}, {false, true, false}});
|
||||||
|
|
||||||
INDArray f = arr2d.dup('f');
|
INDArray f = arr2d.dup('f');
|
||||||
INDArray out2dc = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('c')));
|
INDArray out2dc = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('c')))[0];
|
||||||
INDArray out2df = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('f')));
|
INDArray out2df = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('f')))[0];
|
||||||
assertEquals(exp2d, out2dc);
|
assertEquals(exp2d, out2dc);
|
||||||
assertEquals(exp2d, out2df);
|
assertEquals(exp2d, out2df);
|
||||||
}
|
}
|
||||||
|
@@ -803,16 +804,48 @@ public class Nd4jTestsC extends BaseNd4jTest {
    @Test
    public void testIsMaxEqualValues_2() {
        //[0 2]    [0 1]
        //[2 1] -> [0 0]
-        INDArray orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}});
+        INDArray orig = Nd4j.create(new double[][] {{0, 3}, {2, 1}});
        INDArray exp = Nd4j.create(new double[][] {{0, 1}, {0, 0}});
        INDArray outc = Transforms.isMax(orig.dup('c'));
        assertEquals(exp, outc);
 
-        INDArray outf = Transforms.isMax(orig.dup('f'));
+        log.info("Orig: {}", orig.dup('f').data().asFloat());
+
+        INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike());
+        log.info("OutF: {}", outf.data().asFloat());
        assertEquals(exp, outf);
    }
 
+    @Test
+    public void testIsMaxEqualValues_3() {
+        //[0 2]    [0 1]
+        //[2 1] -> [0 0]
+        INDArray orig = Nd4j.create(new double[][] {{0, 2}, {3, 1}});
+        INDArray exp = Nd4j.create(new double[][] {{0, 0}, {1, 0}});
+        INDArray outc = Transforms.isMax(orig.dup('c'));
+        assertEquals(exp, outc);
+
+        INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike());
+        assertEquals(exp, outf);
+    }
+
+    @Test
+    public void testSqrt_1() {
+        val x = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0);
+        val x2 = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0);
+        val e = Nd4j.createFromArray(3.0, 3.0, 3.0, 3.0);
+
+        val z1 = Transforms.sqrt(x, true);
+        val z2 = Transforms.sqrt(x2, false);
+
+        assertEquals(e, z2);
+        assertEquals(e, x2);
+        assertEquals(e, z1);
+    }
+
    @Test
    public void testAssign_CF() {
        val orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}});
@ -828,8 +861,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
//1d: row vector
INDArray orig = Nd4j.create(new double[] {1, 2, 3, 1}).reshape(1, 4);

-INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0));
+INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0))[0];
-INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1));
+INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1))[0];

INDArray expAlong0 = Nd4j.create(new boolean[] {true, true, true, true}).reshape(1, 4);
INDArray expAlong1 = Nd4j.create(new boolean[] {false, false, true, false}).reshape(1, 4);
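These along-dimension calls are worth unpacking: for a 1x4 row, reducing along dimension 0 makes every element the maximum of its one-element column, while dimension 1 flags only the row's argmax. A sketch reusing the test's construction (IsMax package assumed, as above):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.IsMax; // assumed package
    import org.nd4j.linalg.factory.Nd4j;

    public class IsMaxAlongDimSketch {
        public static void main(String[] args) {
            INDArray row = Nd4j.create(new double[] {1, 2, 3, 1}).reshape(1, 4);

            // dim 0: each column of the 1x4 row has a single entry, so all true.
            INDArray along0 = Nd4j.getExecutioner().exec(
                    new IsMax(row.dup(), Nd4j.createUninitialized(DataType.BOOL, row.shape()), 0))[0];
            // dim 1: only the argmax within the row is flagged.
            INDArray along1 = Nd4j.getExecutioner().exec(
                    new IsMax(row.dup(), Nd4j.createUninitialized(DataType.BOOL, row.shape()), 1))[0];

            System.out.println(along0); // [[true, true, true, true]]
            System.out.println(along1); // [[false, false, true, false]]
        }
    }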
@ -841,8 +874,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
//1d: col vector
System.out.println("----------------------------------");
INDArray col = Nd4j.create(new double[] {1, 2, 3, 1}, new long[] {4, 1});
-INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0));
+INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0))[0];
-INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 1));
+INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 1))[0];

INDArray expAlong0col = Nd4j.create(new boolean[] {false, false, true, false}).reshape(4, 1);
INDArray expAlong1col = Nd4j.create(new boolean[] {true, true, true, true}).reshape(4, 1);
@ -877,10 +910,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
//[0 1 0]
System.out.println("---------------------");
INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}});
-INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0));
+INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0];
-INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0));
+INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0))[0];
-INDArray alongDim1c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 1));
+INDArray alongDim1c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 1))[0];
-INDArray alongDim1f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 1));
+INDArray alongDim1f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 1))[0];

INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}});
INDArray expAlong1_2d = Nd4j.create(new boolean[][] {{false, false, true}, {false, true, false}});
@ -904,7 +937,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testIsMaxSingleDim1() {
INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}});
-INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0));
+INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0];
INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}});

System.out.println("Original shapeInfo: " + orig2d.dup('c').shapeInfoDataBuffer());
@ -1056,8 +1089,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
        + Arrays.toString(shape) + ")");
INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape);
INDArray arrF = arrC.dup('f');
-val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension));
+val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension))[0];
-val resF = Nd4j.getExecutioner().exec(new IsMax(arrF, alongDimension));
+val resF = Nd4j.getExecutioner().exec(new IsMax(arrF, alongDimension))[0];

double[] cBuffer = resC.data().asDouble();
@ -3932,7 +3965,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
v.assign(t);
}

-val result = Nd4j.getExecutioner().exec(new IsMax(arr, Nd4j.createUninitialized(DataType.BOOL, arr.shape(), arr.ordering()), 1, 2));
+val result = Nd4j.getExecutioner().exec(new IsMax(arr, Nd4j.createUninitialized(DataType.BOOL, arr.shape(), arr.ordering()), 1, 2))[0];

assertEquals(expected, result);
}
@ -3971,8 +4004,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}

-INDArray actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), Nd4j.createUninitialized(DataType.BOOL, arr.shape()), 0, 1));
+INDArray actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), Nd4j.createUninitialized(DataType.BOOL, arr.shape()), 0, 1))[0];
-INDArray actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), Nd4j.createUninitialized(DataType.BOOL, arr.shape(), 'f'), 0, 1));
+INDArray actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), Nd4j.createUninitialized(DataType.BOOL, arr.shape(), 'f'), 0, 1))[0];

assertEquals(exp, actC);
assertEquals(exp, actF);
@ -4006,8 +4039,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}

-actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), 2, 3));
+actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), arr.dup('c').ulike(), 2, 3))[0];
-actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), 2, 3));
+actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), arr.dup('f').ulike(), 2, 3))[0];

assertEquals(exp, actC);
assertEquals(exp, actF);
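This hunk swaps the bare (input, dims) constructor for an explicit output buffer built with ulike(), which, as used throughout this commit, allocates an uninitialized array with the same shape, type, and ordering as its receiver. The pattern in isolation, using only calls that appear in these hunks:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class UlikeOutputSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.create(new double[][] {{0, 2}, {3, 1}});
            // ulike(): same shape/type/ordering as `in`, contents uninitialized;
            // the op overwrites it, so the initial contents are irrelevant.
            INDArray out = Transforms.isMax(in, in.ulike());
            System.out.println(out); // expected: [[0, 0], [1, 0]] — 3 is the global max
        }
    }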
@ -6527,7 +6560,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertTrue(x.sumNumber().floatValue() > 0);

x = Nd4j.randn(DataType.BFLOAT16, 10);
-assertTrue(x.sumNumber().floatValue() > 0);
+assertTrue(x.sumNumber().floatValue() != 0.0);
}

@Test
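The relaxed assertion is a correctness fix, not a cosmetic one: the sum of ten draws from N(0, 1) is negative roughly half the time, so `> 0` is a coin-flip, while `!= 0.0` still catches the failure mode it was after — a bfloat16 RNG path that leaves the buffer zeroed. The reasoning as code:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class Bfloat16RandnSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.randn(DataType.BFLOAT16, 10);
            float sum = x.sumNumber().floatValue();
            // A zero sum across ten gaussian samples almost certainly means the
            // generator never wrote the buffer; a negative sum is perfectly normal.
            System.out.println("sum = " + sum + ", rng alive = " + (sum != 0.0f));
        }
    }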
@ -7962,7 +7995,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testBatchToSpace(){

INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5);
-DynamicCustomOp c = new BatchToSpace();
+DynamicCustomOp c = new BatchToSpaceND();

c.addInputArgument(
Nd4j.rand(DataType.FLOAT, new int[]{4, 4, 3}),
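The rename lines this test up with the N-dimensional wrappers added in this commit (SpaceToBatchND/BatchToSpaceND). The rest of the op's setup is cut off in this diff, so the sketch below assumes the usual TF-style inputs — data, block_shape, crops — and the package path; both are assumptions, not read from this change:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND; // assumed package
    import org.nd4j.linalg.factory.Nd4j;

    public class BatchToSpaceNDSketch {
        public static void main(String[] args) {
            // batch 4 = block(2) * 2, one spatial dim of size 1, 3 channels (assumed layout)
            INDArray data = Nd4j.rand(DataType.FLOAT, new int[]{4, 1, 3});
            INDArray blockShape = Nd4j.createFromArray(2L);
            INDArray crops = Nd4j.createFromArray(new long[][] {{0, 0}});
            INDArray out = Nd4j.create(DataType.FLOAT, 2, 2, 3); // batch/2, spatial*2, channels

            DynamicCustomOp c = new BatchToSpaceND();
            c.addInputArgument(data, blockShape, crops);
            c.addOutputArgument(out);
            Nd4j.getExecutioner().exec(c);
        }
    }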
@ -106,115 +106,6 @@ public class CudaTests extends BaseNd4jTest {
        assertEquals(exp, arrayA);
    }

-    @Test(timeout = 40000L)
-    public void testContextSpam() throws Exception {
-        if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
-            return;
-
-        val success = new AtomicInteger(0);
-        val iterations = 101;
-
-        val threads = new ArrayList<Thread>();
-        for (int e = 0; e < iterations; e++) {
-            val f = e;
-            val t = new Thread(new Runnable() {
-                @Override
-                public void run() {
-                    Nd4j.create(1);
-                    if (f % 50 == 0)
-                        log.info("Context {} created", f);
-
-                    Nd4j.getMemoryManager().releaseCurrentContext();
-                    success.incrementAndGet();
-                    try {
-                        Thread.sleep(1000L);
-                    } catch (InterruptedException ex) {
-                        ex.printStackTrace();
-                    }
-                }
-            });
-
-            t.start();
-            threads.add(t);
-        }
-
-        for (val t: threads)
-            t.join();
-
-        assertEquals(iterations, success.get());
-    }
-
-    @Ignore
-    @Test(timeout = 180000L)
-    public void testContextSpam_2() throws Exception {
-        if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
-            return;
-
-        val success = new AtomicInteger(0);
-        val iterations = 101;
-
-        val threads = new ArrayList<Thread>();
-        for (int e = 0; e < iterations; e++) {
-            val f = e;
-            val t = new Thread(new Runnable() {
-                @Override
-                public void run() {
-                    Nd4j.create(1);
-                    if (f % 50 == 0)
-                        log.info("Context {} created", f);
-
-                    //Nd4j.getMemoryManager().releaseCurrentContext();
-                    success.incrementAndGet();
-                    try {
-                        Thread.sleep(1000L);
-                    } catch (InterruptedException ex) {
-                        ex.printStackTrace();
-                    }
-                }
-            });
-
-            t.start();
-            threads.add(t);
-        }
-
-        for (val t: threads)
-            t.join();
-
-        assertEquals(iterations, success.get());
-    }
-
-    @Test
-    public void testSequentialReleaseAndReacquire() throws Exception {
-        if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
-            return;
-
-        Nd4j.create(128);
-
-        Nd4j.getMemoryManager().releaseCurrentContext();
-
-        val array = Nd4j.create(128);
-        array.addi(1.0f);
-    }
-
-    @Test
-    @Ignore
-    public void test(){
-        if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA)
-            return;
-
-        val SD = SameDiff.create();
-        val in = SD.one("test", 5, 8, 3, 4);
-        SDVariable out = in.reshape(-1, 4);
-        SDVariable out1 = out.reshape(4, 15, -1);
-        SDVariable out2 = SD.dot(out1, out1, 2);
-
-        SDVariable out3 = out2.reshape(-1, 4); // <---- error here
-
-        System.out.println(Arrays.toString(out3.eval().toFloatMatrix()));
-    }

    @Override
    public char ordering() {
        return 'c';
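For context on what was dropped: the first three deleted methods all stress the context lifecycle around Nd4j.getMemoryManager().releaseCurrentContext() (the fourth is an @Ignore'd SameDiff reshape repro). Their common core, reduced to the sequential case, was essentially:

    import org.nd4j.linalg.factory.Nd4j;

    public class ReleaseReacquireSketch {
        public static void main(String[] args) {
            Nd4j.create(128);                                // first allocation creates a context
            Nd4j.getMemoryManager().releaseCurrentContext(); // drop this thread's context
            Nd4j.create(128).addi(1.0f);                     // must transparently reacquire one
        }
    }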
@ -27,6 +27,18 @@
    <packaging>jar</packaging>

    <name>nd4j-parameter-server-node</name>
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <configuration>
+                    <source>8</source>
+                    <target>8</target>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>

    <dependencies>
        <dependency>