[WIP] ops fixes (#168)

* - correct layer_norm

Signed-off-by: Yurii <yurii@skymind.io>

* - further fix of layer norm

Signed-off-by: Yurii <yurii@skymind.io>

* - correct scatter_upd op

Signed-off-by: Yurii <yurii@skymind.io>

* - correct cuda kernel for histogram_fixed_width op

Signed-off-by: Yurii <yurii@skymind.io>

* - delete comments

Signed-off-by: Yurii <yurii@skymind.io>

* enabled one ignored test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-26 19:37:05 +03:00 committed by GitHub
parent b417ca21bf
commit bb5fc36e5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 240 additions and 109 deletions

View File

@ -35,6 +35,9 @@ namespace ops {
std::vector<int> axis = *block.getIArguments(); std::vector<int> axis = *block.getIArguments();
const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC
const int dimC = isNCHW ? 1 : input->rankOf() - 1;
NDArray* bias = nullptr; NDArray* bias = nullptr;
if (block.width() > 2) if (block.width() > 2)
bias = INPUT_VARIABLE(2); bias = INPUT_VARIABLE(2);
@ -48,9 +51,12 @@ namespace ops {
std::vector<bool> bargs = {}; std::vector<bool> bargs = {};
standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output); // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output);
if(bias != nullptr) output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain);
output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output); if(bias != nullptr) {
// output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output);
output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias);
}
return Status::OK(); return Status::OK();
} }
@ -71,12 +77,17 @@ namespace ops {
auto dLdg = OUTPUT_VARIABLE(1); auto dLdg = OUTPUT_VARIABLE(1);
auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr;
const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC
const int dimC = isNCHW ? 1 : input->rankOf() - 1;
std::vector<int> axis = *block.getIArguments(); std::vector<int> axis = *block.getIArguments();
std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis); std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
if(bias != nullptr) if(bias != nullptr) {
eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true); // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true);
eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
}
NDArray standardized(input->shapeInfo(), false, block.launchContext()); NDArray standardized(input->shapeInfo(), false, block.launchContext());
@ -88,10 +99,11 @@ namespace ops {
standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr); standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr);
standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, {0}, true); standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
nd4j::ops::standardize_bp standardizeBp; nd4j::ops::standardize_bp standardizeBp;
eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx); // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx);
eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain, dLdx);
auto dLdx_tmp = dLdx->dup(); auto dLdx_tmp = dLdx->dup();
std::vector<NDArray *> standardizeBpArgs = {input, dLdx_tmp}; std::vector<NDArray *> standardizeBpArgs = {input, dLdx_tmp};

View File

@ -31,10 +31,7 @@ void histogramFixedWidth_(const NDArray& input, const NDArray& range, NDArray& o
const int nbins = output.lengthOf(); const int nbins = output.lengthOf();
// firstly initialize output with zeros // firstly initialize output with zeros
if(output.ews() == 1) output.nullify();
memset(output.buffer(), 0, nbins * output.sizeOfT());
else
output = 0;
const T leftEdge = range.e<double>(0); const T leftEdge = range.e<double>(0);
const T rightEdge = range.e<double>(1); const T rightEdge = range.e<double>(1);

View File

@ -54,6 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
std::vector<int> dimsToExcludeUpd(sizeOfDims); std::vector<int> dimsToExcludeUpd(sizeOfDims);
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
shape::printIntArray(dimsToExcludeUpd.data(),dimsToExcludeUpd.size());
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug ! // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) { for(Nd4jLong i = 0; i < indLen; ++i) {

View File

@ -20,110 +20,181 @@
#include <ops/declarable/helpers/histogramFixedWidth.h> #include <ops/declarable/helpers/histogramFixedWidth.h>
#include <cuda_exception.h> #include <cuda_exception.h>
#include <PointersManager.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> ///////////////////////////////////////////////////////////////////
__global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { template<typename T>
const auto tid = blockIdx.x * gridDim.x + threadIdx.x; __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo,
const auto step = gridDim.x * blockDim.x; void* vz, const Nd4jLong* zShapeInfo,
for (int t = tid; t < bufferLength; t += step) { const T leftEdge, const T rightEdge) {
destination[t] = reinterpret_cast<T const*>(source)[shape::getIndexOffset(t, sourceShape, bufferLength)];
}
}
template <typename T> const T* x = reinterpret_cast<const T*>(vx);
__global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { Nd4jLong* z = reinterpret_cast<Nd4jLong*>(vz);
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x;
for (int t = tid; t < bufferLength; t += step) {
reinterpret_cast<T*>(destination)[shape::getIndexOffset(t, destinationShape, bufferLength)] = source[t];
}
}
template <typename T> __shared__ Nd4jLong xLen, zLen, totalThreads, nbins;
static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) { __shared__ T binWidth, secondEdge, lastButOneEdge;
__shared__ T const* x;
__shared__ Nd4jLong* z; // output buffer
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
z = reinterpret_cast<Nd4jLong*>(outputBuffer);
x = reinterpret_cast<T const*>(inputBuffer); xLen = shape::length(xShapeInfo);
nbins = shape::length(zShapeInfo); // nbins = zLen
totalThreads = gridDim.x * blockDim.x;
binWidth = (rightEdge - leftEdge ) / nbins;
secondEdge = leftEdge + binWidth;
lastButOneEdge = rightEdge - binWidth;
} }
__syncthreads(); __syncthreads();
auto tid = blockIdx.x * gridDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for(auto i = tid; i < inputLength; i += step) { const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const T value = x[shape::getIndexOffset(i, inputShape, inputLength)]; for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
const T value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
Nd4jLong zIndex;
if(value < secondEdge) if(value < secondEdge)
currInd = 0; zIndex = 0;
else if(value >= lastButOneEdge) else if(value >= lastButOneEdge)
currInd = outputLength - 1; zIndex = nbins - 1;
nd4j::math::atomics::nd4j_atomicAdd(&z[currInd], 1LL); else
} zIndex = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
}
nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1LL);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) {
const T leftEdge = range.e<T>(0);
const T rightEdge = range.e<T>(1);
histogramFixedWidthCuda<T><<<512, MAX_NUM_THREADS / 2, 512, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
}
////////////////////////////////////////////////////////////////////////
void histogramFixedWidth(nd4j::LaunchContext* context, const NDArray& input, const NDArray& range, NDArray& output) {
template <typename T>
void histogramFixedWidth_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
const int nbins = output.lengthOf();
auto stream = context->getCudaStream();
// firstly initialize output with zeros // firstly initialize output with zeros
//if(output.ews() == 1) output.nullify();
// memset(output.buffer(), 0, nbins * output.sizeOfT());
//else
output.assign(0);
if (!input.isActualOnDeviceSide())
input.syncToDevice();
const double leftEdge = range.e<double>(0); PointersManager manager(context, "histogramFixedWidth");
const double rightEdge = range.e<double>(1);
const double binWidth = (rightEdge - leftEdge ) / nbins; NDArray::prepareSpecialUse({&output}, {&input});
const double secondEdge = leftEdge + binWidth; BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES);
double lastButOneEdge = rightEdge - binWidth; NDArray::registerSpecialUse({&output}, {&input});
Nd4jLong* outputBuffer;
cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong)); manager.synchronize();
if (err != 0) }
throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err);
copyBuffers<Nd4jLong ><<<256, 512, 8192, *stream>>>(outputBuffer, output.getSpecialBuffer(), output.getSpecialShapeInfo(), output.lengthOf());
histogramFixedWidthKernel<T><<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge); // template <typename T>
returnBuffers<Nd4jLong><<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf()); // __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) {
//cudaSyncStream(*stream); // const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
err = cudaFree(outputBuffer); // const auto step = gridDim.x * blockDim.x;
if (err != 0) // for (int t = tid; t < bufferLength; t += step) {
throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err); // destination[t] = reinterpret_cast<T const*>(source)[shape::getIndexOffset(t, sourceShape, bufferLength)];
output.tickWriteDevice(); // }
//#pragma omp parallel for schedule(guided) // }
// for(Nd4jLong i = 0; i < input.lengthOf(); ++i) {
// // template <typename T>
// const T value = input.e<T>(i); // __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) {
// // const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
// if(value < secondEdge) // const auto step = gridDim.x * blockDim.x;
//#pragma omp critical // for (int t = tid; t < bufferLength; t += step) {
// output.p<Nd4jLong>(0, output.e<Nd4jLong>(0) + 1); // reinterpret_cast<T*>(destination)[shape::getIndexOffset(t, destinationShape, bufferLength)] = source[t];
// else if(value >= lastButOneEdge) // }
//#pragma omp critical // }
// output.p<Nd4jLong>(nbins-1, output.e<Nd4jLong>(nbins-1) + 1);
// else { // template <typename T>
// static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) {
// __shared__ T const* x;
// __shared__ Nd4jLong* z; // output buffer
// if (threadIdx.x == 0) {
// z = reinterpret_cast<Nd4jLong*>(outputBuffer);
// x = reinterpret_cast<T const*>(inputBuffer);
// }
// __syncthreads();
// auto tid = blockIdx.x * gridDim.x + threadIdx.x;
// auto step = blockDim.x * gridDim.x;
// for(auto i = tid; i < inputLength; i += step) {
// const T value = x[shape::getIndexOffset(i, inputShape, inputLength)];
// Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth); // Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
//#pragma omp critical
// output.p<Nd4jLong>(currInd, output.e<Nd4jLong>(currInd) + 1);
// }
// }
}
void histogramFixedWidth(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { // if(value < secondEdge)
BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES); // currInd = 0;
} // else if(value >= lastButOneEdge)
BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); // currInd = outputLength - 1;
// nd4j::math::atomics::nd4j_atomicAdd(&z[currInd], 1LL);
// }
// }
// template <typename T>
// void histogramFixedWidth_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
// const int nbins = output.lengthOf();
// auto stream = context->getCudaStream();
// // firstly initialize output with zeros
// //if(output.ews() == 1)
// // memset(output.buffer(), 0, nbins * output.sizeOfT());
// //else
// output.assign(0);
// if (!input.isActualOnDeviceSide())
// input.syncToDevice();
// const double leftEdge = range.e<double>(0);
// const double rightEdge = range.e<double>(1);
// const double binWidth = (rightEdge - leftEdge ) / nbins;
// const double secondEdge = leftEdge + binWidth;
// double lastButOneEdge = rightEdge - binWidth;
// Nd4jLong* outputBuffer;
// cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong));
// if (err != 0)
// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err);
// copyBuffers<Nd4jLong ><<<256, 512, 8192, *stream>>>(outputBuffer, output.getSpecialBuffer(), output.getSpecialShapeInfo(), output.lengthOf());
// histogramFixedWidthKernel<T><<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge);
// returnBuffers<Nd4jLong><<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf());
// //cudaSyncStream(*stream);
// err = cudaFree(outputBuffer);
// if (err != 0)
// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err);
// output.tickWriteDevice();
// //#pragma omp parallel for schedule(guided)
// // for(Nd4jLong i = 0; i < input.lengthOf(); ++i) {
// //
// // const T value = input.e<T>(i);
// //
// // if(value < secondEdge)
// //#pragma omp critical
// // output.p<Nd4jLong>(0, output.e<Nd4jLong>(0) + 1);
// // else if(value >= lastButOneEdge)
// //#pragma omp critical
// // output.p<Nd4jLong>(nbins-1, output.e<Nd4jLong>(nbins-1) + 1);
// // else {
// // Nd4jLong currInd = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
// //#pragma omp critical
// // output.p<Nd4jLong>(currInd, output.e<Nd4jLong>(currInd) + 1);
// // }
// // }
// }
// void histogramFixedWidth(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) {
// BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES);
// }
// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES);
} }
} }

View File

@ -398,10 +398,15 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind
const int xRank = indices.rankOf(); const int xRank = indices.rankOf();
std::vector<int> zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); std::vector<int> zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0});
std::vector<int> yTadDims(xRank);
std::iota(yTadDims.begin(), yTadDims.end(), xRank == 1 ? 0 : xRank);
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), yTadDims); int sizeOfUpdDims = xRank;
if(output.rankOf() == updates.rankOf() && indices.isVector())
sizeOfUpdDims = 1;
std::vector<int> yTadDims(sizeOfUpdDims);
std::iota(yTadDims.begin(), yTadDims.end(), 0);
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims));
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims);
const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo()); const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo());

View File

@ -910,7 +910,31 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) {
auto *out = results->at(0); auto *out = results->at(0);
ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.isSameShape(out));
out->printBuffer("5HIST"); // out->printBuffer("5HIST");
ASSERT_TRUE(exp.equalsTo(out));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) {
auto input = NDArrayFactory::create<double>('c', {7},{0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9});
auto range = NDArrayFactory::create<double>('c', {2}, {0, 1});
auto bins = NDArrayFactory::create<int>(5);
auto exp = NDArrayFactory::create<Nd4jLong>('c', {5}, {3, 1, 2, 0, 1});
nd4j::ops::histogram_fixed_width op;
auto results = op.execute({&input, &range, &bins}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto out = results->at(0);
// out->printShapeInfo();
// out->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(out));
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
delete results; delete results;

View File

@ -249,7 +249,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.}); auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
nd4j::ops::layer_norm op; nd4j::ops::layer_norm op;
auto result = op.execute({&x, &g, &b}, {}, {0}, {}); auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
} }
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.}); auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.});
nd4j::ops::layer_norm_bp op; nd4j::ops::layer_norm_bp op;
auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {}); auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
delete result; delete result;
} }

View File

@ -39,7 +39,7 @@ public:
} }
}; };
TEST_F(DeclarableOpsTests16, test_scatter_update_119) { TEST_F(DeclarableOpsTests16, scatter_upd_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1}); auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto y = NDArrayFactory::create<int>(0); auto y = NDArrayFactory::create<int>(0);
auto w = NDArrayFactory::create<float>(3.0f); auto w = NDArrayFactory::create<float>(3.0f);
@ -56,6 +56,27 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests16, scatter_upd_2) {
NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32);
NDArray indices('c', {2}, {2,5}, nd4j::DataType::INT32);
NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32);
NDArray e('c', {10, 3}, {1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30}, nd4j::DataType::FLOAT32);
x.linspace(1);
nd4j::ops::scatter_upd op;
auto result = op.execute({&x, &indices, &updates}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests16, test_size_dtype_1) { TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1}); auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto z = NDArrayFactory::create<float>(0.0f); auto z = NDArrayFactory::create<float>(0.0f);

View File

@ -1297,7 +1297,6 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testLayerNormMixedOrders(){ public void testLayerNormMixedOrders(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');