From bb5fc36e5e4d446d1f60fa949c2780e9cba7f75d Mon Sep 17 00:00:00 2001
From: raver119
Date: Mon, 26 Aug 2019 19:37:05 +0300
Subject: [PATCH] [WIP] ops fixes (#168)

* - correct layer_norm

Signed-off-by: Yurii

* - further fix of layer norm

Signed-off-by: Yurii

* - correct scatter_upd op

Signed-off-by: Yurii

* - correct cuda kernel for histogram_fixed_width op

Signed-off-by: Yurii

* - delete comments

Signed-off-by: Yurii

* enabled one ignored test

Signed-off-by: raver119
---
 .../generic/transforms/layer_norm.cpp         |  28 +-
 .../helpers/cpu/histogramFixedWidth.cpp       |   9 +-
 .../ops/declarable/helpers/cpu/scatter.cpp    |   2 +
 .../helpers/cuda/histogramFixedWidth.cu       | 245 +++++++++++-------
 .../ops/declarable/helpers/cuda/scatter.cu    |  11 +-
 .../layers_tests/DeclarableOpsTests10.cpp     |  26 +-
 .../layers_tests/DeclarableOpsTests15.cpp     |   4 +-
 .../layers_tests/DeclarableOpsTests16.cpp     |  23 +-
 .../opvalidation/LayerOpValidation.java       |   1 -
 9 files changed, 240 insertions(+), 109 deletions(-)

diff --git a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp
index 4e612565e..684d98d6d 100644
--- a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp
+++ b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp
@@ -32,9 +32,12 @@ namespace ops {
             auto input = INPUT_VARIABLE(0);
             auto gain = INPUT_VARIABLE(1);
             auto output = OUTPUT_VARIABLE(0);
-
+
             std::vector<int> axis = *block.getIArguments();
 
+            const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;    // B_ARG(0): true-NCHW, false-NHWC
+            const int dimC = isNCHW ? 1 : input->rankOf() - 1;
+
             NDArray* bias = nullptr;
             if (block.width() > 2)
                 bias = INPUT_VARIABLE(2);
@@ -48,9 +51,12 @@ namespace ops {
             std::vector<bool> bargs = {};
             standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
 
-            output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output);
-            if(bias != nullptr)
-                output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output);
+            // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output);
+            output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain);
+            if(bias != nullptr) {
+                // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output);
+                output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias);
+            }
 
             return Status::OK();
         }
@@ -71,12 +77,17 @@ namespace ops {
             auto dLdg = OUTPUT_VARIABLE(1);
             auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr;
 
+            const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;    // B_ARG(0): true-NCHW, false-NHWC
+            const int dimC = isNCHW ? 1 : input->rankOf() - 1;
+
             std::vector<int> axis = *block.getIArguments();
 
             std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
 
-            if(bias != nullptr)
-                eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true);
+            if(bias != nullptr) {
+                // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true);
+                eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+            }
 
             NDArray standardized(input->shapeInfo(), false, block.launchContext());
@@ -88,10 +99,11 @@ namespace ops {
             standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
             standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr);
-            standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, {0}, true);
+            standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
 
             nd4j::ops::standardize_bp standardizeBp;
-            eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx);
+            // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx);
+            eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain, dLdx);
 
             auto dLdx_tmp = dLdx->dup();
             std::vector<NDArray*> standardizeBpArgs = {input, dLdx_tmp};
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp
index a3f3e2f9e..349d0381a 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp
@@ -28,13 +28,10 @@ namespace helpers {
 template <typename T>
 void histogramFixedWidth_(const NDArray& input, const NDArray& range, NDArray& output) {
 
-    const int nbins = output.lengthOf();
+    const int nbins = output.lengthOf();
 
-    // firstly initialize output with zeros
-    if(output.ews() == 1)
-        memset(output.buffer(), 0, nbins * output.sizeOfT());
-    else
-        output = 0;
+    // firstly initialize output with zeros
+    output.nullify();
 
     const T leftEdge = range.e<T>(0);
     const T rightEdge = range.e<T>(1);
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp
index c1d01930c..0b16ac989 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp
@@ -54,6 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
         std::vector<int> dimsToExcludeUpd(sizeOfDims);
         std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
 
+        shape::printIntArray(dimsToExcludeUpd.data(),dimsToExcludeUpd.size());
+
 // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
 PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
         for(Nd4jLong i = 0; i < indLen; ++i) {
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu
index 2c46210cf..ebde4909c 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu
@@ -20,110 +20,181 @@
 #include <ops/declarable/helpers/histogramFixedWidth.h>
 #include <NDArrayFactory.h>
+#include <PointersManager.h>
 
-namespace nd4j {
-namespace ops {
+namespace nd4j {
+namespace ops {
 namespace helpers {
 
-    template <typename T>
-    __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) {
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
-        const auto step = gridDim.x * blockDim.x;
-        for (int t = tid; t < bufferLength; t += step) {
-            destination[t] = reinterpret_cast<T const*>(source)[shape::getIndexOffset(t, sourceShape, bufferLength)];
-        }
+///////////////////////////////////////////////////////////////////
+template <typename T>
+__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo,
+                                                void* vz, const Nd4jLong* zShapeInfo,
+                                                const T leftEdge, const T rightEdge) {
+
+    const T* x = reinterpret_cast<const T*>(vx);
+    Nd4jLong* z = reinterpret_cast<Nd4jLong*>(vz);
+
+    __shared__ Nd4jLong xLen, zLen, totalThreads, nbins;
+    __shared__ T binWidth, secondEdge, lastButOneEdge;
+
+    if (threadIdx.x == 0) {
+
+        xLen = shape::length(xShapeInfo);
+        nbins = shape::length(zShapeInfo);      // nbins = zLen
+        totalThreads = gridDim.x * blockDim.x;
+
+        binWidth       = (rightEdge - leftEdge ) / nbins;
+        secondEdge     = leftEdge + binWidth;
+        lastButOneEdge = rightEdge - binWidth;
     }
-    template <typename T>
-    __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) {
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
-        const auto step = gridDim.x * blockDim.x;
-        for (int t = tid; t < bufferLength; t += step) {
-            reinterpret_cast<T*>(destination)[shape::getIndexOffset(t, destinationShape, bufferLength)] = source[t];
-        }
+    __syncthreads();
+
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
+
+        const T value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
+
+        Nd4jLong zIndex;
+
+        if(value < secondEdge)
+            zIndex = 0;
+        else if(value >= lastButOneEdge)
+            zIndex = nbins - 1;
+        else
+            zIndex = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
+
+        nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1LL);
     }
+}
 
-    template <typename T>
-    static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) {
+///////////////////////////////////////////////////////////////////
+template <typename T>
+__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) {
 
-        __shared__ T const* x;
-        __shared__ Nd4jLong* z; // output buffer
+    const T leftEdge  = range.e<T>(0);
+    const T rightEdge = range.e<T>(1);
 
-        if (threadIdx.x == 0) {
-            z = reinterpret_cast<Nd4jLong*>(outputBuffer);
-            x = reinterpret_cast<T const*>(inputBuffer);
-        }
-        __syncthreads();
-        auto tid = blockIdx.x * gridDim.x + threadIdx.x;
-        auto step = blockDim.x * gridDim.x;
+    histogramFixedWidthCuda<T><<<512, MAX_NUM_THREADS / 2, 512, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(),
+                                                                           output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
+}
 
-        for(auto i = tid; i < inputLength; i += step) {
+////////////////////////////////////////////////////////////////////////
+void histogramFixedWidth(nd4j::LaunchContext* context, const NDArray& input, const NDArray& range, NDArray& output) {
 
-            const T value = x[shape::getIndexOffset(i, inputShape, inputLength)];
-            Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
+    // firstly initialize output with zeros
+    output.nullify();
 
-            if(value < secondEdge)
-                currInd = 0;
-            else if(value >= lastButOneEdge)
-                currInd = outputLength - 1;
-            nd4j::math::atomics::nd4j_atomicAdd(&z[currInd], 1LL);
-        }
-    }
+    PointersManager manager(context, "histogramFixedWidth");
+
+    NDArray::prepareSpecialUse({&output}, {&input});
+    BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES);
+    NDArray::registerSpecialUse({&output}, {&input});
+
+    manager.synchronize();
+}
 
-    template <typename T>
-    void histogramFixedWidth_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
-        const int nbins = output.lengthOf();
-        auto stream = context->getCudaStream();
-        // firstly initialize output with zeros
-        //if(output.ews() == 1)
-        //    memset(output.buffer(), 0, nbins * output.sizeOfT());
-        //else
-        output.assign(0);
-        if (!input.isActualOnDeviceSide())
-            input.syncToDevice();
+// template <typename T>
+// __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) {
+//     const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+//     const auto step = gridDim.x * blockDim.x;
+//     for (int t = tid; t < bufferLength; t += step) {
+//         destination[t] = reinterpret_cast<T const*>(source)[shape::getIndexOffset(t, sourceShape, bufferLength)];
+//     }
+// }
 
-        const double leftEdge = range.e<double>(0);
-        const double rightEdge = range.e<double>(1);
+// template <typename T>
+// __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) {
+//     const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+//     const auto step = gridDim.x * blockDim.x;
+//     for (int t = tid; t < bufferLength; t += step) {
+//         reinterpret_cast<T*>(destination)[shape::getIndexOffset(t, destinationShape, bufferLength)] = source[t];
+//     }
+// }
 
-        const double binWidth = (rightEdge - leftEdge ) / nbins;
-        const double secondEdge = leftEdge + binWidth;
-        double lastButOneEdge = rightEdge - binWidth;
-        Nd4jLong* outputBuffer;
-        cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong));
-        if (err != 0)
-            throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err);
-        copyBuffers<Nd4jLong><<<256, 512, 8192, *stream>>>(outputBuffer, output.getSpecialBuffer(), output.getSpecialShapeInfo(), output.lengthOf());
-        histogramFixedWidthKernel<T><<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge);
-        returnBuffers<Nd4jLong><<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf());
-        //cudaSyncStream(*stream);
-        err = cudaFree(outputBuffer);
-        if (err != 0)
-            throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err);
-        output.tickWriteDevice();
-//#pragma omp parallel for schedule(guided)
-//    for(Nd4jLong i = 0; i < input.lengthOf(); ++i) {
-//
-//        const T value = input.e<T>(i);
-//
-//        if(value < secondEdge)
-//#pragma omp critical
-//            output.p(0, output.e<Nd4jLong>(0) + 1);
-//        else if(value >= lastButOneEdge)
-//#pragma omp critical
-//            output.p(nbins-1, output.e<Nd4jLong>(nbins-1) + 1);
-//        else {
-//            Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
-//#pragma omp critical
-//            output.p(currInd, output.e<Nd4jLong>(currInd) + 1);
-//        }
-//    }
-    }
 
-    void histogramFixedWidth(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
-        BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES);
-    }
-    BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES);
 
+// template <typename T>
+// static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) {
 
+//     __shared__ T const* x;
+//     __shared__ Nd4jLong* z; // output buffer
 
+//     if (threadIdx.x == 0) {
+//         z = reinterpret_cast<Nd4jLong*>(outputBuffer);
+//         x = reinterpret_cast<T const*>(inputBuffer);
+//     }
+//     __syncthreads();
+//     auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+//     auto step = blockDim.x * gridDim.x;
 
+//     for(auto i = tid; i < inputLength; i += step) {
 
+//         const T value = x[shape::getIndexOffset(i, inputShape, inputLength)];
+//         Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
 
+//         if(value < secondEdge)
+//             currInd = 0;
+//         else if(value >= lastButOneEdge)
+//             currInd = outputLength - 1;
+//         nd4j::math::atomics::nd4j_atomicAdd(&z[currInd], 1LL);
+//     }
+// }
 
+// template <typename T>
+// void histogramFixedWidth_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
+//     const int nbins = output.lengthOf();
+//     auto stream = context->getCudaStream();
+//     // firstly initialize output with zeros
+//     //if(output.ews() == 1)
+//     //    memset(output.buffer(), 0, nbins * output.sizeOfT());
+//     //else
+//     output.assign(0);
+//     if (!input.isActualOnDeviceSide())
+//         input.syncToDevice();
 
+//     const double leftEdge = range.e<double>(0);
+//     const double rightEdge = range.e<double>(1);
 
+//     const double binWidth = (rightEdge - leftEdge ) / nbins;
+//     const double secondEdge = leftEdge + binWidth;
+//     double lastButOneEdge = rightEdge - binWidth;
+//     Nd4jLong* outputBuffer;
+//     cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong));
+//     if (err != 0)
+//         throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err);
+//     copyBuffers<Nd4jLong><<<256, 512, 8192, *stream>>>(outputBuffer, output.getSpecialBuffer(), output.getSpecialShapeInfo(), output.lengthOf());
+//     histogramFixedWidthKernel<T><<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge);
+//     returnBuffers<Nd4jLong><<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf());
+//     //cudaSyncStream(*stream);
+//     err = cudaFree(outputBuffer);
+//     if (err != 0)
+//         throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err);
+//     output.tickWriteDevice();
+// //#pragma omp parallel for schedule(guided)
+// //    for(Nd4jLong i = 0; i < input.lengthOf(); ++i) {
+// //
+// //        const T value = input.e<T>(i);
+// //
+// //        if(value < secondEdge)
+// //#pragma omp critical
+// //            output.p(0, output.e<Nd4jLong>(0) + 1);
+// //        else if(value >= lastButOneEdge)
+// //#pragma omp critical
+// //            output.p(nbins-1, output.e<Nd4jLong>(nbins-1) + 1);
+// //        else {
+// //            Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
+// //#pragma omp critical
+// //            output.p(currInd, output.e<Nd4jLong>(currInd) + 1);
+// //        }
+// //    }
+// }

+// void histogramFixedWidth(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
+//     BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES);
+// }
+// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES);
 
 }
 }
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu
index ec0d304df..54d350f47 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu
@@ -398,10 +398,15 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind
         const int xRank = indices.rankOf();
 
         std::vector<int> zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0});
-        std::vector<int> yTadDims(xRank);
-        std::iota(yTadDims.begin(), yTadDims.end(), xRank == 1 ? 0 : xRank);
 
-        auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), yTadDims);
+        int sizeOfUpdDims = xRank;
+        if(output.rankOf() == updates.rankOf() && indices.isVector())
+            sizeOfUpdDims = 1;
+
+        std::vector<int> yTadDims(sizeOfUpdDims);
+        std::iota(yTadDims.begin(), yTadDims.end(), 0);
+
+        auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims));
         auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims);
 
         const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo());
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
index 1d4bf7338..82ed21709 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
@@ -910,7 +910,31 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) {
     auto *out = results->at(0);
 
     ASSERT_TRUE(exp.isSameShape(out));
-    out->printBuffer("5HIST");
+    // out->printBuffer("5HIST");
     ASSERT_TRUE(exp.equalsTo(out));
 
     delete results;
 }
 
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) {
+
+    auto input = NDArrayFactory::create<double>('c', {7},{0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9});
+    auto range = NDArrayFactory::create<double>('c', {2}, {0, 1});
+    auto bins  = NDArrayFactory::create<int>(5);
+
+    auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {3, 1, 2, 0, 1});
+
+    nd4j::ops::histogram_fixed_width op;
+    auto results = op.execute({&input, &range, &bins}, {}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto out = results->at(0);
+    // out->printShapeInfo();
+    // out->printIndexedBuffer();
+
+    ASSERT_TRUE(exp.isSameShape(out));
+    ASSERT_TRUE(exp.equalsTo(out));
+
+    delete results;
+}
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
index df1421d71..21a0381e9 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -249,7 +249,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
     auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
 
     nd4j::ops::layer_norm op;
-    auto result = op.execute({&x, &g, &b}, {}, {0}, {});
+    auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
 }
@@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
     auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.});
 
     nd4j::ops::layer_norm_bp op;
-    auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {});
+    auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
 }
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
index a23d5421e..992b21c0f 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
@@ -39,7 +39,7 @@ public:
     }
 };
 
-TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
+TEST_F(DeclarableOpsTests16, scatter_upd_1) {
     auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
     auto y = NDArrayFactory::create<int>(0);
     auto w = NDArrayFactory::create<float>(3.0f);
@@ -56,6 +56,27 @@
     delete result;
 }
 
+TEST_F(DeclarableOpsTests16, scatter_upd_2) {
+
+    NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32);
+    NDArray indices('c', {2}, {2,5}, nd4j::DataType::INT32);
+    NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32);
+    NDArray e('c', {10, 3}, {1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30}, nd4j::DataType::FLOAT32);
+
+    x.linspace(1);
+
+    nd4j::ops::scatter_upd op;
+    auto result = op.execute({&x, &indices, &updates}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+
+    ASSERT_EQ(e, *z);
+
+    delete result;
+}
+
+
 TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
     auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
     auto z = NDArrayFactory::create<float>(0.0f);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index 9437ad7b2..84bd96ad6 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -1297,7 +1297,6 @@ public class LayerOpValidation extends BaseOpValidation {
     }
 
     @Test
-    @Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
     public void testLayerNormMixedOrders(){
         Nd4j.getRandom().setSeed(12345);
         INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
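
Note on the binning rule (not part of the patch): the CPU helper and the rewritten histogramFixedWidthCuda kernel share the same edge handling — values below the second edge land in the first bin, values at or above the last-but-one edge land in the last bin, everything else is placed by integer division. The standalone C++ sketch below is illustrative only (all names are local to the example, none are libnd4j APIs); it reproduces the expected counts {3, 1, 2, 0, 1} from histogram_fixed_width_test6 above.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Inputs taken from histogram_fixed_width_test6: 7 values, range [0, 1), 5 bins.
    const std::vector<double> input = {0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9};
    const double leftEdge = 0.0, rightEdge = 1.0;
    const int nbins = 5;

    const double binWidth       = (rightEdge - leftEdge) / nbins;
    const double secondEdge     = leftEdge + binWidth;   // upper edge of the first bin
    const double lastButOneEdge = rightEdge - binWidth;  // lower edge of the last bin

    std::vector<int64_t> hist(nbins, 0);
    for (double value : input) {
        int64_t idx;
        if (value < secondEdge)            // everything below the second edge -> first bin
            idx = 0;
        else if (value >= lastButOneEdge)  // everything at/above the last-but-one edge -> last bin
            idx = nbins - 1;
        else                               // interior values: plain fixed-width binning
            idx = static_cast<int64_t>((value - leftEdge) / binWidth);
        ++hist[idx];
    }

    for (int64_t c : hist)
        std::cout << c << " ";             // prints: 3 1 2 0 1
    std::cout << "\n";
    return 0;
}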