From bf0ddbc06c3e02fd65f23a18ac6e08be8db1d7ad Mon Sep 17 00:00:00 2001 From: Oleh Date: Mon, 30 Mar 2020 16:33:51 +0300 Subject: [PATCH 01/19] libnd4j fixes for context sync in operation execution (#350) Signed-off-by: Oleg --- libnd4j/include/array/impl/NDArrayList.cpp | 38 ++++++---- .../generic/broadcastable/floormod.cpp | 32 ++++----- .../generic/compat/compat_string_split.cpp | 18 ++--- .../ops/declarable/generic/loss/hingeLoss.cpp | 2 + .../ops/declarable/generic/loss/huberLoss.cpp | 2 + .../ops/declarable/generic/loss/logLoss.cpp | 2 + .../generic/loss/log_poisson_loss.cpp | 2 + .../generic/loss/meanPairWsSqErr.cpp | 2 + .../ops/declarable/generic/loss/meanSqErr.cpp | 2 + .../generic/loss/sigmCrossEntropy.cpp | 2 + .../parity_ops/compare_and_bitpack.cpp | 20 +++--- .../generic/parity_ops/normalize_moments.cpp | 14 ++-- .../ops/declarable/generic/random/uniform.cpp | 4 +- .../declarable/generic/tsne/symmetrized.cpp | 70 +++++++++---------- .../ops/declarable/helpers/cpu/lup.cpp | 2 +- .../include/ops/declarable/helpers/cpu/qr.cpp | 4 +- .../ops/declarable/helpers/cpu/top_k.cpp | 8 +-- .../ops/declarable/helpers/cuda/histogram.cu | 2 +- .../helpers/cuda/image_draw_bounding_boxes.cu | 6 +- .../helpers/cuda/image_suppression.cu | 8 +-- .../ops/declarable/helpers/cuda/lup.cu | 6 +- .../include/ops/declarable/helpers/cuda/qr.cu | 3 +- .../declarable/helpers/cuda/segment_max.cu | 8 +-- .../declarable/helpers/cuda/segment_mean.cu | 16 ++--- .../declarable/helpers/cuda/segment_min.cu | 8 +-- .../declarable/helpers/cuda/segment_prod.cu | 8 +-- .../declarable/helpers/cuda/segment_sqrtn.cu | 8 +-- .../declarable/helpers/cuda/segment_sum.cu | 8 +-- .../declarable/helpers/impl/multiUnique.cpp | 2 +- 29 files changed, 167 insertions(+), 140 deletions(-) diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index deb7828d8..ecd4bcaca 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace sd { NDArrayList::NDArrayList(int height, bool expandable) { @@ -144,25 +145,38 @@ namespace sd { NDArray* NDArrayList::stack() { // FIXME: this is bad for perf, but ok as poc - sd::ops::stack op; - std::vector inputs; - std::vector targs; - std::vector iargs({0}); - std::vector bargs; + int numElements = _elements.load(); - + std::vector inputs(numElements); for (int e = 0; e < numElements; e++) { _chunks[e]->syncToDevice(); - inputs.emplace_back(_chunks[e]); + inputs[e] = _chunks[e]; } - iargs.push_back(_axis); + auto inShapeInfo = inputs[0]->getShapeInfo(); + int rank = shape::rank(inShapeInfo); + NDArray* array = nullptr; - auto result = op.evaluate(inputs); + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + if (numElements == 1) { + array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + } else { + array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ; + } + } + } + } + else{ + std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); + outShape.insert(outShape.begin(), (Nd4jLong) numElements); + array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + } + + ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0); - auto array = new NDArray(result.at(0)->dup()); - - return array; + return array; } std::pair& NDArrayList::id() { diff 
--git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index 52377739b..fac209905 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -14,10 +14,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author raver119@gmail.com -// modified by sgazeos@gmail.com with backprop implementation. -// + // + // @author raver119@gmail.com + // modified by sgazeos@gmail.com with backprop implementation. + // #include #if NOT_EXCLUDED(OP_floormod) @@ -31,7 +31,7 @@ namespace sd { auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); @@ -46,15 +46,15 @@ namespace sd { DECLARE_TYPES(floormod) { getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); } DECLARE_TYPES(floormod_bp) { getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); } CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { @@ -66,11 +66,11 @@ namespace sd { auto gradY = OUTPUT_VARIABLE(1); gradX->assign(epsNext); - sd::ops::floormod op; - auto tmpResult(op.evaluate({x, y})); + NDArray temp(*epsNext); + BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); if (gradY->rankOf() == gradX->rankOf()) - epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY); + epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY); else // epsNext is greater than gradY { std::vector dims(epsNext->rankOf() * 2); @@ -78,7 +78,7 @@ namespace sd { for (Nd4jLong d = 0; d < gap; d++) { dims[d * 2 + 1] = 1; } - auto tempIn((*tmpResult.at(0))(dims)); + auto tempIn((temp)(dims)); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); } return Status::OK(); @@ -92,8 +92,8 @@ namespace sd { // eps always has shape of x // grad always has shape of y - Nd4jLong *shapeE; - Nd4jLong *shapeG; + Nd4jLong* shapeE; + Nd4jLong* shapeG; COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp index f88710904..40e080a8f 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author raver119@gmail.com -// + // + // @author raver119@gmail.com + // #include #if NOT_EXCLUDED(OP_split_string) @@ -60,7 +60,7 @@ namespace sd { // filling output indices for (uint64_t f = 0; f < cnt; f++) { - for (auto v: icoords) + for (auto v : icoords) indices->p(ic++, v); // last index @@ -75,12 +75,12 @@ namespace sd { for (auto e = 0L; e < input->lengthOf(); e++) { auto split = StringUtils::split(input->e(e), d); - for (const 
auto &s:split) + for (const auto& s : split) strings.emplace_back(s); } // now once we have all strings in single vector time to fill - auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings); + auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext()); auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); // for CUDA mostly @@ -129,9 +129,9 @@ namespace sd { DECLARE_TYPES(compat_string_split) { getOpDescriptor() - ->setAllowedInputTypes({ALL_STRINGS}) - ->setAllowedOutputTypes(0, {ALL_INDICES}) - ->setAllowedOutputTypes(1, {ALL_STRINGS}); + ->setAllowedInputTypes({ ALL_STRINGS }) + ->setAllowedOutputTypes(0, { ALL_INDICES }) + ->setAllowedOutputTypes(1, { ALL_STRINGS }); } } } diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index 244083a03..7d8eeec3a 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -68,6 +68,7 @@ namespace sd { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -201,6 +202,7 @@ namespace sd { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index 0c05de0ba..a29bd1cf2 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -73,6 +73,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -216,6 +217,7 @@ DECLARE_SHAPE_FN(huber_loss) { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index dc889d5c9..99140a394 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -70,6 +70,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -206,6 +207,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) 
sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 9a00b4eb4..20e03e92b 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -74,6 +74,7 @@ namespace ops { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -209,6 +210,7 @@ namespace ops { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index 312a32674..f8006a3ed 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -143,6 +143,7 @@ namespace sd { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else @@ -282,6 +283,7 @@ namespace sd { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index c5925fe90..b0ccf968b 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -67,6 +67,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else @@ -200,6 +201,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index 4d3c5749c..28d66bc93 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -78,6 +78,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else @@ -219,6 +220,7 @@ 
CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp index 43a7c7c75..1decc65f0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author sgazeos@gmail.com -// + // + // @author sgazeos@gmail.com + // #include #include @@ -29,24 +29,24 @@ namespace sd { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector()); + auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), block.launchContext()); BROADCAST_CHECK_EMPTY(x, y, (&z0)); - + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); bitcast res; - auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false); + auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false); if (tZ != &z0) { delete tZ; } - + return status; } DECLARE_TYPES(compare_and_bitpack) { getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::UINT8); + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::UINT8); } DECLARE_SHAPE_FN(compare_and_bitpack) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index d3ccff82a..f8a4c5c6e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// Created by george@skymind.io on 26.01.2018. -// + // + // Created by george@skymind.io on 26.01.2018. + // #include #if NOT_EXCLUDED(OP_normalize_moments) @@ -34,7 +34,7 @@ namespace sd { auto resVariances = OUTPUT_VARIABLE(1); // FIXME: double? 
- NDArray shift = NDArrayFactory::create(0.); + NDArray shift = NDArrayFactory::create(0., block.launchContext()); if (block.getTArguments()->size() > 0) { shift.assign(T_ARG(0)); @@ -47,7 +47,7 @@ namespace sd { squareMeans.applyTransform(transform::Square, squareMeans, nullptr); variances->applyScalarArr(scalar::Divide, *counts, tempVariances); -// tempVariances.printIndexedBuffer("varianced divided by count"); + // tempVariances.printIndexedBuffer("varianced divided by count"); tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); if (shift.e(0) != 0) { @@ -75,8 +75,8 @@ namespace sd { DECLARE_TYPES(normalize_moments) { getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); } } diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 6dec96739..94df6b32d 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -49,8 +49,8 @@ namespace sd { bool disposable = false; if (min == nullptr && max == nullptr && block.numT() >= 2) { - min = NDArrayFactory::create_(dtype); - max = NDArrayFactory::create_(dtype); + min = NDArrayFactory::create_(dtype, block.launchContext()); + max = NDArrayFactory::create_(dtype, block.launchContext()); min->p(0, T_ARG(0)); max->p(0, T_ARG(1)); disposable = true; diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index 82dd8c36e..cf3675122 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author George A. Shulinok , created on 4/18/2019. -// + // + // @author George A. Shulinok , created on 4/18/2019. 
+ // #include #if NOT_EXCLUDED(OP_barnes_symmetrized) @@ -25,20 +25,20 @@ #include namespace sd { -namespace ops { - NDArray* rowCountsPtr = nullptr; + namespace ops { + NDArray* rowCountsPtr = nullptr; - CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto valP = INPUT_VARIABLE(2); + CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto valP = INPUT_VARIABLE(2); auto N = rowP->lengthOf() - 1; - auto outputRows = OUTPUT_VARIABLE(0); + auto outputRows = OUTPUT_VARIABLE(0); auto outputCols = OUTPUT_VARIABLE(1); auto outputVals = OUTPUT_VARIABLE(2); - if (block.getIArguments()->size() > 0) - N = INT_ARG(0); + if (block.getIArguments()->size() > 0) + N = INT_ARG(0); if (rowCountsPtr) { helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr); @@ -46,33 +46,33 @@ namespace ops { return Status::OK(); } return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); - } + } - DECLARE_TYPES(barnes_symmetrized) { - getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::INT32}) - ->setAllowedInputTypes(1, {DataType::INT32}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {DataType::INT32}) - ->setAllowedOutputTypes(1, {DataType::INT32}) - ->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setSameMode(false); - } + DECLARE_TYPES(barnes_symmetrized) { + getOpDescriptor() + ->setAllowedInputTypes(0, { DataType::INT32 }) + ->setAllowedInputTypes(1, { DataType::INT32 }) + ->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS }) + ->setAllowedOutputTypes(1, { DataType::INT32 }) + ->setAllowedOutputTypes(1, { DataType::INT32 }) + ->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS }) + ->setSameMode(false); + } - DECLARE_SHAPE_FN(barnes_symmetrized) { - auto valPShapeInfo = inputShape->at(2); + DECLARE_SHAPE_FN(barnes_symmetrized) { + auto valPShapeInfo = inputShape->at(2); Nd4jLong* outShapeInfo; - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); auto N = rowP->lengthOf() - 1; if (block.getIArguments()->size() > 0) N = INT_ARG(0); auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); - NDArray* rowCounts = NDArrayFactory::create_('c', {N}); //rowP->dup(); + NDArray* rowCounts = NDArrayFactory::create_('c', { N }, block.launchContext()); //rowP->dup(); //srowCounts->assign(0); Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); rowCounts->syncToHost(); -// rowCounts->printBuffer("Row Counts"); + // rowCounts->printBuffer("Row Counts"); if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len."); rowCountsPtr = rowCounts; //ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); @@ -80,13 +80,13 @@ namespace ops { // outShapeInfo[2] = len; // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c'); //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace()); - outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.getWorkspace()); - auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, len}, block.getWorkspace()); - auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, N + 1}, 
block.getWorkspace()); - return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); - } + outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace()); + auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace()); + auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace()); + return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); + } -} + } } #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8466631da..8938a98f9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -142,7 +142,7 @@ namespace helpers { const int rowNum = input->rows(); const int columnNum = input->columns(); - NDArray determinant = NDArrayFactory::create(1.f); + NDArray determinant = NDArrayFactory::create(1.f, context); NDArray compoundMatrix = *input; // copy NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides permutationMatrix.setIdentity(); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp index 2ea18a79d..1f980e553 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -39,7 +39,7 @@ namespace helpers { template NDArray vmul(NDArray const& v, int n) { - NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n); + NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n); T const* vBuf = v.getDataBuffer()->primaryAsT(); T* resBuf = res.dataBuffer()->primaryAsT(); auto interloop = PRAGMA_THREADS_FOR_2D { @@ -61,7 +61,7 @@ namespace helpers { std::vector q(M); NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT()); // two internal buffers and scalar for squared norm + NDArray e('c', {M}, DataTypeUtils::fromT(), Q->getContext()); // two internal buffers and scalar for squared norm for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number e.nullify(); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp index 7e0b07da0..78b06d71e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp @@ -69,9 +69,9 @@ namespace helpers { auto trial = (*input)(e, dimsToExclude); // fill up the first k elements - NDArray topValues = NDArrayFactory::create('c', {k}); - NDArray sortedVals = NDArrayFactory::create('c', {k}); - NDArray topIndices = NDArrayFactory::create('c', {k}); + NDArray topValues = NDArrayFactory::create('c', {k}, input->getContext()); + NDArray sortedVals = NDArrayFactory::create('c', {k}, input->getContext()); + NDArray topIndices = NDArrayFactory::create('c', {k}, input->getContext()); for (uint pos = 0; pos < k; ++pos) { topIndices.t(pos) = pos; topValues.t(pos) = trial.t(pos); @@ -144,7 +144,7 @@ namespace helpers { for (int i = 0; i < input->rankOf() - 1; i++) shapeI[i] = input->sizeAt(i); shapeI[input->rankOf() - 1] = k; - std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI)); + std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI, context)); NDArray* values = 
nullptr; int status = topKFunctor(context, input, values, indices.get(), k, true); result->assign(0); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index 51af14fc4..c6123d6da 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -112,7 +112,7 @@ namespace sd { int numThreads = 256; int numBlocks = sd::math::nd4j_max(256, sd::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); int workspaceSize = numBlocks * numBins; - auto tmp = NDArrayFactory::create('c', {workspaceSize}); + auto tmp = NDArrayFactory::create('c', {workspaceSize}, context); histogramKernel<<getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast(min_val), reinterpret_cast(max_val)); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu index 9817471bb..6d6ec95ed 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu @@ -25,7 +25,7 @@ namespace ops { namespace helpers { typedef NDArray ColorTable_t; - static NDArray DefaultColorTable(int depth) { + static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) { //std::vector> colorTable; const Nd4jLong kDefaultTableLength = 10; const Nd4jLong kDefaultChannelLength = 4; @@ -40,7 +40,7 @@ namespace helpers { 0, 0, 0.5, 1, // 7: navy blue 0, 1, 1, 1, // 8: aqua 1, 0, 1, 1 // 9: fuchsia - }, DataType::FLOAT32); + }, DataType::FLOAT32, context); if (depth == 1) { colorTable.assign(1.f); // all to white when black and white colors @@ -144,7 +144,7 @@ namespace helpers { auto channels = images->sizeAt(3); auto stream = context->getCudaStream(); auto boxSize = boxes->sizeAt(1); - NDArray colorsTable = DefaultColorTable(channels); + NDArray colorsTable = DefaultColorTable(channels, context); if ((colors != nullptr && colors->lengthOf() > 0)) { colorsTable = *colors; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 5ed534cb6..e6d9a27b1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -188,7 +188,7 @@ namespace helpers { static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {boxes, scales}); - std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); + std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext()); NDArray scores(*scales); Nd4jPointer extras[2] = {nullptr, stream}; @@ -198,7 +198,7 @@ namespace helpers { indices->tickWriteDevice(); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); indices->tickWriteDevice(); - NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}); + 
NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}, context); int numSelected = 0; int numBoxes = boxes->sizeAt(0); auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); @@ -347,8 +347,8 @@ namespace helpers { scores->syncToDevice(); } - NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); - NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}); + NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext()); + NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}, context); NDArray selectedScores(*scores); Nd4jPointer extras[2] = {nullptr, stream}; auto indexBuf = indices.dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 7630694e1..2ca731912 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -598,7 +598,7 @@ namespace helpers { static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { auto n = input->sizeAt(-1); auto stream = context->getCudaStream(); - NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // ('c', {n}); + NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // ('c', {n}); iota.linspace(0); iota.syncToDevice(); output->assign(input); // fill up output tensor with zeros @@ -631,7 +631,7 @@ namespace helpers { // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1); + auto det = NDArrayFactory::create(1, context); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); dim3 launchDims(256, 256, 1024); @@ -677,7 +677,7 @@ namespace helpers { dtype = DataType::FLOAT32; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1); + auto det = NDArrayFactory::create(1, context); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); dim3 launchDims(256, 256, 1024); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu index d0d5fddd5..394840376 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu @@ -110,7 +110,7 @@ namespace helpers { auto resR = fullMatricies?R->ulike():matrix->ulike(); std::vector q(M); NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT()); // two internal buffers and scalar for squared norm + NDArray e('c', {M}, DataTypeUtils::fromT(), context); // two internal buffers and scalar for squared norm for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number e.nullify(); z = matrixMinor(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) @@ -177,4 +177,3 @@ namespace helpers { } } } - diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index 208bf764b..e7baf2370 100644 --- 
a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -167,8 +167,8 @@ namespace sd { auto stream = context->getCudaStream(); indices->syncToHost(); Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -209,8 +209,8 @@ namespace sd { // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); output->assign(DataTypeUtils::infOrMax()); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes); classesRangesBegs.assign(indices->lengthOf()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index fa8882190..76036a5e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -158,8 +158,8 @@ namespace helpers { static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -198,8 +198,8 @@ namespace helpers { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); @@ -314,8 +314,8 @@ namespace helpers { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); 
classesRangesLens.assign(0); @@ -367,8 +367,8 @@ namespace helpers { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index b83d37567..0133b3b11 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -161,8 +161,8 @@ namespace helpers { static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); output->assign(DataTypeUtils::infOrMax()); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -202,8 +202,8 @@ namespace helpers { static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); output->assign(DataTypeUtils::infOrMax()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index baec75b9e..d08f79817 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -122,8 +122,8 @@ namespace helpers { static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); output->assign(1); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -160,8 +160,8 @@ namespace helpers { static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* 
output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 7d85a0ea6..f9b6eaad0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -86,8 +86,8 @@ namespace helpers { static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); @@ -207,8 +207,8 @@ namespace helpers { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 6e1e3fca8..56d53710f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -162,8 +162,8 @@ namespace helpers { static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -201,8 +201,8 @@ namespace helpers { static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray 
classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); diff --git a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp index ff9ae3144..5989f5246 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp @@ -41,7 +41,7 @@ namespace helpers { length += array->lengthOf(); pos++; } - NDArray arrayFull('c', {length}, sd::DataType::INT32); + NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext()); cContext.setOutputArray(0, &arrayFull); cContext.setIArguments(&axis, 1); From 29e61579c1ead6379f7afa01b8d036299cb8b00a Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Tue, 31 Mar 2020 07:41:16 +0300 Subject: [PATCH 02/19] Shyrma reshape empty (#338) * - start working on reshape op which operates with empty shapes Signed-off-by: Yurii * - correct reshaping for empty arrays Signed-off-by: Yurii * - remove unnecessary check in Loopkind Signed-off-by: Yurii --- libnd4j/include/array/NDArray.hXX | 2 +- libnd4j/include/helpers/LoopKind.h | 24 +- .../ops/declarable/generic/shape/reshape.cpp | 282 +++------- .../layers_tests/ArrayOptionsTests.cpp | 7 +- .../layers_tests/DeclarableOpsTests1.cpp | 301 +++------- .../layers_tests/DeclarableOpsTests14.cpp | 526 ++++++++++++++---- .../layers_tests/DeclarableOpsTests15.cpp | 92 ++- .../layers_tests/DeclarableOpsTests4.cpp | 35 -- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 66 --- .../tests_cpu/layers_tests/ParityOpsTests.cpp | 174 +++--- .../tests_cpu/layers_tests/ScalarTests.cpp | 30 +- .../tests_cpu/layers_tests/SingleDimTests.cpp | 45 +- 12 files changed, 730 insertions(+), 854 deletions(-) diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 43c6fe2ad..1caae85a4 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -1944,7 +1944,7 @@ void NDArray::tilei(const std::vector& reps) { Nd4jLong NDArray::sizeAt(const int dim) const { if (dim >= this->rankOf() || dim < -this->rankOf()) - throw std::runtime_error("Bad size index requested"); + throw std::runtime_error("NDArray::sizeAt: bad size index requested"); if (dim >= 0) return shape::shapeOf(_shapeInfo)[dim]; diff --git a/libnd4j/include/helpers/LoopKind.h b/libnd4j/include/helpers/LoopKind.h index 95e9238ad..e3ca932b3 100644 --- a/libnd4j/include/helpers/LoopKind.h +++ b/libnd4j/include/helpers/LoopKind.h @@ -35,16 +35,16 @@ namespace sd { class ND4J_EXPORT LoopKind { - + public: enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D }; static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE 
Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - + }; ////////////////////////////////////////////////////////////////////////////// @@ -59,8 +59,8 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd int temp; const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); + const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c')) return EWS1; @@ -160,7 +160,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const N const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); + const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) return EWS1; @@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; - if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 && + if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 && tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) return SMALLARR2DX; if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) @@ -233,18 +233,18 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const ////////////////////////////////////////////////////////////////////////////// LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) { - // both tad shapes are the same, but strides and ews may be different + // both tad shapes are the same, but strides and ews may be different const int tadRank = shape::rank(xTadShapeInfo); const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); - const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); const char xTadOrder = shape::order(xTadShapeInfo); const char yTadOrder = 
shape::order(xTadShapeInfo); const char zOrder = shape::order(zShapeInfo); - + int position; const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; @@ -265,7 +265,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, c return RANK4; if(tadRank == 5 && zEws > 0 && zVectorOrC) return RANK5; - return COMMON; + return COMMON; } diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index ace58a0b8..5ac7686e2 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -35,111 +35,19 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { auto z = OUTPUT_VARIABLE(0); //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); - return Status::OK(); //No op - } - - if (block.width() == 1) { - - auto arguments = block.getIArguments(); - int argsSize = arguments->size(); - - - - int e = 1; - char order = (char) -(*arguments)[0]; - if (order != 'c' && order != 'f') { - order = 'c'; //x->ordering(); - e = 0; - } - - REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension"); - - std::vector shapeNew; - int e2 = e; - for (; e < (int) arguments->size(); e++) { - if (arguments->at(e) == -1){ - Nd4jLong shapeLength = 1; - for(; e2 < e; e2++){ - shapeLength *= arguments->at(e2); - } - for(e2 = e + 1; e2 < arguments->size(); e2++){ - shapeLength *= arguments->at(e2); - } - Nd4jLong realShape = x->lengthOf() / shapeLength; - shapeNew.push_back(realShape); - } - else{ - shapeNew.push_back(arguments->at(e)); - } - - } - - auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); - REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); - - if (Environment::getInstance()->isDebugAndVerbose()) { - nd4j_printv("Reshape: new shape", shapeNew); - } - - auto xr = x->reshape(order, shapeNew); - z->assign(xr); - STORE_RESULT(*z); - - return Status::OK(); - - } else if (block.width() == 2) { - - auto s = INPUT_VARIABLE(1); - - char order = 'c'; - if (block.numI() > 0) - order = (char) -INT_ARG(0); - - std::vector shapeNew(s->lengthOf()); - - for (int e = 0; e < (int) s->lengthOf(); e++) { - auto dim = s->e(e); - if (dim == -1){ - Nd4jLong shapeLength = 1; - for(int e2 = 0; e2 < e; e2++){ - shapeLength *= s->e(e2); - } - for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){ - REQUIRE_TRUE(s->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= s->e(e2); - } - Nd4jLong realShape = x->lengthOf() / shapeLength; - shapeNew[e] = realShape; - } - else{ - shapeNew[e] = dim; - } - } - - if (Environment::getInstance()->isDebugAndVerbose()) { - nd4j_printv("Reshape: new shape", shapeNew); - } - - if (s->isScalar()) { - // just a scalar - z->assign(x); - } else { - // in some cases we might go away with simple memcpy call instead of assign call - if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) { - z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset()); - } else { - auto xr = x->reshape(order, shapeNew); - z->assign(xr); - } - } - - 
return Status::OK(); - + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); + return Status::OK(); //No op } - return ND4J_STATUS_BAD_INPUT; + REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); + + if (Environment::getInstance()->isDebugAndVerbose()) + nd4j_printv("Reshape: new shape", z->getShapeAsVector()); + + z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); + + return Status::OK(); } @@ -151,117 +59,73 @@ DECLARE_TYPES(reshape) { } DECLARE_SHAPE_FN(reshape) { - auto inp = inputShape->at(0); - // we can launch op using Int arguments - if (inputShape->size() == 1) { - REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined"); - std::vector *arguments = block.getIArguments(); + const auto x = INPUT_VARIABLE(0); - int e = 1; - char order = (char) -(*arguments)[0]; - if (order != 'c' && order != 'f') { - order = shape::order(inp); - e = 0; + std::vector reshapeArgs; + std::vector shapeNew; + char orderNew = 'c'; + + if (block.width() == 1) { + reshapeArgs = *block.getIArguments(); + if(!reshapeArgs.empty()) { + orderNew = (char) -reshapeArgs[0]; + if(orderNew == 'c' || orderNew == 'f') + reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case } - - std::vector shapeNew; - - int e2 = e; - for (; e < (int) arguments->size(); e++) { - if ((int) arguments->at(e) == -1){ - - Nd4jLong shapeLength = 1; - for(; e2 < e; e2 ++){ - shapeLength *= arguments->at(e2); - } - for(e2 = e + 1; e2 < arguments->size(); e2++){ - REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= arguments->at(e2); - } - - if(shapeLength == 0){ - //Edge case for empty: - shapeNew.push_back(0); - } else { - //Standard case - Nd4jLong realShape = shape::length(inp) / shapeLength; - shapeNew.push_back(realShape); - } - } - else{ - shapeNew.push_back(arguments->at(e)); - } - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew))); - } else { - // or, with second input "as shape" - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - // special case here - if (y->isEmpty()) { - REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array"); - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp))); - } - //Special case: empty.reshape(-1) -> return empty - if (x->isEmpty()) { - //REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); - auto shapeOf = y->getBufferAsVector(); - Nd4jLong prod = 1; - bool hasNegs = false; - for (auto v:shapeOf) { - if (v < 0) { - hasNegs = true; - v = 0; - } - - prod *= v; - } - - REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); - - // if there are -1s - we turn them into zeros - if (hasNegs) { - for (int e = 0; e < shapeOf.size(); e++) - if (shapeOf[e] < 0) - shapeOf[e] = 0; - } - - auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); - return SHAPELIST(CONSTANT(newShape)); - } - - std::vector shapeNew(y->lengthOf()); - - for (int e = 0; e < (int) y->lengthOf(); e++) { - auto dim = y->e(e); - if (dim == -1){ - Nd4jLong shapeLength = 
1; - for(int e2 = 0; e2 < e; e2++){ - shapeLength *= y->e(e2); - } - for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){ - REQUIRE_TRUE(y->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= y->e(e2); - } - - if(shapeLength == 0){ - //Edge case for empty: - shapeNew[e] = 0; - } else { - Nd4jLong realShape = shape::length(inp) / shapeLength; - shapeNew[e] = realShape; - } - }else { - shapeNew[e] = dim; - } - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew)); } + else { + reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); + orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c'; + } + + REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); + + Nd4jLong xLen = x->lengthOf(); + if(x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + if(x->sizeAt(i) != 0) + xLen *= x->sizeAt(i); + } + + for (uint i = 0; i < reshapeArgs.size(); ++i) { + + if (reshapeArgs[i] == -1) { + + uint shapeLength = 1, numOfZeros = 0; + + for(uint j = 0; j < i; ++j) + if(reshapeArgs[j] != 0) + shapeLength *= reshapeArgs[j]; + else + ++numOfZeros; + + for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + if(reshapeArgs[j] != 0) + shapeLength *= reshapeArgs[j]; + else + ++numOfZeros; + } + + const auto dim = xLen / shapeLength; + + if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + shapeNew.push_back(0); + else + shapeNew.push_back(dim); + } + else + shapeNew.push_back(reshapeArgs[i]); + } + + auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); + REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew)); } + } } diff --git a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp index 18551909c..97893ca5b 100644 --- a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp @@ -40,16 +40,15 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_0) { TEST_F(ArrayOptionsTests, TestShape_Basic_1) { shape[5] = 2; - + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); } - TEST_F(ArrayOptionsTests, TestShape_Basic_2) { shape[5] = 258; - + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); @@ -58,7 +57,7 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_2) { TEST_F(ArrayOptionsTests, TestShape_Basic_3) { ASSERT_EQ(0, shape::extra(shape)); - + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 28240cc10..8a03d4abc 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -180,7 +180,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(exp)); - + } 
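A minimal standalone sketch of the -1 inference performed by the new DECLARE_SHAPE_FN(reshape) above, for the non-empty input case. This is not the libnd4j implementation: it assumes plain std::vector<int64_t> shapes instead of Nd4jLong shapeInfo buffers, and the name inferReshape is a hypothetical helper used only for illustration.

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Resolve a single -1 ("unknown") dimension from the input length, skipping
// zero-sized dims the same way the shape function above skips them.
std::vector<int64_t> inferReshape(int64_t inputLength, std::vector<int64_t> newShape) {
    int unknownIdx = -1;
    int64_t knownProduct = 1;
    for (size_t i = 0; i < newShape.size(); ++i) {
        if (newShape[i] == -1) {
            if (unknownIdx >= 0)
                throw std::invalid_argument("Reshape: only one unknown dimension (-1) is allowed.");
            unknownIdx = static_cast<int>(i);
        } else if (newShape[i] != 0) {
            knownProduct *= newShape[i];
        }
    }
    if (unknownIdx >= 0)
        newShape[unknownIdx] = (knownProduct == 0) ? 0 : inputLength / knownProduct;
    return newShape;
}

int main() {
    // Mirrors the Reshape6 test later in this patch: length 60 reshaped with {4, -1} becomes {4, 15}.
    auto s = inferReshape(60, {4, -1});
    std::printf("%lld %lld\n", (long long) s[0], (long long) s[1]);
    return 0;
}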
////////////////////////////////////////////////////////////////////// @@ -198,7 +198,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) { ASSERT_TRUE(z1->equalsTo(exp1)); ASSERT_TRUE(z2->equalsTo(exp2)); - + } ////////////////////////////////////////////////////////////////////// @@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(exp)); - + } TEST_F(DeclarableOpsTests1, BasicInitialization3) { @@ -258,7 +258,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } TEST_F(DeclarableOpsTests1, TestTensorDot2) { @@ -278,7 +278,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } TEST_F(DeclarableOpsTests1, TestTensorDot3) { @@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } TEST_F(DeclarableOpsTests1, TestTensorDot4) { @@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////// @@ -338,7 +338,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot5) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot6) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -381,7 +381,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot7) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot8) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot9) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot10) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -474,7 +474,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot11) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot12) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot13) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -537,7 +537,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot14) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -558,7 +558,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot15) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -579,7 +579,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot16) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) { ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } TEST_F(DeclarableOpsTests1, TestRng1) { @@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1071,7 
+1071,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1093,7 +1093,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) { ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(exp.equalsTo(&z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1147,7 +1147,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1402,7 +1402,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1421,7 +1421,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1437,7 +1437,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1463,7 +1463,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { ASSERT_TRUE(z.equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1676,31 +1676,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { delete block; } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Reshapeas1) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; - - auto x = NDArrayFactory::create_('f', xShape); - auto y = NDArrayFactory::create_('f', yShape); - - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); - - sd::ops::reshapeas reshape; - - reshape.execute(block); - - ASSERT_TRUE(x->isSameShape(y)); - - delete variableSpace; - delete block; -} - TEST_F(DeclarableOpsTests1, Test_Cast_1) { // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected auto x = NDArrayFactory::create('c', { 5, 5 }); @@ -1715,7 +1690,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) { auto z = result.at(0); ASSERT_TRUE(yExp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1848,113 +1823,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) { #endif -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Reshape2) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; - - auto x = NDArrayFactory::create_('c', xShape); - auto y = NDArrayFactory::create_('c', yShape); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1 }); - std::vector* 
arguments = block->getIArguments(); - arguments->push_back(-y->ordering()); - arguments->push_back(3); - arguments->push_back(5); - arguments->push_back(4); - - sd::ops::reshape reshape; - - Nd4jStatus status = reshape.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - - ASSERT_TRUE(result->isSameShape(y)); - - delete y; - delete block; - delete variableSpace; -} - -TEST_F(DeclarableOpsTests1, Reshape3) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(x.isSameShape(z)); - - -} - -TEST_F(DeclarableOpsTests1, Reshape4) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(x.isSameShape(z)); - - -} - -TEST_F(DeclarableOpsTests1, Reshape5) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - -} - -TEST_F(DeclarableOpsTests1, Reshape6) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 4, 15 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 4, -1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(exp)); - - - -} - -TEST_F(DeclarableOpsTests1, Reshape7) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 60 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(exp)); - - - -} ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Transpose1) { @@ -1983,7 +1851,6 @@ TEST_F(DeclarableOpsTests1, Transpose1) { delete variableSpace; } - ////////////////////////////////////////////////////////////////////// // not-in-place TEST_F(DeclarableOpsTests1, Permute1) { @@ -2259,7 +2126,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) { //res->printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); - + } ////////////////////////////////////////////////////////////////////// @@ -2281,7 +2148,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) { //res->printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); - + } ////////////////////////////////////////////////////////////////////// @@ -2303,7 +2170,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) { //res->printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); - + } ////////////////////////////////////////////////////////////////////// @@ -2352,7 +2219,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) { // ASSERT_TRUE(expState.equalsTo(state)); // ASSERT_TRUE(expOut.equalsTo(output)); -// +// // } ////////////////////////////////////////////////////////////////// @@ -2386,7 +2253,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) { ASSERT_TRUE(expState.equalsTo(state)); ASSERT_TRUE(expOut.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2438,7 +2305,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) { ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - + } 
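The Reshape2 and Reshape3 tests removed above pass the target ordering through the integer args as a negated character code (`-y->ordering()`, or the literal -99 for 'c'), and the new shape function recovers it with `(char) -reshapeArgs[0]`. The sketch below shows only that encoding; decodeOrder is a hypothetical helper, and the fallback to 'c' for a non-order leading argument is added here for clarity rather than copied from the op.

#include <cstdio>

// Decode reshape's order argument: the first IArg, when negative, is the
// negated ASCII code of the order character (-99 -> 'c', -102 -> 'f').
char decodeOrder(long long firstIArg, char fallback = 'c') {
    char order = static_cast<char>(-firstIArg);
    return (order == 'c' || order == 'f') ? order : fallback;
}

int main() {
    // A positive leading arg (e.g. the dimension 3) is not an order code, so the fallback applies.
    std::printf("%c %c %c\n", decodeOrder(-99), decodeOrder(-102), decodeOrder(3));
    return 0;
}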
////////////////////////////////////////////////////////////////// @@ -2474,7 +2341,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) { ASSERT_TRUE(expState.equalsTo(state)); ASSERT_TRUE(expOut.equalsTo(output)); - + } TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { @@ -2527,7 +2394,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - + } TEST_F(DeclarableOpsTests1, ArgMax1) { @@ -2547,7 +2414,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2568,7 +2435,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2590,7 +2457,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, ArgMax4) { @@ -2611,7 +2478,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2633,7 +2500,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, ArgMax6) { @@ -2676,7 +2543,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2697,7 +2564,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_1) { @@ -2717,7 +2584,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_2) { @@ -2736,7 +2603,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_3) { @@ -2756,7 +2623,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_4) { @@ -2775,7 +2642,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_5) { @@ -2796,7 +2663,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_6) { @@ -2809,7 +2676,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_6) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests1, OneHotTests_7) { @@ -2822,7 +2689,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_7) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests1, FillAs_1) { @@ -2840,7 +2707,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) { ASSERT_NEAR(scalar, result.at(0)->meanNumber().e(0), 1e-5f); - + } ////////////////////////////////////////////////////////////////////// @@ -2866,7 +2733,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - + } @@ -2893,7 +2760,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - + } @@ -2913,7 +2780,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - + } ////////////////////////////////////////////////////////////////////// @@ -2931,7 +2798,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2947,7 +2814,7 
@@ TEST_F(DeclarableOpsTests1, softmax_test2) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2963,7 +2830,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2979,7 +2846,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2995,7 +2862,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3011,7 +2878,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3027,7 +2894,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3043,7 +2910,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3059,7 +2926,7 @@ TEST_F(DeclarableOpsTests1, softmax_test9) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test10) { @@ -3074,7 +2941,7 @@ TEST_F(DeclarableOpsTests1, softmax_test10) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test11) { @@ -3089,7 +2956,7 @@ TEST_F(DeclarableOpsTests1, softmax_test11) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3108,7 +2975,7 @@ TEST_F(DeclarableOpsTests1, softmax_test12) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_1) { @@ -3132,7 +2999,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3157,7 +3024,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2) { ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); - + } ////////////////////////////////////////////////////////////////////// @@ -3183,7 +3050,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3209,7 +3076,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3234,7 +3101,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5) { 
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -3260,7 +3127,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6) { ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); - + } @@ -3288,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -3316,7 +3183,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -3341,7 +3208,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests1, Reverse_10) { @@ -3357,7 +3224,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3380,7 +3247,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3402,7 +3269,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3423,7 +3290,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3444,7 +3311,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests1, Test_Expose_1) { @@ -3463,7 +3330,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { ASSERT_TRUE(input0.equalsTo(z0)); ASSERT_TRUE(input1.equalsTo(z1)); - + } TEST_F(DeclarableOpsTests1, Test_Expose_2) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index b3a710be9..db49c12f2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -51,23 +51,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { ASSERT_EQ(exp, *z); - -} -TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) { - auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); - auto e = NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); - - auto r = x.reshape('c', {3, 2});; - r.streamline('f'); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {3, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - } TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { @@ -108,7 +92,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) { ASSERT_EQ(e, r); ASSERT_EQ(e, *f); - + } } @@ -124,7 +108,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { @@ -139,7 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) { @@ -193,7 +177,7 @@ 
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { ASSERT_EQ(e, *result.at(0)); - + } TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { @@ -210,7 +194,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { ASSERT_EQ(e, *result.at(0)); - + } TEST_F(DeclarableOpsTests14, test_empty_fill_1) { @@ -224,7 +208,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) { auto z = result.at(0); ASSERT_EQ(y, *z); - + } TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { @@ -259,7 +243,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { auto out = res2.at(0); ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); - + } TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { @@ -271,7 +255,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { auto out = res2.at(0); ASSERT_EQ(out->e(0), -DataTypeUtils::infOrMax()); - + } TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { @@ -286,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { ASSERT_EQ(res2.status(), Status::OK()); auto out = res2.at(0); ASSERT_EQ(out->e(0), 0.f); - + } TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { @@ -303,7 +287,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { // out->printShapeInfo("ReduceMean empty shape with keep dims"); // out->printIndexedBuffer("ReduceMean scalar"); ASSERT_TRUE(std::isnan(out->e(0))); - + } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { @@ -324,7 +308,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { ASSERT_TRUE(exp.isSameShape(z)); - + } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { @@ -345,7 +329,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { ASSERT_TRUE(exp.isSameShape(z)); - + } TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { @@ -363,7 +347,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { @@ -391,7 +375,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { ASSERT_TRUE(x.isSameShape(z)); ASSERT_EQ(x, *z); - + } ////////////////////////////////////////////////////////////////////// @@ -409,7 +393,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -427,7 +411,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -445,7 +429,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -463,7 +447,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -481,7 +465,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { @@ -502,7 +486,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { ASSERT_EQ(e, res); - + } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { @@ -523,7 +507,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { ASSERT_EQ(e, res); - + } 
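The test_empty_reduce_* cases above pin down the identity values returned for reductions over empty inputs: +infinity for min, -infinity for max, 0 for sum, and NaN for mean. Below is a plain C++ sketch of that convention, assuming simple float vectors rather than NDArrays; reduceMin, reduceSum, and reduceMean are hypothetical helpers, and max is symmetric to min.

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Each reduction starts from its identity element, so an empty input
// returns that identity: +inf for min, 0 for sum, and 0/0 == NaN for mean.
float reduceMin(const std::vector<float>& v) {
    float r = std::numeric_limits<float>::infinity();
    for (float x : v) r = std::fmin(r, x);
    return r;
}

float reduceSum(const std::vector<float>& v) {
    float r = 0.f;
    for (float x : v) r += x;
    return r;
}

float reduceMean(const std::vector<float>& v) {
    return reduceSum(v) / static_cast<float>(v.size());
}

int main() {
    std::vector<float> empty;
    std::printf("min=%f sum=%f mean=%f\n", reduceMin(empty), reduceSum(empty), reduceMean(empty));
    return 0;
}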
/////////////////////////////////////////////////////////////////////// @@ -639,7 +623,7 @@ TEST_F(DeclarableOpsTests14, matmul_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -661,7 +645,7 @@ TEST_F(DeclarableOpsTests14, matmul_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -682,7 +666,7 @@ TEST_F(DeclarableOpsTests14, matmul_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -704,7 +688,7 @@ TEST_F(DeclarableOpsTests14, matmul_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -726,7 +710,7 @@ TEST_F(DeclarableOpsTests14, matmul_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -747,7 +731,7 @@ TEST_F(DeclarableOpsTests14, matmul_test6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -770,7 +754,7 @@ TEST_F(DeclarableOpsTests14, matmul_test7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -795,7 +779,7 @@ TEST_F(DeclarableOpsTests14, matmul_test8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -820,7 +804,7 @@ TEST_F(DeclarableOpsTests14, matmul_test9) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test10) { @@ -876,7 +860,7 @@ TEST_F(DeclarableOpsTests14, matmul_test11) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test12) { @@ -894,7 +878,7 @@ TEST_F(DeclarableOpsTests14, matmul_test12) { ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -914,7 +898,7 @@ TEST_F(DeclarableOpsTests14, matmul_test13) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test14) { @@ -933,7 +917,7 @@ TEST_F(DeclarableOpsTests14, matmul_test14) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test15) { @@ -952,7 +936,7 @@ TEST_F(DeclarableOpsTests14, matmul_test15) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test16) { @@ -971,7 +955,7 @@ TEST_F(DeclarableOpsTests14, matmul_test16) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test17) { @@ -985,7 +969,7 @@ TEST_F(DeclarableOpsTests14, matmul_test17) { ASSERT_EQ(exp, *result.at(0)); - + } @@ -1007,7 +991,7 @@ TEST_F(DeclarableOpsTests14, matmul_test18) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1027,7 +1011,7 @@ TEST_F(DeclarableOpsTests14, matmul_test19) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1048,7 +1032,7 @@ TEST_F(DeclarableOpsTests14, matmul_test20) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1069,7 +1053,7 @@ TEST_F(DeclarableOpsTests14, matmul_test21) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } 
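The matmul_test* cases in this file check the op output against small hand-computed expectations. A naive row-major reference multiply like the one below is enough to reproduce those expected buffers; this is plain C++, not the libnd4j GEMM path, and naiveMatmul is a hypothetical helper.

#include <cstddef>
#include <cstdio>
#include <vector>

// Reference multiply for row-major ('c' order) buffers: C[i,j] = sum_k A[i,k] * B[k,j].
std::vector<float> naiveMatmul(const std::vector<float>& a, const std::vector<float>& b,
                               std::size_t M, std::size_t K, std::size_t N) {
    std::vector<float> c(M * N, 0.f);
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t k = 0; k < K; ++k)
            for (std::size_t j = 0; j < N; ++j)
                c[i * N + j] += a[i * K + k] * b[k * N + j];
    return c;
}

int main() {
    // 2x2 example: [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
    auto c = naiveMatmul({1, 2, 3, 4}, {5, 6, 7, 8}, 2, 2, 2);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
    return 0;
}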
////////////////////////////////////////////////////////////////////// @@ -1090,7 +1074,7 @@ TEST_F(DeclarableOpsTests14, matmul_test22) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1111,7 +1095,7 @@ TEST_F(DeclarableOpsTests14, matmul_test23) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1135,7 +1119,7 @@ TEST_F(DeclarableOpsTests14, matmul_test24) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1156,7 +1140,7 @@ TEST_F(DeclarableOpsTests14, matmul_test25) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1177,7 +1161,7 @@ TEST_F(DeclarableOpsTests14, matmul_test26) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1198,7 +1182,7 @@ TEST_F(DeclarableOpsTests14, matmul_test27) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1220,7 +1204,7 @@ TEST_F(DeclarableOpsTests14, matmul_test28) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1242,7 +1226,7 @@ TEST_F(DeclarableOpsTests14, matmul_test29) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test30) { @@ -1262,7 +1246,7 @@ TEST_F(DeclarableOpsTests14, matmul_test30) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test31) { @@ -1282,7 +1266,7 @@ TEST_F(DeclarableOpsTests14, matmul_test31) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test32) { @@ -1299,7 +1283,7 @@ TEST_F(DeclarableOpsTests14, matmul_test32) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test33) { @@ -1319,7 +1303,7 @@ TEST_F(DeclarableOpsTests14, matmul_test33) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test34) { @@ -1336,7 +1320,7 @@ TEST_F(DeclarableOpsTests14, matmul_test34) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ///////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test35) { @@ -1353,7 +1337,7 @@ TEST_F(DeclarableOpsTests14, matmul_test35) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test36) { @@ -1370,7 +1354,7 @@ TEST_F(DeclarableOpsTests14, matmul_test36) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test37) { @@ -1617,7 +1601,7 @@ TEST_F(DeclarableOpsTests14, 
Stack_1) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1645,7 +1629,7 @@ TEST_F(DeclarableOpsTests14, Stack_2) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1673,7 +1657,7 @@ TEST_F(DeclarableOpsTests14, Stack_3) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1700,7 +1684,7 @@ TEST_F(DeclarableOpsTests14, Stack_4) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1727,7 +1711,7 @@ TEST_F(DeclarableOpsTests14, Stack_5) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1754,7 +1738,7 @@ TEST_F(DeclarableOpsTests14, Stack_6) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1778,7 +1762,7 @@ TEST_F(DeclarableOpsTests14, Stack_7) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1801,7 +1785,7 @@ TEST_F(DeclarableOpsTests14, Stack_8) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1824,7 +1808,7 @@ TEST_F(DeclarableOpsTests14, Stack_9) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1850,7 +1834,7 @@ TEST_F(DeclarableOpsTests14, Stack_10) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } TEST_F(DeclarableOpsTests14, Stack_11) { @@ -1872,7 +1856,7 @@ TEST_F(DeclarableOpsTests14, Stack_11) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1895,7 +1879,7 @@ TEST_F(DeclarableOpsTests14, Stack_12) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1917,7 +1901,7 @@ TEST_F(DeclarableOpsTests14, Stack_13) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1941,7 +1925,7 @@ TEST_F(DeclarableOpsTests14, Stack_14) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, Stack_15) { @@ -1959,7 +1943,7 @@ TEST_F(DeclarableOpsTests14, Stack_15) { ASSERT_TRUE(exp.isSameShape(z)); - + } @@ -1978,7 +1962,7 @@ TEST_F(DeclarableOpsTests14, Stack_16) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, Stack_17) { @@ -1999,7 +1983,7 @@ TEST_F(DeclarableOpsTests14, Stack_17) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, Stack_18) { @@ -2018,8 +2002,8 @@ TEST_F(DeclarableOpsTests14, Stack_18) { auto out = res2.at(0); ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); - - + + } TEST_F(DeclarableOpsTests14, Stack_19) { @@ -2033,7 +2017,7 @@ TEST_F(DeclarableOpsTests14, Stack_19) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Stack_20) { @@ -2047,7 +2031,7 @@ 
TEST_F(DeclarableOpsTests14, Stack_20) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Stack_21) { @@ -2073,7 +2057,363 @@ TEST_F(DeclarableOpsTests14, Stack_21) { ASSERT_TRUE(outStack->isSameShape(outConcat)); ASSERT_TRUE(outStack->equalsTo(outConcat)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Reshape1) { + const std::vector xShape = { 5,4,3 }; + const std::vector yShape = { 3,5,4 }; + + auto x = NDArrayFactory::create_('f', xShape); + auto y = NDArrayFactory::create_('f', yShape); + + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reshapeas reshape; + + reshape.execute(block); + + ASSERT_TRUE(x->isSameShape(y)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Reshape2) { + const std::vector xShape = { 5,4,3 }; + const std::vector yShape = { 3,5,4 }; + + auto x = NDArrayFactory::create_('c', xShape); + auto y = NDArrayFactory::create_('c', yShape); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1 }); + std::vector* arguments = block->getIArguments(); + arguments->push_back(-y->ordering()); + arguments->push_back(3); + arguments->push_back(5); + arguments->push_back(4); + + sd::ops::reshape reshape; + + Nd4jStatus status = reshape.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + + ASSERT_TRUE(result->isSameShape(y)); + + delete y; + delete block; + delete variableSpace; +} + +TEST_F(DeclarableOpsTests14, Reshape3) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape4) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape5) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); +} + +TEST_F(DeclarableOpsTests14, Reshape6) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 4, 15 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 4, -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); +} + +TEST_F(DeclarableOpsTests14, Reshape7) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 60 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); +} + +TEST_F(DeclarableOpsTests14, Reshape8) { + auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); + auto e = 
NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + + auto r = x.reshape('c', {3, 2});; + r.streamline('f'); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {3, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +} + +TEST_F(DeclarableOpsTests14, Reshape9) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + + sd::ops::reshape op; + auto result = op.evaluate({&array}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests14, Reshape10) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto z = NDArrayFactory::create('c', {1, 1}); + + sd::ops::reshape op; + auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests14, Reshape11) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('c', {4, 3}); + + x.linspace(1); + exp.linspace(1); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {-99, 4, 3}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape12) { + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); + auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + + sd::ops::reshape op; + auto result = op.evaluate({&x, &shape}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape13) { + auto vector = NDArrayFactory::create('c', {1}, {119.0f}); + auto exp = NDArrayFactory::create(119.f); + auto empty = NDArrayFactory::empty_(); + + sd::ops::reshape op; + auto result = op.evaluate({&vector, empty}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, *result.at(0)); + + delete empty; +} + +TEST_F(DeclarableOpsTests14, Reshape14) { + auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); + auto y = NDArrayFactory::create('c', {2}, {10, 0}); + auto e = NDArrayFactory::create('c', {10, 0}); + + sd::ops::reshape op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_EQ(e, *z); } +TEST_F(DeclarableOpsTests14, Reshape15) { + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + + auto e0 = NDArrayFactory::create('c', {2, 0, 0}); + auto e1 = NDArrayFactory::create('c', {0, 1}); + + sd::ops::reshape op; + auto result0 = op.evaluate({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0.status()); + auto z0 = result0.at(0); + ASSERT_EQ(e0, *z0); + + auto result1 = op.evaluate({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1.status()); + auto z1 = result1.at(0); + ASSERT_EQ(e1, *z1); +} + +TEST_F(DeclarableOpsTests14, Reshape16) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); + + auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); + + sd::ops::reshape op; + + auto result = op.evaluate({&x, &shape}, {}, {}); + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape17) { + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape18) { + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, Reshape19) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(DeclarableOpsTests14, Reshape20) { + + NDArray x1('c', {2,0}, sd::DataType::FLOAT32); + NDArray x2('c', {10,0}, sd::DataType::FLOAT32); + NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32); + NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32); + NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); + NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); + NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); + + sd::ops::reshape op; + + auto result = op.evaluate({&x1}, {}, {2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0})); + + result = op.evaluate({&x2}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,5})); + + result = op.evaluate({&x2}, {}, {5, 2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + + result = op.evaluate({&x2}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + + result = op.evaluate({&x3}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,10})); + + result = op.evaluate({&x4}, {}, {2, -1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,5,0})); + + result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10})); + + result = op.evaluate({&x6}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0})); + + result = op.evaluate({&x7}, {}, {-1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2, 0})); + + result = op.evaluate({&x7}, {}, {10,0,50,100}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index f983d27a3..5fffa73c5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -89,7 +89,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) 
{ sd::ops::standardize_bp op; auto result = op.evaluate({&x, &eps}, {0}); ASSERT_EQ(Status::OK(), result.status()); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { @@ -108,7 +108,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { auto out = result.at(0); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { @@ -126,7 +126,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { @@ -144,7 +144,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { @@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { @@ -177,7 +177,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } /* @@ -308,7 +308,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { // out->printBuffer("Adjusted Constrast6"); // e.printBuffer("Adjusted Expected 6"); // ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { @@ -415,7 +415,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { auto diff = e - *out; // diff.printBuffer("Adjusted subtract 7"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_1) { @@ -429,7 +429,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) { auto out = result.at(0); // out->printIndexedBuffer("Casted result"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_2) { @@ -444,7 +444,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) { auto out = result.at(0); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_3) { @@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { // e.printIndexedBuffer("Double to int64"); auto res = result.at(0); ASSERT_EQ(*res, e); - + } @@ -508,7 +508,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_5) { // res->printIndexedBuffer("BITCAST5"); ASSERT_TRUE(e.equalsTo(res)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_6) { @@ -528,7 +528,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) { // res->printIndexedBuffer("BITCAST6"); ASSERT_TRUE(e.equalsTo(res)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_7) { auto x = NDArrayFactory::create('c', {4, 4}, { @@ -547,7 +547,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) { // res->printIndexedBuffer("BITCAST7"); ASSERT_TRUE(e.equalsTo(res)); - + } TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { @@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { sd::ops::layer_norm op; auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result.status()); - + } TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { @@ -649,7 +649,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { sd::ops::layer_norm_bp op; auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result.status()); - + } ////////////////////////////////////////////////////////////////////// @@ -710,30 +710,6 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) { 
ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); } -TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - - sd::ops::reshape op; - auto result = op.evaluate({&array}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_EQ(e, *z); -} - -TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - auto z = NDArrayFactory::create('c', {1, 1}); - - sd::ops::reshape op; - auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); -} - TEST_F(DeclarableOpsTests15, test_rank_1) { auto array = NDArrayFactory::create('c', {4, 64}); auto e = NDArrayFactory::create('c', {}, {2}); @@ -757,7 +733,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { @@ -800,7 +776,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - + } TEST_F(DeclarableOpsTests15, test_lstmBlock_3) { @@ -969,7 +945,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { sd::ops::rgb_to_grs op; auto result = op.evaluate({ &rgbs }, {}, {}); ASSERT_EQ(Status::THROW(), result.status()); - + } catch (std::exception& e) { nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); } @@ -1063,7 +1039,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1074,7 +1050,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { sd::ops::rgb_to_yuv op; auto result = op.evaluate({ &rgbs }, {}, {}); ASSERT_EQ(Status::THROW(), result.status()); - + } catch (std::exception & e) { nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); @@ -1109,7 +1085,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1168,7 +1144,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1179,7 +1155,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) { sd::ops::yuv_to_rgb op; auto result = op.evaluate({ &yuv }, {}, {}); ASSERT_EQ(Status::THROW(), result.status()); - + } catch (std::exception & e) { nd4j_printf("Error should be here `%s'. 
It's OK.\n", e.what()); @@ -1423,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); - + } TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { @@ -1515,7 +1491,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); } - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { @@ -1532,10 +1508,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); - + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); @@ -1554,10 +1530,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); - + ASSERT_TRUE(B.isSameShape(*dLdAbp)); ASSERT_TRUE(B.equalsTo(*dLdAbp)); @@ -1606,7 +1582,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); @@ -1632,7 +1608,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); @@ -1655,7 +1631,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); @@ -1706,7 +1682,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 69dec8359..1e877ecc6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -789,24 +789,6 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(DeclarableOpsTests4, Test_Reshape_Again) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('c', {4, 3}); - - x.linspace(1); - exp.linspace(1); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {-99, 4, 3}); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests4, Test_Split_1) { @@ -1209,23 +1191,6 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) { ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) { - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); - auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - - sd::ops::reshape op; - auto result = op.evaluate({&x, &shape}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z 
= result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index ab6bad3c4..e6aeb43d4 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -140,37 +140,6 @@ TEST_F(EmptyTests, Test_Concat_4) { ASSERT_EQ(exp, *z); } -TEST_F(EmptyTests, Test_Reshape_1) { - auto vector = NDArrayFactory::create('c', {1}, {119.0f}); - auto exp = NDArrayFactory::create(119.f); - auto empty = NDArrayFactory::empty_(); - - sd::ops::reshape op; - auto result = op.evaluate({&vector, empty}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(exp, *result.at(0)); - - delete empty; -} - -TEST_F(EmptyTests, Test_Reshape_3) { - auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); - auto y = NDArrayFactory::create('c', {2}, {10, 0}); - auto e = NDArrayFactory::create('c', {10, 0}); - - sd::ops::reshape op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_EQ(e, *z); - -} - TEST_F(EmptyTests, Test_dup_1) { auto empty = NDArrayFactory::empty(); auto dup = new NDArray(empty.dup()); @@ -256,41 +225,6 @@ TEST_F(EmptyTests, test_shaped_empty_4) { ASSERT_EQ(shapeOf, array.getShapeAsVector()); } -TEST_F(EmptyTests, test_empty_reshape_1) { - /* - INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); - INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); - - INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; - INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; - INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; - - assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); - assertArrayEquals(new long[]{0, 1}, out1.shape()); - assertArrayEquals(new long[]{10, 0}, out2.shape()); - */ - auto x0 = NDArrayFactory::create('c', {2, 0}); - auto x1 = NDArrayFactory::create('c', {0, 1, 2}); - - auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); - auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); - - auto e0 = NDArrayFactory::create('c', {2, 0, 0}); - auto e1 = NDArrayFactory::create('c', {0, 1}); - - sd::ops::reshape op; - auto result0 = op.evaluate({&x0, &shape0}, {}, {}); - ASSERT_EQ(Status::OK(), result0.status()); - auto z0 = result0.at(0); - ASSERT_EQ(e0, *z0); - - auto result1 = op.evaluate({&x1, &shape1}, {}, {}); - ASSERT_EQ(Status::OK(), result1.status()); - auto z1 = result1.at(0); - ASSERT_EQ(e1, *z1); - -} - TEST_F(EmptyTests, test_empty_matmul_1) { auto x = NDArrayFactory::create('c', {0, 1}); diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index ce24c8a9b..089b4a92f 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -48,7 +48,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) { ASSERT_TRUE(z->isSameShape(&x)); ASSERT_TRUE(z->equalsTo(&exp)); - + } TEST_F(ParityOpsTests, TestMaximum1) { @@ -66,7 +66,7 @@ TEST_F(ParityOpsTests, TestMaximum1) { ASSERT_TRUE(y.equalsTo(z)); - + } @@ -86,7 +86,7 @@ TEST_F(ParityOpsTests, TestMinimum1) { ASSERT_TRUE(y.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestTear1) { @@ -106,7 +106,7 @@ 
TEST_F(ParityOpsTests, TestTear1) { for (int e = 0; e < result.size(); e++) ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); - + } TEST_F(ParityOpsTests, TestUnstack1) { @@ -126,7 +126,7 @@ TEST_F(ParityOpsTests, TestUnstack1) { for (int e = 0; e < result.size(); e++) ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); - + } @@ -148,7 +148,7 @@ TEST_F(ParityOpsTests, TestUnstack2) { for (int e = 0; e < result.size(); e++) ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); - + } TEST_F(ParityOpsTests, TestUnstack3) { @@ -166,7 +166,7 @@ TEST_F(ParityOpsTests, TestUnstack3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -185,7 +185,7 @@ TEST_F(ParityOpsTests, TestUnstack4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack5) { @@ -203,7 +203,7 @@ TEST_F(ParityOpsTests, TestUnstack5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack6) { @@ -221,7 +221,7 @@ TEST_F(ParityOpsTests, TestUnstack6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack7) { @@ -239,7 +239,7 @@ TEST_F(ParityOpsTests, TestUnstack7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack8) { @@ -257,7 +257,7 @@ TEST_F(ParityOpsTests, TestUnstack8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack9) { @@ -275,7 +275,7 @@ TEST_F(ParityOpsTests, TestUnstack9) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -293,7 +293,7 @@ TEST_F(ParityOpsTests, TestUnstack10) { ASSERT_TRUE(exp.isSameShape(result.at(1))); ASSERT_TRUE(exp.isSameShape(result.at(2))); - + } //////////////////////////////////////////////////////////////////////// @@ -310,7 +310,7 @@ TEST_F(ParityOpsTests, TestUnstack11) { ASSERT_TRUE(exp.isSameShape(result.at(0))); ASSERT_TRUE(exp.isSameShape(result.at(1))); - + } //////////////////////////////////////////////////////////////////////// @@ -325,7 +325,7 @@ TEST_F(ParityOpsTests, TestUnstack12) { ASSERT_TRUE(result.size() == 0); - + } TEST_F(ParityOpsTests, TestUnstack13) { @@ -361,7 +361,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } @@ -380,7 +380,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } @@ -399,7 +399,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } TEST_F(ParityOpsTests, ExpandDimsTest4) { @@ -417,7 +417,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } @@ -434,7 +434,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -452,7 +452,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -470,7 +470,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Less_1) { @@ -487,7 +487,7 @@ TEST_F(ParityOpsTests, Test_Less_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_LessEquals_1) { @@ -504,7 +504,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) { 
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_GreaterEquals_1) { @@ -521,7 +521,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_GreaterEquals_2) { @@ -538,7 +538,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Greater_1) { @@ -555,7 +555,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Where_1) { @@ -575,7 +575,7 @@ TEST_F(ParityOpsTests, Test_Where_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Where_2) { @@ -593,7 +593,7 @@ TEST_F(ParityOpsTests, Test_Where_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -612,7 +612,7 @@ TEST_F(ParityOpsTests, Test_Where_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Select_1) { @@ -630,7 +630,7 @@ TEST_F(ParityOpsTests, Test_Select_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Select_2) { @@ -648,7 +648,7 @@ TEST_F(ParityOpsTests, Test_Select_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Select_3) { @@ -666,25 +666,7 @@ TEST_F(ParityOpsTests, Test_Select_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - -} -TEST_F(ParityOpsTests, Test_Reshape_TF_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); - - auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); - - sd::ops::reshape op; - - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - } TEST_F(ParityOpsTests, Test_Bias_Add_1) { @@ -702,7 +684,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) { for (int e = 0; e < tads.size(); e++) { ASSERT_TRUE(bias.equalsTo(tads.at(e))); } - + } TEST_F(ParityOpsTests, Test_Scatter_Add_1) { @@ -718,7 +700,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_2) { @@ -735,7 +717,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_3) { @@ -751,7 +733,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_4) { @@ -767,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_5) { @@ -784,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) { // z->printBuffer(); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_6) { @@ -800,7 +782,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_7) { @@ -816,7 +798,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -864,7 +846,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) { auto z = 
result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test2) { @@ -880,7 +862,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test3) { @@ -897,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test4) { @@ -913,7 +895,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test5) { @@ -929,7 +911,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test6) { @@ -945,7 +927,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -963,7 +945,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMin_test2) { @@ -979,7 +961,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMin_test3) { @@ -995,7 +977,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMin_test4) { @@ -1012,7 +994,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) { // z->printBuffer(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1044,7 +1026,7 @@ TEST_F(ParityOpsTests, scatterND_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1064,7 +1046,7 @@ TEST_F(ParityOpsTests, scatterND_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1088,7 +1070,7 @@ TEST_F(ParityOpsTests, scatterND_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1107,7 +1089,7 @@ TEST_F(ParityOpsTests, scatterND_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1127,7 +1109,7 @@ TEST_F(ParityOpsTests, scatterND_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1154,7 +1136,7 @@ TEST_F(ParityOpsTests, scatterND_test6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1181,7 +1163,7 @@ TEST_F(ParityOpsTests, scatterND_test7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1202,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1236,7 +1218,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1260,7 +1242,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } 
//////////////////////////////////////////////////////////////////////// @@ -1283,7 +1265,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1310,7 +1292,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1346,7 +1328,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1379,7 +1361,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1404,7 +1386,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1427,7 +1409,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1454,7 +1436,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1490,7 +1472,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1511,7 +1493,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1535,7 +1517,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1559,7 +1541,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1586,7 +1568,7 @@ TEST_F(ParityOpsTests, scatterND_update_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1622,7 +1604,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1655,7 +1637,7 @@ TEST_F(ParityOpsTests, scatter_update_1) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } ////////////////////////////////////////////////////////////////////// @@ -1674,7 +1656,7 @@ TEST_F(ParityOpsTests, scatter_update_2) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } ////////////////////////////////////////////////////////////////////// @@ -1693,7 +1675,7 @@ TEST_F(ParityOpsTests, scatter_update_3) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } ////////////////////////////////////////////////////////////////////// @@ -1712,5 +1694,5 @@ TEST_F(ParityOpsTests, scatter_update_4) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } diff --git a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp 
b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp index 898af1722..937ca4675 100644 --- a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp @@ -103,7 +103,7 @@ TEST_F(ScalarTests, Test_Concat_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -124,7 +124,7 @@ TEST_F(ScalarTests, Test_Concat_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -146,7 +146,7 @@ TEST_F(ScalarTests, Test_Concat_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ScalarTests, Test_ExpandDims_1) { @@ -163,7 +163,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ScalarTests, Test_Squeeze_1) { @@ -179,27 +179,9 @@ TEST_F(ScalarTests, Test_Squeeze_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } - -TEST_F(ScalarTests, Test_Reshape_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - - TEST_F(ScalarTests, Test_Permute_1) { auto x = NDArrayFactory::create(3.0f); auto exp = NDArrayFactory::create(3.0f); @@ -213,7 +195,7 @@ TEST_F(ScalarTests, Test_Permute_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ScalarTests, Test_Concat_Scalar_1) { diff --git a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp index 636206957..cc13f3529 100644 --- a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp @@ -77,7 +77,7 @@ TEST_F(SingleDimTests, Test_Concat_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(SingleDimTests, Test_Reduce_1) { @@ -111,7 +111,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -129,7 +129,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -149,7 +149,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) { ASSERT_EQ(exp.rankOf(), z->rankOf()); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(SingleDimTests, Test_Squeeze_2) { @@ -165,42 +165,9 @@ TEST_F(SingleDimTests, Test_Squeeze_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } -TEST_F(SingleDimTests, Test_Reshape_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - -TEST_F(SingleDimTests, Test_Reshape_2) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - - TEST_F(SingleDimTests, Test_Permute_1) { auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); @@ -214,5 +181,5 @@ TEST_F(SingleDimTests, Test_Permute_1) { 
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } \ No newline at end of file From 1d004b542a017daf4d741accd6f5bdc1ae1c5fab Mon Sep 17 00:00:00 2001 From: Oleh Date: Tue, 31 Mar 2020 13:03:10 +0300 Subject: [PATCH 03/19] xw_plus_b mkldnn implementation (#247) * libnd4j first step of mkldnn for xw_plus_b and test of aurora crash in imageHelper * libnd4j sync folders with master * libnd4j merge master, raw implementation of xw_plus_b on mkldnn, clean up, need testing and adding checks for corresponding input shapes * libnd4j corrections and checks added to xw_plus_b mkl * libnd4j corrected dataType description based on mkl operation description, need more investigation * libnd4j fixed xw_plus_b mkl implementation, need testing Signed-off-by: Oleg * libnd4j two unit tests added Signed-off-by: Oleg * libnd4j fixed check input dimensions bug Signed-off-by: Oleg * libnd4j one more test added to cover different order handling Signed-off-by: Oleg * libnd4j added optional int arg support to define weights format, if arg == 1, mkldnn (do not need transpose in mkldnn implementation), else mmul weights format, corrected check points, added unit test Signed-off-by: Oleg * libnd4j merge master Signed-off-by: Oleg * libnd4j some improvements to avoid NDArray transpose in xw_plus_b operation Signed-off-by: Oleg * libnd4j fixed issues connected with weights rank, also added support of one case based on tf (for mkldnn, cpu, cuda), test case added Signed-off-by: Oleg * libnd4j added proper handling of empty inputs (all implementations) * libnd4j fixed compilation error * libnd4j several more corrections after conflict resolution and fixed typos Signed-off-by: Oleg * libnd4j removed unsupported data types Signed-off-by: Oleg * libnd4j merge master and fixed issues Signed-off-by: Oleg * libnd4j added propagation implementation for xw_plus_b, fixed issue connected with mkl weights data format, avoided data copy in transpose mode, test cases added, manually tested with gradCheck Signed-off-by: Oleg * libnd4j one minor fix of double operation declaration Signed-off-by: Oleg * libnd4j code clean up Signed-off-by: Oleg * libnd4j minor tests fixes Signed-off-by: Oleg * libnd4j fixed build problem, integrate helpers changes Signed-off-by: Oleg Co-authored-by: raver119 --- .../ops/declarable/generic/nn/xw_plus_b.cpp | 108 ++++- .../ops/declarable/headers/parity_ops.h | 5 +- .../declarable/platform/mkldnn/mkldnnUtils.h | 4 + .../declarable/platform/mkldnn/xw_plus_b.cpp | 426 ++++++++++++++++++ .../layers_tests/DeclarableOpsTests18.cpp | 205 ++++++++- .../layers_tests/DeclarableOpsTests5.cpp | 130 +++++- .../tests_cpu/layers_tests/MklDnnTests.cpp | 9 +- 7 files changed, 856 insertions(+), 31 deletions(-) create mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp diff --git a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp index ad7a430f4..dbabad395 100644 --- a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp @@ -14,10 +14,11 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// xw_plus_b op. Created by GS 31.01.2018 -// -// + // + // xw_plus_b op.
Created by GS 31.01.2018 + // @author Oleg Semeniv + // + // #include #if NOT_EXCLUDED(OP_xw_plus_b) @@ -29,36 +30,115 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(x->rankOf() <= 2 && y->rankOf() <= 2 && z->rankOf() <= 2, 0, "xw_plus_b: Input and Output NDArrays should have rank less or equal to 2"); - REQUIRE_TRUE(b->isVector() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input vector should have proper dimension 1x%i. " - "But %i != %i.", z->sizeAt(-1), b->lengthOf(), z->sizeAt(-1)); + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) + return Status::OK(); + + const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); + + auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); + + REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); + REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf()); + + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i." + " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); + // multiply x to y - MmulHelper::mmul(x, y, z, 1.0, 0.0); + MmulHelper::mmul(x, w, z, 1.0, 0.0); // adding b vector z->addiRowVector(*b); + if (bTranspose) + delete w; + return Status::OK(); } DECLARE_SHAPE_FN(xw_plus_b) { - auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false, - ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); + + auto weights = INPUT_VARIABLE(1); + + const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1); + + auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false, + ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); return SHAPELIST(CONSTANT(outputShape)); } DECLARE_TYPES(xw_plus_b) { getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); + } + + + CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) { + + auto x = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty()) + return Status::OK(); + + const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); + + auto w = bTranspose ? 
new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); + + REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i." + " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1)); + + auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1); + + // dLdb + dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 })); + + matmul_bp mmul_bp; + mmul_bp.execute({ x, w, dLdz }, std::vector{dLdx, dLdw}, {}, {}, {}); + + if (bTranspose) { + delete w; + delete dLdw; + } + return Status::OK(); + } + + DECLARE_SHAPE_FN(xw_plus_b_bp) { + + Nd4jLong* xShapeInfo; + Nd4jLong* wShapeInfo; + Nd4jLong* bShapeInfo; + + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), wShapeInfo); + COPY_SHAPE(inputShape->at(2), bShapeInfo); + + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo)); + } + + DECLARE_TYPES(xw_plus_b_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); } } } -#endif \ No newline at end of file +#endif diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 81742fa3d..f3131c193 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -867,9 +867,12 @@ namespace sd { * - 2D matrix MxN * - 1D vector with N elements * output value - 2D matrix NxN as multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul */ #if NOT_EXCLUDED(OP_xw_plus_b) - DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index dd512a884..514a325c7 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -96,6 +96,10 @@ namespace sd { DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); + DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU); + + DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); + } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp new file mode 100644 index 000000000..01a003c2c --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleg Semeniv + // + // + +#include +#include +#include +#include +#include "mkldnnUtils.h" + +using namespace dnnl; + +namespace sd { + namespace ops { + namespace platforms { + + ////////////////////////////////////////////////////////////////////// + static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, const NDArray* bias, NDArray* z, const bool bShouldTransp) { + + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = z->sizeAt(1); + + dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); + dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); + dnnl::memory::dims zShape = dnnl::memory::dims({ M, N }); + dnnl::memory::dims bShape = dnnl::memory::dims({ N }); + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + + // x type + dnnl::memory::data_type xType = dnnl::memory::data_type::f32; + if (x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else if (x->dataType() == DataType::INT8) + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) ? + dnnl::memory::data_type::f32 : dnnl::memory::data_type::s8; + + // bias type (TODO: add description for bias) + dnnl::memory::data_type bType = dnnl::memory::data_type::f32; + if (bias->dataType() == DataType::INT32) + bType = dnnl::memory::data_type::s32; + else if (bias->dataType() == DataType::UINT8) + bType = dnnl::memory::data_type::u8; + else if (bias->dataType() == DataType::INT8) + bType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = dnnl::memory::data_type::f32; + if (z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if (z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format); + if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + + weights_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); + } + else { + weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); + } + } + // bias + dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape,
bType, dnnl::memory::format_tag::x); + dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(bias, bias_user_md); + + // z + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); + mkldnnUtils::setBlockStrides(z, z_user_md); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, z_mkl_md); + + dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + + // bias + auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = bias_mkl_mem; + + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::inner_product_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + } + + ////////////////////////////////////////////////////////////////////// + static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, const NDArray* bias, const NDArray* dLdz, + NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, const bool bShouldTransp) { + + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = dLdz->sizeAt(1); + // input dims + dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); + dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); + dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N }); + + dnnl::memory::dims bShape = dnnl::memory::dims({ N }); + // output dims + dnnl::memory::dims dLdxShape = xShape; + dnnl::memory::dims dLdwShape = wShape; + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format); + if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + + weights_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + weights_user_md.data.format_desc.blocking.strides[0] = 
weights->strideAt(1); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); + } + else { + weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); + weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); + } + } + // bias + dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(bias, bias_user_md); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format); + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + + // dLdw + dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format); + dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format); + if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) { + + dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1); + dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0); + } + else { + dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0); + dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1); + } + } + + // dLdb + dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format); + mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // forward + // operation primitive description + dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md); + dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backprop + // dLdw + auto op_bpdw_desc = inner_product_backward_weights::desc(x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md); + auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc(op_bpdw_desc, engine, op_ff_prim_desc); + + // backprop + // dLdx + auto op_bpdx_desc = inner_product_backward_data::desc(dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md); + auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc(op_bpdx_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map argsDw, argsDx; + + dnnl::stream stream(engine); + + // dLdz - dw + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]); + + // dLdz - dx + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]); + + // input x for dw + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]); + + // weights - dx + mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]); + + // dLdw +
auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->getBuffer()); + const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc(); + auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem; + argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem; + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); + const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; + argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + + // dLdb + auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->getBuffer()); + const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc(); + auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem; + argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem; + + // run calculations dw + dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw); + // run calculations dx + dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + if (dLdwReorder) + dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem); + + if (dLdbReorder) + dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem); + + stream.wait(); + } + + PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty()) + return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int zRank = z->rankOf(); + + const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank); + REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank); + REQUIRE_TRUE(zRank == 2, 0, "xw_plus_b MKL: Output array should have rank equal 2, but got instead %i!", zRank); + + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b MKL: Input bias vector should be 1D and have proper dimension 1x%i." 
+ " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); + + // mkldnn inner product + xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp); + + return Status::OK(); + } + + PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType zType = z->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32 + */ + return block.isUseMKLDNN() && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ( // x + (xType == DataType::UINT8 || xType == DataType::INT8) && + // w + (wType == DataType::UINT8 || wType == DataType::INT8) && + // b + (bType == DataType::UINT8 || bType == DataType::INT8 || bType == DataType::INT32 || bType == DataType::FLOAT32) && + // z + (zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32) + )); + } + + PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty()) + return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP MKL: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, "xw_plus_b BP MKL: Input bias vector should be 1D and have proper dimension 1x%i."
+ " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1)); + + xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp); + + return Status::OK(); + } + + PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + const DataType dLdwType = dLdw->dataType(); + const DataType dLdbType = dLdb->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + */ + return block.isUseMKLDNN() && + (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 && + dLdwType == DataType::FLOAT32); + } + + } + } +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index b1cafa073..1f36a8f2c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -15,9 +15,9 @@ ******************************************************************************/ -// -// @author raver119@gmail.com -// + // + // @author raver119@gmail.com + // #include "testlayers.h" #include @@ -45,19 +45,19 @@ TEST_F(DeclarableOpsTests18, test_bitcast_1) { auto e = NDArrayFactory::create(4597464930322771456L); sd::ops::bitcast op; - auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) sd::DataType::INT64}, {}); + auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {}); ASSERT_EQ(Status::OK(), status); ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests18, test_tanh_1) { - auto x = NDArrayFactory::create('c', {8}, {0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f}); + auto x = NDArrayFactory::create('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f }); auto z = x.ulike(); - auto e = NDArrayFactory::create('c', {8}, {0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f}); + auto e = NDArrayFactory::create('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f }); sd::ops::tanh op; - op.execute({&x}, {&z}); + op.execute({ &x }, { &z }); ASSERT_EQ(e, z); } @@ -187,6 +187,197 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.equalsTo(exp)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32); + dLdz.linspace(1); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f }); + auto edLdw = 
NDArrayFactory::create('c', { 3,2 }, { 43.f, 58.f, 26.f, 42.f, 21.f, 30.f }); + auto edLdb = NDArrayFactory::create('c', { 2 }, { 4.f, 6.f }); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) { + + auto x = NDArrayFactory::create('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create('c', { 4 }, { 100.f, 200.f, 100.f, 200.f }); + + NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32); + dLdz.linspace(.1, .5); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f }); + auto edLdw = NDArrayFactory::create('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f }); + auto edLdb = NDArrayFactory::create('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto w = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); + + auto dLdz = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 3937.f, 3096.f }); + auto edLdw = NDArrayFactory::create('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f }); + auto edLdb = NDArrayFactory::create('c', { 3 }, { 166.f, 269.f, 326.f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto w = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + + auto dLdz = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, 
{}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 2684.f, 732.f }); + auto edLdw = NDArrayFactory::create('c', { 2,1 }, { 244.f, 2684.f }); + auto edLdb = NDArrayFactory::create('c', { 1 }, { 244.f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) { + + auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto dLdz = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdxC = NDArrayFactory::create('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f }); + auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f }); + auto edLdbC = NDArrayFactory::create('c', { 2 }, { 427.f, 584.f }); + + auto edLdx = NDArrayFactory::create('f', { 2,3 }); + auto edLdw = NDArrayFactory::create('f', { 3,2 }); + auto edLdb = NDArrayFactory::create('f', { 2 }); + + edLdx.assign(edLdxC); + edLdw.assign(edLdwC); + edLdb.assign(edLdbC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto dLdz = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + // mkl-format + w.permutei({ 1,0 }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f }); + auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); + auto edLdb = NDArrayFactory::create('c', { 2 }, { 483.f, 543.f }); + auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); + edLdw.permutei({ 1,0 }); + edLdw.assign(edLdwC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { diff --git 
a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 8958f9023..6ac9d34cd 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -2432,18 +2432,36 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) { } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_1) { - auto x = NDArrayFactory::create('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - auto y = NDArrayFactory::create('c', {3,2}, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); - auto b = NDArrayFactory::create({100.f, 200.f}); + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); - auto exp = NDArrayFactory::create('c', {2,2}, {173.f, 264.f, 310.f, 279.f}); + auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); sd::ops::xw_plus_b op; - auto result = op.evaluate({&x, &y, &b}, {}, {}); + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_2) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto y = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); + + auto exp = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -2452,9 +2470,107 @@ TEST_F(DeclarableOpsTests5, XWPlusB_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_3) { + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + + auto exp = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_4) { + + auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_5) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 
3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + + y = y.transpose(); + + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_6) { + + auto x = NDArrayFactory::create('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + + auto b = NDArrayFactory::create('c', { 1 }, { 100.f }); + + auto exp = NDArrayFactory::create('c', { 3, 1 }, { 144.f, 175.f, 173.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_7) { + + auto x = NDArrayFactory::create('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f }); + + auto b = NDArrayFactory::create('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f }); + + auto exp = NDArrayFactory::create('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, StopGradient_1) { diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index dcbfa29b0..bb3934994 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -76,8 +76,13 @@ TEST_F(MklDnnTests, helpers_includer) { sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; + + sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp }); + sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp; + + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp }); + #endif -} \ No newline at end of file +} From 0a27e9f41d7ceb66f515cb42c9d26a62e0217c03 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 1 Apr 2020 04:00:38 +0300 Subject: [PATCH 04/19] Fix incompatibilities with generated code (#303) * Cholesky fixed * Constructors added * MatMul wrapper * Constructor added * Missing wrappers added * Generate Linalg namespace added * Output 
data types * Unit tests * Added mmul * Code generation * Code generated * Build fixed * Fixing signatures * Tests fixed * Tests fixed * Added enum * Fix tests * Some fixes * Eye test fixed * SameDiff: small fix for renameVariable - also replace variable name in lossVariable list if necessary Signed-off-by: Alex Black * Some fixes * Tests fixed * Revert wrong fix * Some fixes * Some fixes * Extending base test class * Added pad * Fixed for generated signatures * Fixes due to nd4j codegen * Backwards compatibility fixes * Fixed errors in tests, reverted wrong changes * Test fixed * Added missing operations used for nd4s operators * Compilation fixed * Added meshgrid * Fixed constructors * fixes Signed-off-by: Alex Black * Fix bad commit (incorrectly reverted change from master) Signed-off-by: Alex Black * Fixed test Co-authored-by: Alex Black --- .../samediff/testlayers/SameDiffConv.java | 16 +- .../nn/conf/layers/CapsuleLayer.java | 17 +- .../nn/conf/layers/LocallyConnected1D.java | 6 +- .../nn/conf/layers/LocallyConnected2D.java | 6 +- .../org/deeplearning4j/util/CapsuleUtils.java | 11 - .../remote/JsonModelServerTest.java | 2 +- .../org/deeplearning4j/ui/TestSameDiffUI.java | 2 +- .../samediff/SameDiffMLPTestCases.java | 1 + .../DifferentialFunctionFactory.java | 9 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 18 + .../nd4j/autodiff/samediff/ops/SDBitwise.java | 573 +- .../org/nd4j/autodiff/samediff/ops/SDCNN.java | 1774 +++--- .../nd4j/autodiff/samediff/ops/SDImage.java | 585 +- .../nd4j/autodiff/samediff/ops/SDLinalg.java | 561 ++ .../nd4j/autodiff/samediff/ops/SDLoss.java | 1493 +++-- .../nd4j/autodiff/samediff/ops/SDMath.java | 5470 +++++++++-------- .../org/nd4j/autodiff/samediff/ops/SDNN.java | 2175 +++---- .../org/nd4j/autodiff/samediff/ops/SDOps.java | 24 +- .../org/nd4j/autodiff/samediff/ops/SDRNN.java | 376 +- .../nd4j/autodiff/samediff/ops/SDRandom.java | 529 +- .../autodiff/samediff/ops/SDValidation.java | 33 + .../factory => }/enums/DataFormat.java | 2 +- .../converters/ImportClassMapping.java | 4 +- .../linalg/api/ops/BaseIndexAccumulation.java | 6 + .../linalg/api/ops/custom/AdjustContrast.java | 9 +- .../nd4j/linalg/api/ops/custom/AdjustHue.java | 5 + .../api/ops/custom/AdjustSaturation.java | 5 + .../nd4j/linalg/api/ops/custom/Logdet.java | 52 + .../org/nd4j/linalg/api/ops/custom/Lstsq.java | 21 + .../linalg/api/ops/custom/MatrixBandPart.java | 4 +- .../api/ops/impl/image/CropAndResize.java | 12 +- .../ops/impl/image/ExtractImagePatches.java | 18 +- .../api/ops/impl/image/NonMaxSuppression.java | 7 + .../api/ops/impl/indexaccum/FirstIndex.java | 9 +- .../linalg/api/ops/impl/indexaccum/IMax.java | 5 + .../linalg/api/ops/impl/indexaccum/IMin.java | 4 + .../api/ops/impl/indexaccum/LastIndex.java | 15 +- .../impl/layers/convolution/AvgPooling3D.java | 4 - .../impl/layers/convolution/BatchNorm.java | 13 + .../ops/impl/layers/convolution/Conv1D.java | 12 +- .../ops/impl/layers/convolution/Conv2D.java | 13 +- .../ops/impl/layers/convolution/Conv3D.java | 13 +- .../ops/impl/layers/convolution/DeConv2D.java | 14 +- .../ops/impl/layers/convolution/DeConv3D.java | 16 +- .../impl/layers/convolution/DepthToSpace.java | 26 +- .../layers/convolution/DepthwiseConv2D.java | 22 +- .../LocalResponseNormalization.java | 4 + .../impl/layers/convolution/MaxPooling2D.java | 7 +- .../impl/layers/convolution/MaxPooling3D.java | 8 +- .../ops/impl/layers/convolution/SConv2D.java | 13 +- .../impl/layers/convolution/SpaceToDepth.java | 28 +- .../impl/layers/convolution/Upsampling2d.java | 8 + 
.../ops/impl/loss/AbsoluteDifferenceLoss.java | 5 + .../linalg/api/ops/impl/loss/BaseLoss.java | 10 +- .../api/ops/impl/loss/CosineDistanceLoss.java | 5 + .../linalg/api/ops/impl/loss/HingeLoss.java | 5 + .../linalg/api/ops/impl/loss/HuberLoss.java | 5 + .../linalg/api/ops/impl/loss/LogLoss.java | 5 + .../api/ops/impl/loss/LogPoissonLoss.java | 5 + .../loss/MeanPairwiseSquaredErrorLoss.java | 5 + .../ops/impl/loss/MeanSquaredErrorLoss.java | 5 + .../impl/loss/SigmoidCrossEntropyLoss.java | 5 + .../impl/loss/SoftmaxCrossEntropyLoss.java | 5 + .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 18 + .../api/ops/impl/reduce/TensorMmul.java | 19 + .../linalg/api/ops/impl/reduce/bool/Any.java | 4 + .../api/ops/impl/reduce/custom/LogSumExp.java | 4 + .../ops/impl/reduce/floating/SquaredNorm.java | 4 + .../impl/reduce/longer/MatchCondition.java | 8 + .../linalg/api/ops/impl/reduce/same/Sum.java | 4 + .../linalg/api/ops/impl/scalar/LeakyReLU.java | 4 + .../nd4j/linalg/api/ops/impl/scalar/Pow.java | 4 + .../api/ops/impl/scalar/RectifiedLinear.java | 4 + .../linalg/api/ops/impl/scalar/Relu6.java | 4 + .../nd4j/linalg/api/ops/impl/scalar/Step.java | 4 + .../scalar/comparison/ScalarLessThan.java | 3 + .../api/ops/impl/scatter/ScatterAdd.java | 5 + .../api/ops/impl/scatter/ScatterDiv.java | 5 + .../api/ops/impl/scatter/ScatterMax.java | 5 + .../api/ops/impl/scatter/ScatterMin.java | 5 + .../api/ops/impl/scatter/ScatterMul.java | 5 + .../api/ops/impl/scatter/ScatterSub.java | 5 + .../api/ops/impl/scatter/ScatterUpdate.java | 5 + .../linalg/api/ops/impl/shape/Concat.java | 8 + .../api/ops/impl/shape/ConfusionMatrix.java | 11 + .../nd4j/linalg/api/ops/impl/shape/Cross.java | 9 +- .../nd4j/linalg/api/ops/impl/shape/Diag.java | 9 +- .../linalg/api/ops/impl/shape/DiagPart.java | 4 + .../linalg/api/ops/impl/shape/ExpandDims.java | 9 + .../nd4j/linalg/api/ops/impl/shape/Eye.java | 7 + .../linalg/api/ops/impl/shape/Gather.java | 23 + .../linalg/api/ops/impl/shape/GatherNd.java | 11 + .../linalg/api/ops/impl/shape/Linspace.java | 14 + .../linalg/api/ops/impl/shape/MeshGrid.java | 4 + .../linalg/api/ops/impl/shape/OneHot.java | 11 + .../linalg/api/ops/impl/shape/OnesLike.java | 8 + .../linalg/api/ops/impl/shape/Permute.java | 5 + .../nd4j/linalg/api/ops/impl/shape/Rank.java | 9 + .../linalg/api/ops/impl/shape/Reshape.java | 4 + .../api/ops/impl/shape/SequenceMask.java | 8 +- .../nd4j/linalg/api/ops/impl/shape/Shape.java | 8 + .../nd4j/linalg/api/ops/impl/shape/Size.java | 6 + .../nd4j/linalg/api/ops/impl/shape/Slice.java | 6 + .../linalg/api/ops/impl/shape/Squeeze.java | 10 + .../nd4j/linalg/api/ops/impl/shape/Stack.java | 10 + .../api/ops/impl/shape/StridedSlice.java | 15 + .../nd4j/linalg/api/ops/impl/shape/Tile.java | 8 + .../linalg/api/ops/impl/shape/Transpose.java | 4 + .../linalg/api/ops/impl/shape/ZerosLike.java | 8 + .../api/ops/impl/summarystats/Variance.java | 4 + .../api/ops/impl/transforms/Cholesky.java | 12 + .../linalg/api/ops/impl/transforms/Pad.java | 4 + .../ops/impl/transforms/bool/IsFinite.java | 6 +- .../api/ops/impl/transforms/bool/IsInf.java | 6 +- .../api/ops/impl/transforms/bool/IsNaN.java | 6 +- .../impl/transforms/custom/BatchToSpace.java | 21 +- .../transforms/custom/BatchToSpaceND.java | 2 +- .../ops/impl/transforms/custom/CumProd.java | 6 +- .../ops/impl/transforms/custom/CumSum.java | 5 +- .../impl/transforms/custom/Dilation2D.java | 33 +- .../transforms/custom/DynamicPartition.java | 13 + .../impl/transforms/custom/DynamicStitch.java | 11 + .../ops/impl/transforms/custom/EqualTo.java | 
8 + .../api/ops/impl/transforms/custom/Fill.java | 5 + .../impl/transforms/custom/GreaterThan.java | 8 + .../transforms/custom/GreaterThanOrEqual.java | 9 + .../transforms/custom/InvertPermutation.java | 9 + .../transforms/custom/IsNonDecreasing.java | 11 +- .../transforms/custom/IsNumericTensor.java | 7 + .../custom/IsStrictlyIncreasing.java | 4 + .../ops/impl/transforms/custom/LessThan.java | 8 + .../transforms/custom/LessThanOrEqual.java | 8 + .../transforms/custom/MatrixDeterminant.java | 4 + .../impl/transforms/custom/MatrixInverse.java | 3 + .../impl/transforms/custom/MatrixSetDiag.java | 4 + .../api/ops/impl/transforms/custom/Max.java | 4 + .../api/ops/impl/transforms/custom/Min.java | 3 + .../impl/transforms/custom/NotEqualTo.java | 8 + .../api/ops/impl/transforms/custom/Qr.java | 56 + .../transforms/custom/ReverseSequence.java | 6 + .../ops/impl/transforms/custom/SoftMax.java | 4 + .../impl/transforms/custom/SpaceToBatch.java | 19 +- .../transforms/custom/SpaceToBatchND.java | 2 +- .../api/ops/impl/transforms/custom/Svd.java | 6 + .../api/ops/impl/transforms/custom/Trace.java | 2 +- .../transforms/custom/segment/SegmentMax.java | 5 + .../custom/segment/SegmentMean.java | 5 + .../transforms/custom/segment/SegmentMin.java | 5 + .../custom/segment/SegmentProd.java | 5 + .../transforms/custom/segment/SegmentSum.java | 5 + .../ops/impl/transforms/floating/RSqrt.java | 9 +- .../ops/impl/transforms/floating/Sqrt.java | 4 + .../gradient/HardTanhDerivative.java | 6 +- .../gradient/LeakyReLUDerivative.java | 8 +- .../gradient/SoftSignDerivative.java | 5 +- .../pairwise/arithmetic/MergeAddOp.java | 8 +- .../api/ops/impl/transforms/same/Abs.java | 9 +- .../api/ops/impl/transforms/same/Cube.java | 10 +- .../api/ops/impl/transforms/same/Floor.java | 5 +- .../ops/impl/transforms/same/Identity.java | 4 + .../ops/impl/transforms/same/Negative.java | 6 +- .../ops/impl/transforms/same/Reciprocal.java | 9 +- .../api/ops/impl/transforms/same/Round.java | 9 +- .../api/ops/impl/transforms/same/Sign.java | 4 + .../api/ops/impl/transforms/same/Square.java | 4 + .../segment/UnsortedSegmentMax.java | 6 + .../segment/UnsortedSegmentMean.java | 6 + .../segment/UnsortedSegmentMin.java | 6 + .../segment/UnsortedSegmentProd.java | 6 + .../segment/UnsortedSegmentSum.java | 6 + .../api/ops/impl/transforms/strict/ACos.java | 9 +- .../api/ops/impl/transforms/strict/ACosh.java | 5 +- .../api/ops/impl/transforms/strict/ASin.java | 6 +- .../api/ops/impl/transforms/strict/ATan.java | 5 +- .../api/ops/impl/transforms/strict/Cos.java | 9 +- .../api/ops/impl/transforms/strict/Cosh.java | 9 +- .../api/ops/impl/transforms/strict/Erf.java | 10 +- .../api/ops/impl/transforms/strict/Erfc.java | 10 +- .../api/ops/impl/transforms/strict/Exp.java | 10 +- .../api/ops/impl/transforms/strict/Expm1.java | 9 +- .../api/ops/impl/transforms/strict/GELU.java | 9 +- .../impl/transforms/strict/HardSigmoid.java | 7 +- .../ops/impl/transforms/strict/HardTanh.java | 6 +- .../api/ops/impl/transforms/strict/Log.java | 4 + .../api/ops/impl/transforms/strict/Log1p.java | 6 +- .../impl/transforms/strict/LogSigmoid.java | 5 +- .../api/ops/impl/transforms/strict/SELU.java | 4 + .../ops/impl/transforms/strict/Sigmoid.java | 4 + .../api/ops/impl/transforms/strict/Sin.java | 4 + .../api/ops/impl/transforms/strict/Sinh.java | 4 + .../ops/impl/transforms/strict/SoftPlus.java | 4 + .../ops/impl/transforms/strict/SoftSign.java | 4 + .../api/ops/impl/transforms/strict/Swish.java | 5 +- .../api/ops/impl/transforms/strict/Tan.java | 4 + 
.../api/ops/impl/transforms/strict/Tanh.java | 4 + .../linalg/api/ops/random/BaseRandomOp.java | 3 +- .../ops/random/custom/RandomExponential.java | 10 + .../random/impl/BernoulliDistribution.java | 9 +- .../ops/random/impl/BinomialDistribution.java | 4 + .../ops/random/impl/GaussianDistribution.java | 14 +- .../random/impl/LogNormalDistribution.java | 11 +- .../impl/TruncatedNormalDistribution.java | 7 + .../ops/random/impl/UniformDistribution.java | 7 +- .../org/nd4j/linalg/factory/NDValidation.java | 22 + .../org/nd4j/linalg/factory/ops/NDBase.java | 2056 +++++++ .../nd4j/linalg/factory/ops/NDBitwise.java | 4 +- .../org/nd4j/linalg/factory/ops/NDCNN.java | 2 +- .../org/nd4j/linalg/factory/ops/NDLinalg.java | 274 + .../org/nd4j/linalg/factory/ops/NDLoss.java | 2 +- .../org/nd4j/linalg/factory/ops/NDMath.java | 77 +- .../org/nd4j/linalg/factory/ops/NDNN.java | 10 +- .../org/nd4j/linalg/factory/ops/NDRNN.java | 0 .../org/nd4j/linalg/factory/ops/NDRandom.java | 11 +- .../opvalidation/LayerOpValidation.java | 44 +- .../opvalidation/LossOpValidation.java | 6 +- .../opvalidation/MiscOpValidation.java | 2 +- .../opvalidation/RandomOpValidation.java | 44 +- .../opvalidation/ReductionOpValidation.java | 22 +- .../opvalidation/ShapeOpValidation.java | 7 +- .../opvalidation/TransformOpValidation.java | 32 +- .../samediff/FailingSameDiffTests.java | 2 +- .../samediff/FlatBufferSerdeTest.java | 8 +- .../autodiff/samediff/NameScopeTests.java | 4 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 15 +- .../samediff/SameDiffTrainingTest.java | 2 +- .../listeners/ExecDebuggingListenerTest.java | 1 + .../samediff/listeners/ListenerTest.java | 2 +- .../nd4j/linalg/custom/CustomOpsTests.java | 40 + .../nd4j/linalg/factory/ops/NDLossTest.java | 17 +- .../nd4j/linalg/generated/SDLinalgTest.java | 285 + 230 files changed, 12090 insertions(+), 6072 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java rename nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/{linalg/factory => }/enums/DataFormat.java (95%) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java mode change 100755 => 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java mode change 100755 => 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java index 1be09182c..7b78c14fc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java @@ -130,14 +130,6 @@ public class SameDiffConv extends 
SameDiffLayer { SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); - SDVariable[] vars; - if(hasBias){ - SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - vars = new SDVariable[]{layerInput, w, b}; - } else { - vars = new SDVariable[]{layerInput, w}; - } - Conv2DConfig c = Conv2DConfig.builder() .kH(kernel[0]).kW(kernel[1]) .pH(padding[0]).pW(padding[1]) @@ -146,7 +138,13 @@ public class SameDiffConv extends SameDiffLayer { .isSameMode(this.cm == ConvolutionMode.Same) .build(); - SDVariable conv = sameDiff.cnn().conv2d(vars, c); //TODO can't set name + SDVariable conv = null; + if(hasBias){ + SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); + conv = sameDiff.cnn().conv2d(layerInput, w, b, c); + } else { + conv = sameDiff.cnn().conv2d(layerInput, w, c); + } return activation.asSameDiff("out", sameDiff, conv); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java index 4b4a69159..98bda1f3a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.util.ArrayUtil; import java.util.Map; @@ -99,15 +100,15 @@ public class CapsuleLayer extends SameDiffLayer { } @Override - public SDVariable defineLayer(SameDiff SD, SDVariable input, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sd, SDVariable input, Map paramTable, SDVariable mask) { // input: [mb, inputCapsules, inputCapsuleDimensions] // [mb, inputCapsules, 1, inputCapsuleDimensions, 1] - SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4); + SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4); // [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1] - SDVariable tiled = SD.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1); + SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1); // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions] SDVariable weights = paramTable.get(WEIGHT_PARAM); @@ -119,13 +120,13 @@ public class CapsuleLayer extends SameDiffLayer { // b is the logits of the routing procedure // [mb, inputCapsules, capsules, 1, 1] - SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1)); + SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1)); for(int i = 0 ; i < routings ; i++){ // c is the coupling coefficient, i.e. the edge weight between the 2 capsules // [mb, inputCapsules, capsules, 1, 1] - SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5); + SDVariable c = sd.nn.softmax(b, 2); // [mb, 1, capsules, capsuleDimensions, 1] SDVariable s = c.times(uHat).sum(true, 1); @@ -135,14 +136,14 @@ public class CapsuleLayer extends SameDiffLayer { // v is the per capsule activations. 
On the last routing iteration, this is output // [mb, 1, capsules, capsuleDimensions, 1] - SDVariable v = CapsuleUtils.squash(SD, s, 3); + SDVariable v = CapsuleUtils.squash(sd, s, 3); if(i == routings - 1){ - return SD.squeeze(SD.squeeze(v, 1), 3); + return sd.squeeze(sd.squeeze(v, 1), 3); } // [mb, inputCapsules, capsules, capsuleDimensions, 1] - SDVariable vTiled = SD.tile(v, 1, (int) inputCapsules, 1, 1, 1); + SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1); // [mb, inputCapsules, capsules, 1, 1] b = b.plus(uHat.times(vTiled).sum(true, 3)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index fc805f0ca..60ecbf057 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -178,9 +178,11 @@ public class LocallyConnected1D extends SameDiffLayer { //Note: for same mode, bottom/right padding can be 1 more than top/left padding //NCW format. if(cm == ConvolutionMode.Same) { - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, paddingR}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0); } else { - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, padding}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index ef07c9dc5..5044017a0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -184,9 +184,11 @@ public class LocallyConnected2D extends SameDiffLayer { //Note: for same mode, bottom/right padding can be 1 more than top/left padding //NCHW format if(cm == ConvolutionMode.Same){ - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0); } else { - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java index ff605d028..66d732907 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java @@ -45,15 +45,4 @@ public class CapsuleUtils { return 
x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale)); } - /** - * Compute softmax along a given dimension - */ - public static SDVariable softmax(SameDiff SD, SDVariable x, int dimension, int rank){ - int[] permutation = ArrayUtil.range(0, rank); - permutation[0] = dimension; - permutation[dimension] = 0; - - return SD.nn.softmax(x.permute(permutation)).permute(ArrayUtil.invertPermutation(permutation)); - } - } diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java index 8b060a77c..e94ffbb40 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -495,7 +495,7 @@ public class JsonModelServerTest extends BaseDL4JTest { SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10)); SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10)); - SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); + SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b), -1); val server = new JsonModelServer.Builder(sd) .outputSerializer( new IntSerde()) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java index 4ba24eafa..7401874d3 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java @@ -58,7 +58,7 @@ public class TestSameDiffUI extends BaseDL4JTest { SDVariable b = sd.var("b", DataType.FLOAT, 1, 4); SDVariable z = in.mmul(w).add(b); - SDVariable a = sd.nn().tanh(z); + SDVariable a = sd.math().tanh(z); LogFileWriter lfw = new LogFileWriter(f); lfw.writeGraphStructure(sd); diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java index a3fdc0c3f..ced461089 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java @@ -20,6 +20,7 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.integration.ModelType; import org.deeplearning4j.integration.TestCase; +import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 8a9bd8edc..fcb63ea0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -28,6 +28,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.DataFormat; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; @@ -1489,7 +1490,7 @@ public class DifferentialFunctionFactory { } public SDVariable reciprocal(SDVariable a) { - return new Reciprocal(sameDiff(), a, false).outputVariable(); + return new Reciprocal(sameDiff(), a).outputVariable(); } @@ -1990,13 +1991,13 @@ public class DifferentialFunctionFactory { .outputVariable(); } - public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, String dataFormat) { + public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { validateDifferentialFunctionsameDiff(differentialFunction); return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) .outputVariable(); } - public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, String dataFormat) { + public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { validateDifferentialFunctionsameDiff(differentialFunction); return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) .outputVariable(); @@ -2635,7 +2636,7 @@ public class DifferentialFunctionFactory { return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); } - public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) { + public SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) { return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 3411e2007..ab3279fd0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -181,6 +181,11 @@ public class SameDiff extends SDBaseOps { */ public final SDBitwise bitwise = new SDBitwise(this); + /** + * Op creator object for linalg operations + */ + public final SDLinalg linalg = new SDLinalg(this); + /** * Op creator object for math operations */ @@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps { return bitwise; } + /** + * Op creator object for linalg operations + */ + public SDLinalg linalg(){ + return linalg; + } + private Map sameDiffFunctionInstances; private Table fieldVariableResolutionMapping; @@ -3448,6 +3460,12 @@ public class SameDiff extends SDBaseOps { sd.renameVariable(from, to); } } + + //Check losses: + if(lossVariables.contains(from)){ + int idx = lossVariables.indexOf(from); + lossVariables.set(idx, to); + } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java index a255afbc3..956444ffe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -1,217 +1,416 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; +public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { + super(sameDiff); + } -/** - * - */ -public class SDBitwise extends SDOps { - public SDBitwise(SameDiff sameDiff) { - super(sameDiff); - } + /** + * Bitwise AND operation. Supports broadcasting.
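+   * For example: {@code and(00000110, 00000011) -> 00000010}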
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise AND array (INT type) + */ + public SDVariable and(SDVariable x, SDVariable y) { + SDValidation.validateInteger("and", "x", x); + SDValidation.validateInteger("and", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + } - /** - * See {@link #leftShift(String, SDVariable, SDVariable)} - */ - public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){ - return leftShift(null, x, y); - } + /** + * Bitwise AND operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise AND array (INT type) + */ + public SDVariable and(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("and", "x", x); + SDValidation.validateInteger("and", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise left shift operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise shifted input x - */ - public SDVariable leftShift(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise left shift", x); - validateInteger("bitwise left shift", y); + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
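+   * For example, using 8-bit values for illustration: {@code bitRotl(01110000, 2) -> 11000001}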
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotl(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotl", "x", x); + SDValidation.validateInteger("bitRotl", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + } - SDVariable ret = f().shift(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotl", "x", x); + SDValidation.validateInteger("bitRotl", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #rightShift(String, SDVariable, SDVariable)} - */ - public SDVariable rightShift(SDVariable x, SDVariable y){ - return rightShift(null, x, y); - } + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
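+   * For example, using 8-bit values for illustration: {@code bitRotr(00001110, 2) -> 10000011}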
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotr(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotr", "x", x); + SDValidation.validateInteger("bitRotr", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + } - /** - * Bitwise right shift operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise shifted input x - */ - public SDVariable rightShift(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise right shift", x); - validateInteger("bitwise right shift", y); + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotr", "x", x); + SDValidation.validateInteger("bitRotr", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - SDVariable ret = f().rshift(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Shift integer bits to the left, i.e. var << 4
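+   * For example: {@code bitShift(00010001, 2) -> 01000100}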
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShift(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShift", "x", x); + SDValidation.validateInteger("bitShift", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + } - /** - * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)} - */ - public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){ - return leftShiftCyclic(null, x, y); - } + /** + * Shift integer bits to the left, i.e. var << 4
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShift", "x", x); + SDValidation.validateInteger("bitShift", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise left cyclical shift operation. Supports broadcasting. - * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around": - * {@code leftShiftCyclic(01110000, 2) -> 11000001} - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise cyclic shifted input x - */ - public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise left shift (cyclic)", x); - validateInteger("bitwise left shift (cyclic)", y); + /** + * Shift integer bits to the right, i.e. var >> 4
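+   * For example: {@code bitShiftRight(00010001, 2) -> 00000100}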
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShiftRight", "x", x); + SDValidation.validateInteger("bitShiftRight", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + } - SDVariable ret = f().rotl(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Shift integer bits to the right, i.e. var >> 4
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShiftRight", "x", x); + SDValidation.validateInteger("bitShiftRight", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)} - */ - public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){ - return rightShiftCyclic(null, x, y); - } + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+   * For example, if x=01100000 and y=10100000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * + * @param x First input array. (INT type) + * @param y Second input array. (INT type) + * @return output bitwise Hamming distance (INT type) + */ + public SDVariable bitsHammingDistance(SDVariable x, SDVariable y) { + SDValidation.validateInteger("bitsHammingDistance", "x", x); + SDValidation.validateInteger("bitsHammingDistance", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + } - /** - * Bitwise right cyclical shift operation. Supports broadcasting. - * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around": - * {@code rightShiftCyclic(00001110, 2) -> 10000011} - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise cyclic shifted input x - */ - public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise right shift (cyclic)", x); - validateInteger("bitwise right shift (cyclic)", y); + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+   * For example, if x=01100000 and y=10100000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array. (INT type) + * @param y Second input array. (INT type) + * @return output bitwise Hamming distance (INT type) + */ + public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("bitsHammingDistance", "x", x); + SDValidation.validateInteger("bitsHammingDistance", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - SDVariable ret = f().rotr(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise left shift operation. Supports broadcasting.
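+   * For example: {@code leftShift(00010001, 2) -> 01000100}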
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable leftShift(SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShift", "x", x); + SDValidation.validateInteger("leftShift", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + } - /** - * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)} - */ - public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){ - return bitsHammingDistance(null, x, y); - } + /** + * Bitwise left shift operation. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable leftShift(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShift", "x", x); + SDValidation.validateInteger("leftShift", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise Hamming distance reduction over all elements of both input arrays.
- * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1) - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return - */ - public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise hamming distance", x); - validateInteger("bitwise hamming distance", y); + /** + * Bitwise left cyclical shift operation. Supports broadcasting.
+   * Unlike {@link #leftShift(SDVariable, SDVariable)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShiftCyclic", "x", x); + SDValidation.validateInteger("leftShiftCyclic", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + } - SDVariable ret = f().bitwiseHammingDist(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise left cyclical shift operation. Supports broadcasting.
+   * Unlike {@link #leftShift(SDVariable, SDVariable)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShiftCyclic", "x", x); + SDValidation.validateInteger("leftShiftCyclic", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #and(String, SDVariable, SDVariable)} - */ - public SDVariable and(SDVariable x, SDVariable y){ - return and(null, x, y); - } + /** + * Bitwise OR operation. Supports broadcasting.
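+   * For example: {@code or(00000101, 00000011) -> 00000111}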
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise OR array (INT type) + */ + public SDVariable or(SDVariable x, SDVariable y) { + SDValidation.validateInteger("or", "x", x); + SDValidation.validateInteger("or", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + } - /** - * Bitwise AND operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return Bitwise AND array - */ - public SDVariable and(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise AND", x); - validateInteger("bitwise AND", y); + /** + * Bitwise OR operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise OR array (INT type) + */ + public SDVariable or(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("or", "x", x); + SDValidation.validateInteger("or", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - SDVariable ret = f().bitwiseAnd(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise right shift operation. Supports broadcasting.
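+   * For example: {@code rightShift(00010001, 2) -> 00000100}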
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable rightShift(SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShift", "x", x); + SDValidation.validateInteger("rightShift", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + } - /** - * See {@link #or(String, SDVariable, SDVariable)} - */ - public SDVariable or(SDVariable x, SDVariable y){ - return or(null, x, y); - } + /** + * Bitwise right shift operation. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable rightShift(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShift", "x", x); + SDValidation.validateInteger("rightShift", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise OR operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return Bitwise OR array - */ - public SDVariable or(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise OR", x); - validateInteger("bitwise OR", y); + /** + * Bitwise right cyclical shift operation. Supports broadcasting.
+   * Unlike {@link #rightShift(SDVariable, SDVariable)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShiftCyclic", "x", x); + SDValidation.validateInteger("rightShiftCyclic", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + } - SDVariable ret = f().bitwiseOr(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise right cyclical shift operation. Supports broadcasting.
+   * Unlike {@link #rightShift(SDVariable, SDVariable)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShiftCyclic", "x", x); + SDValidation.validateInteger("rightShiftCyclic", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #xor(String, SDVariable, SDVariable)} - */ - public SDVariable xor(SDVariable x, SDVariable y){ - return xor(null, x, y); - } + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting.
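+   * For example: {@code xor(00000101, 00000011) -> 00000110}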
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise XOR array (INT type) + */ + public SDVariable xor(SDVariable x, SDVariable y) { + SDValidation.validateInteger("xor", "x", x); + SDValidation.validateInteger("xor", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + } - /** - * Bitwise XOR operation (exclusive OR). Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return Bitwise XOR array - */ - public SDVariable xor(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise XOR", x); - validateInteger("bitwise XOR", y); - - SDVariable ret = f().bitwiseXor(x, y); - return updateVariableNameAndReference(ret, name); - } - - /** - * Flip bits - * - * @param name Name of the output variable - * @param x input array - * @return array after flipping each input bit - */ - public SDVariable toggleBits(String name, SDVariable x) { - SDVariable res = f().toggleBits(x); - return updateVariableNameAndReference(res, name); - } + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise XOR array (INT type) + */ + public SDVariable xor(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("xor", "x", x); + SDValidation.validateInteger("xor", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index 7b56ca266..d367e3d4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,777 +14,1015 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import org.nd4j.enums.DataFormat; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateFloatingPoint; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateNumerical; - -/** - * SameDiff Convolutional Neural Network operations - CNN1d, 2d and 3d ops - as well as related functions.
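Before the SDCNN changes below, a minimal usage sketch for the regenerated bitwise ops above, assuming the SDBitwise instance is reached through SameDiff#bitwise() as with the other op classes; the values and variable names are illustrative only.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class BitwiseSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // Both inputs must be integer typed and of the same type
            SDVariable x = sd.constant("x", Nd4j.createFromArray(0b1100, 0b1010));
            SDVariable y = sd.constant("y", Nd4j.createFromArray(0b1010, 0b0110));

            SDVariable or  = sd.bitwise().or("or", x, y);
            SDVariable xor = sd.bitwise().xor("xor", x, y);
            SDVariable shifted = sd.bitwise().rightShift("shifted", x, sd.constant(Nd4j.createFromArray(1, 1)));

            System.out.println(or.eval());       // [14, 14]
            System.out.println(xor.eval());      // [6, 12]
            System.out.println(shifted.eval());  // [6, 5]
        }
    }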
- * Accessible via {@link SameDiff#cnn()}
- * See also {@link SDNN} (accessible via {@link SameDiff#nn()} for general neural network ops.
- * See also {@link SDRNN} (accessible via {@link SameDiff#rnn()} for recurrent neural network ops.
- * - * @author Alex Black - */ public class SDCNN extends SDOps { - - public SDCNN(SameDiff sameDiff) { - super(sameDiff); - } - - /** - * See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}. - */ - public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - return avgPooling2d(null, input, pooling2DConfig); - } - - /** - * 2D Convolution layer operation - average pooling 2d - * - * @param name name of the operation in SameDiff - * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration - * @return Result after applying average pooling on the input - */ - public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - validateFloatingPoint("avgPooling2d", input); - SDVariable ret = f().avgPooling2d(input, pooling2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}. - */ - public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - return avgPooling3d(null, input, pooling3DConfig); - } - - /** - * 3D convolution layer operation - average pooling 3d - * - * @param name name of the operation in SameDiff - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying average pooling on the input - */ - public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - validateFloatingPoint("avgPooling3d", input); - SDVariable ret = f().avgPooling3d(input, pooling3DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #batchToSpace(String, SDVariable, int[], int[][]) - */ - public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { - return batchToSpace(null, x, blocks, crops); - } - - /** - * Convolution 2d layer batch to space operation on 4d input. - * Reduces input batch dimension by rearranging data into a larger spatial dimensions - * - * @param name Output variable name - * @param x Input variable. 4d input - * @param blocks Block size, in the height/width dimension - * @param crops Optional 2d int[] array: values [[crop top, crop bottom], [crop left, crop right]] - * @return Output variable - * @see #spaceToBatch(String, SDVariable, int[], int[][]) - */ - public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { - validateNumerical("batchToSpace", x); - SDVariable ret = f().batchToSpace(x, blocks, crops); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #col2Im(String, SDVariable, Conv2DConfig)}. - */ - public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) { - return col2Im(null, in, config); - } - - /** - * col2im operation for use in 2D convolution operations. 
Outputs a 4d array with shape - * [minibatch, inputChannels, height, width] - * - * @param name Name of the output variable - * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * @param config Convolution configuration for the col2im operation - * @return Col2Im output variable - */ - public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { - SDVariable ret = f().col2Im(in, config); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. - */ - public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { - return conv1d((String) null, input, weights, conv1DConfig); - } - - /** - * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. - */ - public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { - validateFloatingPoint("conv1d", input); - validateFloatingPoint("conv1d", weights); - SDVariable ret = f().conv1d(input, weights, conv1DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}. - */ - public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { - return conv1d(null, input, weights, bias, conv1DConfig); - } - - /** - * Conv1d operation. - * - * @param name name of the operation in SameDiff - * @param input the inputs to conv1d - * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] - * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { - validateFloatingPoint("conv1d", input); - validateFloatingPoint("conv1d", weights); - validateFloatingPoint("conv1d", bias); - SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { - return conv2d(layerInput, weights, null, config); - } - - /** - * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { - return conv2d(name, layerInput, weights, null, config); - } - - /** - * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. 
- */ - public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { - return conv2d(null, layerInput, weights, bias, config); - } - - /** - * 2D Convolution operation with optional bias - * - * @param name name of the operation in SameDiff - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param config Conv2DConfig configuration - * @return result of conv2d op - */ - public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { - validateFloatingPoint("conv2d", "input", layerInput); - validateFloatingPoint("conv2d", "weights", weights); - validateFloatingPoint("conv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; - arr[0] = layerInput; - arr[1] = weights; - if (bias != null) - arr[2] = bias; - return conv2d(name, arr, config); - } - - /** - * See {@link #conv2d(String, SDVariable[], Conv2DConfig)}. - */ - public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { - return conv2d(null, inputs, config); - } - - /** - * 2D Convolution operation with optional bias - * - * @param name Name of the output SDVariable - * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as - * described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param config Conv2DConfig configuration - * @return result of convolution 2d operation - */ - public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { - for(SDVariable v : inputs) - validateNumerical("conv2d", v); - SDVariable ret = f().conv2d(inputs, config); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. - */ - public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { - return conv3d(null, input, weights, null, conv3DConfig); - } - - /** - * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. - */ - public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { - return conv3d(name, input, weights, null, conv3DConfig); - } - - /** - * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}. - */ - public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { - return conv3d(null, input, weights, bias, conv3DConfig); - } - - /** - * Convolution 3D operation with optional bias - * - * @param name Name of the output variable - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. 
- * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param conv3DConfig the configuration - * @return Conv3d output variable - */ - public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { - validateFloatingPoint("conv3d", "input", input); - validateFloatingPoint("conv3d", "weights", weights); - validateFloatingPoint("conv3d", "bias", bias); - SDVariable[] args; - if (bias == null) { - args = new SDVariable[]{input, weights}; - } else { - args = new SDVariable[]{input, weights, bias}; - } - SDVariable ret = f().conv3d(args, conv3DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. - */ - public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(layerInput, weights, null, deconv2DConfig); - } - - /** - * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. - */ - public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(name, layerInput, weights, null, deconv2DConfig); - } - - /** - * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}. - */ - public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(null, layerInput, weights, bias, deconv2DConfig); - } - - /** - * 2D deconvolution operation with optional bias - * - * @param name name of the operation in SameDiff - * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param deconv2DConfig DeConv2DConfig configuration - * @return result of deconv2d op - */ - public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { - validateFloatingPoint("deconv2d", "input", layerInput); - validateFloatingPoint("deconv2d", "weights", weights); - validateFloatingPoint("deconv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; - arr[0] = layerInput; - arr[1] = weights; - if (bias != null) - arr[2] = bias; - return deconv2d(name, arr, deconv2DConfig); - } - - /** - * See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}. 
- */ - public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(null, inputs, deconv2DConfig); - } - - /** - * 2D deconvolution operation with or without optional bias - * - * @param name Name of the output variable - * @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights) - * or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)} - * @param deconv2DConfig the configuration - * @return result of deconv2d op - */ - public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { - for(SDVariable v : inputs) - validateNumerical("deconv2d", v); - SDVariable ret = f().deconv2d(inputs, deconv2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. - */ - public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { - return deconv3d(input, weights, null, config); - } - - /** - * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. - */ - public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { - return deconv3d(name, input, weights, null, config); - } - - /** - * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}. - */ - public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { - return deconv3d(null, input, weights, bias, config); - } - - /** - * 3D CNN deconvolution operation with or without optional bias - * - * @param name Name of the output variable - * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - * @param weights Weights array - shape [kD, kH, kW, oC, iC] - * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] - * @param config Configuration - */ - public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { - validateFloatingPoint("conv3d", input); - validateFloatingPoint("conv3d", weights); - validateFloatingPoint("conv3d", bias); - SDVariable ret = f().deconv3d(input, weights, bias, config); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #depthToSpace(String, SDVariable, int, String)}. - */ - public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { - return depthToSpace(null, x, blockSize, dataFormat); - } - - /** - * Convolution 2d layer batch to space operation on 4d input.
- * Reduces input channels dimension by rearranging data into a larger spatial dimensions
- * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] - * = [mb, 2, 4, 4] - * - * @param name Output variable name - * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param blockSize Block size, in the height/width dimension - * @param dataFormat Data format: "NCHW" or "NHWC" - * @return Output variable - * @see #depthToSpace(String, SDVariable, int, String) - */ - public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { - SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { - return depthWiseConv2d(layerInput, depthWeights, null, config); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { - return depthWiseConv2d(name, layerInput, depthWeights, null, config); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. - */ - public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { - return depthWiseConv2d(null, layerInput, depthWeights, bias, config); - } - - /** - * Depth-wise 2D convolution operation with optional bias - * - * @param name name of the operation in SameDiff - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param config Conv2DConfig configuration - * @return result of depthwise conv2d op - */ - public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { - validateFloatingPoint("depthwiseConv2d", "input", layerInput); - validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); - validateFloatingPoint("depthwiseConv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; - arr[0] = layerInput; - arr[1] = depthWeights; - if (bias != null) - arr[2] = bias; - return depthWiseConv2d(name, arr, config); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}. - */ - public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { - return depthWiseConv2d(null, inputs, depthConv2DConfig); - } - - /** - * Depth-wise convolution 2D operation. - * - * @param name name of the output variable - * @param inputs the inputs to depth-wise conv2d. 
An array with either 2 elements (layerInput, depthWeights) - * or 3 elements (layerInput, depthWeights, bias) as described in - * {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param depthConv2DConfig the configuration - * @return result of depthwise conv2d op - */ - public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { - for(SDVariable v : inputs) - validateFloatingPoint("depthWiseConv2d", v); - SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}. - */ - public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, - @NonNull int[] rates, @NonNull boolean isSameMode) { - return dilation2D(null, df, weights, strides, rates, isSameMode); - } - - /** - * TODO doc string - * - * @param name - * @param df - * @param weights - * @param strides - * @param rates - * @param isSameMode - * @return - */ - public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, - @NonNull int[] rates, @NonNull boolean isSameMode) { - SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); - return updateVariableNameAndReference(ret, name); - } - - - /** - * Extract image patches - * - * @param name Name of the output variable - * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] - * @param kH Kernel height - * @param kW Kernel width - * @param sH Stride height - * @param sW Stride width - * @param rH Rate height - * @param rW Rate width - * @param sameMode If true: use same mode padding. If false - * @return - */ - public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { - SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #im2Col(String, SDVariable, Conv2DConfig)}. - */ - public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) { - return im2Col(null, in, config); - } - - /** - * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape - * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * - * @param name Name of the output variable - * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] - * @param config Convolution configuration for the im2col operation - * @return Im2Col output variable - */ - public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { - SDVariable ret = f().im2Col(in, config); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}. 
- */ - public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) { - return localResponseNormalization(null, inputs, lrnConfig); - } - - /** - * 2D convolution layer operation - local response normalization - * - * @param name name of the operation in SameDiff - * @param input the inputs to lrn - * @param lrnConfig the configuration - * @return - */ - public SDVariable localResponseNormalization(String name, @NonNull SDVariable input, - @NonNull LocalResponseNormalizationConfig lrnConfig) { - validateFloatingPoint("local response normalization", input); - SDVariable ret = f().localResponseNormalization(input, lrnConfig); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}. - */ - public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - return maxPooling2d(null, input, pooling2DConfig); - } - - /** - * 2D Convolution layer operation - max pooling 2d - * - * @param name name of the operation in SameDiff - * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration - * @return Result after applying max pooling on the input - */ - public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - validateNumerical("maxPooling2d", input); - SDVariable ret = f().maxPooling2d(input, pooling2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}. - */ - public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - return maxPooling3d(null, input, pooling3DConfig); - } - - /** - * 3D convolution layer operation - max pooling 3d operation. - * - * @param name name of the operation in SameDiff - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying max pooling on the input - */ - public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - validateNumerical("maxPooling3d", input); - SDVariable ret = f().maxPooling3d(input, pooling3DConfig); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - @NonNull Conv2DConfig config) { - return separableConv2d(layerInput, depthWeights, pointWeights, null, config); - } - - - /** - * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - @NonNull Conv2DConfig config) { - return separableConv2d(layerInput, depthWeights, pointWeights, null, config); - } - - /** - * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. 
- */ - public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - SDVariable bias, @NonNull Conv2DConfig config) { - return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config); - } - - /** - * Separable 2D convolution operation with optional bias - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels] - * May be null - * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. - * @param config Conv2DConfig configuration - * @return result of separable convolution 2d operation - */ - public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - SDVariable bias, @NonNull Conv2DConfig config) { - validateFloatingPoint("separableConv2d", "input", layerInput); - validateFloatingPoint("separableConv2d", "depthWeights", depthWeights); - validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); - validateFloatingPoint("separableConv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 3 : 4]; - arr[0] = layerInput; - arr[1] = depthWeights; - arr[2] = pointWeights; - if (bias != null) - arr[3] = bias; - return sconv2d(name, arr, config); - } - - /** - * See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}. - */ - public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { - return sconv2d(null, inputs, conv2DConfig); - } - - /** - * Separable 2D convolution operation with/without optional bias - * - * @param name name of the output variable - * @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights) - * or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param conv2DConfig the configuration - * @return result of separable convolution 2d operation - */ - public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { - for(SDVariable v : inputs) - validateFloatingPoint("sconv2d", v); - SDVariable ret = f().sconv2d(inputs, conv2DConfig); - return updateVariableNameAndReference(ret, name); - } - - - /** - * @see #spaceToBatch(String, SDVariable, int[], int[][]) - */ - public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { - return spaceToBatch(null, x, blocks, padding); - } - - /** - * Convolution 2d layer space to batch operation on 4d input. - * Increases input batch dimension by rearranging data from spatial dimensions into batch dimension - * - * @param name Output variable name - * @param x Input variable. 
4d input - * @param blocks Block size, in the height/width dimension - * @param padding Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] - * @return Output variable - * @see #batchToSpace(String, SDVariable, int[], int[][]) - */ - public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { - SDVariable ret = f().spaceToBatch(x, blocks, padding); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #spaceToDepth(String, SDVariable, int, String) - */ - public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { - return spaceToDepth(null, x, blockSize, dataFormat); - } - - /** - * Convolution 2d layer space to depth operation on 4d input.
- * Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
- * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] - * = [mb, 2, 4, 4] - * - * @param name Output variable name - * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param blockSize Block size, in the height/width dimension - * @param dataFormat Data format: "NCHW" or "NHWC" - * @return Output variable - * @see #depthToSpace(String, SDVariable, int, String) - */ - public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { - SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, - * scale is used for both height and width dimensions. - * - * @param scale The scale for both height and width dimensions. - */ - public SDVariable upsampling2d(@NonNull SDVariable input, int scale) { - return upsampling2d(null, input, true, scale, scale); - } - - /** - * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, - * scale is used for both height and width dimensions. - * - * @param scale The scale for both height and width dimensions. - */ - public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) { - return upsampling2d(name, input, true, scale, scale); - } - - /** - * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}. - */ - public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { - return upsampling2d(null, input, nchw, scaleH, scaleW); - } - - /** - * 2D Convolution layer operation - Upsampling 2d - * - * @param input Input, in NCHW format - * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format - * @param scaleH Scale to upsample in height dimension - * @param scaleW Scale to upsample in width dimension - * @return Upsampled input - */ - public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { - SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); - return updateVariableNameAndReference(ret, name); - } + public SDCNN(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * 2D Convolution layer operation - average pooling 2d
+ * + * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("avgPooling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D(sd,input, Pooling2DConfig).outputVariable(); + } + + /** + * 2D Convolution layer operation - average pooling 2d
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("avgPooling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D(sd,input, Pooling2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D convolution layer operation - average pooling 3d
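A minimal sketch of calling the regenerated avgPooling2d; the Pooling2DConfig builder field names used here (kH, kW, sH, sW, isSameMode) are assumptions about the config class, and the input shape is illustrative only.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;

    public class AvgPoolSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // NCHW activations: [minibatch, channels, height, width]
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 8, 8);

            Pooling2DConfig conf = Pooling2DConfig.builder()
                    .kH(2).kW(2)        // pooling window
                    .sH(2).sW(2)        // stride
                    .isSameMode(false)
                    .build();

            SDVariable pooled = sd.cnn().avgPooling2d("pool", in, conf);  // -> [minibatch, 3, 4, 4]
        }
    }

The 3d pooling and max pooling variants follow the same call pattern with their respective config objects.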
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("avgPooling3d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D(sd,input, Pooling3DConfig).outputVariable(); + } + + /** + * 3D convolution layer operation - average pooling 3d
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("avgPooling3d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D(sd,input, Pooling3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input batch dimension by rearranging data into a larger spatial dimensions
+ * + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param croppingTop (Size: Exactly(count=2)) + * @param croppingBottom (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable batchToSpace(SDVariable x, int[] blocks, int[] croppingTop, + int... croppingBottom) { + SDValidation.validateNumerical("batchToSpace", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(croppingTop.length == 2, "croppingTop has incorrect size/length. Expected: croppingTop.length == 2, got %s", croppingTop.length); + Preconditions.checkArgument(croppingBottom.length == 2, "croppingBottom has incorrect size/length. Expected: croppingBottom.length == 2, got %s", croppingBottom.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace(sd,x, blocks, croppingTop, croppingBottom).outputVariable(); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input batch dimension by rearranging data into a larger spatial dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param croppingTop (Size: Exactly(count=2)) + * @param croppingBottom (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[] croppingTop, + int... croppingBottom) { + SDValidation.validateNumerical("batchToSpace", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(croppingTop.length == 2, "croppingTop has incorrect size/length. Expected: croppingTop.length == 2, got %s", croppingTop.length); + Preconditions.checkArgument(croppingBottom.length == 2, "croppingBottom has incorrect size/length. Expected: croppingBottom.length == 2, got %s", croppingBottom.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace(sd,x, blocks, croppingTop, croppingBottom).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
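For the batchToSpace overloads above, note that the regenerated signature takes the cropping as two separate length-2 arrays (croppingTop, croppingBottom) rather than the old int[][] crops. A minimal sketch, with shapes chosen purely for illustration:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class BatchToSpaceSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // 4d input whose batch dimension is a multiple of the block product (2*2 = 4 here)
            SDVariable x = sd.placeHolder("x", DataType.FLOAT, 4, 2, 2, 1);

            SDVariable out = sd.cnn().batchToSpace("b2s", x,
                    new int[]{2, 2},     // block size in the height/width dimensions
                    new int[]{0, 0},     // cropping, first row
                    new int[]{0, 0});    // cropping, second row
        }
    }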
+ * [minibatch, inputChannels, height, width]
+ * + * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Col2Im output variable (NUMERIC type) + */ + public SDVariable col2Im(SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("col2Im", "in", in); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im(sd,in, Conv2DConfig).outputVariable(); + } + + /** + * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
+ * [minibatch, inputChannels, height, width]
+ * + * @param name name May be null. Name for the output variable + * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Col2Im output variable (NUMERIC type) + */ + public SDVariable col2Im(String name, SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("col2Im", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im(sd,in, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Conv1d operation.
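A minimal sketch pairing im2Col with the col2Im op above (im2Col produces the rank 6 layout that col2Im consumes); the Conv2DConfig builder field names (kH, kW, sH, sW) are assumptions about the config class.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;

    public class Col2ImSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            Conv2DConfig conf = Conv2DConfig.builder()
                    .kH(3).kW(3)
                    .sH(1).sW(1)
                    .build();

            // Rank 4 image -> rank 6 columns -> back to rank 4
            SDVariable img  = sd.placeHolder("img", DataType.FLOAT, 1, 3, 8, 8);  // [mb, c, h, w]
            SDVariable cols = sd.cnn().im2Col("cols", img, conf);                 // [mb, c, kH, kW, oH, oW]
            SDVariable back = sd.cnn().col2Im("back", cols, conf);
        }
    }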
+ * + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, + Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + SDValidation.validateNumerical("conv1d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, bias, Conv1DConfig).outputVariable(); + } + + /** + * Conv1d operation.
+ * + * @param name name May be null. Name for the output variable + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(String name, SDVariable input, SDVariable weights, SDVariable bias, + Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + SDValidation.validateNumerical("conv1d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, bias, Conv1DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Conv1d operation.
+ * + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, null, Conv1DConfig).outputVariable(); + } + + /** + * Conv1d operation.
+ * + * @param name name May be null. Name for the output variable + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(String name, SDVariable input, SDVariable weights, + Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, null, Conv1DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution operation with optional bias
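A minimal sketch of the no-bias conv1d overload above; the Conv1DConfig builder field names used here (k, s) are assumptions that may differ between versions, and the shapes are illustrative only.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;

    public class Conv1dSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // NCW input: [minibatch, inputChannels, width]
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 8, 32);
            // Weights: [kernelSize, inputChannels, outputChannels]
            SDVariable w  = sd.var("w", DataType.FLOAT, 3, 8, 16);

            Conv1DConfig conf = Conv1DConfig.builder()
                    .k(3)   // kernel size (field name assumed)
                    .s(1)   // stride (field name assumed)
                    .build();

            SDVariable out = sd.cnn().conv1d("conv", in, w, conf);
        }
    }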
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + SDValidation.validateNumerical("conv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, bias, Conv2DConfig).outputVariable(); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(String name, SDVariable layerInput, SDVariable weights, SDVariable bias, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + SDValidation.validateNumerical("conv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, bias, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, null, Conv2DConfig).outputVariable(); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(String name, SDVariable layerInput, SDVariable weights, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, null, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 3D operation with optional bias
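A minimal sketch of the conv2d overload with bias, using the weight layout documented above ([kernelHeight, kernelWidth, inputChannels, outputChannels]); the Conv2DConfig builder field names are assumed.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;

    public class Conv2dSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 28, 28);  // NCHW
            SDVariable w  = sd.var("w", DataType.FLOAT, 3, 3, 3, 16);             // [kH, kW, inC, outC]
            SDVariable b  = sd.var("b", DataType.FLOAT, 16);                      // [outC]

            Conv2DConfig conf = Conv2DConfig.builder()
                    .kH(3).kW(3)
                    .sH(1).sW(1)
                    .isSameMode(true)   // "same" padding
                    .build();

            SDVariable out = sd.cnn().conv2d("conv", in, w, b, conf);             // -> [mb, 16, 28, 28]
        }
    }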
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, + Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + SDValidation.validateNumerical("conv3d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, bias, Conv3DConfig).outputVariable(); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, + Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + SDValidation.validateNumerical("conv3d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, bias, Conv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, null, Conv3DConfig).outputVariable(); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(String name, SDVariable input, SDVariable weights, + Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, null, Conv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D deconvolution operation with optional bias
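A minimal sketch of the no-bias conv3d overload above, assuming an NDHWC input layout and Conv3DConfig builder field names (kD, kH, kW, sD, sH, sW); shapes are illustrative only.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;

    public class Conv3dSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // NDHWC activations: [minibatch, depth, height, width, channels] (layout assumed)
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 8, 16, 16, 4);
            // Weights: [kD, kH, kW, inputChannels, outputChannels]
            SDVariable w  = sd.var("w", DataType.FLOAT, 3, 3, 3, 4, 8);

            Conv3DConfig conf = Conv3DConfig.builder()
                    .kD(3).kH(3).kW(3)   // kernel size (field names assumed)
                    .sD(1).sH(1).sW(1)   // stride
                    .build();

            SDVariable out = sd.cnn().conv3d("conv3", in, w, conf);   // no-bias overload
        }
    }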
+ * + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, + DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + SDValidation.validateNumerical("deconv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, bias, DeConv2DConfig).outputVariable(); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(String name, SDVariable layerInput, SDVariable weights, + SDVariable bias, DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + SDValidation.validateNumerical("deconv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, bias, DeConv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, + DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, null, DeConv2DConfig).outputVariable(); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(String name, SDVariable layerInput, SDVariable weights, + DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, null, DeConv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, + DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + SDValidation.validateNumerical("deconv3d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, bias, DeConv3DConfig).outputVariable(); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, + DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + SDValidation.validateNumerical("deconv3d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, bias, DeConv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, null, DeConv3DConfig).outputVariable(); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, + DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, null, DeConv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces the input channels dimension by rearranging data into larger spatial dimensions
+ * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
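+ * A minimal sketch of the shape arithmetic above (the variable names, the {@code sd.cnn()} accessor and the
+ * {@code DataFormat.NCHW} constant are illustrative assumptions):
+ * <pre>{@code
+ * SDVariable in  = sd.var("in", DataType.FLOAT, 1, 8, 2, 2);      // NCHW, 8 channels
+ * SDVariable out = sd.cnn().depthToSpace(in, 2, DataFormat.NCHW); // shape [1, 2, 4, 4]
+ * }</pre>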
+ * + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable depthToSpace(SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("depthToSpace", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace(sd,x, blockSize, dataFormat).outputVariable(); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces the input channels dimension by rearranging data into larger spatial dimensions
+ * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
+ * + * @param name name May be null. Name for the output variable + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable depthToSpace(String name, SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("depthToSpace", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace(sd,x, blockSize, dataFormat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Depth-wise 2D convolution operation with optional bias
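+ * A minimal usage sketch (illustrative only; the variable names, the {@code sd.cnn()} accessor and the
+ * {@code Conv2DConfig} builder field names are assumptions):
+ * <pre>{@code
+ * SDVariable in = sd.var("in", DataType.FLOAT, 1, 3, 8, 8);       // NCHW, 3 input channels
+ * SDVariable dw = sd.var("dw", DataType.FLOAT, 3, 3, 3, 2);       // [kH, kW, iC, depthMultiplier]
+ * Conv2DConfig cfg = Conv2DConfig.builder().kH(3).kW(3).build();
+ * SDVariable out = sd.cnn().depthWiseConv2d(in, dw, cfg);         // 3 * 2 = 6 output channels
+ * }</pre>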
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("depthWiseConv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, bias, Conv2DConfig).outputVariable(); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + SDVariable bias, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("depthWiseConv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, bias, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, null, Conv2DConfig).outputVariable(); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, null, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * TODO doc string
+ * + * @param df (NUMERIC type) + * @param weights df (NUMERIC type) + * @param strides weights (Size: Exactly(count=2)) + * @param rates strides (Size: Exactly(count=2)) + * @param isSameMode isSameMode + * @return output Computed the grayscale dilation of 4-D input and 3-D filters tensors. (NUMERIC type) + */ + public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, int[] rates, + boolean isSameMode) { + SDValidation.validateNumerical("dilation2D", "df", df); + SDValidation.validateNumerical("dilation2D", "weights", weights); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length == 2, "rates has incorrect size/length. Expected: rates.length == 2, got %s", rates.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D(sd,df, weights, strides, rates, isSameMode).outputVariable(); + } + + /** + * TODO doc string
+ * + * @param name name May be null. Name for the output variable + * @param df (NUMERIC type) + * @param weights df (NUMERIC type) + * @param strides weights (Size: Exactly(count=2)) + * @param rates strides (Size: Exactly(count=2)) + * @param isSameMode isSameMode + * @return output Computed the grayscale dilation of 4-D input and 3-D filters tensors. (NUMERIC type) + */ + public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, + int[] rates, boolean isSameMode) { + SDValidation.validateNumerical("dilation2D", "df", df); + SDValidation.validateNumerical("dilation2D", "weights", weights); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length == 2, "rates has incorrect size/length. Expected: rates.length == 2, got %s", rates.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D(sd,df, weights, strides, rates, isSameMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Extract image patches
+ * + * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] (NUMERIC type) + * @param kH Kernel height + * @param kW Kernel width + * @param sH Stride height + * @param sW Stride width + * @param rH Rate height + * @param rW Rate width + * @param sameMode If true: use same mode padding. If false + * @return output The result is a 4D tensor which is indexed by batch, row, and column. (NUMERIC type) + */ + public SDVariable extractImagePatches(SDVariable input, int kH, int kW, int sH, int sW, int rH, + int rW, boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "input", input); + return new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,input, kH, kW, sH, sW, rH, rW, sameMode).outputVariable(); + } + + /** + * Extract image patches
+ * + * @param name name May be null. Name for the output variable + * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] (NUMERIC type) + * @param kH Kernel height + * @param kW Kernel width + * @param sH Stride height + * @param sW Stride width + * @param rH Rate height + * @param rW Rate width + * @param sameMode If true: use same mode padding. If false + * @return output The result is a 4D tensor which is indexed by batch, row, and column. (NUMERIC type) + */ + public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, + int sW, int rH, int rW, boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,input, kH, kW, sH, sW, rH, rW, sameMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
+ * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
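+ * A minimal usage sketch (illustrative only; the {@code sd.cnn()} accessor and the {@code Conv2DConfig}
+ * builder field names are assumptions):
+ * <pre>{@code
+ * SDVariable in = sd.var("in", DataType.FLOAT, 1, 3, 8, 8);       // [minibatch, channels, height, width]
+ * Conv2DConfig cfg = Conv2DConfig.builder().kH(3).kW(3).sH(1).sW(1).build();
+ * SDVariable cols = sd.cnn().im2Col(in, cfg);                     // rank-6 output, as described above
+ * }</pre>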
+ * + * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Im2Col output variable (NUMERIC type) + */ + public SDVariable im2Col(SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("im2Col", "in", in); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col(sd,in, Conv2DConfig).outputVariable(); + } + + /** + * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
+ * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
+ * + * @param name name May be null. Name for the output variable + * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Im2Col output variable (NUMERIC type) + */ + public SDVariable im2Col(String name, SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("im2Col", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col(sd,in, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D convolution layer operation - local response normalization
+ * + * @param input the inputs to lrn (NUMERIC type) + * @param LocalResponseNormalizationConfig Configuration Object + * @return output Result after Local Response Normalization (NUMERIC type) + */ + public SDVariable localResponseNormalization(SDVariable input, + LocalResponseNormalizationConfig LocalResponseNormalizationConfig) { + SDValidation.validateNumerical("localResponseNormalization", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(sd,input, LocalResponseNormalizationConfig).outputVariable(); + } + + /** + * 2D convolution layer operation - local response normalization
+ * + * @param name name May be null. Name for the output variable + * @param input the inputs to lrn (NUMERIC type) + * @param LocalResponseNormalizationConfig Configuration Object + * @return output Result after Local Response Normalization (NUMERIC type) + */ + public SDVariable localResponseNormalization(String name, SDVariable input, + LocalResponseNormalizationConfig LocalResponseNormalizationConfig) { + SDValidation.validateNumerical("localResponseNormalization", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(sd,input, LocalResponseNormalizationConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution layer operation - max pooling 2d
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPooling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D(sd,input, Pooling2DConfig).outputVariable(); + } + + /** + * 2D Convolution layer operation - max pooling 2d
+ * + * @param name name May be null. Name for the output variable + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPooling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D(sd,input, Pooling2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D convolution layer operation - max pooling 3d operation.
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("maxPooling3d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D(sd,input, Pooling3DConfig).outputVariable(); + } + + /** + * 3D convolution layer operation - max pooling 3d operation.
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("maxPooling3d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D(sd,input, Pooling3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, SDVariable bias, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + SDValidation.validateNumerical("separableConv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, bias, Conv2DConfig).outputVariable(); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, SDVariable bias, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + SDValidation.validateNumerical("separableConv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, bias, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, null, Conv2DConfig).outputVariable(); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, null, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer space to batch operation on 4d input.
+ * Increases the input batch dimension by rearranging data from the spatial dimensions into the batch dimension
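+ * For example, with blocks [2, 2] and no padding the batch dimension grows by a factor of 4 while each
+ * spatial dimension is halved. A minimal sketch, assuming an NHWC input (as in the equivalent TensorFlow op)
+ * and the {@code sd.cnn()} accessor:
+ * <pre>{@code
+ * SDVariable x   = sd.var("x", DataType.FLOAT, 2, 4, 4, 1);                          // [mb, h, w, c]
+ * SDVariable out = sd.cnn().spaceToBatch(x, new int[]{2, 2}, new int[]{0, 0}, 0, 0); // shape [8, 2, 2, 1]
+ * }</pre>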
+ * + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param paddingTop Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @param paddingBottom Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[] paddingTop, + int... paddingBottom) { + SDValidation.validateNumerical("spaceToBatch", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(paddingTop.length == 2, "paddingTop has incorrect size/length. Expected: paddingTop.length == 2, got %s", paddingTop.length); + Preconditions.checkArgument(paddingBottom.length == 2, "paddingBottom has incorrect size/length. Expected: paddingBottom.length == 2, got %s", paddingBottom.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch(sd,x, blocks, paddingTop, paddingBottom).outputVariable(); + } + + /** + * Convolution 2d layer space to batch operation on 4d input.
+ * Increases the input batch dimension by rearranging data from the spatial dimensions into the batch dimension
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param paddingTop Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @param paddingBottom Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[] paddingTop, + int... paddingBottom) { + SDValidation.validateNumerical("spaceToBatch", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(paddingTop.length == 2, "paddingTop has incorrect size/length. Expected: paddingTop.length == 2, got %s", paddingTop.length); + Preconditions.checkArgument(paddingBottom.length == 2, "paddingBottom has incorrect size/length. Expected: paddingBottom.length == 2, got %s", paddingBottom.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch(sd,x, blocks, paddingTop, paddingBottom).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer space to depth operation on 4d input.
+ * Increases the input channels dimension (and reduces the spatial dimensions) by rearranging data into a larger channels dimension
+ * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 2*(2*2), 4/2, 4/2]
+ * = [mb, 8, 2, 2]
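+ * A minimal sketch of the shape arithmetic above (the variable names, the {@code sd.cnn()} accessor and the
+ * {@code DataFormat.NCHW} constant are illustrative assumptions):
+ * <pre>{@code
+ * SDVariable in  = sd.var("in", DataType.FLOAT, 1, 2, 4, 4);      // NCHW, 2 channels
+ * SDVariable out = sd.cnn().spaceToDepth(in, 2, DataFormat.NCHW); // shape [1, 8, 2, 2]
+ * }</pre>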
+ * + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToDepth(SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("spaceToDepth", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth(sd,x, blockSize, dataFormat).outputVariable(); + } + + /** + * Convolution 2d layer space to depth operation on 4d input.
+ * Increases the input channels dimension (and reduces the spatial dimensions) by rearranging data into a larger channels dimension
+ * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 2*(2*2), 4/2, 4/2]
+ * = [mb, 8, 2, 2]
+ * + * @param name name May be null. Name for the output variable + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("spaceToDepth", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth(sd,x, blockSize, dataFormat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Upsampling layer for 2D inputs.
+ * scale is used for both height and width dimensions.
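+ * A minimal sketch (the variable names and the {@code sd.cnn()} accessor are illustrative assumptions):
+ * <pre>{@code
+ * SDVariable in = sd.var("in", DataType.FLOAT, 1, 3, 4, 4);   // NCHW
+ * SDVariable up = sd.cnn().upsampling2d(in, 2);               // shape [1, 3, 8, 8]
+ * }</pre>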
+ * + * @param input Input in NCHW format (NUMERIC type) + * @param scale The scale for both height and width dimensions. + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(SDVariable input, int scale) { + SDValidation.validateNumerical("upsampling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scale).outputVariable(); + } + + /** + * Upsampling layer for 2D inputs.
+ * scale is used for both height and width dimensions.
+ * + * @param name name May be null. Name for the output variable + * @param input Input in NCHW format (NUMERIC type) + * @param scale The scale for both height and width dimensions. + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(String name, SDVariable input, int scale) { + SDValidation.validateNumerical("upsampling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scale).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution layer operation - Upsampling 2d
+ * + * @param input Input in NCHW format (NUMERIC type) + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(SDVariable input, int scaleH, int scaleW, boolean nchw) { + SDValidation.validateNumerical("upsampling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scaleH, scaleW, nchw).outputVariable(); + } + + /** + * 2D Convolution layer operation - Upsampling 2d
+ * + * @param name name May be null. Name for the output variable + * @param input Input in NCHW format (NUMERIC type) + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(String name, SDVariable input, int scaleH, int scaleW, + boolean nchw) { + SDValidation.validateNumerical("upsampling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scaleH, scaleW, nchw).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 7b662b960..70940863a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -1,185 +1,440 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.custom.*; -import org.nd4j.linalg.api.ops.impl.image.CropAndResize; -import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; -import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; +import org.nd4j.base.Preconditions; -/** - * @author Alex Black - */ public class SDImage extends SDOps { - public SDImage(SameDiff sameDiff) { - super(sameDiff); - } + public SDImage(SameDiff sameDiff) { + super(sameDiff); + } - /** - * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size. - * - * @param name May be null. Name for the output variable. - * @param image Input image, with shape [batch, height, width, channels] - * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 - * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] - * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] - * @param method Image resize method - * @param extrapolationValue Used for extrapolation, when applicable. 
0.0 should be used for the default - * @return Cropped and resized images - */ - public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize, - CropAndResize.Method method, double extrapolationValue) { - SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
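+ * A minimal usage sketch (illustrative only; the variable names and the {@code sd.image()} accessor are
+ * assumptions, not defined by this patch):
+ * <pre>{@code
+ * SDVariable image      = sd.var("img",   DataType.FLOAT, 2, 64, 64, 3);  // [batch, height, width, channels]
+ * SDVariable cropBoxes  = sd.var("boxes", DataType.FLOAT, 1, 4);          // relative coordinates in 0..1
+ * SDVariable boxIndices = sd.var("idx",   DataType.FLOAT, 1);             // which image each box crops from
+ * SDVariable cropSize   = sd.constant(Nd4j.createFromArray(32, 32));      // [outHeight, outWidth]
+ * SDVariable crops      = sd.image().cropAndResize(image, cropBoxes, boxIndices, cropSize, 0.0); // [1, 32, 32, 3]
+ * }</pre>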
+ * + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, + SDVariable cropOutSize, double extrapolationValue) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable(); + } - /** - * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. - * - * @param name Map be null. Name for the output variable - * @param image Input image to extract image patches from - shape [batch, height, width, channels] - * @param kSizes Kernel size - size of the image patches, [height, width] - * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] - * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels - * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken - * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension - * @param sameMode Padding algorithm. If true: use Same padding - * @return The extracted image patches - */ - public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes, - @NonNull int[] strides, @NonNull int[] rates, boolean sameMode) { - SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param name name May be null. Name for the output variable + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, + SDVariable boxIndices, SDVariable cropOutSize, double extrapolationValue) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Greedily selects a subset of bounding boxes in descending order of score - * @param name Might be null. Name for the output variable - * @param boxes 2D array of shape [num_boxes,4] - * @param scores vector of shape [num_boxes] - * @param maxOutSize scalar representing the maximum number of boxes to be selected - * @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU - * @param scoreThreshold float - threshold for deciding when to remove boxes based on score - * @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size - */ - public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, - @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ - SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, + SDVariable cropOutSize) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable(); + } - /** - * Adjusts contrast of RGB or grayscale images. - * @param name name for the output variable - * @param in images to adjust. 3D shape or higher. - * @param factor float multiplier for adjusting contrast. - * @return Contrast-adjusted image - */ - public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { - SDVariable out = new AdjustContrast(sd, in, factor).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param name name May be null. Name for the output variable + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, + SDVariable boxIndices, SDVariable cropOutSize) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Adjust saturation of RGB images - * @param name name for the output variable - * @param in RGB image as 3D array - * @param factor factor for saturation - * @return adjusted image - */ - public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { - SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjusts contrast of RGB or grayscale images.
+ * + * @param in images to adjust. 3D shape or higher (NUMERIC type) + * @param factor multiplier for adjusting contrast + * @return output Contrast-adjusted image (NUMERIC type) + */ + public SDVariable adjustContrast(SDVariable in, double factor) { + SDValidation.validateNumerical("adjustContrast", "in", in); + return new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable(); + } - /** - * Adjust hue of RGB image - * @param name name for the output variable - * @param in RGB image as 3D array - * @param delta value to add to hue channel - * @return adjusted image - */ - public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) { - SDVariable out = new AdjustHue(sd, in, delta).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjusts contrast of RGB or grayscale images.
+ * + * @param name name May be null. Name for the output variable + * @param in images to adjust. 3D shape or higher (NUMERIC type) + * @param factor multiplier for adjusting contrast + * @return output Contrast-adjusted image (NUMERIC type) + */ + public SDVariable adjustContrast(String name, SDVariable in, double factor) { + SDValidation.validateNumerical("adjustContrast", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Randomly crops image - * @param name name for the output variable - * @param input input array - * @param shape shape for crop - * @return cropped array - */ - public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) { - SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust hue of RGB image
+ * + * @param in image as 3D array (NUMERIC type) + * @param delta value to add to hue channel + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustHue(SDVariable in, double delta) { + SDValidation.validateNumerical("adjustHue", "in", in); + return new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable(); + } - /** - * Converting array from HSV to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable rgbToHsv(String name, @NonNull SDVariable input) { - SDVariable out = new RgbToHsv(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust hue of RGB image
+ * + * @param name name May be null. Name for the output variable + * @param in image as 3D array (NUMERIC type) + * @param delta value to add to hue channel + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustHue(String name, SDVariable in, double delta) { + SDValidation.validateNumerical("adjustHue", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Converting image from HSV to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable hsvToRgb(String name, @NonNull SDVariable input) { - SDVariable out = new HsvToRgb(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust saturation of RGB images
+ * + * @param in RGB image as 3D array (NUMERIC type) + * @param factor factor for saturation + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustSaturation(SDVariable in, double factor) { + SDValidation.validateNumerical("adjustSaturation", "in", in); + return new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable(); + } - /** - * Converting array from RGB to YIQ format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable rgbToYiq(String name, @NonNull SDVariable input) { - SDVariable out = new RgbToYiq(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust saturation of RGB images
+ * + * @param name name May be null. Name for the output variable + * @param in RGB image as 3D array (NUMERIC type) + * @param factor factor for saturation + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustSaturation(String name, SDVariable in, double factor) { + SDValidation.validateNumerical("adjustSaturation", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Converting image from YIQ to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable yiqToRgb(String name, @NonNull SDVariable input) { - SDVariable out = new YiqToRgb(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
+ * + * @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type) + * @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2)) + * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2)) + * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels + * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken + * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0)) + * @param sameMode Padding algorithm. If true: use Same padding + * @return output The extracted image patches (NUMERIC type) + */ + public SDVariable extractImagePatches(SDVariable image, int[] kSizes, int[] strides, int[] rates, + boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "image", image); + Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length); + return new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable(); + } - /** - * Converting array from RGB to YUV format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable rgbToYuv(String name, @NonNull SDVariable input) { - SDVariable out = new RgbToYuv(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
+ * + * @param name name May be null. Name for the output variable + * @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type) + * @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2)) + * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2)) + * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels + * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken + * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0)) + * @param sameMode Padding algorithm. If true: use Same padding + * @return output The extracted image patches (NUMERIC type) + */ + public SDVariable extractImagePatches(String name, SDVariable image, int[] kSizes, int[] strides, + int[] rates, boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "image", image); + Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Converting image from YUV to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable yuvToRgb(String name, @NonNull SDVariable input) { - SDVariable out = new YuvToRgb(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Converting image from HSV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable hsvToRgb(SDVariable input) { + SDValidation.validateNumerical("hsvToRgb", "input", input); + return new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable(); + } + + /** + * Converting image from HSV to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable hsvToRgb(String name, SDVariable input) { + SDValidation.validateNumerical("hsvToRgb", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greedily selects a subset of bounding boxes in descending order of score
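+ * A minimal usage sketch (the variable names and the {@code sd.image()} accessor are illustrative assumptions):
+ * <pre>{@code
+ * SDVariable boxes  = sd.var("boxes",  DataType.FLOAT, 10, 4);   // [num_boxes, 4]
+ * SDVariable scores = sd.var("scores", DataType.FLOAT, 10);      // [num_boxes]
+ * SDVariable keep   = sd.image().nonMaxSuppression(boxes, scores, 5, 0.5, 0.0); // indices of at most 5 kept boxes
+ * }</pre>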
+   *
+   * @param boxes 2D array of shape [num_boxes, 4] containing the bounding boxes to select from (NUMERIC type)
+   * @param scores vector of shape [num_boxes] (NUMERIC type)
+   * @param maxOutSize scalar representing the maximum number of boxes to be selected
+   * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
+   * @param scoreThreshold threshold for deciding when to remove boxes based on score
+   * @return output vector of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
+   */
+  public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize,
+      double iouThreshold, double scoreThreshold) {
+    SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
+    SDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
+    return new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
+  }
+
+  /**
+   * Greedily selects a subset of bounding boxes in descending order of score
+   *
+   * @param name name May be null. Name for the output variable
+   * @param boxes 2D array of shape [num_boxes, 4] containing the bounding boxes to select from (NUMERIC type)
+   * @param scores vector of shape [num_boxes] (NUMERIC type)
+   * @param maxOutSize scalar representing the maximum number of boxes to be selected
+   * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
+   * @param scoreThreshold threshold for deciding when to remove boxes based on score
+   * @return output vector of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
+   */
+  public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores,
+      int maxOutSize, double iouThreshold, double scoreThreshold) {
+    SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
+    SDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Randomly crops image
+ * + * @param input input array (NUMERIC type) + * @param shape shape for crop (INT type) + * @return output cropped array (NUMERIC type) + */ + public SDVariable randomCrop(SDVariable input, SDVariable shape) { + SDValidation.validateNumerical("randomCrop", "input", input); + SDValidation.validateInteger("randomCrop", "shape", shape); + return new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable(); + } + + /** + * Randomly crops image
+   *
+   * @param name name May be null. Name for the output variable
+   * @param input input array (NUMERIC type)
+   * @param shape shape for crop (INT type)
+   * @return output cropped array (NUMERIC type)
+   */
+  public SDVariable randomCrop(String name, SDVariable input, SDVariable shape) {
+    SDValidation.validateNumerical("randomCrop", "input", input);
+    SDValidation.validateInteger("randomCrop", "shape", shape);
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Converting array from RGB to HSV format
+   *
+   * @param input 3D image (NUMERIC type)
+   * @return output 3D image (NUMERIC type)
+   */
+  public SDVariable rgbToHsv(SDVariable input) {
+    SDValidation.validateNumerical("rgbToHsv", "input", input);
+    return new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable();
+  }
+
+  /**
+   * Converting array from RGB to HSV format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToHsv(String name, SDVariable input) { + SDValidation.validateNumerical("rgbToHsv", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YIQ format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYiq(SDVariable input) { + SDValidation.validateNumerical("rgbToYiq", "input", input); + return new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable(); + } + + /** + * Converting array from RGB to YIQ format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYiq(String name, SDVariable input) { + SDValidation.validateNumerical("rgbToYiq", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YUV format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYuv(SDVariable input) { + SDValidation.validateNumerical("rgbToYuv", "input", input); + return new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable(); + } + + /** + * Converting array from RGB to YUV format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYuv(String name, SDVariable input) { + SDValidation.validateNumerical("rgbToYuv", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YIQ to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yiqToRgb(SDVariable input) { + SDValidation.validateNumerical("yiqToRgb", "input", input); + return new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable(); + } + + /** + * Converting image from YIQ to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yiqToRgb(String name, SDVariable input) { + SDValidation.validateNumerical("yiqToRgb", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YUV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yuvToRgb(SDVariable input) { + SDValidation.validateNumerical("yuvToRgb", "input", input); + return new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable(); + } + + /** + * Converting image from YUV to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yuvToRgb(String name, SDVariable input) { + SDValidation.validateNumerical("yuvToRgb", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java new file mode 100644 index 000000000..8dbb9d3b3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java @@ -0,0 +1,561 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.autodiff.samediff.ops; + +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +public class SDLinalg extends SDOps { + public SDLinalg(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Computes the Cholesky decomposition of one or more square matrices.
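+   * <p>
+   * A minimal sketch (illustrative; the ops in this class can be constructed directly with {@code new SDLinalg(sd)}):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * // a symmetric positive-definite 2x2 matrix
+   * SDVariable a = sd.var("a", Nd4j.createFromArray(new double[][]{{4, 2}, {2, 3}}));
+   * SDVariable l = new SDLinalg(sd).cholesky(a); // lower-triangular factor with a = l * l^T
+   * }</pre>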
+ * + * @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type) + * @return output Transformed tensor (NUMERIC type) + */ + public SDVariable cholesky(SDVariable input) { + SDValidation.validateNumerical("Cholesky", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable(); + } + + /** + * Computes the Cholesky decomposition of one or more square matrices.
+   *
+   * @param name name May be null. Name for the output variable
+   * @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type)
+   * @return output Transformed tensor (NUMERIC type)
+   */
+  public SDVariable cholesky(String name, SDVariable input) {
+    SDValidation.validateNumerical("Cholesky", "input", input);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Solver for linear least squares problems.
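+   * <p>
+   * A short sketch (illustrative shapes; an overdetermined system with more equations than unknowns):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable matrix = sd.var("matrix", Nd4j.rand(DataType.DOUBLE, 4, 2)); // 4 equations, 2 unknowns
+   * SDVariable rhs    = sd.var("rhs",    Nd4j.rand(DataType.DOUBLE, 4, 1));
+   * SDVariable x = new SDLinalg(sd).lstsq(matrix, rhs, 0.0, true);          // least squares solution, no l2 regularization
+   * }</pre>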
+   *
+   * @param matrix input tensor (NUMERIC type)
+   * @param rhs input tensor (NUMERIC type)
+   * @param l2_reguralizer regularizer
+   * @param fast fast mode, defaults to True
+   * @return output Transformed tensor (FLOATING_POINT type)
+   */
+  public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer, boolean fast) {
+    SDValidation.validateNumerical("Lstsq", "matrix", matrix);
+    SDValidation.validateNumerical("Lstsq", "rhs", rhs);
+    return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable();
+  }
+
+  /**
+   * Solver for linear least squares problems.
+   *
+   * @param name name May be null. Name for the output variable
+   * @param matrix input tensor (NUMERIC type)
+   * @param rhs input tensor (NUMERIC type)
+   * @param l2_reguralizer regularizer
+   * @param fast fast mode, defaults to True
+   * @return output Transformed tensor (FLOATING_POINT type)
+   */
+  public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer,
+      boolean fast) {
+    SDValidation.validateNumerical("Lstsq", "matrix", matrix);
+    SDValidation.validateNumerical("Lstsq", "rhs", rhs);
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Solver for linear least squares problems.
+   *
+   * @param matrix input tensor (NUMERIC type)
+   * @param rhs input tensor (NUMERIC type)
+   * @param l2_reguralizer regularizer
+   * @return output Transformed tensor (FLOATING_POINT type)
+   */
+  public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer) {
+    SDValidation.validateNumerical("Lstsq", "matrix", matrix);
+    SDValidation.validateNumerical("Lstsq", "rhs", rhs);
+    return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable();
+  }
+
+  /**
+   * Solver for linear least squares problems.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @return output Transformed tensor (FLOATING_POINT type) + */ + public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer) { + SDValidation.validateNumerical("Lstsq", "matrix", matrix); + SDValidation.validateNumerical("Lstsq", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Computes LU decomposition.
+ * + * @param input input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable lu(SDVariable input) { + SDValidation.validateNumerical("Lu", "input", input); + return new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable(); + } + + /** + * Computes LU decomposition.
+   *
+   * @param name name May be null. Name for the output variable
+   * @param input input tensor (NUMERIC type)
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable lu(String name, SDVariable input) {
+    SDValidation.validateNumerical("Lu", "input", input);
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Performs matrix multiplication on input tensors.
+   *
+   * @param a input tensor (NUMERIC type)
+   * @param b input tensor (NUMERIC type)
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable matmul(SDVariable a, SDVariable b) {
+    SDValidation.validateNumerical("Matmul", "a", a);
+    SDValidation.validateNumerical("Matmul", "b", b);
+    return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable();
+  }
+
+  /**
+   * Performs matrix multiplication on input tensors.
+   *
+   * @param name name May be null. Name for the output variable
+   * @param a input tensor (NUMERIC type)
+   * @param b input tensor (NUMERIC type)
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable matmul(String name, SDVariable a, SDVariable b) {
+    SDValidation.validateNumerical("Matmul", "a", a);
+    SDValidation.validateNumerical("Matmul", "b", b);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Copy a tensor, setting everything outside a central band in each innermost matrix to zero.
+   *
+   * @param input input tensor (NUMERIC type)
+   * @param minLower lower diagonal count
+   * @param maxUpper upper diagonal count
+   */
+  public SDVariable[] matrixBandPart(SDVariable input, int minLower, int maxUpper) {
+    SDValidation.validateNumerical("MatrixBandPart", "input", input);
+    return new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables();
+  }
+
+  /**
+   * Copy a tensor, setting everything outside a central band in each innermost matrix to zero.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input input tensor (NUMERIC type) + * @param minLower lower diagonal count + * @param maxUpper upper diagonal count + */ + public SDVariable[] matrixBandPart(String[] names, SDVariable input, int minLower, int maxUpper) { + SDValidation.validateNumerical("MatrixBandPart", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Computes the QR decompositions of input matrix.
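+   * <p>
+   * A short sketch (illustrative; the output ordering Q then R is an assumption of this example):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, 3, 3));
+   * SDVariable[] qr = new SDLinalg(sd).qr(in, false); // qr[0]: orthogonal factor, qr[1]: upper-triangular factor
+   * }</pre>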
+ * + * @param input input tensor (NUMERIC type) + * @param full full matrices mode + */ + public SDVariable[] qr(SDVariable input, boolean full) { + SDValidation.validateNumerical("Qr", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables(); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input input tensor (NUMERIC type) + * @param full full matrices mode + */ + public SDVariable[] qr(String[] names, SDVariable input, boolean full) { + SDValidation.validateNumerical("Qr", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + */ + public SDVariable[] qr(SDVariable input) { + SDValidation.validateNumerical("Qr", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables(); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input input tensor (NUMERIC type) + */ + public SDVariable[] qr(String[] names, SDVariable input) { + SDValidation.validateNumerical("Qr", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Solver for systems of linear equations.
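+   * <p>
+   * A minimal sketch of solving {@code matrix * x = rhs} (illustrative values):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable a = sd.var("a", Nd4j.createFromArray(new double[][]{{3, 1}, {1, 2}}));
+   * SDVariable b = sd.var("b", Nd4j.createFromArray(new double[][]{{9}, {8}}));
+   * SDVariable x = new SDLinalg(sd).solve(a, b, false); // expected solution approximately [[2], [3]]
+   * }</pre>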
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param adjoint adjoint mode, defaults to False + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(SDVariable matrix, SDVariable rhs, boolean adjoint) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable(); + } + + /** + * Solver for systems of linear equations.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param adjoint adjoint mode, defaults to False + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(String name, SDVariable matrix, SDVariable rhs, boolean adjoint) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Solver for systems of linear equations.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(SDVariable matrix, SDVariable rhs) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable(); + } + + /** + * Solver for systems of linear equations.
+   *
+   * @param name name May be null. Name for the output variable
+   * @param matrix input tensor (NUMERIC type)
+   * @param rhs input tensor (NUMERIC type)
+   * @return output Output tensor (FLOATING_POINT type)
+   */
+  public SDVariable solve(String name, SDVariable matrix, SDVariable rhs) {
+    SDValidation.validateNumerical("Solve", "matrix", matrix);
+    SDValidation.validateNumerical("Solve", "rhs", rhs);
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Solver for systems of linear equations.
+   *
+   * @param matrix input tensor (NUMERIC type)
+   * @param rhs input tensor (NUMERIC type)
+   * @param lower defines whether innermost matrices in matrix are lower or upper triangular
+   * @param adjoint adjoint mode
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable triangularSolve(SDVariable matrix, SDVariable rhs, boolean lower,
+      boolean adjoint) {
+    SDValidation.validateNumerical("TriangularSolve", "matrix", matrix);
+    SDValidation.validateNumerical("TriangularSolve", "rhs", rhs);
+    return new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable();
+  }
+
+  /**
+   * Solver for systems of linear equations.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param lower defines whether innermost matrices in matrix are lower or upper triangular + * @param adjoint adjoint mode + * @return output (FLOATING_POINT type) + */ + public SDVariable triangularSolve(String name, SDVariable matrix, SDVariable rhs, boolean lower, + boolean adjoint) { + SDValidation.validateNumerical("TriangularSolve", "matrix", matrix); + SDValidation.validateNumerical("TriangularSolve", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Computes pairwise cross product.
+ * + * @param a (NUMERIC type) + * @param b (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable cross(SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + } + + /** + * Computes pairwise cross product.
+ * + * @param name name May be null. Name for the output variable + * @param a (NUMERIC type) + * @param b (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable cross(String name, SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag(SDVariable input) { + SDValidation.validateNumerical("diag", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable(); + } + + /** + * Calculates diagonal tensor.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag(String name, SDVariable input) { + SDValidation.validateNumerical("diag", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag_part(SDVariable input) { + SDValidation.validateNumerical("diag_part", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable(); + } + + /** + * Calculates diagonal tensor.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag_part(String name, SDVariable input) { + SDValidation.validateNumerical("diag_part", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates log of determinant.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable logdet(SDVariable input) { + SDValidation.validateNumerical("logdet", "input", input); + return new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable(); + } + + /** + * Calculates log of determinant.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable logdet(String name, SDVariable input) { + SDValidation.validateNumerical("logdet", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
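+   * <p>
+   * For example (illustrative), computing {@code x^T * y} without materializing the transpose:
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 4, 3));
+   * SDVariable y = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5));
+   * SDVariable z = new SDLinalg(sd).mmul(x, y, true, false, false); // result shape [3, 5]
+   * }</pre>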
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX, + boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates singular value decomposition.
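+   * <p>
+   * A brief sketch (illustrative; with {@code computeUV = false} only the singular values are requested, and the
+   * 3-argument overload further down uses {@code switchNum = 16}):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, 5, 3));
+   * SDVariable s = new SDLinalg(sd).svd(in, false, false, 16);
+   * }</pre>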
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @param switchNum + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV, int switchNum) { + SDValidation.validateNumerical("svd", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable(); + } + + /** + * Calculates singular value decomposition.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @param switchNum + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV, + int switchNum) { + SDValidation.validateNumerical("svd", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV) { + SDValidation.validateNumerical("svd", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable(); + } + + /** + * Calculates singular value decomposition.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV) { + SDValidation.validateNumerical("svd", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index f0e94a4e5..9a1ef1249 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -1,5 +1,5 @@ -/* ***************************************************************************** - * Copyright (c) 2015-2019 Skymind, Inc. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,544 +14,1045 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.ops.impl.loss.LogLoss; -import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; -import org.nd4j.linalg.factory.Nd4j; -import static org.nd4j.autodiff.samediff.ops.SDValidation.*; - -/** - * SameDiff loss functions
- * Accessible via {@link SameDiff#loss()} - * - * @author Alex Black - */ -@SuppressWarnings("unused") public class SDLoss extends SDOps { - public SDLoss(SameDiff sameDiff) { - super(sameDiff); - } + public SDLoss(SameDiff sameDiff) { + super(sameDiff); + } - /** - * helper to refactor duplicate code - */ - private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){ - String weightName = (name == null) ? null : name + "/weight"; - return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights; - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
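+   * <p>
+   * A minimal usage sketch (illustrative shapes; loss ops are accessible via {@code SameDiff.loss()}):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable labels      = sd.placeHolder("labels", DataType.FLOAT, -1, 10);
+   * SDVariable predictions = sd.placeHolder("predictions", DataType.FLOAT, -1, 10);
+   * SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, null,
+   *     LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
+   * }</pre>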
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return absoluteDifference(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] ) - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("absolute difference loss", "predictions", predictions); - validateNumerical("absolute difference loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return absoluteDifference(name, label, predictions, null, lossReduce); - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #cosineDistance(String, SDVariable, SDVariable, SDVariable, LossReduce, int)}. - */ - public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, int dimension) { - return cosineDistance(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension); - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
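+   * <p>
+   * A minimal sketch (illustrative; the inputs are assumed to already have unit l2 norm along dimension 1):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable labels      = sd.placeHolder("labels", DataType.FLOAT, -1, 32);      // pre-normalized
+   * SDVariable predictions = sd.placeHolder("predictions", DataType.FLOAT, -1, 32); // pre-normalized
+   * SDVariable loss = sd.loss().cosineDistance("loss", labels, predictions, null,
+   *     LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 1);
+   * }</pre>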
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is - * equivalent to cosine distance when both the predictions and labels are normalized.
- * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm. - * If this is not the case, you should normalize them first by dividing by {@link SameDiff#norm2(String, SDVariable, boolean, int...)} - * along the cosine distance dimension (with keepDims=true). - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param dimension Dimension to perform the cosine distance over - * @return Cosine distance loss variable - */ - public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce, int dimension) { - validateFloatingPoint("cosine distance loss", "predictions", predictions); - validateNumerical("cosine distance loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #cosineDistance(String, SDVariable, SDVariable, SDVariable, LossReduce, int)}. - */ - public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - @NonNull LossReduce lossReduce, int dimension) { - return cosineDistance(name, label, predictions, null, lossReduce, dimension); - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, + int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #hingeLoss(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return hingeLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, + SDVariable weights, int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Hinge loss: a loss function used for training classifiers. - * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1} - * from the user specified {0,1}. Note that Labels should be provided with values {0,1}. - * - * @param name Name of the operation - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("hinge loss", "predictions", predictions); - validateNumerical("hinge loss", "labels", label); - if (weights == null) - weights = sd.scalar(null, predictions.dataType(), 1.0); - SDVariable result = f().lossHinge(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
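+   * <p>
+   * A brief sketch (illustrative values; labels use the required {0,1} encoding):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable labels      = sd.var("labels",      Nd4j.createFromArray(new float[][]{{1, 0, 1}}));
+   * SDVariable predictions = sd.var("predictions", Nd4j.createFromArray(new float[][]{{0.8f, -0.6f, 0.3f}}));
+   * SDVariable loss = sd.loss().hingeLoss("loss", labels, predictions, null,
+   *     LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
+   * }</pre>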
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #hingeLoss(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return hingeLoss(name, label, predictions, null, lossReduce); - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #huberLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, double delta) { - return huberLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta); - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss, - * though is less sensitive to outliers than squared error.
- * Huber loss implements: - *
-     * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta
-     *  L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise
-     *     }
-     * 
- * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param delta Loss function delta value - * @return Huber loss variable - */ - public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce, double delta) { - validateFloatingPoint("huber loss", "predictions", predictions); - validateNumerical("huber loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #huberLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce, double delta) { - return huberLoss(name, label, predictions, null, lossReduce, delta); - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * L2 loss: 1/2 * sum(x^2) - * - * @param var Variable to calculate L2 loss of - * @return L2 loss - */ - public SDVariable l2Loss(@NonNull SDVariable var) { - return l2Loss(null, var); - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * L2 loss: 1/2 * sum(x^2) - * - * @param name Name of the output variable - * @param var Variable to calculate L2 loss of - * @return L2 loss - */ - public SDVariable l2Loss(String name, @NonNull SDVariable var) { - validateNumerical("l2 loss", var); - SDVariable result = f().lossL2(var); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though it is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *
+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, + double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return logLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, LogLoss.DEFAULT_EPSILON); - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though it is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *
+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements: - * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))} - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Log loss variable - */ - public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) { - validateFloatingPoint("log loss", "predictions", predictions); - validateNumerical("log loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * L2 loss: 1/2 * sum(x^2)
+ * + * @param var Variable to calculate L2 loss of (NUMERIC type) + * @return output L2 loss (NUMERIC type) + */ + public SDVariable l2Loss(SDVariable var) { + SDValidation.validateNumerical("l2Loss", "var", var); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return logLoss(name, label, predictions, null, lossReduce, LogLoss.DEFAULT_EPSILON); - } + /** + * L2 loss: 1/2 * sum(x^2)
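As a quick check of the formula above, x = [1, 2] gives 0.5 * (1^2 + 2^2) = 2.5. A minimal continuation of the Huber sketch, assuming an additional import of org.nd4j.linalg.factory.Nd4j:

        // x = [1, 2] -> l2Loss = 0.5 * (1 + 4) = 2.5
        SDVariable x = sd.var("x", Nd4j.createFromArray(1.0f, 2.0f));
        SDVariable l2 = sd.loss().l2Loss("l2", x);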
+ * + * @param name name May be null. Name for the output variable + * @param var Variable to calculate L2 loss of (NUMERIC type) + * @return output L2 loss (NUMERIC type) + */ + public SDVariable l2Loss(String name, SDVariable var) { + SDValidation.validateNumerical("l2Loss", "var", var); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #logPoisson(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return logPoisson(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
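This is the usual binary cross entropy with an epsilon term added inside the logs for numerical stability; note that the two-argument convenience overload in this patch passes an epsilon of 0.0, where the method it replaces used LogLoss.DEFAULT_EPSILON. Continuing the same sketch, with predictions that are probabilities in (0, 1) rather than logits:

        // Binary {0,1} labels and sigmoid-activated probabilities (hypothetical shapes)
        SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 1);
        SDVariable probs = sd.placeHolder("probs", DataType.FLOAT, -1, 1);
        SDVariable bce = sd.loss().logLoss("bce", labels, probs);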
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double epsilon) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDValidation.validateNumerical("logLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Log poisson loss: a loss function used for training classifiers. - * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels. - * - * @param name Name of the operation - * @param label Label array. Each value should be 0.0 or 1.0 - * @param predictions Predictions array (has to be log(x) of actual predictions) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("log poisson loss", "predictions", predictions); - validateNumerical("log poisson loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, double epsilon) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDValidation.validateNumerical("logLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #logPoisson(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return logPoisson(name, label, predictions, null, lossReduce); - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(SDVariable label, SDVariable predictions) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logPoissonFull(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return logPoissonFull(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(String name, SDVariable label, SDVariable predictions) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Log poisson loss: a loss function used for training classifiers. - * Implements {@code L = exp(c) - z * c + z * log(z) - z + 0.5 * log(2 * pi * z)} - * where c is log(predictions) and z is labels. - * - * @param name Name of the operation - * @param label Label array. Each value should be 0.0 or 1.0 - * @param predictions Predictions array (has to be log(x) of actual predictions) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("log poisson (full) loss", "predictions", predictions); - validateNumerical("log poisson (full) loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
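Here c is the prediction in log space, so the op expects log(x) rather than x itself. The boolean full flag in the new signatures replaces the separate logPoissonFull method removed by this patch; full = true adds the {@code z * log(z) - z + 0.5 * log(2 * pi * z)} terms from that variant. Continuing the same sketch:

        // logPred holds log-space predictions, i.e. log(x); counts holds the target values
        SDVariable counts = sd.placeHolder("counts", DataType.FLOAT, -1, 1);
        SDVariable logPred = sd.placeHolder("logPred", DataType.FLOAT, -1, 1);
        SDVariable poisson = sd.loss().logPoisson("poisson", counts, logPred, null, false);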
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logPoissonFull(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return logPoissonFull(name, label, predictions, null, lossReduce); - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #meanPairwiseSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return meanPairwiseSquaredError(name, label, predictions, null, lossReduce); - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, + boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays. - * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is: - * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
- * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] - * @return Loss variable, scalar output - */ - public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("main pairwise squared error loss", "predictions", predictions); - validateNumerical("mean pairwise squared error loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, + SDVariable weights, boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #meanSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return meanSquaredError(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
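Working that formula on one row: predictions [1, 2, 4] against labels [1, 2, 3] give pair differences (-1 vs -1), (-3 vs -2) and (-2 vs -1), i.e. squared errors 0, 1 and 1, so the row's loss is 2/3. Continuing the same sketch (Nd4j import assumed as before):

        // One 1 x 3 example; per the formula above the loss for this row is (0 + 1 + 1) / 3
        SDVariable lbl = sd.constant("mpwseLabel", Nd4j.createFromArray(new float[][]{{1f, 2f, 3f}}));
        SDVariable prd = sd.constant("mpwsePred", Nd4j.createFromArray(new float[][]{{1f, 2f, 4f}}));
        SDVariable mpwse = sd.loss().meanPairwiseSquaredError("mpwse", lbl, prd, null);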
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis. - * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default)) - * this is the mean squared error loss function. - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("mean squared error loss", "predictions", predictions); - validateNumerical("mean squared error loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #meanSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return meanSquaredError(name, label, predictions, null, lossReduce); - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #sigmoidCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return sigmoidCrossEntropy(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, SigmoidCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions") - * and implements the binary cross entropy loss function. This implementation is numerically more stable than using - * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements: - * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))} - * though this is done in a mathematically equivalent but more numerical stable form.
- *
- * When label smoothing is > 0, the following label smoothing is used:
- *
-     * {@code numClasses = labels.size(1);
-     * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
-     * 
- * - * @param name Name of the operation - * @param label Label array - * @param predictionLogits Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictionLogits, - SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { - validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits); - validateNumerical("sigmoid cross entropy loss", "labels", label); - weights = getWeights(weights, name, predictionLogits); - SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
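For a single element, a label of 3.0 and a prediction of 2.5 contribute (3.0 - 2.5)^2 = 0.25 before reduction. Reusing the label/pred placeholders from the Huber sketch, with null weights and the default reduction:

        SDVariable mse = sd.loss().meanSquaredError("mse", label, pred, null);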
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #sigmoidCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return sigmoidCrossEntropy(name, label, predictions, null, lossReduce, SigmoidCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #softmaxCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return softmaxCrossEntropy(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, SoftmaxCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels; - * otherwise, the output is a scalar.
- *

- * When label smoothing is > 0, the following label smoothing is used:
- *

-     * {@code numClasses = labels.size(1);
-     * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
-     * 
- * - * @param name Name of the operation - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) - * @param logitPredictions Predictions array (pre-softmax) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 - * @return Loss variable - */ - public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable oneHotLabels, @NonNull SDVariable logitPredictions, - SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { - validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions); - validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels); - weights = getWeights(weights, name, logitPredictions); - SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #softmaxCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return softmaxCrossEntropy(name, label, predictions, null, lossReduce, SoftmaxCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *
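The inputs are raw pre-sigmoid logits; the sigmoid and the cross entropy are fused, which is where the extra numerical stability comes from. Applying the smoothing formula above with labelSmoothing = 0.1 maps a 1.0 label to 0.95 and a 0.0 label to 0.05. Continuing the same sketch, with the labels placeholder from the log loss example and an assumed import of org.nd4j.autodiff.loss.LossReduce:

        // Raw pre-sigmoid scores; labelSmoothing = 0.1 softens {0,1} targets to {0.05, 0.95}
        SDVariable logits = sd.placeHolder("logits", DataType.FLOAT, -1, 1);
        SDVariable sce = sd.loss().sigmoidCrossEntropy("sce", labels, logits, null,
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.1);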

+ * + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #sparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)} - */ - public SDVariable sparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) { - return sparseSoftmaxCrossEntropy(null, logits, labels); - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * As per {@link #softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce)} but the labels variable - * is represented as an integer array instead of the equivalent one-hot array.
- * i.e., if logits are rank N, then labels have rank N-1 - * - * @param name Name of the output variable. May be null - * @param logits Logits array ("pre-softmax activations") - * @param labels Labels array. Must be an integer type. - * @return Softmax cross entropy - */ - public SDVariable sparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) { - validateFloatingPoint("sparse softmax cross entropy", "logits (predictions)", logits); - validateInteger("sparse softmax cross entropy", "labels", labels); - Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", logits); - SDVariable result = f().lossSparseSoftmaxCrossEntropy(logits, labels); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *
+ * + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, + SDVariable weights) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * TODO - * - * @param targets - * @param inputs - * @param weights - * @return - */ - public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, - SDVariable weights) { - return weightedCrossEntropyWithLogits(null, targets, inputs, weights); - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, + SDVariable weights) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * TODO - * - * @param name - * @param targets - * @param inputs - * @param weights - * @return - */ - public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, - SDVariable weights) { - validateFloatingPoint("weighted cross entropy with logits", "inputs", inputs); - validateNumerical("weighted cross entropy with logits", "targets", targets); - SDVariable result = f().weightedCrossEntropyWithLogits(targets, inputs, weights); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, the returned shape is [numExamples] for [numExamples, numClasses] predictions/labels;
+ * otherwise, the output is a scalar.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *
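Labels are one-hot rows and predictions are raw class logits; the softmax is applied inside the op. With LossReduce.NONE the result keeps one value per example instead of reducing to a scalar. Continuing the same sketch with a hypothetical five-class problem:

        SDVariable oneHot = sd.placeHolder("oneHot", DataType.FLOAT, -1, 5);
        SDVariable classLogits = sd.placeHolder("classLogits", DataType.FLOAT, -1, 5);
        SDVariable ce = sd.loss().softmaxCrossEntropy("ce", oneHot, classLogits, null);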

+ * + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, the returned shape is [numExamples] for [numExamples, numClasses] predictions/labels;
+ * otherwise, the output is a scalar.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *
+ * + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, + SDVariable logitPredictions, SDVariable weights, LossReduce lossReduce, + double labelSmoothing) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, the returned shape is [numExamples] for [numExamples, numClasses] predictions/labels;
+ * otherwise, the output is a scalar.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *
+ * + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, + SDVariable weights) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, the returned shape is [numExamples] for [numExamples, numClasses] predictions/labels;
+ * otherwise, the output is a scalar.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *
+ * + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, + SDVariable logitPredictions, SDVariable weights) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1
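The only difference from softmaxCrossEntropy is the label encoding: integer class indices of rank N-1 instead of one-hot arrays of rank N, for example [minibatch] indices against [minibatch, numClasses] logits. Continuing the same sketch, reusing the classLogits placeholder from the previous example:

        // Integer class indices, one per example (rank N-1 relative to the logits)
        SDVariable classIdx = sd.placeHolder("classIdx", DataType.INT, -1);
        SDVariable sparseCe = sd.loss().sparseSoftmaxCrossEntropy("sparseCe", classLogits, classIdx);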
+ * + * @param logits Logits array ("pre-softmax activations") (NUMERIC type) + * @param labels Labels array. Must be an integer type. (INT type) + * @return output Softmax cross entropy (NUMERIC type) + */ + public SDVariable sparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels) { + SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); + SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1
+ * + * @param name name May be null. Name for the output variable + * @param logits Logits array ("pre-softmax activations") (NUMERIC type) + * @param labels Labels array. Must be an integer type. (INT type) + * @return output Softmax cross entropy (NUMERIC type) + */ + public SDVariable sparseSoftmaxCrossEntropy(String name, SDVariable logits, SDVariable labels) { + SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); + SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Weighted cross entropy loss with logits
+ * + * @param targets targets array (NUMERIC type) + * @param inputs input array (NUMERIC type) + * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, + SDVariable weights) { + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * Weighted cross entropy loss with logits
+ * + * @param name name May be null. Name for the output variable + * @param targets targets array (NUMERIC type) + * @param inputs input array (NUMERIC type) + * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, + SDVariable inputs, SDVariable weights) { + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 1e038e193..f4a490813 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,2539 +14,2955 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; -import org.nd4j.linalg.api.ops.impl.shape.Eye; import org.nd4j.linalg.indexing.conditions.Condition; -import java.util.List; - -import static org.nd4j.autodiff.samediff.ops.SDValidation.*; - -/** - * SameDiff math operations
- * Accessible via {@link SameDiff#math()} - * - * @author Alex Black - */ public class SDMath extends SDOps { - public SDMath(SameDiff sameDiff) { - super(sameDiff); - } - - /** - * Elementwise absolute value operation: out = abs(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable abs(SDVariable x) { - return abs(null, x); - } - - /** - * Elementwise absolute value operation: out = abs(x) - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable abs(String name, SDVariable x) { - validateNumerical("abs", x); - SDVariable result = f().abs(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable acos(SDVariable x) { - return acos(null, x); - } - - /** - * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable acos(String name, SDVariable x) { - validateNumerical("acos", x); - SDVariable result = f().acos(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable acosh(SDVariable x) { - return acosh(null, x); - } - - /** - * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable acosh(String name, SDVariable x) { - validateNumerical("acosh", x); - SDVariable result = f().acosh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amax(SDVariable in, int... dimensions) { - return amax(null, in, dimensions); - } - - /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amax(String name, SDVariable in, int... dimensions) { - validateNumerical("amax", in); - SDVariable ret = f().amax(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amean(SDVariable in, int... 
dimensions) { - validateNumerical("amean", in); - return amean(null, in, dimensions); - } - - /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amean(String name, SDVariable in, int... dimensions) { - validateNumerical("amean", in); - SDVariable ret = f().amean(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amin(SDVariable in, int... dimensions) { - return amin(null, in, dimensions); - } - - /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amin(String name, SDVariable in, int... dimensions) { - validateNumerical("amin", in); - SDVariable ret = f().amin(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable and(SDVariable x, SDVariable y) { - return and(null, x, y); - } - - /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable and(String name, SDVariable x, SDVariable y) { - validateBool("boolean and", x, y); - SDVariable result = f().and(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable asin(SDVariable x) { - return asin(null, x); - } - - /** - * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable asin(String name, SDVariable x) { - validateNumerical("asin", x); - SDVariable result = f().asin(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable asinh(SDVariable x) { - return asinh(null, x); - } - - /** - * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable asinh(String name, SDVariable x) { - validateNumerical("asinh", x); - SDVariable result = f().asinh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable asum(SDVariable in, int... dimensions) { - return asum(null, in, dimensions); - } - - /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable asum(String name, SDVariable in, int... dimensions) { - validateNumerical("asum", in); - SDVariable ret = f().asum(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable atan(SDVariable x) { - return atan(null, x); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable atan(String name, SDVariable x) { - validateNumerical("atan", x); - SDVariable result = f().atan(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y). 
- * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result - * - * @param y Input Y variable - * @param x Input X variable - * @return Output variable - */ - public SDVariable atan2(SDVariable y, SDVariable x) { - return atan2(null, y, x); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y). - * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result - * - * @param name Name of the output variable - * @param y Input Y variable - * @param x Input X variable - * @return Output variable - */ - public SDVariable atan2(String name, SDVariable y, SDVariable x) { - validateNumerical("atan2", y, x); - SDVariable ret = f().atan2(y, x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable atanh(SDVariable x) { - return atanh(null, x); - } - - /** - * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable atanh(String name, SDVariable x) { - validateNumerical("atanh", x); - SDVariable result = f().atanh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise ceiling function: out = ceil(x). - * Rounds each value up to the nearest integer value (if not already an integer) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable ceil(SDVariable x) { - return ceil(null, x); - } - - /** - * Element-wise ceiling function: out = ceil(x). - * Rounds each value up to the nearest integer value (if not already an integer) - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable ceil(String name, SDVariable x) { - validateFloatingPoint("ceil", x); - SDVariable ret = f().ceil(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Clipping by L2 norm
- * if l2Norm(x) < clipValue, then input is returned unmodified<br>
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in) - * - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @return Output variable - */ - public SDVariable clipByNorm(SDVariable x, double clipValue) { - return clipByNorm(null, x, clipValue); - } - - /** - * Clipping by L2 norm
- * if l2Norm(x) < clipValue, then input is returned unmodified<br>
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in) - * - * @param name Name of the output variable - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @return Output variable - */ - public SDVariable clipByNorm(String name, SDVariable x, double clipValue) { - validateFloatingPoint("clip by norm", x); - SDVariable ret = f().clipByNorm(x, clipValue); - return updateVariableNameAndReference(ret, name); - } - - /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodified<br>
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according - * to the corresponding l2Norm along the specified dimensions - * - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions If not specified, all dimensions are used - * @return Output variable - */ - public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { - return clipByNorm(null, x, clipValue, dimensions); - } - - /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodified<br>
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according - * to the corresponding l2Norm along the specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions If not specified, all dimensions are used - * @return Output variable - */ - public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { - validateFloatingPoint("clip by norm", x); - SDVariable ret = f().clipByNorm(x, clipValue, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
- * - * @param x Input variable - * @param clipValueMin Minimum value for clipping - * @param clipValueMax Maximum value for clipping - * @return Output variable - */ - public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { - return clipByValue(null, x, clipValueMin, clipValueMax); - } - - /** - * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
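A small numeric sketch of the two clipping rules described above, using the variants documented in this class (reached via sd.math() per the class javadoc, and assuming the usual org.nd4j.linalg.factory.Nd4j import); the input values are made up:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.constant("in", Nd4j.createFromArray(-5.0, 0.5, 3.0));
    // clipByValue with min=-1, max=1: [-5.0, 0.5, 3.0] -> [-1.0, 0.5, 1.0]
    SDVariable byValue = sd.math().clipByValue("byValue", in, -1.0, 1.0);
    // clipByNorm with clipValue=1: the L2 norm (~5.85) exceeds 1, so each element is scaled by clipValue / l2Norm(in)
    SDVariable byNorm = sd.math().clipByNorm("byNorm", in, 1.0);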
- * - * @param name Name of the output variable - * @param x Input variable - * @param clipValueMin Minimum value for clipping - * @param clipValueMax Maximum value for clipping - * @return Output variable - */ - public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) { - validateNumerical("clip by value", x); - SDVariable ret = f().clipByValue(x, clipValueMin, clipValueMax); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable predictions) { - return confusionMatrix((String) null, labels, predictions); - } - - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred) { - return confusionMatrix(name, labels, pred, ConfusionMatrix.DEFAULT_DTYPE); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
- * [1, 0, 0]
- * [0, 1, 1]
- * [0, 0, 0]
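The worked example above can be reproduced with the variant that infers the number of classes; a sketch (constant names are made up, sd.math() per the class javadoc):

    SameDiff sd = SameDiff.create();
    SDVariable labels = sd.constant("labels", Nd4j.createFromArray(0, 1, 1));
    SDVariable pred = sd.constant("pred", Nd4j.createFromArray(0, 2, 1));
    // numClasses is inferred as 1 + max(max(labels), max(pred)) = 3, giving the 3x3 matrix shown above
    SDVariable cm = sd.math().confusionMatrix(labels, pred);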
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, DataType dataType) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - SDVariable result = f().confusionMatrix(labels, pred, dataType); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable, Integer) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) { - return confusionMatrix(null, labels, pred, numClasses); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
- * [1, 0, 0, 0]
- * [0, 1, 1, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @param numClasses Number of classes - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - SDVariable result = f().confusionMatrix(labels, pred, numClasses); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { - return confusionMatrix(null, labels, pred, weights); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3], then output is:<br> - * [1, 0, 0]<br>
- * [0, 3, 2]
- * [0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of - * each prediction. Must be same length as both labels and predictions arrays - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, SDVariable weights) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - validateNumerical("confusionMatrix", "weights", weights); - SDVariable result = f().confusionMatrix(labels, pred, weights); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable, Integer, SDVariable) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - return confusionMatrix(null, labels, pred, numClasses, weights); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3], then output is:<br> - * [1, 0, 0, 0]<br>
- * [0, 3, 2, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of - * each prediction. Must be same length as both labels and predictions arrays - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - validateNumerical("confusionMatrix", "weights", weights); - SDVariable result = f().confusionMatrix(labels, pred, numClasses, weights); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise cosine operation: out = cos(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable cos(SDVariable x) { - return cos(null, x); - } - - /** - * Elementwise cosine operation: out = cos(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable cos(String name, SDVariable x) { - validateNumerical("cos", x); - SDVariable result = f().cos(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable cosh(SDVariable x) { - return cosh(null, x); - } - - /** - * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable cosh(String name, SDVariable x) { - validateNumerical("cosh", x); - SDVariable result = f().cosh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #cosineDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) { - return cosineDistance(null, x, y, dimensions); - } - - /** - * Cosine distance reduction operation. The output contains the cosine distance for each - * tensor/subset along the specified dimensions:
- * out = 1.0 - cosineSimilarity(x,y)
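A sketch of how the two reductions relate in practice (sd.math() per the class javadoc; names, shapes and the reduction dimension are illustrative):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.placeHolder("x", DataType.FLOAT, -1, 100);
    SDVariable y = sd.placeHolder("y", DataType.FLOAT, -1, 100);
    // Reducing along dimension 1 gives one value per row
    SDVariable sim = sd.math().cosineSimilarity("sim", x, y, 1);
    SDVariable dist = sd.math().cosineDistance("dist", x, y, 1); // equals 1.0 - sim, per the formula above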
- * See {@link #cosineSimilarity(String, SDVariable, SDVariable, int...)} - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("cosine distance", x, y); - SDVariable result = f().cosineDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #cosineSimilarity(String, SDVariable, SDVariable, int...) - */ - public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { - return cosineSimilarity(sd.generateNewVarName(CosineSimilarity.OP_NAME, 0), x, y, dimensions); - } - - /** - * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset - * along the specified dimensions:
- * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) - * - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("cosine similarity", x, y); - SDVariable cosim = f().cosineSimilarity(x, y, dimensions); - return updateVariableNameAndReference(cosim, name); - } - - /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0) - * - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countNonZero(SDVariable input, int... dimensions) { - return countNonZero(null, input, dimensions); - } - - /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0) - * - * @param name Name of the output variable - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countNonZero(String name, SDVariable input, int... dimensions) { - validateNumerical("countNonZero", input); - SDVariable res = f().countNonZero(input, dimensions); - return updateVariableNameAndReference(res, name); - } - - /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0) - * - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countZero(SDVariable input, int... dimensions) { - return countZero(null, input, dimensions); - } - - /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0) - * - * @param name Name of the output variable - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countZero(String name, SDVariable input, int... dimensions) { - validateNumerical("countNonZero", input); - SDVariable res = f().countZero(input, dimensions); - return updateVariableNameAndReference(res, name); - } - - /** - * @see #cross(String, SDVariable, SDVariable) - */ - public SDVariable cross(SDVariable a, SDVariable b) { - return cross(null, a, b); - } - - /** - * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta). 
- * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3 - * - * @param a First input - * @param b Second input - * @return Element-wise cross product - */ - public SDVariable cross(String name, SDVariable a, SDVariable b) { - validateNumerical("cross", a, b); - SDVariable ret = f().cross(a, b); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise cube function: out = x^3 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable cube(SDVariable x) { - return cube(null, x); - } - - /** - * Element-wise cube function: out = x^3 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable cube(String name, SDVariable x) { - validateNumerical("cube", x); - SDVariable result = f().cube(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #diag(String, SDVariable) - */ - public SDVariable diag(SDVariable x) { - return diag(null, x); - } - - /** - * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
- * For example, if input = [1,2,3], then output is given by:
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- *
- * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k]. - * i.e., for input rank R, output has rank 2R - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable diag(String name, SDVariable x) { - SDVariable ret = f().diag(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #diagPart(String, SDVariable) - */ - public SDVariable diagPart(SDVariable x) { - return diagPart(null, x); - } - - /** - * Extract the diagonal part from the input array.
- * If input is
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- * then output is [1, 2, 3].
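The round trip between diag and diagPart can be sketched with the values from the example above (constant name is made up, sd.math() per the class javadoc):

    SameDiff sd = SameDiff.create();
    SDVariable v = sd.constant("v", Nd4j.createFromArray(1.0, 2.0, 3.0));
    SDVariable m = sd.math().diag(v);        // [[1,0,0],[0,2,0],[0,0,3]]
    SDVariable back = sd.math().diagPart(m); // recovers [1, 2, 3]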
- * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k] - * - * @param x Input variable - * @return Diagonal part of the input - * @see #diag(String, SDVariable) - */ - public SDVariable diagPart(String name, SDVariable x) { - SDVariable ret = f().diagPart(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Entropy reduction: -sum(x * log(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable - */ - public SDVariable entropy(SDVariable in, int... dimensions) { - return entropy(null, in, dimensions); - } - - /** - * Entropy reduction: -sum(x * log(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable entropy(String name, SDVariable in, int... dimensions) { - validateNumerical("entropy reduction", in); - SDVariable ret = f().entropy(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise Gaussian error function - out = erf(in) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable erf(SDVariable x) { - return erf(null, x); - } - - /** - * Element-wise Gaussian error function - out = erf(in) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable erf(String name, SDVariable x) { - validateNumerical("erf (error function)", x); - SDVariable ret = f().erf(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable erfc(SDVariable x) { - return erfc(null, x); - } - - /** - * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in) - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable erfc(String name, SDVariable x) { - validateNumerical("erfc", x); - SDVariable ret = f().erfc(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #euclideanDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) { - return euclideanDistance(sd.generateNewVarName(EuclideanDistance.OP_NAME, 0), x, y, dimensions); - } - - /** - * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each - * tensor/subset along the specified dimensions:
- * out = sqrt( sum_i (x[i] - y[i])^2 ) - * - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("euclidean distance", x, y); - SDVariable result = f().euclideanDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise exponent function: out = exp(x) = 2.71828...^x - * - * @param x Input variable - * @return Output variable - */ - public SDVariable exp(SDVariable x) { - return exp(null, x); - } - - /** - * Elementwise exponent function: out = exp(x) = 2.71828...^x - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable exp(String name, SDVariable x) { - validateNumerical("exp", x); - SDVariable result = f().exp(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x - * - * @param x Input variable - * @return Output variable - */ - public SDVariable expm1(SDVariable x) { - return expm1(null, x); - } - - /** - * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable expm1(String name, SDVariable x) { - validateNumerical("expm1", x); - SDVariable result = f().expm1(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Generate a square identity matrix with the specified number of rows. - * - * @param rows Number of rows (and columns) - * @return SDVariable with an identity matrix array - */ - public SDVariable eye(int rows) { - return eye(rows, rows); - } - - /** - * Generate an identity matrix with the specified number of rows and columns. - * - * @param rows Number of rows - */ - public SDVariable eye(String name, int rows) { - return eye(name, rows, rows); - } - - /** - * @see #eye(String, int, int) - */ - public SDVariable eye(int rows, int cols) { - return eye(null, rows, cols); - } - - /** - * As per {@link #eye(String, int, int, DataType)} but with the default datatype, {@link Eye#DEFAULT_DTYPE} - */ - public SDVariable eye(String name, int rows, int cols) { - return eye(name, rows, cols, Eye.DEFAULT_DTYPE); - } - - /** - * Generate an identity matrix with the specified number of rows and columns - * Example:
- *
-     * {@code SDVariable eye = eye(3,2)
-     * eye:
-     * [ 1, 0]
-     * [ 0, 1]
-     * [ 0, 0]}
-     * 
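A usage sketch matching the example above (variable name is made up, sd.math() per the class javadoc):

    SameDiff sd = SameDiff.create();
    // 3 rows, 2 columns: ones on the main diagonal, zeros elsewhere
    SDVariable eye = sd.math().eye("eye", 3, 2);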
- * - * @param name Name of the new SDVariable - * @param rows Number of rows - * @param cols Number of columns - * @return SDVaribable identity matrix - */ - public SDVariable eye(String name, int rows, int cols, DataType dataType) { - return eye(name, rows, cols, dataType); - } - - /** - * see {@link #eye(String, int, int, DataType, int...)} - */ - public SDVariable eye(int rows, int cols, DataType dataType, int... batchDimension) { - return eye(null, rows, cols, dataType, batchDimension); - } - - /** - * Generate an identity matrix with the specified number of rows and columns, with optional leading dims
- * Example:
- * batchShape: [3,3]
- * numRows: 2
- * numCols: 4
- * returns a tensor of shape (3, 3, 2, 4) that consists of 3 * 3 batches of (2,4)-shaped identity matrices:
- * 1 0 0 0
- * 0 1 0 0
- * - * @param rows Number of rows - * @param cols Number of columns - * @param batchDimension Batch dimensions. May be null - */ - public SDVariable eye(String name, int rows, int cols, DataType dataType, int... batchDimension) { - SDVariable eye = new Eye(sd, rows, cols, dataType, batchDimension).outputVariables()[0]; - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(int, int, DataType, int...)} bit with the number of rows/columns specified as scalar SDVariables, - * and the batch dimension specified as a 1D SDVariable - */ - public SDVariable eye(SDVariable rows, SDVariable cols, SDVariable batchDimension) { - return eye(null, rows, cols, batchDimension); - } - - /** - * As per {@link #eye(String, int, int, int...)} bit with the number of rows/columns specified as scalar SDVariables, - * and the batch dimension specified as a 1D SDVariable - */ - public SDVariable eye(String name, SDVariable rows, SDVariable cols, SDVariable batchDimension) { - SDVariable eye = new Eye(sd, rows, cols, batchDimension).outputVariable(); - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(String, int, int)} bit with the number of rows/columns specified as scalar SDVariables - */ - public SDVariable eye(String name, SDVariable rows, SDVariable cols) { - SDVariable eye = new Eye(sd, rows, cols).outputVariables()[0]; - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(int, int)} bit with the number of rows/columns specified as scalar SDVariables - */ - public SDVariable eye(SDVariable rows, SDVariable cols) { - SDVariable eye = new Eye(sd, rows, cols).outputVariables()[0]; - return updateVariableNameAndReference(eye, null); - } - - /** - * As per {@link #eye(String, int)} but with the number of rows specified as a scalar SDVariable - */ - public SDVariable eye(String name, SDVariable rows) { - SDVariable eye = new Eye(sd, rows).outputVariables()[0]; - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(int)} but with the number of rows specified as a scalar SDVariable - */ - public SDVariable eye(SDVariable rows) { - SDVariable eye = new Eye(sd, rows).outputVariables()[0]; - return updateVariableNameAndReference(eye, null); - } - - /** - * @see #firstIndex(String, SDVariable, Condition, int...) - */ - public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) { - return firstIndex(null, in, condition, dimensions); - } - - /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each - * slice along the specified dimensions) - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) { - return firstIndex(name, in, condition, false, dimensions); - } - - /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each - * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - SDVariable ret = f().firstIndex(in, condition, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #firstIndex(String, SDVariable, Condition, boolean, int...) - */ - public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return firstIndex(null, in, condition, keepDims, dimensions); - } - - /** - * Element-wise floor function: out = floor(x). - * Rounds each value down to the nearest integer value (if not already an integer) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable floor(SDVariable x) { - return floor(null, x); - } - - /** - * Element-wise floor function: out = floor(x). - * Rounds each value down to the nearest integer value (if not already an integer) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable floor(String name, SDVariable x) { - validateFloatingPoint("floor", x); - SDVariable result = f().floor(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #hammingDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) { - return hammingDistance(null, x, y, dimensions); - } - - /** - * Hamming distance reduction operation. The output contains the cosine distance for each - * tensor/subset along the specified dimensions:
- * out = count( x[i] != y[i] ) - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("hamming distance reduction", x, y); - SDVariable result = f().hammingDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(SDVariable, int...) - */ - public SDVariable iamax(SDVariable in, int... dimensions) { - return iamax(null, in, dimensions); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(String, SDVariable, boolean, int...) - */ - public SDVariable iamax(String name, SDVariable in, int... dimensions) { - return iamax(name, in, false, dimensions); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(String, SDVariable, boolean, int...) - */ - public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("iamax", in); - SDVariable ret = f().iamax(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(String, SDVariable, boolean, int...) - */ - public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { - return iamax(null, in, keepDims, dimensions); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(SDVariable in, int... dimensions) { - return iamin(null, in, dimensions); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(String name, SDVariable in, int... dimensions) { - return iamin(name, in, false, dimensions); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("iamin", in); - SDVariable ret = f().iamin(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { - return iamin(null, in, keepDims, dimensions); - } - - /** - * Is finite operation: elementwise isFinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isFinite(SDVariable x) { - return isFinite(null, x); - } - - /** - * Is finite operation: elementwise isFinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Output variable name - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isFinite(String name, SDVariable x) { - validateFloatingPoint("isFinite", x); - SDVariable result = f().isFinite(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isInfinite(SDVariable x) { - return isInfinite(null, x); - } - - /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Output variable name - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isInfinite(String name, SDVariable x) { - validateFloatingPoint("isInfinite", x); - SDVariable result = f().isInfinite(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isMax(SDVariable x) { - return isMax(null, x); - } - - /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isMax(String name, SDVariable x) { - validateNumerical("isMax", x); - SDVariable ret = f().isMax(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isNaN(SDVariable x) { - return isNaN(null, x); - } - - /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Output variable name - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isNaN(String name, SDVariable x) { - validateFloatingPoint("isNaN", x); - SDVariable result = f().isNaN(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param x Input variable - * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise - */ - public SDVariable isNonDecreasing(SDVariable x) { - return isNonDecreasing(null, x); - } - - /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param name Output name - * @param x Input variable - * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise - */ - public SDVariable isNonDecreasing(String name, SDVariable x) { - validateNumerical("isNonDecreasing", x); - SDVariable result = f().isNonDecreasing(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param x Input variable - * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise - */ - public SDVariable isStrictlyIncreasing(SDVariable x) { - return isStrictlyIncreasing(null, x); - - } - - /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param name Output variable name - * @param x Input variable - * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise - */ - public SDVariable isStrictlyIncreasing(String name, SDVariable x) { - validateNumerical("isStrictlyIncreasing", x); - SDVariable result = f().isStrictlyIncreasing(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Jaccard similarity reduction operation. The output contains the Jaccard distance for each - * tensor along the specified dimensions. - * - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate Jaccard similarity over - * @return Output variable - */ - public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) { - return jaccardDistance(null, x, y, dimensions); - } - - /** - * Jaccard similarity reduction operation. The output contains the Jaccard distance for each - * tensor along the specified dimensions. - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate Jaccard similarity over - * @return Output variable - */ - public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("Jaccard distance reduction", x, y); - SDVariable result = f().jaccardDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #lastIndex(String, SDVariable, Condition, int...) - */ - public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) { - return lastIndex(null, in, condition, dimensions); - } - - /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each - * slice along the specified dimensions) - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { - return lastIndex(name, in, condition, false, dimensions); - } - - /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each - * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - SDVariable ret = f().lastIndex(in, condition, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #lastIndex(String, SDVariable, Condition, boolean, int...) - */ - public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return lastIndex(null, in, condition, keepDims, dimensions); - } - - /** - * List diff operation computes the difference between two 1d arrays, and also returns the indices - i.e., the positions - * where the output appears in the input X.
- * For inputs X and Y, listDiff returns everything in X but not in Y.
- * For example, if {@code X=[1,10,3,7,6]} and {@code Y=[10, 6]}, then: - * output 0 (difference) = {@code [1,3,7]}<br>
- * output 1 (indices) = {@code [0, 2, 3]}
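A sketch reproducing this example (constant names are made up, sd.math() per the class javadoc):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.createFromArray(1, 10, 3, 7, 6));
    SDVariable y = sd.constant("y", Nd4j.createFromArray(10, 6));
    SDVariable[] out = sd.math().listDiff(x, y);
    // out[0] = values in x but not in y = [1, 3, 7]; out[1] = their indices in x = [0, 2, 3]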
- * @param x Input 1 - input values - * @param y Input 2 - values to remove - * @return 2 outputs - difference, and indices - */ - public SDVariable[] listDiff(SDVariable x, SDVariable y){ - return f().listdiff(x, y); - } - - /** - * Element-wise logarithm function (base e - natural logarithm): out = log(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable log(SDVariable x) { - return log(null, x); - } - - /** - * Element-wise logarithm function (base e - natural logarithm): out = log(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable log(String name, SDVariable x) { - validateNumerical("log", x); - SDVariable result = f().log(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise logarithm function (with specified base): out = log_{base}(x) - * - * @param in Input variable - * @param base Logarithm base - * @return Output variable - */ - public SDVariable log(SDVariable in, double base) { - return log(null, in, base); - } - - /** - * Element-wise logarithm function (with specified base): out = log_{base}(x) - * - * @param name Name of the output variable - * @param in Input variable - * @param base Logarithm base - * @return Output variable - */ - public SDVariable log(String name, SDVariable in, double base) { - validateNumerical("log", in); - SDVariable ret = f().log(in, base); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise natural logarithm function: out = log_e (1 + x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable log1p(SDVariable x) { - return log1p(null, x); - } - - /** - * Elementwise natural logarithm function: out = log_e (1 + x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable log1p(String name, SDVariable x) { - validateNumerical("log1p", x); - SDVariable result = f().log1p(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Log entropy reduction: log(-sum(x * log(x))) - * - * @param in Input variable - * @param dimensions Dimensions to reduce on (null for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable logEntropy(SDVariable in, int... dimensions) { - return logEntropy(null, in, dimensions); - } - - /** - * Log entropy reduction: log(-sum(x * log(x))) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce on (null for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { - validateNumerical("logEntropy reduction", in); - SDVariable ret = f().logEntropy(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Log-sum-exp reduction (optionally along dimension). - * Computes log(sum(exp(x)) - * - * @param input Input variable - * @param dimensions Optional dimensions to reduce along - * @return Output variable - */ - public SDVariable logSumExp(SDVariable input, int... dimensions) { - return logSumExp(null, input, dimensions); - } - - /** - * Log-sum-exp reduction (optionally along dimension). 
- * Computes log(sum(exp(x)) - * - * @param name Name of the output variable - * @param input Input variable - * @param dimensions Optional dimensions to reduce along - * @return Output variable - */ - public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { - return logSumExp(name, input, false, dimensions); - } - - public SDVariable logSumExp(String name, SDVariable input, boolean keepDims, int... dimensions) { - validateNumerical("logSumExp reduction", input); - SDVariable ret = f().logSumExp(input, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #manhattanDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) { - return manhattanDistance(sd.generateNewVarName(ManhattanDistance.OP_NAME, 0), x, y, dimensions); - } - - /** - * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each - * tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]-y[i]) - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("manhattan distance", x, y); - SDVariable result = f().manhattanDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #matrixDeterminant(String, SDVariable) - */ - public SDVariable matrixDeterminant(SDVariable in) { - return matrixDeterminant(null, in); - } - - /** - * Matrix determinant op. For 2D input, this returns the standard matrix determinant. - * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each - * shape [m,m] sub-matrix. - * - * @param name Name of the output variable - * @param in Input - * @return Matrix determinant variable - */ - public SDVariable matrixDeterminant(String name, SDVariable in) { - validateNumerical("matrix determinant", in); - SDVariable ret = f().matrixDeterminant(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #matrixInverse(String, SDVariable) - */ - public SDVariable matrixInverse(SDVariable in) { - return matrixInverse(null, in); - } - - /** - * Matrix inverse op. For 2D input, this returns the standard matrix inverse. - * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each - * shape [m,m] sub-matrix. - * - * @param name Name of the output variable - * @param in Input - * @return Matrix inverse variable - */ - public SDVariable matrixInverse(String name, SDVariable in) { - validateFloatingPoint("matrix inverse", in); - SDVariable ret = f().matrixInverse(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * Merge add function: merges an arbitrary number of equal shaped arrays using elementwise addition: - * out = sum_i in[i] - * - * @param x Input variables - * @return Output variable - */ - public SDVariable mergeAdd(SDVariable... x) { - return mergeAdd(null, x); - } - - /** - * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition: - * out = sum_i in[i] - * - * @param name Name of the output variable - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeAdd(String name, SDVariable... inputs) { - validateSameType("mergeAdd", true, inputs); - SDVariable ret = f().mergeAdd(inputs); - return updateVariableNameAndReference(ret, name); - } - - /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation: - * out = mean_i in[i] - * - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeAvg(SDVariable... inputs) { - return mergeAvg(null, inputs); - } - - /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation: - * out = mean_i in[i] - * - * @param name Name of the output variable - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeAvg(String name, SDVariable... 
inputs) { - validateSameType("mergeAvg", true, inputs); - SDVariable ret = f().mergeAvg(inputs); - return updateVariableNameAndReference(ret, name); - } - - /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation: - * out = max_i in[i] - * - * @param x Input variables - * @return Output variable - */ - public SDVariable mergeMax(SDVariable... x) { - return mergeMax(null, x); - } - - /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation: - * out = max_i in[i] - * - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeMax(String name, SDVariable... inputs) { - validateSameType("mergeMax", true, inputs); - SDVariable ret = f().mergeMax(inputs); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #meshgrid(List, SDVariable...) - */ - public SDVariable[] meshgrid(SDVariable... inputs) { - return meshgrid(null, inputs); - } - - /** - * Broadcast the 1D input variables onto an n-dimensional grid.
- * The resulting variable can be used for example for evaluating functions at all locations on a grid.
- * Example:
- *
-     * {@code input1 = [1, 2, 3]
-     * input2 = [4, 5, 6]
-     * SDVariable[] out = meshgrid(input1, input2)
-     * out[0]:
-     * [ 1, 2, 3]
-     * [ 1, 2, 3]
-     * [ 1, 2, 3]
-     *
-     * out[1]:
-     * [ 4, 4, 4]
-     * [ 5, 5, 5]
-     * [ 6, 6, 6]}
-     * 
- *
- * - * @param names List of names for the output variables. Must have exactly N names for N input arrays - * @param inputs N x 1D input variables - * @return an array of exactly N SDVariables (for N inputs), of rank N - */ - public SDVariable[] meshgrid(List names, SDVariable... inputs) { - return meshgrid(names, true, inputs); - } - - /** - * @see #meshgrid(List, SDVariable...) - */ - public SDVariable[] meshgrid(List names, boolean cartesian, SDVariable... inputs) { - Preconditions.checkState(names == null || names.size() == inputs.length, - "Got %s names but %s inputs", (names == null ? 0 : names.size()), inputs.length); - validateSameType("meshgrid", false, inputs); - SDVariable[] ret = f().meshgrid(cartesian, inputs); - for (int i = 0; i < ret.length; i++) { - ret[i] = updateVariableNameAndReference(ret[i], names == null ? null : names.get(i)); - } - return ret; - } - - /** - * @see #moments(String[], SDVariable, int...) - */ - public SDVariable[] moments(SDVariable input, int... axes) { - return moments(null, input, axes); - } - - /** - * Calculate the mean and (population) variance for the input variable, for the specified axis - * - * @param name Name of the output variables. Can be null; if non-null, must be length 2 - * @param input Input to calculate moments for - * @param axes Dimensions to perform calculation over - * @return Mean and variance variables - */ - public SDVariable[] moments(String[] name, SDVariable input, int... axes) { - validateNumerical("moments", input); - SDVariable[] res = f().moments(input, axes); - return sd.updateVariableNamesAndReferences(res, name); - } - - /** - * Elementwise negative operation: out = -x - * - * @param x Input variable - * @return Output variable - */ - public SDVariable neg(SDVariable x) { - return neg(null, x); - } - - /** - * Elementwise negative operation: out = -x - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable neg(String name, SDVariable x) { - validateNumerical("neg", x); - SDVariable result = f().neg(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #normalizeMoments(String[], SDVariable, SDVariable, SDVariable, double) - */ - public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { - return normalizeMoments(null, counts, means, variances, shift); - } - - /** - * Calculate the mean and variance from the sufficient statistics - * - * @param name Name of the output variables. Can be null; if non-null, must be length 2 - * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics - * @param means Mean-value sufficient statistics: this is the SUM of all data values - * @param variances Variaance sufficient statistics: this is the squared sum of all data values - * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) - * @return Output variables: mean and population variance - */ - public SDVariable[] normalizeMoments(String[] name, SDVariable counts, SDVariable means, SDVariable variances, - double shift) { - SDVariable[] res = f().normalizeMoments(counts, means, variances, shift); - return sd.updateVariableNamesAndReferences(res, name); - } - - /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable or(SDVariable x, SDVariable y) { - return or(null, x, y); - } - - /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable or(String name, SDVariable x, SDVariable y) { - validateBool("or", x, y); - SDVariable result = f().or(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise power function: out = x^value - * - * @param x Input variable - * @param value Power to raise each element to - * @return Output variable - */ - public SDVariable pow(SDVariable x, double value) { - return pow(null, x, value); - } - - /** - * Element-wise power function: out = x^value - * - * @param name Output variable name - * @param x Input variable - * @param value Power to raise each element to - * @return Output variable - */ - public SDVariable pow(String name, SDVariable x, double value) { - validateNumerical("pow", x); - SDVariable result = f().pow(x, value); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise (broadcastable) power function: out = x[i]^y[i] - * - * @param x Input variable - * @param y Power - * @return Output variable - */ - public SDVariable pow(SDVariable x, SDVariable y) { - return pow(null, x, y); - } - - /** - * Element-wise (broadcastable) power function: out = x[i]^y[i] - * - * @param name Output variable name - * @param x Input variable - * @param y Power - * @return Output variable - */ - public SDVariable pow(String name, SDVariable x, SDVariable y) { - validateNumerical("pow", x, y); - SDVariable result = f().pow(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i] - * - * @param a Input variable - * @return Output variable - */ - public SDVariable reciprocal(SDVariable a) { - return reciprocal(null, a); - } - - /** - * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i] - * - * @param name Name of the output variable - * @param a Input variable - * @return Output variable - */ - public SDVariable reciprocal(String name, SDVariable a) { - validateNumerical("reciprocal", a); - SDVariable ret = f().reciprocal(a); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise round function: out = round(x). - * Rounds (up or down depending on value) to the nearest integer value. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable round(SDVariable x) { - return round(null, x); - } - - /** - * Element-wise round function: out = round(x). - * Rounds (up or down depending on value) to the nearest integer value. 
- * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable round(String name, SDVariable x) { - validateFloatingPoint("round", x); - SDVariable result = f().round(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable rsqrt(SDVariable x) { - return rsqrt(null, x); - } - - /** - * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable rsqrt(String name, SDVariable x) { - validateNumerical("rsqrt", x); - SDVariable result = f().rsqrt(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #setDiag(String, SDVariable, SDVariable) - */ - public SDVariable setDiag(SDVariable in, SDVariable diag) { - return setDiag(null, in, diag); - } - - /** - * Set the diagonal value to the specified values
- * If input is
- * [ a, b, c]
- * [ d, e, f]
- * [ g, h, i]
- * and diag = [ 1, 2, 3] then output is
- * [ 1, b, c]
- * [ d, 2, f]
- * [ g, h, 3]
- * - * @param name Name of the output variable - * @param in Input variable - * @param diag Diagonal - * @return Output variable - */ - public SDVariable setDiag(String name, SDVariable in, SDVariable diag) { - SDVariable ret = f().setDiag(in, diag); - return updateVariableNameAndReference(ret, name); - } - - /** - * Shannon Entropy reduction: -sum(x * log2(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable - */ - public SDVariable shannonEntropy(SDVariable in, int... dimensions) { - return shannonEntropy(null, in, dimensions); - } - - /** - * Shannon Entropy reduction: -sum(x * log2(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { - validateNumerical("shannon entropy reduction", in); - SDVariable ret = f().shannonEntropy(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sign(SDVariable x) { - return sign(null, x); - } - - /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sign(String name, SDVariable x) { - validateNumerical("sign", x); - SDVariable result = f().sign(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise sine operation: out = sin(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sin(SDVariable x) { - return sin(null, x); - } - - /** - * Elementwise sine operation: out = sin(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sin(String name, SDVariable x) { - validateNumerical("sin", x); - SDVariable result = f().sin(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise sinh (hyperbolic sine) operation: out = sinh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sinh(SDVariable x) { - return sinh(null, x); - } - - /** - * Elementwise sinh (hyperbolic sine) operation: out = sinh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sinh(String name, SDVariable x) { - validateNumerical("sinh", x); - SDVariable result = f().sinh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise square root function: out = sqrt(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sqrt(SDVariable x) { - return sqrt(null, x); - } - - /** - * Element-wise square root function: out = sqrt(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sqrt(String name, SDVariable x) { - validateNumerical("sqrt", x); - SDVariable result = f().sqrt(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise square function: out = x^2 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable square(SDVariable x) { - return square(null, x); - } - - /** - * Element-wise square function: out = x^2 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable square(String name, SDVariable x) { - validateNumerical("square", x); - SDVariable result = f().square(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise step function:
- * out(x) = 1 if x >= cutoff
- * out(x) = 0 otherwise
- * - * @param in Input variable - * @param cutoff Cutoff value for step function - * @return Output variable - */ - public SDVariable step(SDVariable in, double cutoff) { - return step(null, in, cutoff); - } - - /** - * Elementwise step function:
- * out(x) = 1 if x >= cutoff
- * out(x) = 0 otherwise
- * - * @param name Name of the output variable - * @param in Input variable - * @param cutoff Cutoff value for step function - * @return Output variable - */ - public SDVariable step(String name, SDVariable in, double cutoff) { - validateNumerical("step", in); - SDVariable ret = f().step(in, cutoff); - return updateVariableNameAndReference(ret, name); - } - - /** - * Standardize input variable along given axis - * - * @see #standardize(String, SDVariable, int...) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable standardize(SDVariable x, int... dimensions) { - return standardize(null, x, dimensions); - } - - /** - * Standardize input variable along given axis - *

- * out = (x - mean) / stdev
- * with mean and stdev being calculated along the given dimension.
- *
- * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
- * • use dimension 1 to use the statistics (mean, stdev) for each example
- * • use dimension 0 if you want to use the statistics for each column across all examples
- * • use dimensions 0,1 if you want to use the statistics across all columns and examples
- * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable standardize(String name, SDVariable x, int... dimensions) { - validateNumerical("standardize", x); - SDVariable result = f().standardize(x, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise tangent operation: out = tan(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable tan(SDVariable x) { - return tan(null, x); - } - - /** - * Elementwise tangent operation: out = tan(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable tan(String name, SDVariable x) { - validateNumerical("tan", x); - SDVariable result = f().tan(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable tanh(SDVariable x) { - return tanh(null, x); - } - - /** - * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable tanh(String name, SDVariable x) { - validateNumerical("tanh", x); - SDVariable result = f().tanh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #trace(String, SDVariable) - */ - public SDVariable trace(SDVariable in) { - return trace(null, in); - } - - /** - * Matrix trace operation - * For rank 2 matrices, the output is a scalar vith the trace - i.e., sum of the main diagonal.
- * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:]) - * - * @param name Name of the output variable. May be null. - * @param in Input variable - * @return Trace - */ - public SDVariable trace(String name, SDVariable in) { - validateNumerical("trace", in); - SDVariable ret = f().trace(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable xor(SDVariable x, SDVariable y) { - return xor(null, x, y); - } - - /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable xor(String name, SDVariable x, SDVariable y) { - validateBool("xor", x, y); - SDVariable result = f().xor(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Shift integer bits to the left, i.e. var << 4 - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { - validateInteger("shift_bits", x); - SDVariable result = f().shift(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Shift integer bits to the right, i.e. var >> 4 - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { - validateInteger("rshift_bits", x); - SDVariable result = f().rshift(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4) - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { - validateInteger("cyclic_shift_bits", x); - SDVariable result = f().rotl(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4) - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { - validateInteger("cyclic_rshift_bits", x); - SDVariable result = f().rotr(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x)) - * - * @param input Input variable - * @return Reduced array of rank 0 (scalar) - */ - public SDVariable zeroFraction(SDVariable input) { - return zeroFraction(null, input); - } - - /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x)) - * - * @param name Name of the output variable - * @param input Input variable - * @return Reduced array of rank 0 (scalar) - */ - public SDVariable zeroFraction(String name, SDVariable input) { - validateNumerical("zeroFraction", input); - SDVariable res = f().zeroFraction(input); - return updateVariableNameAndReference(res, name); - } - - /** - * Compute the regularized incomplete beta integral - * - * @param name Name of the output variable - * @param a input array - * @param b input array - * @param x input array - * @return array - */ - public SDVariable betainc(String name,SDVariable a,SDVariable b,SDVariable x) { - SDVariable res = f().betainc(a,b,x); - return updateVariableNameAndReference(res, name); - } - - /** - * Copy a tensor setting everything outside a central band in each innermost matrix. - * - * @param name Name of the output variable - * @param input Rank k array - * @param minLower Number of subdiagonals to keep. - * @param maxUpper Number of superdiagonals to keep. 
- * @return Rank k array of the same shape as input. - */ - public SDVariable matrixBandPart(String name, SDVariable input, SDVariable minLower, SDVariable maxUpper) { - SDVariable res = f().matrixBandPart(input,minLower,maxUpper); - return updateVariableNameAndReference(res, name); - } - - /** - * Polygamma function - * - * @param name Name of the output variable - * @param n array - * @param x array - * @return array - */ - public SDVariable polygamma(String name, SDVariable n, SDVariable x) { - SDVariable res = f().polygamma(n,x); - return updateVariableNameAndReference(res, name); - } - - /** - * Rolls the elements of input - * - * @param name Name of the output variable - * @param input array - * @param shift number of places to shift elements - * @return array - */ - public SDVariable roll(String name, SDVariable input, int shift) { - SDVariable res = f().roll(input,shift); - return updateVariableNameAndReference(res, name); - } + public SDMath(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Elementwise absolute value operation: out = abs(x)
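+ * A minimal usage sketch (illustrative only, assuming a SameDiff instance {@code sd} and an existing numeric variable {@code x}):
+ * {@code SDVariable y = sd.math().abs(x); // e.g. x = [-1.5, 0, 2.0] gives y = [1.5, 0, 2.0]}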
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable abs(SDVariable x) { + SDValidation.validateNumerical("abs", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + } + + /** + * Elementwise absolute value operation: out = abs(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable abs(String name, SDVariable x) { + SDValidation.validateNumerical("abs", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acos(SDVariable x) { + SDValidation.validateNumerical("acos", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + } + + /** + * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acos(String name, SDVariable x) { + SDValidation.validateNumerical("acos", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acosh(SDVariable x) { + SDValidation.validateNumerical("acosh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + } + + /** + * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acosh(String name, SDVariable x) { + SDValidation.validateNumerical("acosh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
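+ * For example (illustrative values): with in = [[-1, 2], [3, -4]], amax(in, 1) = [2, 4] and a full reduction amax(in) = 4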
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amax(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amax(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amean(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amean", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amean(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amean", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amin(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amin(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean AND operation: elementwise (x != 0) && (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
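+ * For example (illustrative values): with x = [1, 0, 1, 0] and y = [1, 1, 0, 0], and(x, y) = [1, 0, 0, 0]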
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable and(SDVariable x, SDVariable y) { + SDValidation.validateBool("and", "x", x); + SDValidation.validateBool("and", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + } + + /** + * Boolean AND operation: elementwise (x != 0) && (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable and(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("and", "x", x); + SDValidation.validateBool("and", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asin(SDVariable x) { + SDValidation.validateNumerical("asin", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + } + + /** + * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asin(String name, SDVariable x) { + SDValidation.validateNumerical("asin", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asinh(SDVariable x) { + SDValidation.validateNumerical("asinh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + } + + /** + * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asinh(String name, SDVariable x) { + SDValidation.validateNumerical("asinh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable asum(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("asum", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable asum(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("asum", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan(SDVariable x) { + SDValidation.validateNumerical("atan", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan(String name, SDVariable x) { + SDValidation.validateNumerical("atan", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
+ * Similar to atan(y/x) but the signs of x and y are used to determine the quadrant of the result
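+ * For example: atan2(y, x) with y = 1 and x = -1 gives 3*pi/4 ≈ 2.356, whereas atan(y/x) alone would give -pi/4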
+ * + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan2(SDVariable y, SDVariable x) { + SDValidation.validateNumerical("atan2", "y", y); + SDValidation.validateNumerical("atan2", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
+ * Similar to atan(y/x) but the signs of x and y are used to determine the quadrant of the result
+ * + * @param name name May be null. Name for the output variable + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan2(String name, SDVariable y, SDVariable x) { + SDValidation.validateNumerical("atan2", "y", y); + SDValidation.validateNumerical("atan2", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atanh(SDVariable x) { + SDValidation.validateNumerical("atanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + } + + /** + * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atanh(String name, SDVariable x) { + SDValidation.validateNumerical("atanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Bit shift operation
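+ * For example (illustrative values; this op shifts bits to the left, i.e. {@code x << shift}): with x = 4 and shift = 2 the output is 16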
+ * + * @param x input (NUMERIC type) + * @param shift shift value (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShift(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShift", "x", x); + SDValidation.validateNumerical("bitShift", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Bit shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x input (NUMERIC type) + * @param shift shift value (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShift", "x", x); + SDValidation.validateNumerical("bitShift", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Right bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRight", "x", x); + SDValidation.validateNumerical("bitShiftRight", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Right bit shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRight", "x", x); + SDValidation.validateNumerical("bitShiftRight", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cyclic bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotl(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotl", "x", x); + SDValidation.validateNumerical("bitShiftRotl", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Cyclic bit shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotl(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotl", "x", x); + SDValidation.validateNumerical("bitShiftRotl", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cyclic right shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift Shift argument (NUMERIC type) + * @return output Shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotr(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotr", "x", x); + SDValidation.validateNumerical("bitShiftRotr", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Cyclic right shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param shift Shift argument (NUMERIC type) + * @return output Shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotr(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotr", "x", x); + SDValidation.validateNumerical("bitShiftRotr", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise ceiling function: out = ceil(x).
+ * Rounds each value up to the nearest integer value (if not already an integer)
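+ * For example: ceil(2.1) = 3.0, ceil(-2.1) = -2.0 and ceil(3.0) = 3.0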
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable ceil(SDVariable x) { + SDValidation.validateNumerical("ceil", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + } + + /** + * Element-wise ceiling function: out = ceil(x).
+ * Rounds each value up to the nearest integer value (if not already an integer)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable ceil(String name, SDVariable x) { + SDValidation.validateNumerical("ceil", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Clipping by L2 norm, optionally along dimension(s)
+ * if l2Norm(x,dimension) < clipValue, then input is returned unmodified
+ * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
+ * to the corresponding l2Norm along the specified dimensions
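+ * For example (illustrative values): with in = [3, 4] (l2 norm 5) and clipValue = 1.0 the output is [0.6, 0.8]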
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("clipByNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + } + + /** + * Clipping by L2 norm, optionally along dimension(s)
+ * if l2Norm(x,dimension) < clipValue, then input is returned unmodified
+ * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
+ * to the corresponding l2Norm along the specified dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("clipByNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise clipping function:
+ * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
+ * out[i] = clipValueMin if in[i] < clipValueMin
+ * out[i] = clipValueMax if in[i] > clipValueMax
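+ * For example (illustrative values): with in = [-2, 0.5, 3], clipValueMin = 0 and clipValueMax = 1 the output is [0, 0.5, 1]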
+ * + * @param x Input variable (NUMERIC type) + * @param clipValueMin Minimum value for clipping + * @param clipValueMax Maximum value for clipping + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { + SDValidation.validateNumerical("clipByValue", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + } + + /** + * Element-wise clipping function:
+ * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
+ * out[i] = clipValueMin if in[i] < clipValueMin
+ * out[i] = clipValueMax if in[i] > clipValueMax
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValueMin Minimum value for clipping + * @param clipValueMax Maximum value for clipping + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, + double clipValueMax) { + SDValidation.validateNumerical("clipByValue", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
+ * [1, 0, 0]
+ * [0, 1, 1]
+ * [0, 0, 0]
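+ * A minimal usage sketch (illustrative only, assuming a SameDiff instance {@code sd} with integer variables {@code labels} and {@code pred} already defined):
+ * {@code SDVariable cm = sd.math().confusionMatrix(labels, pred, DataType.INT);}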
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param dataType Data type + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
+ * [1, 0, 0]
+ * [0, 1, 1]
+ * [0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param dataType Data type + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + DataType dataType) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
+ * [1, 0, 0, 0]
+ * [0, 1, 1, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param numClasses Number of classes + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
+ * [1, 0, 0, 0]
+ * [0, 1, 1, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param numClasses Number of classes + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
+ * [1, 0, 0]
+ * [0, 3, 2]
+ * [0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
+ * [1, 0, 0]
+ * [0, 3, 2]
+ * [0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + SDVariable weights) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
+ * [1, 0, 0, 0]
+ * [0, 3, 2, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param numClasses + * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights, + int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
+ * [1, 0, 0, 0]
+ * [0, 3, 2, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param numClasses + * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + SDVariable weights, int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise cosine operation: out = cos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cos(SDVariable x) { + SDValidation.validateNumerical("cos", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + } + + /** + * Elementwise cosine operation: out = cos(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cos(String name, SDVariable x) { + SDValidation.validateNumerical("cos", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosh(SDVariable x) { + SDValidation.validateNumerical("cosh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + } + + /** + * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosh(String name, SDVariable x) { + SDValidation.validateNumerical("cosh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cosine distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = 1.0 - cosineSimilarity(x,y)
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineDistance", "x", x); + SDValidation.validateNumerical("cosineDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Cosine distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = 1.0 - cosineSimilarity(x,y)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineDistance", "x", x); + SDValidation.validateNumerical("cosineDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
+ * along the specified dimensions:
+ * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) )
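A small sketch of both cosine reductions on rank-1 inputs (same assumed SameDiff setup as the confusionMatrix sketch earlier; dimension 0 is the only axis of a vector):

    SDVariable x = sd.constant(Nd4j.createFromArray(1.0, 0.0, 0.0));
    SDVariable y = sd.constant(Nd4j.createFromArray(1.0, 1.0, 0.0));
    SDVariable sim  = sd.math().cosineSimilarity(x, y, 0);  // dot/(|x||y|) = 1/sqrt(2) ~ 0.707
    SDVariable dist = sd.math().cosineDistance(x, y, 0);    // 1 - 0.707 ~ 0.293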
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineSimilarity", "x", x); + SDValidation.validateNumerical("cosineSimilarity", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + } + + /** + * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
+ * along the specified dimensions:
+ * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) )
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineSimilarity", "x", x); + SDValidation.validateNumerical("cosineSimilarity", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
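A minimal sketch of the count reductions (same assumed setup as above); countZero, documented just below, is the complement. With no dimensions passed, the whole array is reduced to a scalar:

    SDVariable in = sd.constant(Nd4j.createFromArray(0.0, 3.0, 0.0, 5.0));
    SDVariable nonZero = sd.math().countNonZero(in);  // 2
    SDVariable zero    = sd.math().countZero(in);     // 2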
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countNonZero(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countNonZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + } + + /** + * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countNonZero(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countNonZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countZero(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + } + + /** + * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countZero(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
+ * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
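For instance, the cross product of the unit x and y vectors is the unit z vector; a sketch under the same assumed setup as above:

    SDVariable a = sd.constant(Nd4j.createFromArray(1.0, 0.0, 0.0));
    SDVariable b = sd.constant(Nd4j.createFromArray(0.0, 1.0, 0.0));
    SDVariable c = sd.math().cross(a, b);  // evaluates to [0, 0, 1]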
+ * + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) + * @return output Element-wise cross product (NUMERIC type) + */ + public SDVariable cross(SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + } + + /** + * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
+ * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
+ * + * @param name name May be null. Name for the output variable + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) + * @return output Element-wise cross product (NUMERIC type) + */ + public SDVariable cross(String name, SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise cube function: out = x^3
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cube(SDVariable x) { + SDValidation.validateNumerical("cube", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + } + + /** + * Element-wise cube function: out = x^3
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cube(String name, SDVariable x) { + SDValidation.validateNumerical("cube", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
+ * For example, if input = [1,2,3], then output is given by:
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ *
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
+ * i.e., for input rank R, output has rank 2R
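A sketch of the rank-1 case described above (same assumed setup):

    SDVariable v = sd.math().diag(sd.constant(Nd4j.createFromArray(1.0, 2.0, 3.0)));
    // v.eval() is the 3x3 matrix shown above: [1, 2, 3] on the diagonal, zeros elsewhere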
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable diag(SDVariable x) { + SDValidation.validateNumerical("diag", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + } + + /** + * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
+ * For example, if input = [1,2,3], then output is given by:
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ *
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
+ * i.e., for input rank R, output has rank 2R
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable diag(String name, SDVariable x) { + SDValidation.validateNumerical("diag", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Extract the diagonal part from the input array.
+ * If input is
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ * then output is [1, 2, 3].
+ * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
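diagPart inverts diag for the square case; a sketch (same assumed setup):

    SDVariable m = sd.math().diag(sd.constant(Nd4j.createFromArray(1.0, 2.0, 3.0)));
    SDVariable d = sd.math().diagPart(m);  // recovers [1, 2, 3]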
+ * + * @param x Input variable (NUMERIC type) + * @return output Diagonal part of the input (NUMERIC type) + */ + public SDVariable diagPart(SDVariable x) { + SDValidation.validateNumerical("diagPart", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + } + + /** + * Extract the diagonal part from the input array.
+ * If input is
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ * then output is [1, 2, 3].
+ * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Diagonal part of the input (NUMERIC type) + */ + public SDVariable diagPart(String name, SDVariable x) { + SDValidation.validateNumerical("diagPart", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Entropy reduction: -sum(x * log(x))
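For a probability vector this reduction is the Shannon entropy in nats; a sketch (same assumed setup):

    SDVariable p = sd.constant(Nd4j.createFromArray(0.5, 0.25, 0.25));
    SDVariable h = sd.math().entropy(p);  // -sum(p * ln(p)) ~ 1.04 nats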
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable entropy(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("entropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + } + + /** + * Entropy reduction: -sum(x * log(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable entropy(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("entropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise Gaussian error function - out = erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erf(SDVariable x) { + SDValidation.validateNumerical("erf", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + } + + /** + * Element-wise Gaussian error function - out = erf(in)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erf(String name, SDVariable x) { + SDValidation.validateNumerical("erf", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erfc(SDVariable x) { + SDValidation.validateNumerical("erfc", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + } + + /** + * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erfc(String name, SDVariable x) { + SDValidation.validateNumerical("erfc", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sqrt( sum_i (x[i] - y[i])^2 )
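A sketch on rank-1 inputs (same assumed setup; dimension 0 reduces the whole vector):

    SDVariable x = sd.constant(Nd4j.createFromArray(1.0, 2.0, 3.0));
    SDVariable y = sd.constant(Nd4j.createFromArray(4.0, 6.0, 3.0));
    SDVariable d = sd.math().euclideanDistance(x, y, 0);  // sqrt(9 + 16 + 0) = 5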
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("euclideanDistance", "x", x); + SDValidation.validateNumerical("euclideanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("euclideanDistance", "x", x); + SDValidation.validateNumerical("euclideanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise exponent function: out = exp(x) = 2.71828...^x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable exp(SDVariable x) { + SDValidation.validateNumerical("exp", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + } + + /** + * Elementwise exponent function: out = exp(x) = 2.71828...^x
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable exp(String name, SDVariable x) { + SDValidation.validateNumerical("exp", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise exponent-minus-one function: out = exp(x) - 1.0 = 2.71828...^x - 1.0
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable expm1(SDVariable x) { + SDValidation.validateNumerical("expm1", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + } + + /** + * Elementwise exponent-minus-one function: out = exp(x) - 1.0 = 2.71828...^x - 1.0
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable expm1(String name, SDVariable x) { + SDValidation.validateNumerical("expm1", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate an identity matrix with the specified number of rows and columns.
+ * + * @param rows Number of rows + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(int rows) { + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + } + + /** + * Generate an identity matrix with the specified number of rows and columns.
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, int rows) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per eye(String, int, int, DataType) but with the default datatype, Eye.DEFAULT_DTYPE
+ * + * @param rows Number of rows + * @param cols Number of columns + * @return output (NUMERIC type) + */ + public SDVariable eye(int rows, int cols) { + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + } + + /** + * As per eye(String, int, int, DataType) but with the default datatype, Eye.DEFAULT_DTYPE
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @param cols Number of columns + * @return output (NUMERIC type) + */ + public SDVariable eye(String name, int rows, int cols) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate an identity matrix with the specified number of rows and columns
+ * Example:
+ * <pre>
+ * {@code INDArray eye = eye(3,2)
+ * eye:
+ * [ 1, 0]
+ * [ 0, 1]
+ * [ 0, 0]}
+ * </pre>
+ * + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(int rows, int cols, DataType dataType, int... dimensions) { + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + } + + /** + * Generate an identity matrix with the specified number of rows and columns
+ * Example:
+ * <pre>
+ * {@code INDArray eye = eye(3,2)
+ * eye:
+ * [ 1, 0]
+ * [ 0, 1]
+ * [ 0, 0]}
+ * </pre>
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, int rows, int cols, DataType dataType, int... dimensions) { + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per eye(int, int) bit with the number of rows/columns specified as scalar INDArrays
+ * + * @param rows Number of rows (INT type) + * @param cols Number of columns (INT type) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(SDVariable rows, SDVariable cols) { + SDValidation.validateInteger("eye", "rows", rows); + SDValidation.validateInteger("eye", "cols", cols); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + } + + /** + * As per eye(int, int) bit with the number of rows/columns specified as scalar INDArrays
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows (INT type) + * @param cols Number of columns (INT type) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, SDVariable rows, SDVariable cols) { + SDValidation.validateInteger("eye", "rows", rows); + SDValidation.validateInteger("eye", "cols", cols); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per eye(String, int) but with the number of rows specified as a scalar INDArray
+ * + * @param rows Number of rows (INT type) + * @return output SDVaribable identity matrix (NUMERIC type) + */ + public SDVariable eye(SDVariable rows) { + SDValidation.validateInteger("eye", "rows", rows); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + } + + /** + * As per eye(String, int) but with the number of rows specified as a scalar INDArray
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows (INT type) + * @return output SDVaribable identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, SDVariable rows) { + SDValidation.validateInteger("eye", "rows", rows); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
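A sketch with a simple condition (same assumed setup as above; Conditions.greaterThan is assumed to come from org.nd4j.linalg.indexing.conditions.Conditions):

    import org.nd4j.linalg.indexing.conditions.Conditions;

    SDVariable in  = sd.constant(Nd4j.createFromArray(1.0, 3.0, 5.0, 3.0));
    SDVariable idx = sd.math().firstIndex(in, Conditions.greaterThan(2.0), 0);  // 1: first element > 2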
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise floor function: out = floor(x).
+ * Rounds each value down to the nearest integer value (if not already an integer)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floor(SDVariable x) { + SDValidation.validateNumerical("floor", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + } + + /** + * Element-wise floor function: out = floor(x).
+ * Rounds each value down to the nearest integer value (if not already an integer)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floor(String name, SDVariable x) { + SDValidation.validateNumerical("floor", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Hamming distance reduction operation. The output contains the Hamming distance for each
+ * tensor/subset along the specified dimensions:
+ * out = count( x[i] != y[i] )
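A sketch on binary rank-1 inputs (same assumed setup as above):

    SDVariable x = sd.constant(Nd4j.createFromArray(1.0, 0.0, 1.0, 1.0));
    SDVariable y = sd.constant(Nd4j.createFromArray(1.0, 1.0, 0.0, 1.0));
    SDVariable hd = sd.math().hammingDistance(x, y, 0);  // 2: the inputs differ at two positions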
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("hammingDistance", "x", x); + SDValidation.validateNumerical("hammingDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Hamming distance reduction operation. The output contains the Hamming distance for each
+ * tensor/subset along the specified dimensions:
+ * out = count( x[i] != y[i] )
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("hammingDistance", "x", x); + SDValidation.validateNumerical("hammingDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
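A sketch (same assumed setup as above); iamin, documented further below, is the analogous argmin of the absolute value:

    SDVariable in  = sd.constant(Nd4j.createFromArray(-7.0, 2.0, 3.0));
    SDVariable idx = sd.math().iamax(in, 0);  // 0, since |-7| is the largest absolute value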
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable(); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable(); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is finite operation: elementwise isFinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isFinite(SDVariable x) { + SDValidation.validateNumerical("isFinite", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable(); + } + + /** + * Is finite operation: elementwise isFinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isFinite(String name, SDVariable x) { + SDValidation.validateNumerical("isFinite", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is infinite operation: elementwise isInfinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isInfinite(SDVariable x) { + SDValidation.validateNumerical("isInfinite", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + } + + /** + * Is infinite operation: elementwise isInfinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isInfinite(String name, SDVariable x) { + SDValidation.validateNumerical("isInfinite", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is maximum operation: elementwise x == max(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isMax(SDVariable x) { + SDValidation.validateNumerical("isMax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + } + + /** + * Is maximum operation: elementwise x == max(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isMax(String name, SDVariable x) { + SDValidation.validateNumerical("isMax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is Not a Number operation: elementwise isNaN(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isNaN(SDVariable x) { + SDValidation.validateNumerical("isNaN", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + } + + /** + * Is Not a Number operation: elementwise isNaN(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isNaN(String name, SDVariable x) { + SDValidation.validateNumerical("isNaN", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is the array non decreasing?
+ * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isNonDecreasing(SDVariable x) { + SDValidation.validateNumerical("isNonDecreasing", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + } + + /** + * Is the array non decreasing?
+ * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isNonDecreasing(String name, SDVariable x) { + SDValidation.validateNumerical("isNonDecreasing", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is the array strictly increasing?
+ * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
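The two monotonicity checks differ only in whether equal neighbours are allowed; a sketch (same assumed setup as above):

    SDVariable a = sd.constant(Nd4j.createFromArray(1.0, 2.0, 2.0, 3.0));
    SDVariable nonDec = sd.math().isNonDecreasing(a);       // 1: equal neighbours are allowed
    SDVariable strict = sd.math().isStrictlyIncreasing(a);  // 0: the value 2.0 repeats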
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isStrictlyIncreasing(SDVariable x) { + SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + } + + /** + * Is the array strictly increasing?
+ * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isStrictlyIncreasing(String name, SDVariable x) { + SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Jaccard distance reduction operation. The output contains the Jaccard distance for each
+ * tensor along the specified dimensions.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("jaccardDistance", "x", x); + SDValidation.validateNumerical("jaccardDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Jaccard distance reduction operation. The output contains the Jaccard distance for each
+ * tensor along the specified dimensions.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("jaccardDistance", "x", x); + SDValidation.validateNumerical("jaccardDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates difference between inputs X and Y.
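listDiff returns two outputs: the values of x that are absent from y, and their indices in x; a sketch (same assumed setup as above):

    SDVariable x = sd.constant(Nd4j.createFromArray(1.0, 2.0, 3.0, 4.0));
    SDVariable y = sd.constant(Nd4j.createFromArray(2.0, 4.0));
    SDVariable[] diff = sd.math().listDiff(x, y);
    // diff[0] -> [1, 3] (values kept), diff[1] -> [0, 2] (their positions in x)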
+ * + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) + */ + public SDVariable[] listDiff(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("listDiff", "x", x); + SDValidation.validateNumerical("listDiff", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + } + + /** + * Calculates difference between inputs X and Y.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) + */ + public SDVariable[] listDiff(String[] names, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("listDiff", "x", x); + SDValidation.validateNumerical("listDiff", "y", y); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Element-wise logarithm function (base e - natural logarithm): out = log(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(SDVariable x) { + SDValidation.validateNumerical("log", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + } + + /** + * Element-wise logarithm function (base e - natural logarithm): out = log(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(String name, SDVariable x) { + SDValidation.validateNumerical("log", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise logarithm function (with specified base): out = log_{base}(x)
+ * + * @param x Input variable (NUMERIC type) + * @param base Logarithm base + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(SDVariable x, double base) { + SDValidation.validateNumerical("log", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + } + + /** + * Element-wise logarithm function (with specified base): out = log_{base}(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param base Logarithm base + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(String name, SDVariable x, double base) { + SDValidation.validateNumerical("log", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise natural logarithm function: out = log_e (1 + x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log1p(SDVariable x) { + SDValidation.validateNumerical("log1p", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + } + + /** + * Elementwise natural logarithm function: out = log_e (1 + x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log1p(String name, SDVariable x) { + SDValidation.validateNumerical("log1p", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log entropy reduction: log(-sum(x * log(x)))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable logEntropy(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("logEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + } + + /** + * Log entropy reduction: log(-sum(x * log(x)))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("logEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log-sum-exp reduction (optionally along dimension).
+ * Computes log(sum(exp(x)))<br>
+ * + * @param input Input variable (NUMERIC type) + * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSumExp(SDVariable input, int... dimensions) { + SDValidation.validateNumerical("logSumExp", "input", input); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + } + + /** + * Log-sum-exp reduction (optionally along dimension).
+ * Computes log(sum(exp(x)))<br>
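+ * A minimal usage sketch (hypothetical variable names, imports omitted; assumes an existing {@code SameDiff} instance {@code sd}):
+ * <pre>{@code
+ * SDVariable x   = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4));
+ * // numerically stable log(sum(exp(x))) reduced along dimension 1 -> shape [3]
+ * SDVariable lse = sd.math().logSumExp("lse", x, 1);
+ * }</pre>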
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { + SDValidation.validateNumerical("logSumExp", "input", input); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i]-y[i])
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("manhattanDistance", "x", x); + SDValidation.validateNumerical("manhattanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i]-y[i])
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("manhattanDistance", "x", x); + SDValidation.validateNumerical("manhattanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
+ * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix determinant variable (NUMERIC type) + */ + public SDVariable matrixDeterminant(SDVariable in) { + SDValidation.validateNumerical("matrixDeterminant", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + } + + /** + * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
+ * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @return output Matrix determinant variable (NUMERIC type) + */ + public SDVariable matrixDeterminant(String name, SDVariable in) { + SDValidation.validateNumerical("matrixDeterminant", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
+ * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix inverse variable (NUMERIC type) + */ + public SDVariable matrixInverse(SDVariable in) { + SDValidation.validateNumerical("matrixInverse", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + } + + /** + * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
+ * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @return output Matrix inverse variable (NUMERIC type) + */ + public SDVariable matrixInverse(String name, SDVariable in) { + SDValidation.validateNumerical("matrixInverse", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
+ * out = sum_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAdd(SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAdd", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + } + + /** + * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
+ * out = sum_i in[i]
+ * + * @param name name May be null. Name for the output variable + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAdd(String name, SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAdd", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
+ * out = mean_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAvg(SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAvg", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + } + + /** + * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
+ * out = mean_i in[i]
+ * + * @param name name May be null. Name for the output variable + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAvg(String name, SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAvg", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
+ * out = max_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeMax(SDVariable[] inputs) { + SDValidation.validateNumerical("mergeMax", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + } + + /** + * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
+ * out = max_i in[i]
+ * + * @param name name May be null. Name for the output variable + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeMax(String name, SDVariable[] inputs) { + SDValidation.validateNumerical("mergeMax", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Broadcasts parameters for evaluation on an N-D grid.
+ * + * @param inputs (NUMERIC type) + * @param cartesian + */ + public SDVariable[] meshgrid(SDVariable[] inputs, boolean cartesian) { + SDValidation.validateNumerical("meshgrid", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + } + + /** + * Broadcasts parameters for evaluation on an N-D grid.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param inputs (NUMERIC type) + * @param cartesian + */ + public SDVariable[] meshgrid(String[] names, SDVariable[] inputs, boolean cartesian) { + SDValidation.validateNumerical("meshgrid", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * + * @param input Input to calculate moments for (NUMERIC type) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + */ + public SDVariable[] moments(SDVariable input, int... axes) { + SDValidation.validateNumerical("moments", "input", input); + Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + return new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + } + + /** + * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input Input to calculate moments for (NUMERIC type) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + */ + public SDVariable[] moments(String[] names, SDVariable input, int... axes) { + SDValidation.validateNumerical("moments", "input", input); + Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Elementwise negative operation: out = -x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable neg(SDVariable x) { + SDValidation.validateNumerical("neg", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + } + + /** + * Elementwise negative operation: out = -x
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable neg(String name, SDVariable x) { + SDValidation.validateNumerical("neg", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculate the mean and variance from the sufficient statistics
+ * + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) + * @param variances Variance sufficient statistics: this is the squared sum of all data values (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + */ + public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, + double shift) { + SDValidation.validateNumerical("normalizeMoments", "counts", counts); + SDValidation.validateNumerical("normalizeMoments", "means", means); + SDValidation.validateNumerical("normalizeMoments", "variances", variances); + return new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + } + + /** + * Calculate the mean and variance from the sufficient statistics<br>
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) + * @param variances Variance sufficient statistics: this is the squared sum of all data values (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + */ + public SDVariable[] normalizeMoments(String[] names, SDVariable counts, SDVariable means, + SDVariable variances, double shift) { + SDValidation.validateNumerical("normalizeMoments", "counts", counts); + SDValidation.validateNumerical("normalizeMoments", "means", means); + SDValidation.validateNumerical("normalizeMoments", "variances", variances); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Boolean OR operation: elementwise (x != 0) || (y != 0)<br>
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable or(SDVariable x, SDVariable y) { + SDValidation.validateBool("or", "x", x); + SDValidation.validateBool("or", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + } + + /** + * Boolean OR operation: elementwise (x != 0) || (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable or(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("or", "x", x); + SDValidation.validateBool("or", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise power function: out = x^value
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(SDVariable x, double value) { + SDValidation.validateNumerical("pow", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + } + + /** + * Element-wise power function: out = x^value
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(String name, SDVariable x, double value) { + SDValidation.validateNumerical("pow", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise (broadcastable) power function: out = x[i]^y[i]
+ * + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("pow", "x", x); + SDValidation.validateNumerical("pow", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + } + + /** + * Element-wise (broadcastable) power function: out = x[i]^y[i]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("pow", "x", x); + SDValidation.validateNumerical("pow", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reciprocal(SDVariable x) { + SDValidation.validateNumerical("reciprocal", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + } + + /** + * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reciprocal(String name, SDVariable x) { + SDValidation.validateNumerical("reciprocal", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise round function: out = round(x).
+ * Rounds (up or down depending on value) to the nearest integer value.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable round(SDVariable x) { + SDValidation.validateNumerical("round", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + } + + /** + * Element-wise round function: out = round(x).
+ * Rounds (up or down depending on value) to the nearest integer value.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable round(String name, SDVariable x) { + SDValidation.validateNumerical("round", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsqrt(SDVariable x) { + SDValidation.validateNumerical("rsqrt", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + } + + /** + * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsqrt(String name, SDVariable x) { + SDValidation.validateNumerical("rsqrt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Set the diagonal value to the specified values
+ * If input is
+ * [ a, b, c]
+ * [ d, e, f]
+ * [ g, h, i]
+ * and diag = [ 1, 2, 3] then output is
+ * [ 1, b, c]
+ * [ d, 2, f]
+ * [ g, h, 3]
+ * + * @param in Input variable (NUMERIC type) + * @param diag Diagonal (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable setDiag(SDVariable in, SDVariable diag) { + SDValidation.validateNumerical("setDiag", "in", in); + SDValidation.validateNumerical("setDiag", "diag", diag); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + } + + /** + * Set the diagonal value to the specified values
+ * If input is
+ * [ a, b, c]
+ * [ d, e, f]
+ * [ g, h, i]
+ * and diag = [ 1, 2, 3] then output is
+ * [ 1, b, c]
+ * [ d, 2, f]
+ * [ g, h, 3]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param diag Diagonal (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable setDiag(String name, SDVariable in, SDVariable diag) { + SDValidation.validateNumerical("setDiag", "in", in); + SDValidation.validateNumerical("setDiag", "diag", diag); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Shannon Entropy reduction: -sum(x * log2(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable shannonEntropy(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("shannonEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + } + + /** + * Shannon Entropy reduction: -sum(x * log2(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("shannonEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sign (signum) function:
+ * out = -1 if in < 0
+ * out = 0 if in = 0
+ * out = 1 if in > 0
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sign(SDVariable x) { + SDValidation.validateNumerical("sign", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + } + + /** + * Element-wise sign (signum) function:
+ * out = -1 if in < 0
+ * out = 0 if in = 0
+ * out = 1 if in > 0
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sign(String name, SDVariable x) { + SDValidation.validateNumerical("sign", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise sine operation: out = sin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sin(SDVariable x) { + SDValidation.validateNumerical("sin", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + } + + /** + * Elementwise sine operation: out = sin(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sin(String name, SDVariable x) { + SDValidation.validateNumerical("sin", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sinh(SDVariable x) { + SDValidation.validateNumerical("sinh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + } + + /** + * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sinh(String name, SDVariable x) { + SDValidation.validateNumerical("sinh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise square root function: out = sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sqrt(SDVariable x) { + SDValidation.validateNumerical("sqrt", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + } + + /** + * Element-wise square root function: out = sqrt(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sqrt(String name, SDVariable x) { + SDValidation.validateNumerical("sqrt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise square function: out = x^2
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable square(SDVariable x) { + SDValidation.validateNumerical("square", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + } + + /** + * Element-wise square function: out = x^2
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable square(String name, SDVariable x) { + SDValidation.validateNumerical("square", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Standardize input variable along given axis
+ * <p>
+ * out = (x - mean) / stdev<br>
+ * <p>
+ * with mean and stdev being calculated along the given dimension.<br>
+ * <p>
+ * For example: given x as a mini batch of the shape [numExamples, exampleLength]:<br>
+ * <ul>
+ * <li>use dimension 1 to use the statistics (mean, stdev) for each example</li>
+ * <li>use dimension 0 if you want to use the statistics for each column across all examples</li>
+ * <li>use dimensions 0,1 if you want to use the statistics across all columns and examples</li>
+ * </ul>
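+ * A minimal usage sketch (hypothetical variable names, imports omitted; assumes an existing {@code SameDiff} instance {@code sd}):
+ * <pre>{@code
+ * SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 8, 100)); // [numExamples, exampleLength]
+ * // zero mean / unit stdev per example (statistics computed along dimension 1)
+ * SDVariable z = sd.math().standardize(x, 1);
+ * }</pre>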
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable standardize(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("standardize", "x", x); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + } + + /** + * Standardize input variable along given axis
+ * <p>
+ * out = (x - mean) / stdev<br>
+ * <p>
+ * with mean and stdev being calculated along the given dimension.<br>
+ * <p>
+ * For example: given x as a mini batch of the shape [numExamples, exampleLength]:<br>
+ * <ul>
+ * <li>use dimension 1 to use the statistics (mean, stdev) for each example</li>
+ * <li>use dimension 0 if you want to use the statistics for each column across all examples</li>
+ * <li>use dimensions 0,1 if you want to use the statistics across all columns and examples</li>
+ * </ul>
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable standardize(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("standardize", "x", x); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise step function:
+ * out(x) = 1 if x >= cutoff
+ * out(x) = 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable step(SDVariable x, double value) { + SDValidation.validateNumerical("step", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + } + + /** + * Elementwise step function:
+ * out(x) = 1 if x >= cutoff
+ * out(x) = 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable step(String name, SDVariable x, double value) { + SDValidation.validateNumerical("step", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise tangent operation: out = tan(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tan(SDVariable x) { + SDValidation.validateNumerical("tan", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + } + + /** + * Elementwise tangent operation: out = tan(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tan(String name, SDVariable x) { + SDValidation.validateNumerical("tan", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(String name, SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix trace operation
+ * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.<br>
+ * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
+ * + * @param in Input variable (NUMERIC type) + * @return output Trace (NUMERIC type) + */ + public SDVariable trace(SDVariable in) { + SDValidation.validateNumerical("trace", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + } + + /** + * Matrix trace operation
+ * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.<br>
+ * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @return output Trace (NUMERIC type) + */ + public SDVariable trace(String name, SDVariable in) { + SDValidation.validateNumerical("trace", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable xor(SDVariable x, SDVariable y) { + SDValidation.validateBool("xor", "x", x); + SDValidation.validateBool("xor", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + } + + /** + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable xor(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("xor", "x", x); + SDValidation.validateBool("xor", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
+ * + * @param input Input variable (NUMERIC type) + * @return output Reduced array of rank 0 (scalar) (NUMERIC type) + */ + public SDVariable zeroFraction(SDVariable input) { + SDValidation.validateNumerical("zeroFraction", "input", input); + return new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + } + + /** + * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output Reduced array of rank 0 (scalar) (NUMERIC type) + */ + public SDVariable zeroFraction(String name, SDVariable input) { + SDValidation.validateNumerical("zeroFraction", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 63aab3f33..7b18c3614 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,1054 +14,1139 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.base.Preconditions; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateFloatingPoint; - -/** - * SameDiff general neural network operations
- * Accessible via {@link SameDiff#math()}
- * See also {@link SDCNN} (accessible via {@link SameDiff#cnn()} for convolutional neural network ops.
- * See also {@link SDRNN} (accessible via {@link SameDiff#rnn()} for recurrent neural network ops.
- * - * @author Alex Black - */ public class SDNN extends SDOps { - public SDNN(SameDiff sameDiff) { - super(sameDiff); - } - - /** - * Batch norm operation. - * - * @see #batchNorm(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, double, int...) - */ - public SDVariable batchNorm(SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, double epsilon, int... axis) { - return batchNorm(null, input, mean, variance, gamma, beta, true, true, epsilon, axis); - } - - /** - * Batch normalization with optional application of gamma/beta args. - * See {@link #batchNorm(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, double, int...)} - */ - public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon, int... axis) { - validateFloatingPoint("batchNorm", "input", input); - validateFloatingPoint("batchNorm", "mean", mean); - validateFloatingPoint("batchNorm", "variance", variance); - validateFloatingPoint("batchNorm", "gamma", gamma); - validateFloatingPoint("batchNorm", "beta", beta); - SDVariable res = f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon, axis); - return updateVariableNameAndReference(res, name); - } - - /** - * Neural network batch normalization operation.
- * For details, see https://arxiv.org/abs/1502.03167 - * - * @param name Name of the output variable - * @param input Input variable. - * @param mean Mean value. For 1d axis, this should match input.size(axis) - * @param variance Variance value. For 1d axis, this should match input.size(axis) - * @param gamma Gamma value. For 1d axis, this should match input.size(axis) - * @param beta Beta value. For 1d axis, this should match input.size(axis) - * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) - * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations.
- * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC
- * For 1d/RNN activations: 1 for NCW format, 2 for NWC - * @return Output variable for batch normalization - */ - public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, double epsilon, int... axis) { - return batchNorm(name, input, mean, variance, gamma, beta, true, true, epsilon, axis); - } - - /** - * @see #biasAdd(String, SDVariable, SDVariable, boolean) - */ - public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { - return biasAdd(null, input, bias, nchw); - } - - /** - * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector - * - * @param name Name of the output variable - * @param input 4d input variable - * @param bias 1d bias - * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. - * Unused for 2d inputs - * @return Output variable - */ - public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { - validateFloatingPoint("biasAdd", "input", input); - validateFloatingPoint("biasAdd", "bias", bias); - SDVariable ret = f().biasAdd(input, bias, nchw); - return updateVariableNameAndReference(ret, name); - } - - /** - * @param input Input - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) - * @return - */ - public SDVariable dropout(SDVariable input, double inputRetainProbability) { - return dropout(null, input, inputRetainProbability); - } - - /** - * @param input Input - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) - * @return - */ - public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { - validateFloatingPoint("dropout", input); - SDVariable res = f().dropout(input, inputRetainProbability); - return updateVariableNameAndReference(res, name); - } - - /** - * Element-wise exponential linear unit (ELU) function:
- * out = x if x > 0
- * out = a * (exp(x) - 1) if x <= 0
- * with constant a = 1.0 - *

- * See: https://arxiv.org/abs/1511.07289 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable elu(SDVariable x) { - return elu(null, x); - } - - /** - * Element-wise exponential linear unit (ELU) function:
- * out = x if x > 0
- * out = a * (exp(x) - 1) if x <= 0
- * with constant a = 1.0 - *

- * See: https://arxiv.org/abs/1511.07289 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable elu(String name, SDVariable x) { - validateFloatingPoint("elu", x); - SDVariable result = f().elu(x); - return updateVariableNameAndReference(result, name); - } - - /** - * GELU activation function - Gaussian Error Linear Units
- * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415 - * This method uses the sigmoid approximation - * - * @param x Input - * @return Output variable - GELU applied to the input - */ - public SDVariable gelu(SDVariable x) { - return gelu(null, x); - } - - /** - * GELU activation function - Gaussian Error Linear Units
- * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415 - * This method uses the sigmoid approximation - * - * @param name Name of the output variable. May be null. - * @param x Input - * @return Output variable - GELU applied to the input - */ - public SDVariable gelu(String name, SDVariable x) { - validateFloatingPoint("gelu", x); - SDVariable ret = f().gelu(x, false); //Defaults to si - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise hard sigmoid function:
- * out[i] = 0 if in[i] <= -2.5
- * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
- * out[i] = 1 if in[i] >= 2.5
- * - * @param in Input variable - * @return Output variable - */ - public SDVariable hardSigmoid(SDVariable in) { - return hardSigmoid(null, in); - } - - /** - * Element-wise hard sigmoid function:
- * out[i] = 0 if in[i] <= -2.5
- * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
- * out[i] = 1 if in[i] >= 2.5
- * - * @param name Name of the output variable - * @param in Input variable - * @return Output variable - */ - public SDVariable hardSigmoid(String name, SDVariable in) { - validateFloatingPoint("hard sigmoid", in); - SDVariable ret = f().hardSigmoid(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise hard tanh function:
- * out[i] = -1 if in[i] <= -1
- * out[1] = in[i] if -1 < in[i] < 1
- * out[i] = 1 if in[i] >= 1
- * - * @param in Input variable - * @return Output variable - */ - public SDVariable hardTanh(SDVariable in) { - return hardTanh(null, in); - } - - /** - * Element-wise hard tanh function:
- * out[i] = -1 if in[i] <= -1
- * out[1] = in[i] if -1 < in[i] < 1
- * out[i] = 1 if in[i] >= 1
- * - * @param name Output variable name - * @param in Input variable - * @return Output variable - */ - public SDVariable hardTanh(String name, SDVariable in) { - validateFloatingPoint("hard Tanh", in); - SDVariable result = f().hardTanh(in); - return updateVariableNameAndReference(result, name); - } - - /** - * Derivative (dOut/dIn) of the element-wise hard Tanh function - {@link #hardTanh(SDVariable)} - * - * @param x Input - * @return Output variable - */ - public SDVariable hardTanhDerivative(SDVariable x) { - return hardTanhDerivative(null, x); - } - - /** - * Derivative (dOut/dIn) of the element-wise hard Tanh function - {@link #hardTanh(SDVariable)} - * - * @param name Output variable name - * @param x Input - * @return Output variable - */ - public SDVariable hardTanhDerivative(String name, SDVariable x) { - validateFloatingPoint("hard Tanh derivative", x); - SDVariable result = f().hardTanhDerivative(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise leaky ReLU function:
- * out = x if x >= 0.0
- * out = alpha * x if x < cutoff
- * Alpha value is most commonly set to 0.01 - * - * @param x Input variable - * @param alpha Cutoff - usually 0.0 - * @return Output variable - */ - public SDVariable leakyRelu(SDVariable x, double alpha) { - return leakyRelu(null, x, alpha); - } - - /** - * Element-wise leaky ReLU function:
- * out = x if x >= 0.0
- * out = alpha * x if x < cutoff
- * Alpha value is most commonly set to 0.01 - * - * @param x Input variable - * @param alpha Cutoff - usually 0.0 - * @return Output variable - */ - public SDVariable leakyRelu(String name, SDVariable x, double alpha) { - validateFloatingPoint("leaky ReLU", x); - SDVariable result = f().leakyRelu(x, alpha); - return updateVariableNameAndReference(result, name); - } - - /** - * Leaky ReLU derivative: dOut/dIn given input.
- * See {@link #leakyRelu(String, SDVariable, double)} - * - * @param x Input variable - * @param alpha Alpha value - * @return Output variable - */ - public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { - validateFloatingPoint("leaky ReLU derivative", x); - SDVariable result = f().leakyReluDerivative(x, alpha); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #linear(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { - return linear(null, input, weights, bias); - } - - /** - * Linear layer operation: out = mmul(in,w) + bias
- * Note that bias array is optional - * - * @param name Name of the output variable - * @param input Input data - * @param weights Weights variable - * @param bias Optional bias variable (may be null) - * @return Output variable - */ - public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { - validateFloatingPoint("linear", "input", input); - validateFloatingPoint("linear", "weights", weights); - validateFloatingPoint("linear", "bias", bias); - SDVariable res = f().xwPlusB(input, weights, bias); - return updateVariableNameAndReference(res, name); - } - - /** - * Element-wise sigmoid function: out[i] = log(sigmoid(in[i])) - * - * @param x Input Variable - * @return Output variable - */ - public SDVariable logSigmoid(SDVariable x) { - return logSigmoid(null, x); - } - - /** - * Element-wise sigmoid function: out[i] = log(sigmoid(in[i])) - * - * @param name Name of the output variable - * @param x Input Variable - * @return Output variable - */ - public SDVariable logSigmoid(String name, SDVariable x) { - validateFloatingPoint("log sigmoid", x); - SDVariable ret = f().logSigmoid(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Log softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(SDVariable x) { - return logSoftmax(null, x); - } - - /** - * Log softmax activation - * - * @param name Variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(String name, SDVariable x) { - validateFloatingPoint("log softmax", x); - SDVariable ret = f().logSoftmax(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Log softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(SDVariable x, int dimension) { - return logSoftmax(null, x, dimension); - } - - /** - * Log softmax activation - * - * @param name Variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(String name, SDVariable x, int dimension) { - validateFloatingPoint("log softmax", x); - SDVariable ret = f().logSoftmax(x, dimension); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise rectified linear function with specified cutoff:
- * out[i] = in[i] if in[i] >= cutoff - * out[i] = 0 otherwise - * - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu(SDVariable x, double cutoff) { - return relu(null, x, cutoff); - } - - /** - * Element-wise rectified linear function with specified cutoff:
- * out[i] = in[i] if in[i] >= cutoff - * out[i] = 0 otherwise - * - * @param name Output variable name - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu(String name, SDVariable x, double cutoff) { - validateFloatingPoint("ReLU", x); - SDVariable result = f().relu(x, cutoff); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise "rectified linear 6" function with specified cutoff:
- * out[i] = min(max(in, cutoff), 6) - * - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu6(SDVariable x, double cutoff) { - return relu6(null, x, cutoff); - } - - /** - * Element-wise "rectified linear 6" function with specified cutoff:
- * out[i] = min(max(in, cutoff), 6) - * - * @param name Output variable name - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu6(String name, SDVariable x, double cutoff) { - validateFloatingPoint("ReLU6", x); - SDVariable result = f().relu6(x, cutoff); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #reluLayer(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { - return reluLayer(null, input, weights, bias); - } - - /** - * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
- * Note that bias array is optional - * - * @param name Name of the output variable - * @param input Input data - * @param weights Weights variable - * @param bias Optional bias variable (may be null) - * @return Output variable - */ - public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { - validateFloatingPoint("reluLayer", "input", input); - validateFloatingPoint("reluLayer", "weights", weights); - validateFloatingPoint("reluLayer", "bias", bias); - SDVariable res = f().reluLayer(input, weights, bias); - return updateVariableNameAndReference(res, name); - } - - /** - * See {@link #prelu(String, SDVariable, SDVariable, int...)}. - */ - public SDVariable prelu(@NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){ - return f().prelu(input, alpha, sharedAxes); - } - - /** - * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
- * out[i] = in[i] if in[i] >= 0
- * out[i] = in[i] * alpha[i] otherwise
- * - * sharedAxes allows you to share learnable parameters along axes. - * For example, if the input has shape [batchSize, channels, height, width] - * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an - * alpha with shape [channels]. - * - * @param name Name of the output variable - * @param input Input data - * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. - * @param sharedAxes Which axes to share cutoff parameters along. - * @return Output variable - */ - public SDVariable prelu(String name, @NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){ - SDVariable res = f().prelu(input, alpha, sharedAxes); - return updateVariableNameAndReference(res, name); - } - - /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks - *
- * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
- * Uses default lcale and alpha values. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable selu(SDVariable x) { - return selu(null, x); - } - - /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks - *
- * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
- * Uses default lcale and alpha values. - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable selu(String name, SDVariable x) { - validateFloatingPoint("selu", x); - SDVariable ret = f().selu(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i])) - * - * @param x Input Variable - * @return Output variable - */ - public SDVariable sigmoid(SDVariable x) { - return sigmoid(null, x); - } - - /** - * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i])) - * - * @param name Output variable name - * @param x Input Variable - * @return Output variable - */ - public SDVariable sigmoid(String name, SDVariable x) { - validateFloatingPoint("sigmoid", x); - SDVariable result = f().sigmoid(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut - * - * @param x Input Variable - * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input - * @return Output variable - */ - public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { - return sigmoidDerivative(null, x, wrt); - } - - /** - * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut - * - * @param name Output variable name - * @param x Input Variable - * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input - * @return Output variable - */ - public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { - validateFloatingPoint("sigmoidDerivative", x); - SDVariable result = f().sigmoidDerivative(x, wrt); - return updateVariableNameAndReference(result, name); - } - - /** - * Softmax activation on dimension 1. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(SDVariable x) { - return softmax(null, x); - } - - /** - * Softmax activation on dimension 1. 
- * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(String name, SDVariable x) { - validateFloatingPoint("softmax", x); - SDVariable result = f().softmax(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(SDVariable x, int dimension) { - return softmax(null, x, dimension); - } - - /** - * Softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(String name, SDVariable x, int dimension) { - validateFloatingPoint("softmax", x); - SDVariable result = f().softmax(x, dimension); - return updateVariableNameAndReference(result, name); - } - - /** - * @param x - * @return - */ - public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt) { - return softmaxDerivative(name, x, wrt, null); - } - - public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, Integer dimension) { - validateFloatingPoint("softmaxDerivative", x); - SDVariable result = f().softmaxDerivative(x, wrt, dimension); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise softplus function: out = log(exp(x) + 1) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softplus(SDVariable x) { - return softplus(null, x); - } - - /** - * Element-wise softplus function: out = log(exp(x) + 1) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable softplus(String name, SDVariable x) { - validateFloatingPoint("softplus", x); - SDVariable result = f().softplus(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise softsign function: out = x / (abs(x) + 1) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softsign(SDVariable x) { - return softsign(null, x); - } - - /** - * Element-wise softsign function: out = x / (abs(x) + 1) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable softsign(String name, SDVariable x) { - validateFloatingPoint("softsign", x); - SDVariable result = f().softsign(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)} - * - * @param x Input variable - * @return Output varible - */ - public SDVariable softsignDerivative(SDVariable x) { - return softsignDerivative(null, x); - } - - /** - * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)} - * - * @param name Output variable name - * @param x Input variable - * @return Output varible - */ - public SDVariable softsignDerivative(String name, SDVariable x) { - validateFloatingPoint("softsignDerivative", x); - SDVariable result = f().softsignDerivative(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
- * See: https://arxiv.org/abs/1710.05941 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable swish(SDVariable x) { - return swish(null, x); - } - - /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
- * See: https://arxiv.org/abs/1710.05941 - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable swish(String name, SDVariable x) { - validateFloatingPoint("swish", x); - SDVariable ret = f().swish(x); - return updateVariableNameAndReference(ret, name); - } - - public SDVariable tanh(String name, SDVariable x) { - return sd.math().tanh(name, x); - } - - public SDVariable tanh(SDVariable x) { - return sd.math().tanh(x); - } - - /** - * Apply Layer Normalization - * - * y = gain * standardize(x) + bias - * - * @return Output variable - */ - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { - return layerNorm(null, input, gain, bias, channelsFirst, dimensions); - } - - /** - * Apply Layer Normalization - * - * y = gain * standardize(x) + bias - * - * @param name Name of the output variable - * @param input Input variable - * @param gain gain - * @param bias bias - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs - * @return Output variable - */ - public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { - validateFloatingPoint("layerNorm", "input", input); - validateFloatingPoint("layerNorm", "gain", gain); - validateFloatingPoint("layerNorm", "bias", bias); - SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Apply Layer Normalization without bias - * - * y = gain * standardize(x) - * - * @return Output variable - */ - public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - return layerNorm((String)null, input, gain, channelsFirst, dimensions); - } - - /** - * Apply Layer Normalization - * - * y = gain * standardize(x) - * - * @param name Name of the output variable - * @param input Input variable - * @param gain gain - * @return Output variable - */ - public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - validateFloatingPoint("layerNorm", "input", input); - validateFloatingPoint("layerNorm", "gain", gain); - SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * See {@link #pad(SDVariable, SDVariable, double)} - */ - public SDVariable pad(SDVariable input, int[][] padding, double constant){ - return pad(input, sd.constant(Nd4j.createFromArray(padding)), constant); - } - - /** - * Perform padding on the given array, where padded values are the specified constant.
- * Example:
- * Input array:
- * [1, 2]
- * [3, 4]
- * Padding array:
- * [2, 0]
- * [1, 1]
- * Constant = 0&#13;
- * Result:
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * [0, 1, 2, 0]
- * [0, 3, 4, 0]
- *
- * - * - * @param input Input array to pad - * @param padding Padding array - * @param constant Constant to use for padded values - * @return Padded array - */ - public SDVariable pad(SDVariable input, SDVariable padding, double constant){ - return pad(null, input, padding, Pad.Mode.CONSTANT, constant); - } - - /** - * As per {@link #pad(SDVariable, SDVariable, double)} but also supports multiple {@link Pad.Mode} modes.
- * Example: - * Input array:
- * [1, 2]
- * [3, 4]
- * [5, 6]
- * Padding array:
- * [2, 0]
- * [1, 1]
- * Constant = 0&#13;
- * Result: CONSTANT mode
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * [0, 1, 2, 0]
- * [0, 3, 4, 0]
- * [0, 5, 6, 0]
- *
- * Result: SYMMETRIC mode
- * [3, 3, 4, 4]
- * [1, 1, 2, 2]
- * [1, 1, 2, 2]
- * [3, 3, 4, 4]
- * [5, 5, 6, 6]
- *
- * Result: REFLECT:
- * [6, 5, 6, 0]
- * [2, 3, 4, 3]
- * [2, 1, 2, 1]
- * [4, 3, 4, 3]
- * [6, 5, 6, 5]
- *
- * @param outputName - * @param input - * @param padding - * @param mode - * @param constant - * @return - */ - public SDVariable pad(String outputName, SDVariable input, SDVariable padding, Pad.Mode mode, double constant){ - SDVariable out = f().pad(input, padding, mode, constant); - return updateVariableNameAndReference(out, outputName); - } - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled){ - return dotProductAttention(null, queries, keys, values, mask, scaled); - } - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled){ - final SDVariable result = f().dotProductAttention(queries, keys, values, mask, scaled); - return updateVariableNameAndReference(result, name); - } - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public List dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights){ - return dotProductAttention(null, queries, keys, values, mask, scaled, withWeights); - } - - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * out = sum(similarity(k_i, q) * v_i) - * - * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - * - * Optionally with normalization step: - * similarity(k, q) = softmax(k * q / sqrt(size(q)) - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - * - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to - * be 3D but can have queryCount = 1 - * - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - * both. - * - * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The - * output rank will depend on the input rank. 
- * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @param withWeights return attention weights as well, false -> only one output, true -> two outputs - * - * Output Arrays: - * @return [ Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], - * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount]] - */ - public List dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights){ - List result = f().dotProductAttention(queries, keys, values, mask, scaled, withWeights); - if(withWeights){ - return Collections.singletonList(updateVariableNameAndReference(result.get(0), name)); - }else{ - return Arrays.asList( - updateVariableNameAndReference(result.get(0), name), - updateVariableNameAndReference(result.get(1), name+":weights") - ); - } - } - - /** - * This performs multi-headed dot product attention on the given timeseries input - * @see #multiHeadDotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled){ - return multiHeadDotProductAttention(null, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled); - } - - /** - * This performs multi-headed dot product attention on the given timeseries input - * @see #multiHeadDotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled){ - final SDVariable result = f().multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled); - return updateVariableNameAndReference(result, name); - } - - /** - * This performs multi-headed dot product attention on the given timeseries input - * @see #multiHeadDotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public List multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights){ - return multiHeadDotProductAttention(null, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights); - } - - - /** - * This performs multi-headed dot product attention on the given timeseries input - * out = concat(head_1, head_2, ..., head_n) * Wo - * head_i = 
dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) - * - * Optionally with normalization when calculating the attention for each head. - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention") - * - * This makes use of dot_product_attention OP support for rank 4 inputs. - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] - * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] - * @param Wv: input value projection weights of shape [numHeads, projectedValues, featureValues] - * @param Wo: output projection weights of shape [numHeads * projectedValues, outSize] - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @param withWeights return attention weights as well, false -> only one output, true -> two outputs - * - * Output Arrays: - * @return [ Attention result arrays of shape [batchSize, outSize, queryCount] - * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] - */ - public List multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights){ - List result = f().multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights); - if(withWeights){ - return Collections.singletonList(updateVariableNameAndReference(result.get(0), name)); - }else{ - return Arrays.asList( - updateVariableNameAndReference(result.get(0), name), - updateVariableNameAndReference(result.get(1), name+":weights") - ); - } - } - - /** - * Max pooling on the input and outputs both max values and indices - * - * @param name Name of the output variable - * @param x input array - * @return output array and argmax array - */ - public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable x, Pooling2DConfig pooling2DConfig) { - SDVariable[] res = f().maxPoolWithArgmaxs(x, pooling2DConfig); - return sd.updateVariableNamesAndReferences(res, names); - } - - /** - * Batch normalization - * - * @param name Name of the output variable - * @param x 4D array - * @param scale vector for scaling factor of normalized x - * @param offset vector to shift to the normalized x - * @param dataFormat integer scalar - data format - * @param isTraining boolean scalar - is training mode - * @return y: 4D array - * batch_mean: vector - * batch_var: vector - */ - public SDVariable[] fusedBatchNorm(String[] names, SDVariable x, SDVariable scale, SDVariable offset, - SDVariable dataFormat, SDVariable isTraining) { - SDVariable[] res = f().fusedBatchNorm(x,scale,offset,dataFormat,isTraining); - return sd.updateVariableNamesAndReferences(res, names); - } + public SDNN(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Neural network batch normalization operation.
+ * For details, see https://arxiv.org/abs/1502.03167
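A minimal usage sketch for this batchNorm signature, assuming the standard org.nd4j.autodiff.samediff.SameDiff, org.nd4j.autodiff.samediff.SDVariable and org.nd4j.linalg.factory.Nd4j APIs; all names and shapes below are illustrative assumptions, not part of this patch:

    SameDiff sd = SameDiff.create();
    // NCHW activations with 3 channels; axis = 1 selects the channel dimension
    SDVariable in = sd.constant("in", Nd4j.rand(new int[]{2, 3, 8, 8}));
    SDVariable mean = sd.constant("mean", Nd4j.createFromArray(0f, 0f, 0f));
    SDVariable variance = sd.constant("variance", Nd4j.createFromArray(1f, 1f, 1f));
    SDVariable gamma = sd.constant("gamma", Nd4j.createFromArray(1f, 1f, 1f));
    SDVariable beta = sd.constant("beta", Nd4j.createFromArray(0f, 0f, 0f));
    SDVariable out = sd.nn().batchNorm(in, mean, variance, gamma, beta, 1e-5, 1);
    System.out.println(out.eval().shapeInfoToString());   // same shape as the input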
+ * + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. + * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC + * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @return output variable for batch normalization (NUMERIC type) + */ + public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, + SDVariable gamma, SDVariable beta, double epsilon, int... axis) { + SDValidation.validateNumerical("batchNorm", "input", input); + SDValidation.validateNumerical("batchNorm", "mean", mean); + SDValidation.validateNumerical("batchNorm", "variance", variance); + SDValidation.validateNumerical("batchNorm", "gamma", gamma); + SDValidation.validateNumerical("batchNorm", "beta", beta); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + } + + /** + * Neural network batch normalization operation.
+ * For details, see https://arxiv.org/abs/1502.03167
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. + * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC + * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @return output variable for batch normalization (NUMERIC type) + */ + public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, + SDVariable gamma, SDVariable beta, double epsilon, int... axis) { + SDValidation.validateNumerical("batchNorm", "input", input); + SDValidation.validateNumerical("batchNorm", "mean", mean); + SDValidation.validateNumerical("batchNorm", "variance", variance); + SDValidation.validateNumerical("batchNorm", "gamma", gamma); + SDValidation.validateNumerical("batchNorm", "beta", beta); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
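For illustration only (names and shapes are assumptions, not part of this patch), a biasAdd call on NCHW activations might look like:

    SameDiff sd = SameDiff.create();
    SDVariable act = sd.constant("act", Nd4j.rand(new int[]{2, 3, 4, 4}));          // [minibatch, channels, height, width]
    SDVariable bias = sd.constant("bias", Nd4j.createFromArray(0.1f, 0.2f, 0.3f));  // one bias value per channel
    SDVariable out = sd.nn().biasAdd(act, bias, true);                               // nchw = true -> add along dimension 1
    System.out.println(out.eval().shapeInfoToString());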
+ * + * @param input 4d input variable (NUMERIC type) + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs + * @return output Output variable, after applying bias add operation (NUMERIC type) + */ + public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { + SDValidation.validateNumerical("biasAdd", "input", input); + SDValidation.validateNumerical("biasAdd", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + } + + /** + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
+ * + * @param name name May be null. Name for the output variable + * @param input 4d input variable (NUMERIC type) + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs + * @return output Output variable, after applying bias add operation (NUMERIC type) + */ + public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { + SDValidation.validateNumerical("biasAdd", "input", input); + SDValidation.validateNumerical("biasAdd", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * This operation performs dot product attention on the given timeseries input with the given queries
+ * out = sum(similarity(k_i, q) * v_i)
+ *
+ * similarity(k, q) = softmax(k * q) where k * q is the dot product of k and q&#13;
+ *
+ * Optionally with normalization step:
+ * similarity(k, q) = softmax(k * q / sqrt(size(q)))&#13;
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ *
+ * Note: This supports multiple queries at once; if only one query is available, the queries vector still has to&#13;
+ * be 3D but can have queryCount = 1
+ *
+ * Note: keys and values usually are the same array. If you want to use it as the same array, simply pass it for&#13;
+ * both.
+ *
+ * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
+ * output rank will depend on the input rank.
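As a rough sketch of the rank 3 case described above (illustrative shapes and names only, assuming the SameDiff/Nd4j factory APIs):

    SameDiff sd = SameDiff.create();
    // batchSize = 1, featureKeys = 4, queryCount = 2, featureValues = 3, timesteps = 5
    SDVariable q = sd.constant("q", Nd4j.rand(new int[]{1, 4, 2}));
    SDVariable k = sd.constant("k", Nd4j.rand(new int[]{1, 4, 5}));
    SDVariable v = sd.constant("v", Nd4j.rand(new int[]{1, 3, 5}));
    SDVariable mask = sd.constant("mask", Nd4j.ones(1, 5));             // attend to every timestep
    SDVariable out = sd.nn().dotProductAttention(q, k, v, mask, true);  // scaled = true -> divide by sqrt(size(q))
    System.out.println(out.eval().shapeInfoToString());                 // [batchSize, featureValues, queryCount] = [1, 3, 2]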
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] + * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], + * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, + SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("dotProductAttention", "queries", queries); + SDValidation.validateNumerical("dotProductAttention", "keys", keys); + SDValidation.validateNumerical("dotProductAttention", "values", values); + SDValidation.validateNumerical("dotProductAttention", "mask", mask); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + } + + /** + * This operation performs dot product attention on the given timeseries input with the given queries
+ * out = sum(similarity(k_i, q) * v_i)
+ *
+ * similarity(k, q) = softmax(k * q) where k * q is the dot product of k and q&#13;
+ *
+ * Optionally with normalization step:
+ * similarity(k, q) = softmax(k * q / sqrt(size(q)))&#13;
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ *
+ * Note: This supports multiple queries at once; if only one query is available, the queries vector still has to&#13;
+ * be 3D but can have queryCount = 1
+ *
+ * Note: keys and values usually are the same array. If you want to use it as the same array, simply pass it for&#13;
+ * both.
+ *
+ * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
+ * output rank will depend on the input rank.
+ * + * @param name name May be null. Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] + * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], + * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, + SDVariable values, SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("dotProductAttention", "queries", queries); + SDValidation.validateNumerical("dotProductAttention", "keys", keys); + SDValidation.validateNumerical("dotProductAttention", "values", values); + SDValidation.validateNumerical("dotProductAttention", "mask", mask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Dropout operation
+ * + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @return output Output (NUMERIC type) + */ + public SDVariable dropout(SDVariable input, double inputRetainProbability) { + SDValidation.validateNumerical("dropout", "input", input); + return new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + } + + /** + * Dropout operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @return output Output (NUMERIC type) + */ + public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { + SDValidation.validateNumerical("dropout", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise exponential linear unit (ELU) function:
+ * out = x if x > 0
+ * out = a * (exp(x) - 1) if x <= 0
+ * with constant a = 1.0
+ *


+ * See: https://arxiv.org/abs/1511.07289
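A small numeric sketch of this formula (illustrative only, assuming the SameDiff/Nd4j APIs):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.createFromArray(-2.0f, -0.5f, 0.0f, 0.5f, 2.0f));
    SDVariable y = sd.nn().elu(x);
    // negative inputs are squashed towards -1 via a * (exp(x) - 1) with a = 1; positive inputs pass through
    System.out.println(y.eval());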
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable elu(SDVariable x) { + SDValidation.validateNumerical("elu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + } + + /** + * Element-wise exponential linear unit (ELU) function:
+ * out = x if x > 0
+ * out = a * (exp(x) - 1) if x <= 0
+ * with constant a = 1.0
+ *


+ * See: https://arxiv.org/abs/1511.07289
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable elu(String name, SDVariable x) { + SDValidation.validateNumerical("elu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the sigmoid approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable gelu(SDVariable x) { + SDValidation.validateNumerical("gelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the sigmoid approximation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable gelu(String name, SDVariable x) { + SDValidation.validateNumerical("gelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise hard sigmoid function:
+ * out[i] = 0 if in[i] <= -2.5
+ * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5&#13;
+ * out[i] = 1 if in[i] >= 2.5
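A quick numeric sketch of this piecewise definition (illustrative only, assuming the SameDiff/Nd4j APIs):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.createFromArray(-3.0f, -1.0f, 0.0f, 1.0f, 3.0f));
    SDVariable y = sd.nn().hardSigmoid(x);
    // expected approximately [0.0, 0.3, 0.5, 0.7, 1.0]
    System.out.println(y.eval());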
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardSigmoid(SDVariable x) { + SDValidation.validateNumerical("hardSigmoid", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + } + + /** + * Element-wise hard sigmoid function:
+ * out[i] = 0 if in[i] <= -2.5
+ * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5&#13;
+ * out[i] = 1 if in[i] >= 2.5
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardSigmoid(String name, SDVariable x) { + SDValidation.validateNumerical("hardSigmoid", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise hard tanh function:
+ * out[i] = -1 if in[i] <= -1
+ * out[i] = in[i] if -1 < in[i] < 1&#13;
+ * out[i] = 1 if in[i] >= 1
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanh(SDVariable x) { + SDValidation.validateNumerical("hardTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + } + + /** + * Element-wise hard tanh function:
+ * out[i] = -1 if in[i] <= -1
+ * out[i] = in[i] if -1 < in[i] < 1&#13;
+ * out[i] = 1 if in[i] >= 1
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanh(String name, SDVariable x) { + SDValidation.validateNumerical("hardTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanhDerivative(SDVariable x) { + SDValidation.validateNumerical("hardTanhDerivative", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + } + + /** + * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanhDerivative(String name, SDVariable x) { + SDValidation.validateNumerical("hardTanhDerivative", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
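For a 2d activation array, the call might look like the following sketch (illustrative names and shapes, assuming the SameDiff/Nd4j APIs):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.rand(3, 4));                        // [minibatch, nIn]
    SDVariable gain = sd.var("gain", Nd4j.createFromArray(1f, 1f, 1f, 1f));  // one gain per feature
    SDVariable bias = sd.var("bias", Nd4j.createFromArray(0f, 0f, 0f, 0f));  // one bias per feature
    SDVariable out = sd.nn().layerNorm(x, gain, bias, true, 1);              // standardize over dimension 1; channelsFirst unused for 2d
    System.out.println(out.eval());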
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, + boolean channelsFirst, int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + SDValidation.validateNumerical("layerNorm", "bias", bias); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, + boolean channelsFirst, int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + SDValidation.validateNumerical("layerNorm", "bias", bias); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, + int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, + int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise leaky ReLU function:
+ * out = x if x >= 0.0
+ * out = alpha * x if x < 0.0&#13;
+ * Alpha value is most commonly set to 0.01
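A short numeric sketch (illustrative only, assuming the SameDiff/Nd4j APIs):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.createFromArray(-2.0f, -0.5f, 0.0f, 0.5f, 2.0f));
    SDVariable y = sd.nn().leakyRelu(x, 0.01);
    // expected approximately [-0.02, -0.005, 0.0, 0.5, 2.0]
    System.out.println(y.eval());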
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyRelu(SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyRelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + } + + /** + * Element-wise leaky ReLU function:
+ * out = x if x >= 0.0
+ * out = alpha * x if x < 0.0&#13;
+ * Alpha value is most commonly set to 0.01
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyRelu(String name, SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyRelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Leaky ReLU derivative: dOut/dIn given input.
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyReluDerivative(SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyReluDerivative", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + } + + /** + * Leaky ReLU derivative: dOut/dIn given input.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyReluDerivative", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Linear layer operation: out = mmul(in,w) + bias
+ * Note that bias array is optional
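A minimal sketch of the fully connected step (illustrative names and shapes, assuming the SameDiff/Nd4j APIs):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.constant("in", Nd4j.rand(3, 4));           // [minibatch, nIn]
    SDVariable w = sd.var("w", Nd4j.rand(4, 2));                   // [nIn, nOut]
    SDVariable b = sd.var("b", Nd4j.createFromArray(0.1f, 0.2f));  // [nOut]
    SDVariable out = sd.nn().linear(in, w, b);                     // out = mmul(in, w) + b, shape [3, 2]
    System.out.println(out.eval().shapeInfoToString());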
+ * + * @param input Input data (NUMERIC type) + * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("linear", "input", input); + SDValidation.validateNumerical("linear", "weights", weights); + SDValidation.validateNumerical("linear", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + } + + /** + * Linear layer operation: out = mmul(in,w) + bias
+ * Note that bias array is optional
+ * + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("linear", "input", input); + SDValidation.validateNumerical("linear", "weights", weights); + SDValidation.validateNumerical("linear", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSigmoid(SDVariable x) { + SDValidation.validateNumerical("logSigmoid", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + } + + /** + * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSigmoid(String name, SDVariable x) { + SDValidation.validateNumerical("logSigmoid", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log softmax activation
+ * + * @param x (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable logSoftmax(SDVariable x) { + SDValidation.validateNumerical("logSoftmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + } + + /** + * Log softmax activation
+ * + * @param name name May be null. Name for the output variable + * @param x (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable logSoftmax(String name, SDVariable x) { + SDValidation.validateNumerical("logSoftmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log softmax activation
+ * + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply log softmax + * @return output Output - log(softmax(input)) (NUMERIC type) + */ + public SDVariable logSoftmax(SDVariable x, int dimension) { + SDValidation.validateNumerical("logSoftmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + } + + /** + * Log softmax activation
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply log softmax + * @return output Output - log(softmax(input)) (NUMERIC type) + */ + public SDVariable logSoftmax(String name, SDVariable x, int dimension) { + SDValidation.validateNumerical("logSoftmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * This performs multi-headed dot product attention on the given timeseries input
+ * out = concat(head_1, head_2, ..., head_n) * Wo
+ * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+ *
+ * Optionally with normalization when calculating the attention for each head.
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+ *
+ * This makes use of dot_product_attention OP support for rank 4 inputs.
+ * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
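A rough sketch wiring up the projection weights described above (all shapes and names are illustrative assumptions, not part of this patch):

    SameDiff sd = SameDiff.create();
    // batchSize = 1, featureKeys = featureValues = 4, queryCount = 2, timesteps = 5
    SDVariable q = sd.constant("q", Nd4j.rand(new int[]{1, 4, 2}));
    SDVariable k = sd.constant("k", Nd4j.rand(new int[]{1, 4, 5}));
    SDVariable v = sd.constant("v", Nd4j.rand(new int[]{1, 4, 5}));
    // numHeads = 2, projectedKeys = projectedValues = 3, outSize = 6
    SDVariable wq = sd.var("Wq", Nd4j.rand(new int[]{2, 3, 4}));
    SDVariable wk = sd.var("Wk", Nd4j.rand(new int[]{2, 3, 4}));
    SDVariable wv = sd.var("Wv", Nd4j.rand(new int[]{2, 3, 4}));
    SDVariable wo = sd.var("Wo", Nd4j.rand(6, 6));                 // [numHeads * projectedValues, outSize]
    SDVariable mask = sd.constant("mask", Nd4j.ones(1, 5));
    SDVariable out = sd.nn().multiHeadDotProductAttention(q, k, v, wq, wk, wv, wo, mask, true);
    System.out.println(out.eval().shapeInfoToString());            // [batchSize, outSize, queryCount] = [1, 6, 2]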
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] + * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, + SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, + SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries); + SDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys); + SDValidation.validateNumerical("multiHeadDotProductAttention", "values", values); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); + SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + } + + /** + * This performs multi-headed dot product attention on the given timeseries input
+ * out = concat(head_1, head_2, ..., head_n) * Wo
+ * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+ *
+ * Optionally with normalization when calculating the attention for each head.
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+ *
+ * This makes use of dot_product_attention OP support for rank 4 inputs.
+ * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
+ * + * @param name name May be null. Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] + * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, + SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, + SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries); + SDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys); + SDValidation.validateNumerical("multiHeadDotProductAttention", "values", values); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); + SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(SDVariable input, SDVariable padding, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + } + + /** + * Padding operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+ * out[i] = in[i] if in[i] >= 0
+ * out[i] = in[i] * alpha[i] otherwise
+ *
+ * sharedAxes allows you to share learnable parameters along axes.
+ * For example, if the input has shape [batchSize, channels, height, width]
+ * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+ * alpha with shape [channels].
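Following the [batchSize, channels, height, width] example above, a minimal sketch with assumed sizes, reusing the sd instance from the earlier sketch:

// NCHW input with 3 channels; one learnable alpha per channel, shared across height and width
SDVariable preluIn  = sd.var("preluIn", DataType.FLOAT, 2, 3, 8, 8);
SDVariable alpha    = sd.var("alpha",   DataType.FLOAT, 3);
SDVariable preluOut = sd.nn().prelu(preluIn, alpha, 2, 3);   // sharedAxes = [2, 3]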
+ * + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) + * @return output Output (NUMERIC type) + */ + public SDVariable prelu(SDVariable input, SDVariable alpha, int... sharedAxes) { + SDValidation.validateNumerical("prelu", "input", input); + SDValidation.validateNumerical("prelu", "alpha", alpha); + Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); + return new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + } + + /** + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+ * out[i] = in[i] if in[i] >= 0
+ * out[i] = in[i] * alpha[i] otherwise
+ *
+ * sharedAxes allows you to share learnable parameters along axes.
+ * For example, if the input has shape [batchSize, channels, height, width]
+ * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+ * alpha with shape [channels].
+ * + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) + * @return output Output (NUMERIC type) + */ + public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int... sharedAxes) { + SDValidation.validateNumerical("prelu", "input", input); + SDValidation.validateNumerical("prelu", "alpha", alpha); + Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise rectified linear function with specified cutoff:
+ * out[i] = in[i] if in[i] >= cutoff
+ * out[i] = 0 otherwise
+ * + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu(SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + } + + /** + * Element-wise rectified linear function with specified cutoff:
+ * out[i] = in[i] if in[i] >= cutoff
+ * out[i] = 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu(String name, SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise "rectified linear 6" function with specified cutoff:
+ * out[i] = min(max(in, cutoff), 6)
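A small sketch contrasting relu and relu6 with cutoff 0 (input values are illustrative; sd and Nd4j as in the earlier sketches):

SDVariable reluIn = sd.constant("reluIn", Nd4j.createFromArray(-2.0f, 0.5f, 3.0f, 9.0f));
SDVariable r  = sd.nn().relu(reluIn, 0.0);    // [0.0, 0.5, 3.0, 9.0]
SDVariable r6 = sd.nn().relu6(reluIn, 0.0);   // [0.0, 0.5, 3.0, 6.0], upper-clipped at 6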
+ * + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu6(SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu6", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + } + + /** + * Element-wise "rectified linear 6" function with specified cutoff:
+ * out[i] = min(max(in, cutoff), 6)
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu6(String name, SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu6", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
+ * Note that bias array is optional
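A brief sketch of the relu(mmul(in, w) + bias) form described above (shapes are assumptions; sd as before):

SDVariable layerIn  = sd.var("layerIn", DataType.FLOAT, 2, 4);   // [minibatch, nIn]
SDVariable w        = sd.var("w",       DataType.FLOAT, 4, 3);   // [nIn, nOut]
SDVariable b        = sd.var("b",       DataType.FLOAT, 3);      // [nOut]; may also be passed as null
SDVariable layerOut = sd.nn().reluLayer(layerIn, w, b);          // shape [2, 3]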
+ * + * @param input Input data (NUMERIC type) + * @param weights Weights variable (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("reluLayer", "input", input); + SDValidation.validateNumerical("reluLayer", "weights", weights); + SDValidation.validateNumerical("reluLayer", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + } + + /** + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
+ * Note that bias array is optional
+ * + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param weights Weights variable (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("reluLayer", "input", input); + SDValidation.validateNumerical("reluLayer", "weights", weights); + SDValidation.validateNumerical("reluLayer", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise SeLU function - Scaled Exponential Linear Unit: see Self-Normalizing Neural Networks<br>
+ *
+ * out[i] = scale * in[i] if in[i]>0, or scale * alpha * (exp(in[i])-1) if in[i] <= 0<br>
+ * Uses default scale and alpha values.
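The remaining element-wise activations in this class (selu, sigmoid, softplus, softsign, swish, tanh) and the softmax op documented further down all follow the same single-input pattern; a brief illustrative sketch, with sd as before:

SDVariable act = sd.var("act", DataType.FLOAT, 2, 3);
SDVariable s1 = sd.nn().selu(act);
SDVariable s2 = sd.nn().sigmoid(act);
SDVariable s3 = sd.nn().swish(act);
SDVariable sm = sd.nn().softmax(act, 1);   // softmax over dimension 1, so each row sums to 1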
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable selu(SDVariable x) { + SDValidation.validateNumerical("selu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + } + + /** + * Element-wise SeLU function - Scaled Exponential Linear Unit: see Self-Normalizing Neural Networks<br>
+ *
+ * out[i] = scale * in[i] if in[i]>0, or scale * alpha * (exp(in[i])-1) if in[i] <= 0<br>
+ * Uses default scale and alpha values.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable selu(String name, SDVariable x) { + SDValidation.validateNumerical("selu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sigmoid(SDVariable x) { + SDValidation.validateNumerical("sigmoid", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + } + + /** + * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sigmoid(String name, SDVariable x) { + SDValidation.validateNumerical("sigmoid", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
+ * + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @return output Output (gradient at input of sigmoid) (NUMERIC type) + */ + public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { + SDValidation.validateNumerical("sigmoidDerivative", "x", x); + SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + } + + /** + * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
+ * + * @param name name May be null. Name for the output variable + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @return output Output (gradient at input of sigmoid) (NUMERIC type) + */ + public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { + SDValidation.validateNumerical("sigmoidDerivative", "x", x); + SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply softmax + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(SDVariable x, int dimension) { + SDValidation.validateNumerical("softmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply softmax + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(String name, SDVariable x, int dimension) { + SDValidation.validateNumerical("softmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param x Input (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(SDVariable x) { + SDValidation.validateNumerical("softmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(String name, SDVariable x) { + SDValidation.validateNumerical("softmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Softmax derivative function
+ * + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param dimension Softmax dimension + * @return output (NUMERIC type) + */ + public SDVariable softmaxDerivative(SDVariable x, SDVariable wrt, int dimension) { + SDValidation.validateNumerical("softmaxDerivative", "x", x); + SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + } + + /** + * Softmax derivative function
+ * + * @param name name May be null. Name for the output variable + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param dimension Softmax dimension + * @return output (NUMERIC type) + */ + public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, int dimension) { + SDValidation.validateNumerical("softmaxDerivative", "x", x); + SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise softplus function: out = log(exp(x) + 1)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softplus(SDVariable x) { + SDValidation.validateNumerical("softplus", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + } + + /** + * Element-wise softplus function: out = log(exp(x) + 1)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softplus(String name, SDVariable x) { + SDValidation.validateNumerical("softplus", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise softsign function: out = x / (abs(x) + 1)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softsign(SDVariable x) { + SDValidation.validateNumerical("softsign", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + } + + /** + * Element-wise softsign function: out = x / (abs(x) + 1)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softsign(String name, SDVariable x) { + SDValidation.validateNumerical("softsign", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable softsignDerivative(SDVariable x) { + SDValidation.validateNumerical("softsignDerivative", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + } + + /** + * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable softsignDerivative(String name, SDVariable x) { + SDValidation.validateNumerical("softsignDerivative", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
+ * See: https://arxiv.org/abs/1710.05941
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable swish(SDVariable x) { + SDValidation.validateNumerical("swish", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + } + + /** + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
+ * See: https://arxiv.org/abs/1710.05941
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable swish(String name, SDVariable x) { + SDValidation.validateNumerical("swish", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(String name, SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java index e5cfef684..88792bddb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java @@ -27,17 +27,21 @@ import org.nd4j.autodiff.samediff.SameDiff; */ public abstract class SDOps { - protected final SameDiff sd; + protected final SameDiff sd; - public SDOps(SameDiff sameDiff) { - this.sd = sameDiff; - } + public SDOps() { + sd = null; + } - protected DifferentialFunctionFactory f() { - return sd.f(); - } + public SDOps(SameDiff sameDiff) { + this.sd = sameDiff; + } - protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { - return sd.updateVariableNameAndReference(varToUpdate, newVarName); - } + protected DifferentialFunctionFactory f() { + return sd.f(); + } + + protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { + return sd.updateVariableNameAndReference(varToUpdate, newVarName); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index de0114b92..6b1831de7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,198 +14,232 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; +import java.lang.String; + import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.*; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*; - -import java.util.Arrays; -import java.util.List; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs; import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; -import org.nd4j.linalg.primitives.Pair; -/** - * SameDiff Recurrent Neural Network operations
- * Accessible via {@link SameDiff#rnn()}
- * See also {@link SDNN} (accessible via {@link SameDiff#nn()} for general neural network ops.
- * See also {@link SDCNN} (accessible via {@link SameDiff#cnn()} for convolutional neural network ops.
- * - * @author Alex Black - */ public class SDRNN extends SDOps { - public SDRNN(SameDiff sameDiff) { - super(sameDiff); - } + public SDRNN(SameDiff sameDiff) { + super(sameDiff); + } + /** + * The GRU cell. Does a single time step operation
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) + * @param GRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + SDValidation.validateNumerical("gru", "x", x); + SDValidation.validateNumerical("gru", "hLast", hLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable(); + } - /** - * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}. - */ - public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { - GRUCell c = new GRUCell(sd, x, hLast, weights); - return new GRUCellOutputs(c.outputVariables()); - } + /** + * The GRU cell. Does a single time step operation
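A hedged sketch of a single GRU step, reusing the sd instance from the earlier sketches. The GRUWeights builder field names (ruWeight, cWeight, ruBias, cBias) and the weight shapes below are assumptions to verify against the actual weights class; only the gru(...) call shape comes from this patch:

import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;

// assumed sizes: batchSize=2, inSize=3, numUnits=4
SDVariable gruX     = sd.var("gruX",     DataType.FLOAT, 2, 3);
SDVariable gruHLast = sd.var("gruHLast", DataType.FLOAT, 2, 4);
GRUWeights gruWeights = GRUWeights.builder()
        .ruWeight(sd.var("ruW", DataType.FLOAT, 7, 8))   // assumed [inSize+numUnits, 2*numUnits]
        .cWeight(sd.var("cW",   DataType.FLOAT, 7, 4))   // assumed [inSize+numUnits, numUnits]
        .ruBias(sd.var("ruB",   DataType.FLOAT, 8))
        .cBias(sd.var("cB",     DataType.FLOAT, 4))
        .build();
SDVariable h = sd.rnn().gru(gruX, gruHLast, gruWeights);   // [batchSize, numUnits]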
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) + * @param GRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + SDValidation.validateNumerical("gru", "x", x); + SDValidation.validateNumerical("gru", "hLast", hLast); + GRUCell c = new GRUCell(sd,x, hLast, GRUWeights); + return new GRUCellOutputs(c.outputVariables(name)); + } - /** - * The GRU cell. Does a single time step operation. - * - * @param baseName The base name for the gru cell - * @param x Input, with shape [batchSize, inSize] - * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] - * @param weights The cell's weights. - * @return The cell's outputs. - */ - public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { - GRUCell c = new GRUCell(sd, x, hLast, weights); - return new GRUCellOutputs(c.outputVariables(baseName)); - } + /** + * The LSTM cell. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The cell's outputs (NUMERIC type) + */ + public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, + LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmCell", "x", x); + SDValidation.validateNumerical("lstmCell", "cLast", cLast); + SDValidation.validateNumerical("lstmCell", "yLast", yLast); + LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); + return new LSTMCellOutputs(c.outputVariables()); + } - /** - * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}. - */ - public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - LSTMWeights weights, LSTMConfiguration config){ - LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); - return new LSTMCellOutputs(c.outputVariables()); - } + /** + * The LSTM cell. Does a single time step operation.<br>
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The cell's outputs (NUMERIC type) + */ + public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast, + LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmCell", "x", x); + SDValidation.validateNumerical("lstmCell", "cLast", cLast); + SDValidation.validateNumerical("lstmCell", "yLast", yLast); + LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); + return new LSTMCellOutputs(c.outputVariables(name)); + } - /** - * The LSTM cell. Does a single time step operation. - * - * @param baseName The base name for the lstm cell - * @param x Input, with shape [batchSize, inSize] - * @param cLast Previous cell state, with shape [batchSize, numUnits] - * @param yLast Previous cell output, with shape [batchSize, numUnits] - * @param weights The cell's weights. - * @param config The cell's config. - * @return The cell's outputs. - */ - public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); - return new LSTMCellOutputs(c.outputVariables(baseName)); - } + /** + * The LSTM layer. Does multiple time steps.<br>
+ * + * @param maxTSLength (NUMERIC type) + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast, + SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + } - /** - * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} - */ - public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength, - @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); - return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat()); - } + /** + * The LSTM layer. Does multiple time steps.
+ * + * @param name name May be null. Name for the output variable + * @param maxTSLength (NUMERIC type) + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, + SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} - */ - public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - return lstmLayer( - sd.scalar("lstm_max_ts_length", maxTSLength), - x, cLast, yLast, weights, config); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + SDValidation.validateNumerical("sru", "mask", mask); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable(); + } - /** - * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} - */ - public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - if(baseName != null) { - return lstmLayer(baseName, - sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength), - x, cLast, yLast, weights, config); - } else { - return lstmLayer(maxTSLength, x, cLast, yLast, weights, config); - } - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(String name, SDVariable x, SDVariable initialC, SDVariable mask, + SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + SDValidation.validateNumerical("sru", "mask", mask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * The LSTM layer. Does multiple time steps. - * - * Input shape depends on data format (in config):
- * TNS -> [timeSteps, batchSize, inSize]
- * NST -> [batchSize, inSize, timeSteps]
- * NTS -> [batchSize, timeSteps, inSize]
- * - * @param baseName The base name for the lstm layer - * @param x Input, with shape dependent on the data format (in config). - * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] - * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] - * @param weights The layer's weights. - * @param config The layer's config. - * @return The layer's outputs. - */ - public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength, - @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); - return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat()); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(SDVariable x, SDVariable initialC, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable(); + } - /** - * See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}. - */ - public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) { - return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables()); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(String name, SDVariable x, SDVariable initialC, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * The SRU cell. Does a single time step operation. - * - * @param baseName The base name for the sru cell - * @param x Input, with shape [batchSize, inSize] - * @param cLast Previous cell state, with shape [batchSize, inSize] - * @param weights The cell's weights. - * @return The cell's outputs. - */ - public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) { - return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName)); - } - - /** - * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} - */ - public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) { - return sru(x, initialC, null, weights); - } - - /** - * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} - */ - public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) { - return sru(baseName, x, initialC, null, weights); - } - - /** - * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} - */ - public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { - return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables()); - } - - /** - * The SRU layer. Does a single time step operation. - * - * @param baseName The base name for the sru layer - * @param x Input, with shape [batchSize, inSize, timeSeriesLength] - * @param initialC Initial cell state, with shape [batchSize, inSize] - * @param mask An optional dropout mask, with shape [batchSize, inSize] - * @param weights The layer's weights. - * @return The layer's outputs. - */ - public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { - return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName)); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public SDVariable sruCell(SDVariable x, SDVariable cLast, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sruCell", "x", x); + SDValidation.validateNumerical("sruCell", "cLast", cLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable(); + } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public SDVariable sruCell(String name, SDVariable x, SDVariable cLast, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sruCell", "x", x); + SDValidation.validateNumerical("sruCell", "cLast", cLast); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index cabf41103..cd986d7bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,324 +14,253 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; - -/** - * SameDiff random number generator operations
- * Accessible via {@link SameDiff#random()} - * - * @author Alex Black - */ public class SDRandom extends SDOps { + public SDRandom(SameDiff sameDiff) { + super(sameDiff); + } - public SDRandom(SameDiff sd) { - super(sd); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,
+ * with the specified probability. Array values will have value 1 with probability P and value 0 with probability
+ * 1-P.
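A brief sketch (shape and probability are illustrative; sd as in the earlier sketches):

// 3x4 array of 0/1 values, each element being 1 with probability 0.25
SDVariable bern = sd.random().bernoulli(0.25, DataType.FLOAT, 3, 4);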
+ * + * @param p Probability of value 1 + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable bernoulli(double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable(); + } - /** - * @see #bernoulli(String, double, SDVariable) - */ - public SDVariable bernoulli(double p, SDVariable shape) { - return bernoulli(null, p, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,
+ * with the specified probability. Array values will have value 1 with probability P and value 0 with probability
+ * 1-P.
+ * + * @param name name May be null. Name for the output variable + * @param p Probability of value 1 + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable bernoulli(String name, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution, - * with the specified probability. Array values will have value 1 with probability P and value 0 with probability - * 1-P.
- * See {@link #bernoulli(String, double, long...)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param p Probability of value 1 - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable bernoulli(String name, double p, SDVariable shape) { - validateInteger("bernoulli random", shape); - SDVariable ret = f().randomBernoulli(p, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,
+ * with the specified number of trials and probability.
+ * + * @param nTrials Number of trials parameter for the binomial distribution + * @param p Probability of success for each trial + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable binomial(int nTrials, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable(); + } - /** - * @see #bernoulli(String, double, long...) - */ - public SDVariable bernoulli(double p, long... shape) { - return bernoulli(null, p, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,
+ * with the specified number of trials and probability.
+ * + * @param name name May be null. Name for the output variable + * @param nTrials Number of trials parameter for the binomial distribution + * @param p Probability of success for each trial + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable binomial(String name, int nTrials, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution, - * with the specified probability. Array values will have value 1 with probability P and value 0 with probability - * 1-P.
- * See {@link #bernoulli(String, double, SDVariable)} for the equivalent function where the shape is - * specified as a SDVarible instead - * - * @param name Name of the new SDVariable - * @param p Probability of value 1 - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable bernoulli(String name, double p, long... shape) { - SDVariable ret = f().randomBernoulli(p, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:
+ * P(x) = lambda * exp(-lambda * x)
+ * + * Inputs must satisfy the following constraints:
+ * Must be positive: lambda > 0
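A brief sketch; lambda must be strictly positive or the generated precondition check rejects it:

// exponentially distributed samples with rate lambda = 0.5 (mean 1/lambda = 2.0)
SDVariable expSamples = sd.random().exponential(0.5, DataType.FLOAT, 3, 4);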
+ * + * @param lambda lambda parameter + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable exponential(double lambda, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(lambda > 0, "Must be positive"); + return new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable(); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, - * with the specified number of trials and probability. - * - * @param nTrials Number of trials parameter for the binomial distribution - * @param p Probability of success for each trial - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable binomial(int nTrials, double p, long... shape) { - return binomial(null, nTrials, p, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:
+ * P(x) = lambda * exp(-lambda * x)
+ * + * Inputs must satisfy the following constraints:
+ * Must be positive: lambda > 0
+ * + * @param name name May be null. Name for the output variable + * @param lambda lambda parameter + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable exponential(String name, double lambda, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(lambda > 0, "Must be positive"); + SDVariable out = new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, - * with the specified number of trials and probability. - * - * @param name Name of the new SDVariable - * @param nTrials Number of trials parameter for the binomial distribution - * @param p Probability of success for each trial - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable binomial(String name, int nTrials, double p, long... shape) { - SDVariable ret = f().randomBinomial(nTrials, p, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,
+ * i.e., {@code log(x) ~ N(mean, stdev)}
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable logNormal(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: - * P(x) = lambda * exp(-lambda * x) - * - * @param lambda Must be > 0 - * @param shape Shape of the output - * @return new SDVariable - */ - public SDVariable exponential(double lambda, SDVariable shape) { - return exponential(null, lambda, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,
+ * i.e., {@code log(x) ~ N(mean, stdev)}
+ * + * @param name name May be null. Name for the output variable + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable logNormal(String name, double mean, double stddev, DataType datatype, + long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: - * P(x) = lambda * exp(-lambda * x) - * - * @param name Name of the output variable - * @param lambda Must be > 0 - * @param shape Shape of the new variable - * @return new SDVaribale - */ - public SDVariable exponential(String name, double lambda, SDVariable shape) { - validateInteger("exponential random", shape); - SDVariable ret = f().randomExponential(lambda, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev)
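A brief sketch of the Gaussian family; logNormal above and normalTruncated below share the same signature (sd as before):

SDVariable n  = sd.random().normal(0.0, 1.0, DataType.FLOAT, 3, 4);           // N(0, 1)
SDVariable ln = sd.random().logNormal(0.0, 1.0, DataType.FLOAT, 3, 4);        // log(x) ~ N(0, 1)
SDVariable tn = sd.random().normalTruncated(0.0, 1.0, DataType.FLOAT, 3, 4);  // out-of-range tails re-sampled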
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normal(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + } - /** - * @see #logNormal(String, double, double, long...) - */ - public SDVariable logNormal(double mean, double stddev, long... shape) { - return logNormal(null, mean, stddev, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev)
+ * + * @param name name May be null. Name for the output variable + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normal(String name, double mean, double stddev, DataType datatype, + long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Log Normal distribution, - * i.e., {@code log(x) ~ N(mean, stdev)}
- * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable logNormal(String name, double mean, double stddev, long... shape) { - SDVariable ret = f().randomLogNormal(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normalTruncated(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + } - /** - * @see #normal(String, double, double, SDVariable) - */ - public SDVariable normal(double mean, double stddev, SDVariable shape) { - return normal(null, mean, stddev, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
+ *
+ * @param name May be null. Name for the output variable
+ * @param mean Mean value for the random array
+ * @param stddev Standard deviation for the random array
+ * @param datatype Data type of the output variable
+ * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+ * @return output Tensor with the given shape where values are randomly sampled according to a truncated Gaussian (normal) distribution (NUMERIC type)
+ */
+ public SDVariable normalTruncated(String name, double mean, double stddev, DataType datatype,
+ long... shape) {
+ Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+ SDVariable out = new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
- /**
- * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution,
- * N(mean, stdev)
- * See {@link #normal(String, double, double, long...)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable normal(String name, double mean, double stddev, SDVariable shape) { - validateInteger("normal (Gaussian) random", shape); - SDVariable ret = f().randomNormal(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #normal(String, double, double, long...) - */ - public SDVariable normal(double mean, double stddev, long... shape) { - return normal(null, mean, stddev, shape); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, - * N(mean, stdev)
- * See {@link #normal(String, double, double, SDVariable)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable normal(String name, double mean, double stddev, long... shape) { - SDVariable ret = f().randomNormal(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #normalTruncated(String, double, double, long...) - */ - public SDVariable normalTruncated(double mean, double stddev, long... shape) { - return normalTruncated(null, mean, stddev, shape); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, - * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
- * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable normalTruncated(String name, double mean, double stddev, long... shape) { - SDVariable ret = f().randomNormalTruncated(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #uniform(String, double, double, SDVariable) - */ - public SDVariable uniform(double min, double max, SDVariable shape) { - return uniform(null, min, max, shape); - } - - /** - * @see #uniform(String, double, double, SDVariable) - */ - public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) { - return uniform(null, min, max, shape, dataType); - } - - /** - * As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output - */ - public SDVariable uniform(String name, double min, double max, SDVariable shape) { - return uniform(name, min, max, shape, null); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, - * U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned
- * See {@link #uniform(double, double, long...)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param min Minimum value - * @param max Maximum value. Must satisfy max >= min - * @param shape Shape of the new random SDVariable, as a 1D array - * @param dataType Data type of the output array (if null: Float32 output is returned) - * @return New SDVariable, of the specified data type - */ - public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) { - validateInteger("uniform random", shape); - SDVariable ret = f().randomUniform(min, max, shape, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #uniform(String, double, double, long...) - */ - public SDVariable uniform(double min, double max, long... shape) { - return uniform(null, min, max, shape); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, - * U(min,max)
- * See {@link #uniform(double, double, long...)} for the equivalent function where the shape is - * specified as a SDVariable instead - * - * @param name Name of the new SDVariable - * @param min Minimum value - * @param max Maximum value. Must satisfy max >= min - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable uniform(String name, double min, double max, long... shape) { - SDVariable ret = f().randomUniform(min, max, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * Generate a new random SDVariable with Gamma distribution - * - * @param name Name of the output variable - * @param alpha distribution parameter - * @param beta distribution parameter - * @param shape Shape of the new variable - * @return new SDVariable - */ - public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) { - SDVariable ret = f().randomGamma(alpha, beta, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * Generate a new random SDVariable with Poission distribution - * - * @param name Name of the output variable - * @param lambda rate distribution parameter - * @param shape Shape of the new variable - * @return new SDVariable - */ - public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) { - SDVariable ret = f().randomPoisson(shape, lambda, seeds); - return updateVariableNameAndReference(ret, name); - } - - /** - * Generate a new random SDVariable by random shuffle - * - * @param name Name of the output variable - * @param value array to shuffle - * @return new SDVariable - */ - public SDVariable shuffle(String name, SDVariable value, int... seeds) { - SDVariable ret = f().randomShuffle(value, seeds); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,
+ * U(min,max)
+ *
+ * @param min Minimum value
+ * @param max Maximum value. Must satisfy max >= min
+ * @param datatype Data type of the output variable
+ * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+ * @return output Tensor with the given shape where values are randomly sampled according to a uniform distribution (NUMERIC type)
+ */
+ public SDVariable uniform(double min, double max, DataType datatype, long... shape) {
+ Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+ return new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable();
+ }
+ /**
+ * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,
+ * U(min,max)
+ *
+ * @param name May be null. Name for the output variable
+ * @param min Minimum value
+ * @param max Maximum value. Must satisfy max >= min
+ * @param datatype Data type of the output variable
+ * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+ * @return output Tensor with the given shape where values are randomly sampled according to a uniform distribution (NUMERIC type)
+ */
+ public SDVariable uniform(String name, double min, double max, DataType datatype, long... shape) {
+ Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+ SDVariable out = new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java
index f6434a56f..93999f0fe 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java
@@ -55,6 +55,15 @@ public class SDValidation {
 v.name() + "\" with non-integer data type " + v.dataType());
 }
+ protected static void validateNumerical(String opName, String inputName, SDVariable[] vars) {
+ for (SDVariable v : vars) {
+ if (v == null) continue;
+ if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
+ throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type; got variable \"" +
+ v.name() + "\" with non-numerical data type " + v.dataType());
+ }
+ }
+
 /**
 * Validate that the operation is being applied on numerical SDVariables (not boolean or utf8).
* Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays @@ -97,6 +106,16 @@ public class SDValidation { v.name() + "\" with non-integer data type " + v.dataType()); } + protected static void validateInteger(String opName, String inputName, SDVariable[] vars) { + for (SDVariable v : vars) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" + + v.name() + "\" with non-integer data type " + v.dataType()); + } + } + /** * Validate that the operation is being applied on an floating point type SDVariable * @@ -200,4 +219,18 @@ public class SDValidation { } } + public static boolean isSameType(SDVariable x, SDVariable y) { + return x.dataType() == y.dataType(); + } + + public static boolean isSameType(SDVariable[] x) { + DataType firstDataType = x[0].dataType(); + if (x.length > 1) { + for (int i = 1; i < x.length; ++i) { + if (firstDataType != x[i].dataType()) + return false; + } + } + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/DataFormat.java similarity index 95% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/DataFormat.java index fb3fc9c67..c42795070 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/DataFormat.java @@ -16,7 +16,7 @@ //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== -package org.nd4j.linalg.factory.enums; +package org.nd4j.enums; /** * Data format: "NCHW" or "NHWC" */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index ebe27bd85..62edb778f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -633,7 +633,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.Lu.class, org.nd4j.linalg.api.ops.custom.TriangularSolve.class, org.nd4j.linalg.api.ops.custom.LinearSolve.class, - org.nd4j.linalg.api.ops.custom.Lstsq.class + org.nd4j.linalg.api.ops.custom.Lstsq.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.Qr.class, + org.nd4j.linalg.api.ops.custom.Logdet.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 2a72fc76e..8b598242c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -85,6 +85,12 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum this(x, null, dimensions); } + public 
BaseIndexAccumulation(INDArray x, boolean keepDims, int[] dimensions) { + this(x, null, dimensions); + this.keepDims = keepDims; + defineDimensions(dimensions); + } + public BaseIndexAccumulation(INDArray x, INDArray z, int[] dimensions) { super(x, z); defineDimensions(dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index 55d551369..f842303ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -29,12 +29,17 @@ public class AdjustContrast extends BaseAdjustContrast { super(in, factor, out); } + public AdjustContrast(@NonNull INDArray in, double factor) { + this(in, factor, null); + } + public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { super(sameDiff,new SDVariable[]{in,factor}); } - public AdjustContrast(@NonNull INDArray in, double factor) { - this(in, factor, null); + public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) { + super(sameDiff,new SDVariable[]{in}); + addTArgument(factor); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java index e1a5b0a7a..bd0c88792 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java @@ -50,6 +50,11 @@ public class AdjustHue extends DynamicCustomOp { super(sameDiff,new SDVariable[]{in,factor}); } + public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) { + super(sameDiff,new SDVariable[]{in}); + addTArgument(factor); + } + @Override public String opName() { return "adjust_hue"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java index e9f1f90c8..3c98f2149 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java @@ -49,6 +49,11 @@ public class AdjustSaturation extends DynamicCustomOp { super(sameDiff, new SDVariable[]{in, factor}); } + public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) { + super(sameDiff, new SDVariable[]{in}); + addTArgument(factor); + } + @Override public String opName() { return "adjust_saturation"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java new file mode 100644 index 000000000..81b8cbc08 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 
2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Logdet extends DynamicCustomOp { + + public Logdet(INDArray input) { + addInputArgument(input); + } + + public Logdet(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "logdet"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java index 20751164f..b7c0e4092 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java @@ -17,9 +17,17 @@ package org.nd4j.linalg.api.ops.custom; import lombok.NoArgsConstructor; import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + @NoArgsConstructor public class Lstsq extends DynamicCustomOp { @@ -33,8 +41,21 @@ public class Lstsq extends DynamicCustomOp { this(matrix, rhs, 0.0, true); } + public Lstsq(@NonNull SameDiff sameDiff, @NonNull SDVariable matrix, @NonNull SDVariable rhs, double l2_regularizer, boolean fast) { + super(sameDiff, new SDVariable[]{matrix,rhs}); + addTArgument(l2_regularizer); + addBArgument(fast); + } + @Override public String opName() { return "lstsq"; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java index 554781958..40d50afe3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java @@ -15,6 +15,7 @@ ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -26,10 +27,9 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; import java.util.List; +@NoArgsConstructor public class MatrixBandPart extends DynamicCustomOp { - public MatrixBandPart() {} - public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) { Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher"); long N = input.size(-2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 97b826064..da0896a46 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -1,6 +1,5 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -37,7 +36,6 @@ import java.util.*; */ @NoArgsConstructor public class CropAndResize extends DynamicCustomOp { - public enum Method {BILINEAR, NEAREST}; protected Method method = Method.BILINEAR; protected double extrapolationValue = 0.0; @@ -50,6 +48,10 @@ public class CropAndResize extends DynamicCustomOp { addArgs(); } + public CropAndResize(@NonNull SameDiff sameDiff, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, + SDVariable cropOutSize, double extrapolationValue) { + this(sameDiff, image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue); + } public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue, @@ -65,12 +67,10 @@ public class CropAndResize extends DynamicCustomOp { outputArguments.add(output); } - public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, - @NonNull INDArray cropOutSize, double extrapolationValue) { - this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null); + public CropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, INDArray cropOutSize, double extrapolationValue ) { + this(image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue, null); } - @Override public String opName() { return "crop_and_resize"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java index 71b8b1fb2..5e6362d67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java @@ -46,6 +46,12 @@ public class ExtractImagePatches extends DynamicCustomOp { public ExtractImagePatches(){ } + public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input, + int kH, int kW, int sH, int sW, int rH, int rW, + boolean sameMode) { + this(samediff, input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH,rW}, sameMode); + + } public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input, @NonNull int[] kSizes, @NonNull int[] strides, @NonNull int[] rates, boolean sameMode){ super(samediff, input); @@ -72,16 +78,8 @@ public class ExtractImagePatches extends DynamicCustomOp { addArgs(); } - public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { - super(new INDArray[]{input},null); - int[] kSises = {kH,kW}; - int[] strides = {sH,sW}; - int[] rates = {rH, rW}; - this.kSizes = kSises; - this.strides = strides; - this.rates = rates; - this.isSameMode = sameMode; - addArgs(); + public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { + this(input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index f8763c41a..f7ab95d77 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -42,6 +42,13 @@ public class NonMaxSuppression extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false); } + public NonMaxSuppression(SameDiff sameDiff, SDVariable boxes, SDVariable scores, int maxOutSize, + double iouThreshold, double scoreThreshold) { + super(null, sameDiff, new SDVariable[]{boxes, scores}, false); + addIArgument(maxOutSize); + addTArgument(iouThreshold, scoreThreshold); + } + public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) { addInputArgument(boxes,scores); addIArgument(maxOutSize); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java index 60b278ed7..dd61c03e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java @@ -54,10 +54,18 @@ public class FirstIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } + public FirstIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... 
dimensions) { + this(sameDiff, i_v, condition, keepDims, dimensions); + } + public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) { this(x, condition, false, dimension); } + public FirstIndex(INDArray x, boolean keepDims, @NonNull Condition condition, int... dimension) { + this(x,condition,keepDims,dimension); + } + public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) { this(x, condition, Nd4j.EPS_THRESHOLD, dimension); this.keepDims = keepDims; @@ -72,7 +80,6 @@ public class FirstIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } - @Override public int opNum() { return 4; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index 7280d7adf..8b7872b49 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -45,6 +45,11 @@ public class IMax extends BaseIndexAccumulation { super(x, z, dimensions); } + public IMax(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + + } + public IMax(INDArray x, int... dimensions) { super(x, null, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index 449ea36a0..06b3deb1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -44,6 +44,10 @@ public class IMin extends BaseIndexAccumulation { super(x, dimensions); } + public IMin(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + public IMin(INDArray x, INDArray z, int... dimensions) { super(x, z, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java index e77d42398..1325d33c5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; import lombok.Data; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -38,12 +39,16 @@ import java.util.Map; * @author raver119@gmail.com */ @Data +@NoArgsConstructor public class LastIndex extends BaseIndexAccumulation { protected Condition condition; protected double compare; protected double eps; protected int mode; + public LastIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) { + this(sameDiff, i_v, condition, keepDims, dimensions); + } public LastIndex(SameDiff sameDiff, SDVariable i_v, Condition condition, boolean keepDims, int... 
dimensions) { super(sameDiff, i_v, keepDims, dimensions); this.condition = condition; @@ -53,13 +58,19 @@ public class LastIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } - public LastIndex() {} - + public LastIndex(SameDiff sameDiff, SDVariable x, @NonNull Condition condition, int... dimensions) { + super(sameDiff, x, false, dimensions); + this.condition = condition; + } public LastIndex(INDArray x, @NonNull Condition condition, int... dimensions) { this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); } + public LastIndex(INDArray in, boolean keepDim, Condition condition, int... dimensions) { + this(in, condition, keepDim, dimensions); + } + public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) { this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); this.keepDims = keepDim; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java index 79bcacab0..7c6b5186c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java @@ -47,10 +47,6 @@ public class AvgPooling3D extends Pooling3D { super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG); } - public AvgPooling3D(SameDiff sameDiff,INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) { - super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG); - } - public AvgPooling3D(@NonNull INDArray input, Pooling3DConfig pooling3DConfig) { super(null,null,new INDArray[]{input},null,false, pooling3DConfig, Pooling3DType.AVG); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index e3716bc24..b4f881fac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -76,6 +76,19 @@ public class BatchNorm extends DynamicCustomOp { addArgs(); } + public BatchNorm(SameDiff sameDiff, SDVariable input, SDVariable mean, SDVariable variance, + SDVariable gamma, SDVariable beta, double epsilon, int[] axis) { + super(null,sameDiff, wrapFilterNull(input, mean, variance, gamma, beta), false); + Preconditions.checkState(axis != null && axis.length > 0, "Invalid axis argument: axis must be specified" + + "and length > 0. Got %s", axis); + this.sameDiff = sameDiff; + this.applyBeta = beta != null; + this.applyGamma = gamma != null; + this.epsilon = epsilon; + this.jaxis = axis; + addArgs(); + } + public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... 
axis){ super(wrapFilterNull(input, mean, variance, gamma, beta), null); this.jaxis = axis; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 819d1d10c..e33037738 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -46,6 +46,10 @@ public class Conv1D extends DynamicCustomOp { protected Conv1DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; + public Conv1D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { + this(sameDiff, wrapFilterNull(input, weights, bias), conv1DConfig); + } + @Builder(builderMethodName = "sameDiffBuilder") public Conv1D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -64,12 +68,8 @@ public class Conv1D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public Conv1D( @NonNull INDArray input, @NonNull INDArray weights, INDArray bias, Conv1DConfig conv1DConfig) { - this(wrapFilterNull(input, weights, bias), null, conv1DConfig); - } - - public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, Conv1DConfig conv1DConfig) { - this(new INDArray[]{input, weights}, null, conv1DConfig); + public Conv1D(INDArray input, INDArray weights, INDArray bias, Conv1DConfig config) { + this(input, weights, bias, null, config); } private void initConfig(Conv1DConfig config){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 60bdfbfcc..9635c6f36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -56,6 +56,11 @@ public class Conv2D extends DynamicCustomOp { protected Conv2DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s "; + public Conv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, + SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } + @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -75,12 +80,8 @@ public class Conv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, @NonNull Conv2DConfig conv2DConfig) { - this(new INDArray[]{layerInput, weights}, null, conv2DConfig); - } - - public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, INDArray bias, @NonNull Conv2DConfig conv2DConfig) { - this(wrapFilterNull(layerInput, weights,bias), null, conv2DConfig); + public Conv2D(INDArray layerInput, INDArray weights, INDArray bias, Conv2DConfig config) { + this(layerInput, weights, bias, 
null, config); } protected void initConfig(Conv2DConfig config){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 94fb897b0..bb30930d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -55,6 +55,11 @@ public class Conv3D extends DynamicCustomOp { public Conv3D() { } + public Conv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, + SDVariable bias, @NonNull Conv3DConfig config) { + this(sameDiff, wrapFilterNull(input, weights, bias), config); + } + @Builder(builderMethodName = "sameDiffBuilder") public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) { super(sameDiff, inputFunctions); @@ -70,12 +75,12 @@ public class Conv3D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public Conv3D(@NonNull INDArray input,@NonNull INDArray weights, @NonNull Conv3DConfig conv3DConfig) { - this(new INDArray[]{input, weights}, null, conv3DConfig); + public Conv3D(INDArray input, INDArray weights, INDArray bias, Conv3DConfig config) { + this(wrapFilterNull(input, weights, bias), null, config); } - public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull Conv3DConfig conv3DConfig) { - this(wrapFilterNull(input, weights, bias) , null, conv3DConfig); + public Conv3D(INDArray input, INDArray weights, Conv3DConfig config) { + this(wrapFilterNull(input, weights), null, config); } private void initConfig(Conv3DConfig config){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index f3500bec0..74b448dc8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -52,6 +52,11 @@ public class DeConv2D extends DynamicCustomOp { protected DeConv2DConfig config; + public DeConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, + SDVariable bias, DeConv2DConfig config) { + this(sameDiff, wrapFilterNull(input, weights, bias), config); + } + @Builder(builderMethodName = "sameDiffBuilder") public DeConv2D(SameDiff sameDiff, SDVariable[] inputs, @@ -73,15 +78,10 @@ public class DeConv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DeConv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, DeConv2DConfig deConv2DConfig) { - this(wrapFilterNull(layerInput, weights), null, deConv2DConfig); + public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig config) { + this(layerInput, weights, bias, null, config); } - public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig deConv2DConfig) { - this(wrapFilterNull(layerInput, weights, bias), null, deConv2DConfig); - } - - @Override public long[] iArgs() { if 
(iArguments.size() == 0) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index a4652850c..436659443 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -48,12 +48,18 @@ public class DeConv3D extends DynamicCustomOp { protected DeConv3DConfig config; - public DeConv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { + public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { super(sameDiff, toArr(input, weights, bias)); this.config = config; addArgs(); } + public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { + super(sameDiff, toArr(input, weights, null)); + this.config = config; + addArgs(); + } + public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){ super(inputs, outputs); @@ -65,12 +71,8 @@ public class DeConv3D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, @NonNull DeConv3DConfig deConv3DConfig) { - this(new INDArray[]{input, weights}, null, deConv3DConfig); - } - - public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull DeConv3DConfig deConv3DConfig) { - this(wrapFilterNull(input, weights, bias), null, deConv3DConfig); + public DeConv3D(INDArray input, INDArray weights, INDArray bias, DeConv3DConfig config) { + this(input, weights, bias, null, config); } private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java index 3becef510..20808dff5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java @@ -16,16 +16,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.DataFormat; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.enums.DataFormat; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -46,45 +45,48 @@ import java.util.*; * @author raver119@gmail.com, Max Pumperla */ public class DepthToSpace extends DynamicCustomOp { - private String dataFormat = "NHWC"; + private DataFormat dataFormat = 
DataFormat.NHWC; private int blockSize; public DepthToSpace() { } - public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) { + public DepthToSpace(SameDiff sameDiff, SDVariable args, int blockSize, DataFormat dataFormat) { + this(sameDiff, new SDVariable[]{args}, blockSize, dataFormat); + } + + public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) { super(null, sameDiff, args, false); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - public DepthToSpace(@NonNull INDArray in, INDArray out, int blockSize, @NonNull String dataFormat) { + public DepthToSpace(INDArray in, INDArray out, int blockSize, DataFormat dataFormat) { super(null, in, out, null, null); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - public DepthToSpace(@NonNull INDArray x, int blockSize, DataFormat dataFormat) { - this(x, null, blockSize, dataFormat.toString()); + public DepthToSpace(INDArray in, int blockSize, DataFormat dataFormat) { + this(in, null, blockSize, dataFormat); } - @Override public List doDiff(List i_v) { // Gradient to DepthToSpace is just SpaceToDepth of same block size and data format. SDVariable gradient = i_v.get(0); - SDVariable ret = sameDiff.cnn().spaceToDepth(gradient, blockSize, dataFormat); + SDVariable ret = new SpaceToDepth(sameDiff, new SDVariable[]{gradient}, blockSize, dataFormat).outputVariable(); return Arrays.asList(ret); } @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 
1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index ab42c3c5a..afb51af58 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -16,8 +16,11 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.*; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -49,11 +52,15 @@ import java.util.*; */ @Slf4j @Getter -@NoArgsConstructor public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; + public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } + @Builder(builderMethodName = "sameDiffBuilder") public DepthwiseConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -75,16 +82,11 @@ public class DepthwiseConv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, Conv2DConfig conv2DConfig) { - this(wrapFilterNull(layerInput, depthWeights), null, conv2DConfig); + public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig config) { + this(layerInput, depthWeights, bias, null, config); } - public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig conv2DConfig) { - this(wrapFilterNull(layerInput, depthWeights, bias), null, conv2DConfig); - } - - public DepthwiseConv2D(INDArray inputs, Conv2DConfig conv2DConfig) { - this(wrapFilterNull(inputs), null, conv2DConfig); + public DepthwiseConv2D() { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index cc5780a7c..10108d87c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -58,6 +58,10 @@ public class LocalResponseNormalization extends DynamicCustomOp { addArgs(); } + public LocalResponseNormalization(SameDiff sameDiff, SDVariable input, LocalResponseNormalizationConfig config) { + this(sameDiff, new SDVariable[]{input}, false, config); + } + public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){ super(new INDArray[]{input}, wrapOrNull(output)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index 9f7c9bfb7..c54b63aa7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -60,15 +60,16 @@ public class MaxPooling2D extends DynamicCustomOp { addArgs(); } - public MaxPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ super(null, new INDArray[]{input}, wrapOrNull(output)); config.setType(Pooling2D.Pooling2DType.MAX); + this.config = config; addArgs(); } - public MaxPooling2D(@NonNull INDArray input, @NonNull Pooling2DConfig pooling2DConfig) { - this(input, null, pooling2DConfig); + public MaxPooling2D(INDArray input, @NonNull Pooling2DConfig config){ + this(input, null, config); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java index 6c4ccaa9a..6c7aec888 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java @@ -47,8 +47,12 @@ public class MaxPooling3D extends Pooling3D { super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX); } - public MaxPooling3D(SameDiff sameDiff, INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) { - super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX); + public MaxPooling3D(INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) { + addInputArgument(arrayInput); + if (arrayOutput != null) + addOutputArgument(arrayOutput); + this.config = config; + addArgs(); } public MaxPooling3D(INDArray input, Pooling3DConfig pooling3DConfig) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index b28b9a987..cf4e87814 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -44,18 +44,23 @@ public class SConv2D extends Conv2D { super(sameDiff, inputFunctions, conv2DConfig); } + public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, + @NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig); + } + public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ super(inputs, outputs, config); } - public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, @NonNull Conv2DConfig Conv2DConfig){ - this(wrapFilterNull(layerInput, 
depthWeights, pointWeights, bias), null, Conv2DConfig); - } - public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, @NonNull Conv2DConfig Conv2DConfig){ this(wrapFilterNull(layerInput, depthWeights, pointWeights), null, Conv2DConfig); } + public SConv2D(INDArray layerInput, INDArray depthWeights, INDArray pointWeights, INDArray bias, Conv2DConfig config) { + this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, config); + } + public SConv2D() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java index 5ae281ae2..700824512 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java @@ -16,16 +16,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.DataFormat; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.enums.DataFormat; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -45,47 +44,48 @@ import java.util.*; * @author raver119@gmail.com, Max Pumperla */ public class SpaceToDepth extends DynamicCustomOp { - private String dataFormat; + private DataFormat dataFormat; private int blockSize; public SpaceToDepth() { } - public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) { + public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) { super(null, sameDiff, args, false); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - public SpaceToDepth(INDArray in, INDArray out, int blockSize, String dataFormat){ + public SpaceToDepth(SameDiff sameDiff, SDVariable x, int blockSize, DataFormat dataFormat) { + this(sameDiff, new SDVariable[]{x}, blockSize, dataFormat); + } + + public SpaceToDepth(INDArray in, INDArray out, int blockSize, DataFormat dataFormat){ super(null, in, out, null, null); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - - - public SpaceToDepth(@NonNull INDArray x, int blockSize, @NonNull DataFormat dataFormat) { - this(x, null, blockSize,dataFormat.toString()); + public SpaceToDepth(INDArray x, int blockSize, DataFormat dataFormat) { + this(x, null, blockSize, dataFormat); } - @Override public List doDiff(List i_v) { // Gradient to SpaceToDepth is just DepthToSpace of same block size and data format. 
SDVariable gradient = i_v.get(0); - SDVariable ret = sameDiff.cnn().depthToSpace(gradient, blockSize, dataFormat); + SDVariable ret = new DepthToSpace(sameDiff, gradient, blockSize, dataFormat).outputVariable(); return Arrays.asList(ret); } @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat == null ? true : dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java index 574682a36..df345a2f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java @@ -56,6 +56,14 @@ public class Upsampling2d extends DynamicCustomOp { addIArgument(nchw ? 1 : 0); } + public Upsampling2d(SameDiff sameDiff, SDVariable input, int scaleH, int scaleW, boolean nchw) { + this(sameDiff, input, nchw, scaleH, scaleW); + } + + public Upsampling2d(SameDiff sameDiff, SDVariable input, int scale) { + super(null,sameDiff, new SDVariable[]{input}); + addIArgument(scale); + } public Upsampling2d(INDArray input, int scale) { this(input, scale, scale, true); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index 3a3caa787..adc59e4e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -38,6 +38,11 @@ public class AbsoluteDifferenceLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public AbsoluteDifferenceLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, label); + } + public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java index 9794c7c8b..3f890da0b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java @@ -33,9 +33,9 @@ public abstract class BaseLoss extends DynamicCustomOp { protected LossReduce lossReduce; - public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, @NonNull SDVariable weights, + public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, 
@NonNull SDVariable predictions, SDVariable weights, @NonNull SDVariable labels){ - super(null, sameDiff, new SDVariable[]{predictions, weights, labels}); + super(null, sameDiff, new SDVariable[]{predictions, getWeights(sameDiff, weights, predictions), labels}); this.lossReduce = lossReduce; addArgs(); } @@ -50,6 +50,10 @@ public abstract class BaseLoss extends DynamicCustomOp { return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0); } + protected static SDVariable getWeights(SameDiff sd, SDVariable weights, SDVariable predictions){ + return weights != null ? weights : sd.constant(Nd4j.scalar(predictions.dataType(), 1.0)); + } + protected BaseLoss(){ } protected void addArgs(){ @@ -62,7 +66,7 @@ public abstract class BaseLoss extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2, "Expected at least 2 input datatypes for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index 241404492..7faa5f6b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -39,6 +39,11 @@ public class CosineDistanceLoss extends BaseLoss { this.addIArgument(dimension); } + public CosineDistanceLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, int dimension) { + this(sameDiff, lossReduce, predictions, weights, labels, dimension); + } + public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){ super(lossReduce, predictions, weights, labels); this.dimension = dimension; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index f2998064f..5d85e4933 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -36,6 +36,11 @@ public class HingeLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public HingeLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, labels); + } + public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index 18803cd9f..f08d90566 100644 ---
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -41,6 +41,11 @@ public class HuberLoss extends BaseLoss { tArguments.add(delta); } + public HuberLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double delta) { + this(sameDiff, lossReduce, predictions, weights, labels, delta); + } + public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){ super(lossReduce, predictions, weights, labels); this.delta = delta; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java index 01aa283ed..a7a15f1b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java @@ -41,6 +41,11 @@ public class LogLoss extends BaseLoss { addTArgument(epsilon); } + public LogLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double epsilon) { + this(sameDiff, lossReduce, predictions, weights, labels, epsilon); + } + public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){ super(lossReduce, predictions, weights, labels); this.epsilon = epsilon; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java index 0e0d4f7dd..a893e3f4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java @@ -38,6 +38,11 @@ public class LogPoissonLoss extends BaseLoss { this(sameDiff, lossReduce, predictions, weights, labels, false); } + public LogPoissonLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, boolean full) { + this(sameDiff, lossReduce, predictions, weights, labels, full); + } + public LogPoissonLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels, boolean full){ super(sameDiff, lossReduce, predictions, weights, labels); this.full = full; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 8e7bb9276..6c3c5d01b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -34,6 +34,11 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public MeanPairwiseSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable 
predictions, + SDVariable weights, LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, labels); + } + public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index c38faf29a..a9cf27584 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -36,6 +36,11 @@ public class MeanSquaredErrorLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public MeanSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, labels); + } + public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index 32b176cfd..214380a8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -44,6 +44,11 @@ public class SigmoidCrossEntropyLoss extends BaseLoss { public static final double DEFAULT_LABEL_SMOOTHING = 0.0; private double labelSmoothing = 0.0; + public SigmoidCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits, SDVariable weights, + LossReduce lossReduce, double labelSmoothing) { + this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing); + } + public SigmoidCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels, double labelSmoothing) { super(sameDiff, lossReduce, logits, weights, labels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index c8a40b805..57576b78f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -45,6 +45,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { private double labelSmoothing = 0.0; + public SoftmaxCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing); + } + public SoftmaxCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels, double labelSmoothing) { super(sameDiff, 
lossReduce, logits, weights, labels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index ba3d53e45..d22478e71 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -93,6 +93,24 @@ public class Mmul extends DynamicCustomOp { } } + public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) { + addInputArgument(x, y); + addIArgument(ArrayUtil.fromBoolean(transposeX), + ArrayUtil.fromBoolean(transposeY), + ArrayUtil.fromBoolean(transposeZ)); + } + + public Mmul(INDArray x, INDArray y) { + this(x,y,null,null); + } + + public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + super(null,sameDiff,new SDVariable[]{x,y}); + addIArgument(ArrayUtil.fromBoolean(transposeX), + ArrayUtil.fromBoolean(transposeY), + ArrayUtil.fromBoolean(transposeZ)); + } public Mmul() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index eca14e9f4..c613f107f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -77,6 +77,18 @@ public class TensorMmul extends DynamicCustomOp { addIArgument(dimensions[1]); } + public TensorMmul(SameDiff sameDiff, SDVariable x, SDVariable y, int[] dimensionsX, + int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { + super(null, sameDiff, new SDVariable[]{x,y}); + this.sameDiff = sameDiff; + this.axes = new int[][]{dimensionsX, dimensionsY}; + addIArgument(dimensionsX.length); + addIArgument(dimensionsX[0]); + addIArgument(dimensionsY.length); + addIArgument(dimensionsY[0]); + addBArgument(transposeX, transposeY, transposeZ); + } + @Override public List calculateOutputShape() { List ret = new ArrayList<>(1); @@ -242,6 +254,13 @@ public class TensorMmul extends DynamicCustomOp { this.axes = axes; } + public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + super(null,new INDArray[]{x, y},null); + this.axes = new int[][]{dimensionsX, dimensionsY}; + addBArgument(transposeX, transposeY, transposeZ); + } + @Override public String opName() { return "tensordot"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index 7daebd4cf..d4522ca69 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -41,6 +41,10 @@ public class Any extends BaseReduceBoolOp { super(x); } + public Any(INDArray x, int... 
dimensions) { + super(x, dimensions); + } + @Override public int opNum() { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 44cef710f..26eabf0ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -45,6 +45,10 @@ public class LogSumExp extends DynamicCustomOp { this.keepDims = keepDims; } + public LogSumExp(SameDiff sameDiff, SDVariable i_v, int[] dimensions) { + this(sameDiff, i_v, false, dimensions); + } + public LogSumExp() {} public LogSumExp(INDArray x, int... dimensions) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java index e6b2b064d..b11fe5b1f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java @@ -41,6 +41,10 @@ public class SquaredNorm extends BaseReduceFloatOp { super(input, output, keepDims, dimensions); } + public SquaredNorm(INDArray input, boolean keepDims, int... dimensions){ + this(input, null, keepDims, dimensions); + } + public SquaredNorm(){} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java index f366dc0cd..0fb4db830 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java @@ -38,6 +38,10 @@ public class MatchCondition extends BaseReduceLongOp { private double eps; private int mode; + public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition) { + this(sameDiff, in, condition, false, null); + } + public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { super(sameDiff, in, dimensions, keepDims); this.compare = condition.getValue(); @@ -64,6 +68,10 @@ public class MatchCondition extends BaseReduceLongOp { defineDimensions(dimensions); } + public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... 
dimensions) { + this(in, condition, dimensions); + } + @Override public int opNum() { return 2; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index ae70e44d9..859b89dac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -56,6 +56,10 @@ public class Sum extends BaseReduceSameOp { super(x, z, keepDims, dimensions); } + public Sum(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, keepDims, dimensions); + } + @Override public int opNum() { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index b9a98dc6e..000b0414c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -50,6 +50,10 @@ public class LeakyReLU extends BaseScalarOp { } + public LeakyReLU(SameDiff sameDiff, SDVariable i_v, double alpha) { + this(sameDiff, i_v, false, alpha); + } + public LeakyReLU(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double alpha) { super(sameDiff, i_v, alpha, extraArgs); this.alpha = alpha; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 572e22087..5cfab3768 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -42,6 +42,10 @@ public class Pow extends BaseScalarOp { this.extraArgs = new Object[]{pow}; } + public Pow(SameDiff sameDiff, SDVariable i_v, double pow) { + this(sameDiff, i_v, false, pow); + } + public Pow(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double pow) { super(sameDiff, i_v, pow, extraArgs); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index ca8cee2f1..944d4d095 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -35,6 +35,10 @@ public class RectifiedLinear extends BaseScalarOp { super(sameDiff, i_v, cutoff, inPlace); } + public RectifiedLinear(SameDiff sameDiff, SDVariable i_v, double cutoff) { + this(sameDiff, i_v, false, cutoff); + } + public RectifiedLinear() { super(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java index f593dc663..c80d3c8f9 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java @@ -42,6 +42,10 @@ public class Relu6 extends BaseScalarOp { super(sameDiff, i_v, cutoff, inPlace); } + public Relu6(SameDiff sameDiff, SDVariable i_v, double cutoff) { + this(sameDiff, i_v, false, cutoff); + } + public Relu6() { // } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java index 98e08b010..65f653d64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java @@ -41,6 +41,10 @@ public class Step extends BaseScalarOp { this.extraArgs = new Object[] {cutoff}; } + public Step(SameDiff sameDiff, SDVariable i_v, double cutoff) { + this(sameDiff, i_v, false, cutoff); + } + public Step() { cutoff = 0.0; this.extraArgs = new Object[] {cutoff}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java index 412b024f1..6f72490a1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java @@ -46,6 +46,9 @@ public class ScalarLessThan extends BaseScalarBoolOp { super(sameDiff, i_v, scalar, inPlace); } + public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, double scalar) { + super(sameDiff, i_v, scalar, false); + } @Override public int opNum() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index 73f74665a..160556867 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterAdd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterAdd(){} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index 4e7563e4a..5d6b60c88 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterDiv extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterDiv() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 65162aad3..7f814d928 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,6 +42,10 @@ public class ScatterMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterMax(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterMax() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 8d8fe4e33..2539a3d56 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,6 +42,10 @@ public class ScatterMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterMin(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterMin() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 2790667cd..411c59188 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterMul extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterMul(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterMul() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 382806779..83c4cc222 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterSub extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterSub(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterSub() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index dd9c52891..93e1e5995 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,6 +54,10 @@ public class ScatterUpdate extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterUpdate(){} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 69a62e493..bddcef970 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -49,6 +49,14 @@ public class Concat extends DynamicCustomOp { addIArgument(concatDimension); } + public Concat(INDArray[] arrays, int concatDimension) { + this(concatDimension, arrays); + } + + public Concat(SameDiff sameDiff, SDVariable[] inputs, int concatDimension){ + this(sameDiff, concatDimension, inputs); + } + public Concat(SameDiff sameDiff, int concatDimension, SDVariable... inputs){ super(null, sameDiff, inputs); addIArgument(concatDimension); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 337c1a936..2bf94021a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -68,6 +68,12 @@ public class ConfusionMatrix extends DynamicCustomOp { } } + + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, DataType dataType){ + this(sameDiff, labels, pred, weights); + this.outputType = dataType; + } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){ super(null, sameDiff, new SDVariable[]{labels, pred}); this.outputType = dataType; @@ -82,6 +88,11 @@ public class ConfusionMatrix extends DynamicCustomOp { addIArgument(numClasses); } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, Integer numClasses){ + super(null, sameDiff, new SDVariable[]{labels, pred, weights}); + addIArgument(numClasses); + } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){ super(null, sameDiff, new SDVariable[]{labels, pred, weights}); if(numClasses != null) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java index 3e94cb126..616d4d1aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -39,15 +40,17 @@ import java.util.List; * * @author Max Pumperla */ +@NoArgsConstructor public class Cross extends DynamicCustomOp { - public Cross() { - } - public Cross(SameDiff sameDiff, SDVariable[] args) { super(null, sameDiff, args, false); } + public Cross(SameDiff sameDiff, SDVariable a, SDVariable b) { + this(sameDiff, new SDVariable[]{a,b}); + } + public Cross(INDArray a, INDArray b){ this(a,b,null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java index 54cefce73..94516ec54 
100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -39,11 +40,9 @@ import java.util.Map; * * @author Max Pumperla */ +@NoArgsConstructor public class Diag extends DynamicCustomOp { - public Diag() { - } - public Diag(@NonNull INDArray input) { this(input, null); } @@ -52,6 +51,10 @@ public class Diag extends DynamicCustomOp { super(null, new INDArray[]{input}, wrapOrNull(output)); } + public Diag(SameDiff sameDiff, SDVariable input) { + this(sameDiff, new SDVariable[]{input}, false); + } + public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java index 9162b8935..b498157b6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java @@ -50,6 +50,10 @@ public class DiagPart extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } + public DiagPart(SameDiff sameDiff, SDVariable in) { + this(sameDiff, new SDVariable[]{in}, false); + } + public DiagPart(INDArray in){ this(in, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index a13a03184..f0a6f436a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -46,6 +46,10 @@ public class ExpandDims extends DynamicCustomOp { public ExpandDims() { } + public ExpandDims(SameDiff sameDiff, SDVariable args, int axis) { + this(sameDiff, new SDVariable[]{args}, axis); + } + public ExpandDims(SameDiff sameDiff, SDVariable[] args, int axis) { super(null, sameDiff, args); if (axis == Integer.MAX_VALUE) { @@ -63,6 +67,11 @@ public class ExpandDims extends DynamicCustomOp { super(null, inputs, outputs); } + public ExpandDims(INDArray input, int axis) { + addInputArgument(input); + addIArgument(axis); + } + public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index 3a8bb8f15..1e8409244 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -122,6 +122,13 @@ public class Eye extends DynamicCustomOp { addArgs(); } + public Eye(SameDiff sameDiff, SDVariable numRows, SDVariable numCols, DataType dataType, int[] batchDimension) { + super(null, 
sameDiff, new SDVariable[] {numRows, numCols}, false); + this.batchDimension = batchDimension; + this.dataType = dataType; + addArgs(); + } + protected void addArgs() { iArguments.clear(); tArguments.clear(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index fd6ec5240..b4f690ba5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -40,6 +41,13 @@ public class Gather extends DynamicCustomOp { protected int[] indices; protected int jaxis = 0; + public Gather(SameDiff sameDiff, SDVariable df, SDVariable indices, int axis) { + this(sameDiff, df, indices, axis, false); + } + + public Gather(SameDiff sameDiff, SDVariable df, int[] indices, int axis) { + this(sameDiff, df, indices, axis, false); + } public Gather(SameDiff sameDiff, SDVariable input, int[] indices, int axis, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input}, inPlace); @@ -56,6 +64,21 @@ public class Gather extends DynamicCustomOp { this.jaxis = axis; } + public Gather(INDArray df, int[] indexes, int axis) { + addInputArgument(df); + addIArgument(axis); + addIArgument(indexes); + this.jaxis = axis; + this.indices = indexes; + } + + public Gather(INDArray df, INDArray indexes, int axis) { + addInputArgument(df, indexes); + addIArgument(axis); + this.jaxis = axis; + } + @Override public String onnxName() { return "Gather"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index b8ef51d57..a239bd9ec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -17,10 +17,13 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; +import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.util.ArrayUtil; import java.util.Collections; import java.util.List; @@ -31,11 +34,19 @@ import java.util.List; @NoArgsConstructor public class GatherNd extends DynamicCustomOp { + public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) { + super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false); + } public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input, indices}, inPlace); } + public GatherNd(INDArray[] df, INDArray[] indices) { +
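// Appends the data arrays first, then the index arrays, as this op's input arguments. +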
addInputArgument(df); + addInputArgument(indices); + } + @Override public String opName() { return "gather_nd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index fab4a0066..6fca99eae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -39,11 +40,24 @@ public class Linspace extends DynamicCustomOp { private DataType dataType; + public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { + super(sameDiff, new SDVariable[0]); + addTArgument(start,stop); + addIArgument(number); + addDArgument(dataType); + } + public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){ super(sameDiff, new SDVariable[]{from, to, length}); this.dataType = dataType; } + public Linspace(DataType dataType, double start, double stop, long number) { + addDArgument(dataType); + addTArgument(start, stop); + addIArgument(number); + } + public Linspace(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java index 84eb47fc8..f2c11f1ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java @@ -37,6 +37,10 @@ public class MeshGrid extends DynamicCustomOp { addIArgument(cartesian ? 
1 : 0); } + public MeshGrid(SameDiff sd, SDVariable[] inputs, boolean cartesian) { + this(sd, cartesian, inputs); + } + public MeshGrid(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index beb9d09b9..affc603e9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -66,6 +66,11 @@ public class OneHot extends DynamicCustomOp { this(indices, output, depth, -1, 1, 0); } + public OneHot(INDArray indices, int depth) { + addInputArgument(indices); + addIArgument(depth); + } + public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) { super(null, indices, output, null, null); this.depth = depth; @@ -75,6 +80,12 @@ public class OneHot extends DynamicCustomOp { addArgs(); } + public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { + addInputArgument(indices); + addIArgument(depth, axis); + addTArgument(on, off); + addDArgument(dataType); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java index 4b4b3e578..8da18be06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java @@ -48,10 +48,18 @@ public class OnesLike extends DynamicCustomOp { public OnesLike() { } + public OnesLike(SameDiff sameDiff, SDVariable input) { + this(null, sameDiff, input); + } + public OnesLike(String name, SameDiff sameDiff, SDVariable input) { this(name, sameDiff, input, input.dataType()); } + public OnesLike(SameDiff sameDiff, SDVariable input, DataType dataType) { + this(null, sameDiff, input, dataType); + } + public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) { super(name, sameDiff, new SDVariable[]{input}, false); this.outputType = dataType; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java index cd78e5d12..cfd0bd7ed 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java @@ -55,6 +55,11 @@ public class Permute extends Transpose { addIArgument(permuteDims); } + public Permute(INDArray input, int... 
permuteDims){ + addInputArgument(input); + addIArgument(permuteDims); + } + public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){ super(sd, input, permuteDims); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index 8201e6075..5f1448d06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -39,10 +40,18 @@ public class Rank extends DynamicCustomOp { public Rank() { } + public Rank(SameDiff sameDiff, SDVariable input) { + this(sameDiff, input, false); + } + public Rank(SameDiff sameDiff, SDVariable input, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input}, inPlace); } + public Rank(INDArray indArray) { + addInputArgument(indArray); + } + @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 44d9b79fe..ddf0224db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -59,6 +59,10 @@ public class Reshape extends DynamicCustomOp { super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List)null); } + public Reshape(INDArray in, INDArray shape) { + addInputArgument(in, shape); + } + public Reshape() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index cc5b28bba..a2ca91c65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -69,7 +69,13 @@ public class SequenceMask extends DynamicCustomOp { addIArgument(maxLen); this.dataType = dataType; addDArgument(dataType); - } + } + + public SequenceMask(INDArray input, DataType dataType) { + addInputArgument(input); + this.dataType = dataType; + addDArgument(dataType); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java index 6cd2eec06..62bc5714e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java @@ -48,6 +48,10 @@ public class Shape extends DynamicCustomOp { public 
Shape() {} + public Shape(SameDiff sameDiff, SDVariable input) { + this(sameDiff, input, false); + } + public Shape(SameDiff sameDiff, SDVariable input, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input}, inPlace); } @@ -56,6 +60,10 @@ public class Shape extends DynamicCustomOp { super(null, in, out, null, null); } + public Shape(INDArray in){ + this(in, null); + } + @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java index 27989e878..ce3ce9cae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -47,6 +48,11 @@ public class Size extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input}, false); } + public Size(INDArray in) { + addInputArgument(in); + } + + @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index 379e6515e..c5f7cdd70 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -52,6 +53,11 @@ public class Slice extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{input, begin, end}); } + public Slice(INDArray in, int[] begin, int... 
size) { + addInputArgument(in); + addIArgument(begin); + addIArgument(size); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java index 1ffd0820b..2734d68b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -36,12 +37,21 @@ public class Squeeze extends DynamicCustomOp { public Squeeze() { } + public Squeeze(SameDiff sameDiff, SDVariable arg, int squeezeDims) { + this(sameDiff, arg, new int[] {squeezeDims}); + } + public Squeeze(SameDiff sameDiff, SDVariable arg, int[] squeezeDims) { super(null, sameDiff, new SDVariable[]{arg}); this.squeezeDims = squeezeDims; addIArgument(squeezeDims); } + public Squeeze(INDArray x, int axis) { + addInputArgument(x); + addIArgument(axis); + } + @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { nodeDef.getAttrMap().get("squeeze_dims"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index d2bf9d71b..89c459be3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -50,6 +50,16 @@ public class Stack extends DynamicCustomOp { addArgs(); } + public Stack(INDArray input, int axis) { + addInputArgument(input); + this.jaxis = axis; + addArgs(); + } + + public Stack(SameDiff sameDiff, SDVariable values, int axis) { + this(sameDiff, new SDVariable[]{values}, axis); + } + public Stack(SameDiff sameDiff, SDVariable[] values, int axis) { super(null, sameDiff, values, false); this.jaxis = axis; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index 2208d3a36..a053403af 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -25,6 +25,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.util.ArrayUtil; @@ -95,6 +96,20 @@ public class StridedSlice extends DynamicCustomOp { } + public StridedSlice(INDArray in, int[] begin, int[] 
end, int[] strides, int beginMask, + int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + addInputArgument(in); + this.begin = ArrayUtil.toLongArray(begin); + this.end = ArrayUtil.toLongArray(end); + this.strides = ArrayUtil.toLongArray(strides); + this.beginMask = beginMask; + this.endMask = endMask; + this.ellipsisMask = ellipsisMask; + this.newAxisMask = newAxisMask; + this.shrinkAxisMask = shrinkAxisMask; + addArguments(); + } + private void addArguments(){ addIArgument(beginMask); addIArgument(ellipsisMask); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java index e1fb02be9..c2e476f60 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java @@ -66,6 +66,14 @@ public class Tile extends DynamicCustomOp { this(inputs,outputs,axis,false); } + public Tile(INDArray x, INDArray repeat) { + addInputArgument(x, repeat); + } + + public Tile(INDArray x, int... repeat) { + addInputArgument(x); + addIArgument(repeat); + } public Tile() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 9ab0ad58c..95215b686 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -60,6 +60,10 @@ public class Transpose extends DynamicCustomOp { super(null, new INDArray[]{input}, result == null ? 
null : new INDArray[]{result}, null, (List) null); } + public Transpose(INDArray input) { + addInputArgument(input); + } + public Transpose() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java index 7225ac355..d71200e4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java @@ -45,6 +45,10 @@ public class ZerosLike extends DynamicCustomOp { protected DataType outputType; //Allow customizing dtype for TF import + public ZerosLike(SameDiff sameDiff, SDVariable input) { + this(null, sameDiff, input, false, input.dataType()); + } + public ZerosLike(String name, SameDiff sameDiff, SDVariable input) { this(name, sameDiff, input, false, input.dataType()); } @@ -66,6 +70,10 @@ public class ZerosLike extends DynamicCustomOp { this(in, out, in.dataType()); } + public ZerosLike(INDArray in){ + addInputArgument(in); + } + public ZerosLike(INDArray in, INDArray out, DataType dataType) { super(null, in, out, null, null); if (dataType != null) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index b44b11cf6..adc92549b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -67,6 +67,10 @@ public class Variance extends BaseReduceOp { this.biasCorrected = biasCorrected; } + public Variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { + this(x, null, biasCorrected, keepDims, dimensions); + } + public Variance(INDArray x, boolean biasCorrected, int... 
dimensions) { super(x); this.biasCorrected = biasCorrected; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java index 75a250049..dd45af037 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.ops.SDValidation; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -34,8 +37,17 @@ import java.util.Map; * Cholesky op wrapper * @author raver119@gmail.com */ +@NoArgsConstructor public class Cholesky extends DynamicCustomOp { + public Cholesky(INDArray input) { + addInputArgument(input); + } + + public Cholesky(SameDiff sameDiff, SDVariable sdInput) { + super(sameDiff, new SDVariable[]{sdInput}); + } + @Override public String opName() { return "cholesky"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index b07f52ce1..8d0a9d0d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -54,6 +54,10 @@ public class Pad extends DynamicCustomOp { addTArgument(padValue); } + public Pad(SameDiff sd, SDVariable in, SDVariable padding, double padValue) { + this(sd, in, padding, Mode.CONSTANT, padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? 
null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java index 03256a81a..8df844943 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,12 +31,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class IsFinite extends BaseTransformBoolOp { public IsFinite(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public IsFinite() {} + public IsFinite(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public IsFinite(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java index efefaa1d9..44cb362a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,12 +31,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class IsInf extends BaseTransformBoolOp { public IsInf(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public IsInf() {} + public IsInf(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public IsInf(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java index 206bc32b3..daf9b0ea3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,12 +32,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class IsNaN extends BaseTransformBoolOp { public IsNaN(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public IsNaN() {} + 
public IsNaN(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public IsNaN(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java index 89a9ddc64..78b995669 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java @@ -53,6 +53,14 @@ public class BatchToSpace extends DynamicCustomOp { public BatchToSpace() { } + public BatchToSpace(SameDiff sameDiff, SDVariable x, int[] blocks, int[] croppingTop, int... croppingBottom) { + this(sameDiff, x, blocks, new int[][]{croppingTop, croppingBottom}, false); + } + + public BatchToSpace(SameDiff sameDiff, SDVariable x, int[] blocks, int[][] crops, boolean inPlace) { + this(sameDiff, new SDVariable[]{x}, blocks, crops, inPlace); + } + public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) { super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace); @@ -63,15 +71,14 @@ public class BatchToSpace extends DynamicCustomOp { addIArgument(b); } - public BatchToSpace(INDArray x, int[] blocks, int[] croppingTop, int[] croppingBottom) { - super(null,x,null,null,null); + public BatchToSpace(INDArray x, int[] blocks, int[] croppingTop, int... croppingBottom) { + addInputArgument(x); + int[][] crops = new int[][]{croppingTop, croppingBottom}; this.blocks = blocks; - this.crops = new int[][]{croppingTop,croppingBottom}; + this.crops = crops; + for (val b : blocks) addIArgument(b); - - for (int e = 0; e < crops.length; e++) - addIArgument(crops[e][0], crops[e][1]); } @@ -94,7 +101,7 @@ public class BatchToSpace extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of batch to space is space to batch with same blocks and padding as crops SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops)); + return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java index ef07c7cc6..6622b134e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java @@ -83,7 +83,7 @@ public class BatchToSpaceND extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of batch to space is space to batch with same blocks and padding as crops SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops)); + return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 3874c040b..0be0b08ad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -59,7 +59,7 @@ public class CumProd extends DynamicCustomOp { } public CumProd(INDArray in, INDArray result, boolean exclusive, boolean reverse, int... axis) { - super(null, new INDArray[]{in}, new INDArray[]{result}, null, (List)null); + super(null, new INDArray[]{in}, result != null ? new INDArray[]{result} : null, null, (List)null); this.exclusive = exclusive; this.reverse = reverse; this.jaxis = axis; @@ -69,6 +69,10 @@ public class CumProd extends DynamicCustomOp { addArgs(); } + public CumProd(INDArray in, boolean exclusive, boolean reverse, int... axis) { + this(in, null, exclusive, reverse, axis); + } + @Override public String opName() { return "cumprod"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index 6720b5a75..c24693b01 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -64,13 +64,16 @@ public class CumSum extends DynamicCustomOp { } public CumSum(INDArray in, INDArray result, boolean exclusive, boolean reverse, int... axis) { - super(null, new INDArray[]{in}, new INDArray[]{result}, null, (List)null); + super(null, new INDArray[]{in}, wrapOrNull(result), null, (List)null); this.exclusive = exclusive; this.reverse = reverse; this.jaxis = axis; addArgs(); } + public CumSum(INDArray in, boolean exclusive, boolean reverse, int... 
axis) { + this(in, null, exclusive, reverse, axis); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java index c4de19cfa..3bc812596 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java @@ -16,8 +16,6 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; -import lombok.NoArgsConstructor; -import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -42,7 +40,6 @@ import java.util.*; * * @author raver119@gmail.com */ -@NoArgsConstructor public class Dilation2D extends DynamicCustomOp { protected boolean isSameMode; @@ -52,11 +49,21 @@ public class Dilation2D extends DynamicCustomOp { // strides protected int s0, s1, s2, s3; + + public Dilation2D() { + } + + public Dilation2D(SameDiff sameDiff, SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) { + this(sameDiff, new SDVariable[]{df, weights}, strides, rates, isSameMode, false); + } + public Dilation2D(SameDiff sameDiff, SDVariable[] inputAndWeights, int[] strides, int[] rates, boolean isSameMode, boolean inPlace ) { super(null, sameDiff, inputAndWeights, inPlace); - Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates); - Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides); + Preconditions.checkArgument(rates.length == 4, + "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates); + Preconditions.checkArgument(strides.length == 4, + "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides); r0 = rates[0]; r1 = rates[1]; @@ -69,18 +76,21 @@ public class Dilation2D extends DynamicCustomOp { this.isSameMode = isSameMode; addArgs(); + } public Dilation2D(INDArray[] inputArrays, INDArray[] outputs) { super(null, inputArrays, outputs); + } - public Dilation2D(@NonNull INDArray df, @NonNull INDArray weights, @NonNull int[] strides, @NonNull int[] rates, boolean isSameMode) { - super(null, new INDArray[]{df, weights},null); - Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates); - Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides); + public Dilation2D(INDArray df, INDArray weights, int[] strides, int[] rates, boolean isSameMode) { + addInputArgument(df, weights); - this.isSameMode = isSameMode; + if (rates.length < 4) + throw new IllegalArgumentException("Dilation rate length must be 4."); + if (strides.length < 4) + throw new IllegalArgumentException("Strides length must be 4."); r0 = rates[0]; r1 = rates[1]; @@ -90,10 +100,11 @@ public class Dilation2D extends DynamicCustomOp { s1 = strides[1]; s2 = strides[2]; s3 = strides[3]; + this.isSameMode = isSameMode; + addArgs(); } - protected void addArgs() { 
addIArgument(isSameMode ? 1 : 0, r0, r1, r2, r3, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 0e5232896..b64581b49 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,6 +54,10 @@ public class DynamicPartition extends DynamicCustomOp { public DynamicPartition() { } + public DynamicPartition(SameDiff sameDiff, SDVariable input, SDVariable[] partitions, int numPartitions) { + this(sameDiff, input, partitions[0], numPartitions); + } + public DynamicPartition(SameDiff sameDiff, SDVariable input, SDVariable partitions, int numPartitions) { super(null, sameDiff, new SDVariable[] {input, partitions}, false); @@ -61,6 +66,14 @@ public class DynamicPartition extends DynamicCustomOp { addArgs(); } + public DynamicPartition(INDArray input, INDArray[] partitions, int numPartitions) { + addInputArgument(input); + for (INDArray part : partitions) + addInputArgument(part); + + addIArgument(numPartitions); + } + @Override public List doDiff(List i_v) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java index 72aebe1e2..94c34d108 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -60,6 +61,16 @@ public class DynamicStitch extends DynamicCustomOp { this.numPartitions = inputs.length; } + public DynamicStitch(INDArray[] inputs, INDArray[] indices) { + for (INDArray input : inputs) { + addInputArgument(input); + } + + for (INDArray index : indices) { + addInputArgument(index); + } + } + @Override public List doDiff(List i_v) { // DynamicPartition and DynamicStitch are mutually inverse diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index c4d7b2469..0d1214c9a 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -36,6 +36,10 @@ import java.util.List; public class EqualTo extends BaseDynamicTransformOp { public EqualTo() {} + public EqualTo( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public EqualTo( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -44,6 +48,10 @@ public class EqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } + public EqualTo( INDArray x, INDArray y) { + addInputArgument(x, y); + } + public EqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index a5ffbced5..73f221f35 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -63,6 +63,11 @@ public class Fill extends DynamicCustomOp { this.value = value; } + public Fill(INDArray shape, DataType dataType, double value) { + super(null, shape, null, Collections.singletonList(value), null); + this.value = value; + } + public Fill(INDArray shape, INDArray value, INDArray result) { super(null, new INDArray[]{shape, value}, new INDArray[]{result}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 4c7fce72c..6a1ecc2cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -37,6 +37,10 @@ import java.util.List; public class GreaterThan extends BaseDynamicTransformOp { public GreaterThan() {} + public GreaterThan( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y},false); + } + public GreaterThan( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -45,6 +49,10 @@ public class GreaterThan extends BaseDynamicTransformOp { super(inputs, outputs); } + public GreaterThan( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public GreaterThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java index 6326870ec..dfb7fe8dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -36,6 +36,10 @@ 
import java.util.List; public class GreaterThanOrEqual extends BaseDynamicTransformOp { public GreaterThanOrEqual() {} + public GreaterThanOrEqual( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public GreaterThanOrEqual( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -48,6 +52,11 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { this(new INDArray[]{x, y}, new INDArray[]{z}); } + public GreaterThanOrEqual(INDArray x, INDArray y) { + + this(new INDArray[]{x,y}, null); + } + @Override public int opNum() { return 11; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java index 387f484f2..6048c9dff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; @@ -35,10 +36,18 @@ import java.util.List; @NoArgsConstructor public class InvertPermutation extends BaseDynamicTransformOp { + public InvertPermutation(SameDiff sameDiff, SDVariable input) { + this(sameDiff, input, false); + } + public InvertPermutation(SameDiff sameDiff, SDVariable input, boolean inPlace) { super( sameDiff, new SDVariable[] {input}, inPlace); } + public InvertPermutation(INDArray input) { + addInputArgument(input); + } + @Override public String opName() { return "invert_permutation"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java index 96ad104af..95640fead 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -32,13 +33,21 @@ import java.util.List; * and returns true if for every adjacent pair we have x[i] <= x[i+1]. 
* */ +@NoArgsConstructor public class IsNonDecreasing extends DynamicCustomOp { - public IsNonDecreasing() {} public IsNonDecreasing( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } + public IsNonDecreasing( SameDiff sameDiff, SDVariable[] args) { + super(null, sameDiff, args, false); + } + + public IsNonDecreasing( SameDiff sameDiff, SDVariable input) { + super(null, sameDiff, new SDVariable[]{input}, false); + } + public IsNonDecreasing(@NonNull INDArray input){ this(input, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java index f25372b58..88c0a84ba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java @@ -33,6 +33,10 @@ import java.util.List; public class IsNumericTensor extends DynamicCustomOp { public IsNumericTensor() {} + public IsNumericTensor( SameDiff sameDiff, SDVariable args) { + this(sameDiff, new SDVariable[]{args}, false); + } + public IsNumericTensor( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } @@ -41,6 +45,9 @@ public class IsNumericTensor extends DynamicCustomOp { super(null, inputs, outputs); } + public IsNumericTensor(INDArray input) { + addInputArgument(input); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java index 55b866cad..f6701c4ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java @@ -39,6 +39,10 @@ public class IsStrictlyIncreasing extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } + public IsStrictlyIncreasing( SameDiff sameDiff, SDVariable input) { + super(null, sameDiff, new SDVariable[]{input}); + } + public IsStrictlyIncreasing(@NonNull INDArray input){ this(input, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index 61fbe2bee..b1a38e0ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -37,6 +37,10 @@ import java.util.List; public class LessThan extends BaseDynamicTransformOp { public LessThan() {} + public LessThan( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public LessThan( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -45,6 +49,10 @@ public class LessThan extends BaseDynamicTransformOp { 
super(inputs, outputs); } + public LessThan( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public LessThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index 9f471f8dc..0ca6bf7e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -36,6 +36,10 @@ import java.util.List; public class LessThanOrEqual extends BaseDynamicTransformOp { public LessThanOrEqual() {} + public LessThanOrEqual( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public LessThanOrEqual( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -44,6 +48,10 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { super(inputs, outputs); } + public LessThanOrEqual( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public LessThanOrEqual(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java index 7fd707507..37fe652d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java @@ -48,6 +48,10 @@ public class MatrixDeterminant extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in}, inPlace); } + public MatrixDeterminant(SameDiff sameDiff, SDVariable in) { + this(sameDiff, in, false); + } + @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java index 4ff0f942b..475f3c6a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java @@ -46,6 +46,9 @@ public class MatrixInverse extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in}, inPlace); } + public MatrixInverse(SameDiff sameDiff, SDVariable in) { + this(sameDiff, in, false); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 9bbf6c50f..19d139cbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -34,6 +34,10 @@ public class MatrixSetDiag extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in, diag}, inPlace); } + public MatrixSetDiag(SameDiff sameDiff, SDVariable in, SDVariable diag) { + this(sameDiff, in, diag, false); + } + public MatrixSetDiag(@NonNull INDArray in, @NonNull INDArray diag){ super(new INDArray[]{in, diag}, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index 6c877f96d..e8653d4c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -52,6 +52,10 @@ public class Max extends BaseDynamicTransformOp { super(inputs, outputs); } + public Max( INDArray x, INDArray y) { + addInputArgument(x,y); + } + @Override public String opName() { return "maximum"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java index 73bfbacc7..c195178c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java @@ -52,6 +52,9 @@ public class Min extends BaseDynamicTransformOp { super(inputs, outputs); } + public Min( INDArray x, INDArray y) { + addInputArgument(x,y); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index c2c245979..69d724a7e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -36,6 +36,10 @@ import java.util.List; public class NotEqualTo extends BaseDynamicTransformOp { public NotEqualTo() {} + public NotEqualTo( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public NotEqualTo( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -44,6 +48,10 @@ public class NotEqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } + public NotEqualTo( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public NotEqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java new file mode 100644 index 000000000..409b0bd8e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java @@ -0,0 
+1,56 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.List; + +@NoArgsConstructor +public class Qr extends DynamicCustomOp { + + public Qr(INDArray input) { + this(input, false); + } + public Qr(INDArray input, boolean fullMatrices) { + addInputArgument(input); + addBArgument(fullMatrices); + } + + public Qr(SameDiff sameDiff, SDVariable input, boolean fullMatrices) { + super(sameDiff, new SDVariable[]{input}); + addBArgument(fullMatrices); + } + + @Override + public String opName() { + return "qr"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 3e7276def..11897fef8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -66,6 +67,11 @@ public class ReverseSequence extends DynamicCustomOp { public ReverseSequence() { } + public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { + addInputArgument(x, seq_lengths); + addIArgument(seqDim, batchDim); + } + @Override public String opName() { return "reverse_sequence"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index 
712df46fc..24c2353c1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -50,6 +50,10 @@ public class SoftMax extends BaseDynamicTransformOp { super(sameDiff, args, false); } + public SoftMax(SameDiff sameDiff, SDVariable x, int dimension) { + this(sameDiff, new SDVariable[]{x}, dimension); + } + public SoftMax(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java index 1ce8a7889..c7a8c0cda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java @@ -54,6 +54,10 @@ public class SpaceToBatch extends DynamicCustomOp { public SpaceToBatch() { } + public SpaceToBatch(SameDiff sameDiff, SDVariable x, int[] blocks, int[] paddingTop, int... paddingBottom) { + this(sameDiff, new SDVariable[]{x}, blocks, new int[][]{paddingTop, paddingBottom}, false); + } + public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) { super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace); @@ -63,19 +67,14 @@ public class SpaceToBatch extends DynamicCustomOp { addIArgument(blocks[0]); } - public SpaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int[] paddingBottom) { - super(null,x,null,null,null); + public SpaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int...
paddingBottom) { + addInputArgument(x); + int[][] padding = new int[][]{paddingTop, paddingBottom}; this.blocks = blocks; - this.padding = new int[][]{paddingTop,paddingBottom}; + this.padding = padding; - for (val b : blocks) - addIArgument(b); - - for (int e = 0; e < padding.length; e++) - addIArgument(padding[e][0], padding[e][1]); + addIArgument(blocks[0]); } - @Override public String opName() { return "space_to_batch"; @@ -95,7 +94,7 @@ public class SpaceToBatch extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of space to batch is batch to space with same blocks and crops as padding SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding)); + return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java index 9eb72e54f..12009d955 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java @@ -84,7 +84,7 @@ public class SpaceToBatchND extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of space to batch is batch to space with same blocks and crops as padding SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding)); + return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java index c66285dc2..60de8a665 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java @@ -71,6 +71,12 @@ public class Svd extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum); } + public Svd(INDArray input, boolean fullUV, boolean computeUV, int switchNum) { + addInputArgument(input); + addBArgument(fullUV, computeUV); + addIArgument(switchNum); + } + @Override public String opName(){ return "svd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index 06ebbb5ef..24d79f234 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -53,7 +53,7 @@ public class Trace extends DynamicCustomOp { public List doDiff(List gradAtOutput){ SDVariable rows = f().reshape(f().sizeAt(arg(), -2), new long[]{1}); SDVariable cols = f().reshape(f().sizeAt(arg(), -1), new long[]{1}); - SDVariable eye = sameDiff.math().eye(f().shape(gradAtOutput.get(0)), rows, cols); + SDVariable eye
= sameDiff.math().eye(/*f().shape(gradAtOutput.get(0)),*/ rows, cols); //Reshape gradient from [x,y,z] to [x,y,z,1,1] SDVariable reshapedGrad = f().expandDims(gradAtOutput.get(0), -1); reshapedGrad = f().expandDims(reshapedGrad, -1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java index b0c007a29..5b6cd2517 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentMax(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentMax(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java index 0b881ecbd..d0a9a6784 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentMean extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentMean(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentMean(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java index 7417ccb1d..2bc369f2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, 
segmentIds}, false); } + public SegmentMin(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentMin(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java index 4345b27ec..3be3625e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentProd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentProd(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentProd(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java index 236a74041..5de847162 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentSum extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentSum(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentSum(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java index 87e5281ea..df5cdbcc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.floating; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,13 +32,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class RSqrt extends BaseTransformFloatOp { + + public RSqrt(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RSqrt(SameDiff sameDiff, SDVariable i_v, boolean 
inPlace) { super(sameDiff, i_v, inPlace); } - public RSqrt() {} - public RSqrt(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java index 454b27342..34d74beb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java @@ -36,6 +36,10 @@ public class Sqrt extends BaseTransformFloatOp { super(sameDiff, i_v, inPlace); } + public Sqrt(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sqrt() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java index e5322c02f..2ca198506 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -36,12 +37,15 @@ import java.util.List; * @author Adam Gibson */ @Deprecated +@NoArgsConstructor public class HardTanhDerivative extends BaseTransformStrictOp { public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public HardTanhDerivative() {} + public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public HardTanhDerivative(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java index 202f7e291..259180f5c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,6 +32,7 @@ import java.util.List; /**Leaky ReLU derivative. Default alpha = 0.01. 
Cutoff = 0 */ +@NoArgsConstructor public class LeakyReLUDerivative extends BaseScalarOp { private double alpha = 0.01; @@ -40,14 +42,16 @@ public class LeakyReLUDerivative extends BaseScalarOp { this.extraArgs = new Object[] {alpha}; } + public LeakyReLUDerivative(SameDiff sameDiff, SDVariable i_v, double alpha) { + this(sameDiff, i_v, false, alpha); + } + public LeakyReLUDerivative(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double alpha) { super(sameDiff, i_v, alpha, extraArgs); this.alpha = alpha; this.extraArgs = new Object[] {alpha}; } - public LeakyReLUDerivative() {} - public LeakyReLUDerivative(INDArray x, INDArray z) { this(x, z, 0.01); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java index 4ae26e585..cd189c82d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,14 @@ import java.util.List; * @deprecated Use {@link SoftSignBp} */ @Deprecated +@NoArgsConstructor public class SoftSignDerivative extends BaseTransformStrictOp { public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public SoftSignDerivative() { + public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public SoftSignDerivative(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index 07bea9ae7..fc89333f4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -33,14 +34,17 @@ import java.util.List; * * @author Max Pumperla */ +@NoArgsConstructor public class MergeAddOp extends BaseDynamicTransformOp { - public MergeAddOp() {} - public MergeAddOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } + public MergeAddOp(SameDiff sameDiff, SDVariable[] args) { + this(sameDiff, args, false); + } + public MergeAddOp(@NonNull INDArray... 
inputs){ this(inputs, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index 2a8cf1111..4c6bf0ad9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -32,13 +33,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Abs extends BaseTransformSameOp { + + public Abs(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Abs(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Abs() { - } public Abs(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index d58ad8f3f..6422e8df8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; import java.util.Collections; + +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,13 +34,17 @@ import java.util.List; * * @author Paul Dubs */ +@NoArgsConstructor public class Cube extends BaseTransformSameOp { + + public Cube(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Cube(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Cube() {} - public Cube(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java index 842c78929..ba6dd7171 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,12 +31,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Floor extends BaseTransformSameOp { public Floor(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Floor() { + public Floor(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public Floor(INDArray x, INDArray z) { diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index bf3d28d71..dcee02131 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -41,6 +41,10 @@ public class Identity extends BaseDynamicTransformOp { super(new INDArray[]{x}, new INDArray[]{z}); } + public Identity(INDArray x){ + addInputArgument(x); + } + public Identity(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java index 9a1664cb6..37b370fe9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,12 +32,15 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Negative extends BaseTransformSameOp { public Negative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Negative() {} + public Negative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public Negative(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java index 764aca29f..1e11fa34d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -30,13 +31,11 @@ import java.util.List; /** * Created by susaneraly on 3/28/18. 
*/ +@NoArgsConstructor public class Reciprocal extends BaseTransformSameOp { - public Reciprocal(SameDiff sameDiff, SDVariable in, boolean inPlace) { - super(sameDiff, in, inPlace); - } - - public Reciprocal() { + public Reciprocal(SameDiff sameDiff, SDVariable in) { + super(sameDiff, in, false); } public Reciprocal(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java index 25f3120dc..375a8acb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -31,13 +32,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Round extends BaseTransformSameOp { + + public Round(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Round(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Round() {} - public Round(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java index 8ab85ac18..58c5d9c20 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java @@ -36,6 +36,10 @@ public class Sign extends BaseTransformSameOp { super(sameDiff, i_v, inPlace); } + public Sign(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sign() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java index 9dbc77bac..c63e00114 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java @@ -35,6 +35,10 @@ public class Square extends BaseTransformSameOp { super(sameDiff, i_v, inPlace); } + public Square(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Square() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 0e5426c3c..1506ac5f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -20,6 
+20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -39,6 +40,11 @@ public class UnsortedSegmentMax extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + public UnsortedSegmentMax(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index b0b7f4457..4338cf33d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -43,6 +44,11 @@ public class UnsortedSegmentMean extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_mean"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 5b7e1c7e0..2f8aab0b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -43,6 +44,11 @@ public class UnsortedSegmentMin extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_min"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index bca9e1788..7afd75fac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -43,6 +44,11 @@ public class UnsortedSegmentProd extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_prod"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index b3a507435..336c756ac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -44,6 +45,11 @@ public class UnsortedSegmentSum extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_sum"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java index 21a7e5b38..3e0c60bb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,12 +33,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ACos extends BaseTransformStrictOp { - public ACos(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public ACos(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public ACos() { + public ACos(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public ACos(INDArray x) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java index 2e51ea351..a8d9f12ad 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,11 +33,9 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ACosh extends BaseTransformStrictOp { - public ACosh() { - } - public ACosh(INDArray x) { super(x); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java index 8716a8f7d..fc514a415 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,15 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ASin extends BaseTransformStrictOp { public ASin(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public ASin() { + public ASin(SameDiff sameDiff, SDVariable i_v) { + + this(sameDiff, i_v, false); } public ASin(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java index 458d9fad1..483896dfd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,12 +33,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ATan extends BaseTransformStrictOp { public ATan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public ATan() { + public ATan(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public ATan(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 3beb90343..21076ad6e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -16,6 +16,7 @@ package 
org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,15 +33,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Cos extends BaseTransformStrictOp { + public Cos(SameDiff sameDiff, SDVariable i_v){ + this(sameDiff, i_v, false); + } + public Cos(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Cos() { - } - public Cos(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index b9ada31d0..dc08ead5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,15 +33,17 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Cosh extends BaseTransformStrictOp { + public Cosh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Cosh(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Cosh() { - } - public Cosh(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java index 3d49194ab..2769c95f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -34,12 +35,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Erf extends BaseTransformStrictOp { - public Erf(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + + public Erf(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public Erf() { + public Erf(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Erf(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java index 857d87141..f31e71ee8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -36,12 +37,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Erfc extends BaseTransformStrictOp { - public Erfc(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + + public Erfc(SameDiff sameDiff, SDVariable i_v){ + this(sameDiff, i_v, false); } - public Erfc() { + public Erfc(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Erfc(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java index 05dc708a8..21aa49522 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,12 +32,15 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Exp extends BaseTransformStrictOp { - public Exp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + + public Exp(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public Exp() { + public Exp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Exp(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java index 5aad7aebd..538f6a003 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Expm1 extends BaseTransformStrictOp { - public Expm1(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public Expm1(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public Expm1() { + public Expm1(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Expm1(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index b33ea8b8f..b784ddde0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,18 @@ import java.util.List; * use precise=false; otherwise, use precise = true for the slower but marginally more accurate tanh version. * @author raver119@gmail.com */ +@NoArgsConstructor public class GELU extends BaseTransformStrictOp { public GELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace, boolean precise) { super(sameDiff, i_v, inPlace); } - public GELU() { + public GELU(SameDiff sameDiff, SDVariable i_v, boolean precise) { + this(sameDiff, i_v, false, precise); + } + + public GELU(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false, false); } public GELU(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java index ddca48d4c..ddaa8631f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,8 +33,8 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class HardSigmoid extends BaseTransformStrictOp { - public HardSigmoid() {} public HardSigmoid(INDArray x, INDArray z) { super(x, z); @@ -47,6 +48,10 @@ public class HardSigmoid extends BaseTransformStrictOp { super(sameDiff, in, inPlace); } + public HardSigmoid(SameDiff sameDiff, SDVariable in){ + this(sameDiff, in, false); + } + @Override public int opNum() { return 36; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index 4237e72de..fa80bf880 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import java.util.Collections; + +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,12 +34,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class HardTanh extends BaseTransformStrictOp { public HardTanh(SameDiff sameDiff, SDVariable i_v, boolean 
inPlace) { super(sameDiff, i_v, inPlace); } - public HardTanh() { + public HardTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public HardTanh(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java index 0295b8e52..a937e1d63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java @@ -36,6 +36,10 @@ public class Log extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Log(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Log() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java index 96892a9f0..131986d15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,11 +33,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Log1p extends BaseTransformStrictOp { public Log1p(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Log1p() {} + public Log1p(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public Log1p(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java index 0f4c7abcc..353ced004 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,12 +33,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class LogSigmoid extends BaseTransformStrictOp { public LogSigmoid(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public LogSigmoid() { + public LogSigmoid(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public LogSigmoid(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java index 
f72676f86..00592f0e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java @@ -43,6 +43,10 @@ public class SELU extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public SELU(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SELU() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java index 22d5b6302..37ef4b743 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java @@ -36,6 +36,10 @@ public class Sigmoid extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Sigmoid(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sigmoid() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java index 22357d386..0fa918c11 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java @@ -37,6 +37,10 @@ public class Sin extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Sin(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sin(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java index 84a8f522a..d5e3be988 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java @@ -37,6 +37,10 @@ public class Sinh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Sinh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sinh(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java index abd1ce904..11ffb2ef8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java @@ -34,6 +34,10 @@ public class SoftPlus extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public SoftPlus(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SoftPlus(INDArray x, INDArray z) { super(x, z); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java index c7c90b201..8be5ea2d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java @@ -40,6 +40,10 @@ public class SoftSign extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public SoftSign(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, i_v, false); + } + public SoftSign() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java index 029c7c5b4..0794e0b57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,12 +33,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Swish extends BaseTransformStrictOp { public Swish(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Swish() { + public Swish(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public Swish(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java index 77954deec..3244925b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java @@ -38,6 +38,10 @@ public class Tan extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Tan(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Tan() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java index 667bf6a93..136d0bbea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java @@ -36,6 +36,10 @@ public class Tanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Tanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v,false); + } + public Tanh() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java index 752881c6e..0672f5d15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java @@ -38,6 +38,7 @@ import java.util.List; @NoArgsConstructor public abstract class BaseRandomOp extends BaseOp implements RandomOp { protected long[] shape; + protected DataType dataType = Nd4j.defaultFloatingPointType(); public BaseRandomOp(SameDiff sameDiff, SDVariable i_v) { Preconditions.checkNotNull(i_v, "Input variable can't be null with this constructor"); @@ -72,7 +73,7 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { @Override public List calculateOutputShape(OpContext opContext) { if(shape != null){ - return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType())); + return Collections.singletonList(LongShapeDescriptor.fromShape(shape, dataType)); } else { return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType()))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java index 5b9faa005..dfedd6dcd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java @@ -36,6 +36,7 @@ import java.util.List; @Slf4j public class RandomExponential extends DynamicCustomOp { private double lambda = 0.0; + private DataType dataType = DataType.DOUBLE; public RandomExponential() { // @@ -48,6 +49,15 @@ public class RandomExponential extends DynamicCustomOp { addTArgument(lambda); } + public RandomExponential(SameDiff sd, double lambda, DataType dataType, long... shape){ + super(null, sd, new SDVariable[]{sd.constant(Nd4j.createFromArray(shape))}); + this.lambda = lambda; + addTArgument(lambda); + this.dataType = dataType; + addDArgument(dataType); + addIArgument(shape); + } + public RandomExponential(double lambda, DataType datatype, long... 
shape){ this(Nd4j.createFromArray(shape), Nd4j.createUninitialized(datatype, shape), lambda); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java index 3f08d1619..ec4bf96a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java @@ -44,6 +44,13 @@ public class BernoulliDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.prob}; } + public BernoulliDistribution(SameDiff sd, double prob, DataType dataType, long[] shape){ + this(sd, prob, shape); + this.prob = prob; + this.extraArgs = new Object[] {this.prob}; + super.dataType = dataType; + } + public BernoulliDistribution() { super(); } @@ -113,6 +120,6 @@ public class BernoulliDistribution extends BaseRandomOp { Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); //Input data type specifies the shape; output data type should be any float //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index b08f56be3..93e4e3c66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -45,6 +45,10 @@ public class BinomialDistribution extends BaseRandomOp { this.extraArgs = new Object[] {(double) this.trials, this.probability}; } + public BinomialDistribution(SameDiff sd, int trials, double probability, DataType dataType, long[] shape){ + this(sd, trials, probability, shape); + } + public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){ this(Nd4j.createUninitialized(dt, shape), trials, probability); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java index 1081e141b..1aa031ec0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java @@ -36,7 +36,7 @@ import java.util.List; */ public class GaussianDistribution extends BaseRandomOp { private double mean; - private double stddev; + private double stddev; public GaussianDistribution(SameDiff sd, double mean, double stddev, long[] shape){ super(sd, shape); @@ -45,6 +45,14 @@ public class GaussianDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public GaussianDistribution(SameDiff sd, double mean, double stddev, 
DataType dataType, long[] shape){ + super(sd, shape); + this.mean = mean; + this.stddev = stddev; + this.dataType = dataType; + this.extraArgs = new Object[] {this.mean, this.stddev}; + } + public GaussianDistribution() { super(); } @@ -134,9 +142,7 @@ public class GaussianDistribution extends BaseRandomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); - //Input data type specifies the shape; output data type should be any float - //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java index c007d4e92..44545f8ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java @@ -36,7 +36,7 @@ import java.util.List; */ public class LogNormalDistribution extends BaseRandomOp { private double mean; - private double stddev; + private double stddev; public LogNormalDistribution() { super(); @@ -49,6 +49,11 @@ public class LogNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public LogNormalDistribution(SameDiff sd, double mean, double stdev, DataType dataType, long... shape){ + this(sd, mean, stdev,shape); + this.dataType = dataType; + } + public LogNormalDistribution(double mean, double stddev, DataType datatype, long... 
shape){ this(Nd4j.createUninitialized(datatype, shape), mean, stddev); } @@ -131,9 +136,7 @@ public class LogNormalDistribution extends BaseRandomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); - //Input data type specifies the shape; output data type should be any float - //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java index ba09a2d29..a95169d78 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java @@ -49,6 +49,13 @@ public class TruncatedNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public TruncatedNormalDistribution(SameDiff sd, double mean, double stddev, DataType dataType, long[] shape) { + super(sd, shape); + this.mean = mean; + this.stddev = stddev; + this.extraArgs = new Object[] {this.mean, this.stddev}; + } + public TruncatedNormalDistribution(double mean, double stddev, DataType datatype, long... shape){ this(Nd4j.createUninitialized(datatype, shape), mean, stddev); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java index 408af9ce2..e1b40a382 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java @@ -47,6 +47,11 @@ public class UniformDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.from, this.to}; } + public UniformDistribution(SameDiff sd, double from, double to, DataType dataType, long[] shape) { + this(sd, from, to, shape); + this.dataType = dataType; + } + public UniformDistribution(double min, double max, DataType datatype, long... 
shape){ this(Nd4j.createUninitialized(datatype, shape), min, max); } @@ -111,6 +116,6 @@ public class UniformDistribution extends BaseRandomOp { Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); //Input data type specifies the shape; output data type should be any float //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java index f60726c36..4986b8277 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.factory; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -129,6 +130,16 @@ public class NDValidation { " type; got array with non-integer data type " + v.dataType()); } + public static void validateInteger(String opName, String inputName, INDArray[] vars) { + for (INDArray v : vars) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + + " type; got array with non-integer data type " + v.dataType()); + } + } + /** * Validate that the operation is being applied on an floating point type INDArray * @@ -233,4 +244,15 @@ public class NDValidation { public static boolean isSameType(INDArray x, INDArray y) { return x.dataType() == y.dataType(); } + + public static boolean isSameType(INDArray[] x) { + DataType firstDataType = x[0].dataType(); + if (x.length > 1) { + for (int i = 1; i < x.length; ++i) { + if (firstDataType != x[i].dataType()) + return false; + } + } + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java new file mode 100644 index 000000000..cfaf00d18 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -0,0 +1,2056 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.conditions.Condition; + +public class NDBase { + public NDBase() { + } + + /** + * Boolean and array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public INDArray all(INDArray x, int... dimensions) { + NDValidation.validateBool("all", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); + } + + /** + * Boolean or array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public INDArray any(INDArray x, int... dimensions) { + NDValidation.validateBool("any", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
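A minimal usage sketch of the keepDims shape rule above (the [2,3,4] shape and the Nd4j.create factory call are illustrative assumptions):
NDBase base = new NDBase();
INDArray in = Nd4j.create(DataType.FLOAT, 2, 3, 4);    // [a,b,c] = [2,3,4]
INDArray kept = base.argmax(in, true, 1);               // keepDims = true  -> shape [2,1,4]
INDArray dropped = base.argmax(in, false, 1);           // keepDims = false -> shape [2,4]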
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, keepDims, dimensions)); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmax(INDArray in, int... dimensions) { + NDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, false, dimensions)); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, keepDims, dimensions)); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmin(INDArray in, int... dimensions) { + NDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions)); + } + + /** + * Concatenate a set of inputs along the specified dimension.
+ * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
+ * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
+ * + * Inputs must satisfy the following constraints:
+ * Input arrays must all be the same datatype: isSameType(inputs)
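A minimal usage sketch of the concat rule above, assuming two arrays of shape [a,x,c] and [a,y,c] (shapes are illustrative):
INDArray a = Nd4j.create(DataType.FLOAT, 2, 3, 5);      // [a, x, c]
INDArray b = Nd4j.create(DataType.FLOAT, 2, 4, 5);      // [a, y, c]
INDArray merged = new NDBase().concat(new INDArray[]{a, b}, 1);   // shape [2, 7, 5]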
+ * + * @param inputs Input variables (NUMERIC type) + * @param dimension Dimension to concatenate on + * @return output (NUMERIC type) + */ + public INDArray concat(INDArray[] inputs, int dimension) { + NDValidation.validateNumerical("concat", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension))[0]; + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative product operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, int... axis) { + NDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis))[0]; + } + + /** + * Cumulative product operation.

+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumulative product operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cumprod(INDArray in, int... axis) { + NDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
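A minimal usage sketch of the cumsum variants above, assuming the illustrative input [1, 2, 3]:
INDArray in = Nd4j.createFromArray(1.0, 2.0, 3.0);
NDBase base = new NDBase();
INDArray plain = base.cumsum(in, false, false, 0);   // [1, 3, 6]
INDArray excl  = base.cumsum(in, true,  false, 0);   // [0, 1, 3]
INDArray rev   = base.cumsum(in, false, true,  0);   // [6, 5, 3]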
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, int... axis) { + NDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public INDArray cumsum(INDArray in, int... axis) { + NDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; + } + + /** + * Pairwise dot product reduction along dimension
+ * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
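A minimal usage sketch of the pairwise dot reduction above, with illustrative length-3 vectors:
INDArray x = Nd4j.createFromArray(1.0, 2.0, 3.0);
INDArray y = Nd4j.createFromArray(4.0, 5.0, 6.0);
INDArray d = new NDBase().dot(x, y, 0);   // 1*4 + 2*5 + 3*6 = 32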
+ * + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output output variable (NUMERIC type) + */ + public INDArray dot(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("dot", "x", x); + NDValidation.validateNumerical("dot", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.Dot(x, y, dimensions)); + } + + /** + * Dynamically partition the input variable values into the specified number of partitions, using the indices.
+ * Example:
+ *


+ * input = [1,2,3,4,5]
+ * numPartitions = 2
+ * partitions = [1,0,0,1,0]
+ * out[0] = [2,3,5]
+ * out[1] = [1,4]
+ *

+ * + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 + * @return output Output variables (equal in number to numPartitions) (NUMERIC type) + */ + public INDArray dynamicPartition(INDArray x, INDArray[] partitions, int numPartitions) { + NDValidation.validateNumerical("dynamicPartition", "x", x); + NDValidation.validateInteger("dynamicPartition", "partitions", partitions); + Preconditions.checkArgument(partitions.length >= 1, "partitions has incorrect size/length. Expected: partitions.length >= 1, got %s", partitions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions))[0]; + } + + /** + * Dynamically merge the specified input arrays into a single array, using the specified indices
+ * + * @param x Input variables. (NUMERIC type) + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @return output Merged output variable (NUMERIC type) + */ + public INDArray dynamicStitch(INDArray[] x, INDArray[] indices) { + NDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + NDValidation.validateInteger("dynamicStitch", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(x, indices))[0]; + } + + /** + * Equals operation: elementwise x == y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray eq(INDArray x, double y) { + NDValidation.validateNumerical("eq", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals(x, y)); + } + + /** + * Equal to operation: elementwise x == y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray eq(INDArray x, INDArray y) { + NDValidation.validateNumerical("eq", "x", x); + NDValidation.validateNumerical("eq", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(x, y))[0]; + } + + /** + * Reshape the input by adding a 1 at the specified location.
+ * For example, if input has shape [a, b], then output shape is:
+ * axis = 0: [1, a, b]
+ * axis = 1: [a, 1, b]
+ * axis = 2: [a, b, 1]
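A minimal usage sketch of the expandDims shape rule above (the [3, 4] shape is an illustrative assumption):
INDArray x = Nd4j.create(DataType.FLOAT, 3, 4);        // [a, b]
INDArray y = new NDBase().expandDims(x, 1);            // shape [3, 1, 4]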
+ * + * @param x Input variable (NDARRAY type) + * @param axis Axis to expand + * @return output Output variable (NUMERIC type) + */ + public INDArray expandDims(INDArray x, int axis) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(x, axis))[0]; + } + + /** + * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
+ * + * @param shape Shape: must be a 1D array/variable (INT type) + * @param dataType Datatype of the output array + * @param value Value to set all elements to + * @return output Output variable (NUMERIC type) + */ + public INDArray fill(INDArray shape, DataType dataType, double value) { + NDValidation.validateInteger("fill", "shape", shape); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; + } + + /** + * Gather slices from the input variable where the indices are specified as fixed int[] values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get (Size: AtLeast(min=1)) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public INDArray gather(INDArray df, int[] indices, int axis) { + NDValidation.validateNumerical("gather", "df", df); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; + } + + /** + * Gather slices from the input variable where the indices are specified as dynamic array values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public INDArray gather(INDArray df, INDArray indices, int axis) { + NDValidation.validateNumerical("gather", "df", df); + NDValidation.validateInteger("gather", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; + } + + /** + * Gather slices from df with shape specified by indices.
+ * + * @param df (NUMERIC type) + * @param indices (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray gatherNd(INDArray[] df, INDArray[] indices) { + NDValidation.validateNumerical("gatherNd", "df", df); + Preconditions.checkArgument(df.length >= 1, "df has incorrect size/length. Expected: df.length >= 1, got %s", df.length); + NDValidation.validateNumerical("gatherNd", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0]; + } + + /** + * Greater than operation: elementwise x > y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray gt(INDArray x, double y) { + NDValidation.validateNumerical("gt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan(x, y)); + } + + /** + * Greater than operation: elementwise x > y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray gt(INDArray x, INDArray y) { + NDValidation.validateNumerical("gt", "x", x); + NDValidation.validateNumerical("gt", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(x, y))[0]; + } + + /** + * Greater than or equals operation: elementwise x >= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray gte(INDArray x, double y) { + NDValidation.validateNumerical("gte", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(x, y)); + } + + /** + * Greater than or equal to operation: elementwise x >= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray gte(INDArray x, INDArray y) { + NDValidation.validateNumerical("gte", "x", x); + NDValidation.validateNumerical("gte", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; + } + + /** + * Elementwise identity operation: out = x
+ * + * @param input Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray identity(INDArray input) { + NDValidation.validateNumerical("identity", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(input))[0]; + } + + /** + * Compute the inverse permutation indices for a permutation operation
+ * Example: if input is [2, 0, 1] then output is [1, 2, 0]
+ * The idea is that x.permute(input).permute(invertPermutation(input)) == x
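A minimal usage sketch of the permutation example above:
INDArray perm = Nd4j.createFromArray(2, 0, 1);
INDArray inv  = new NDBase().invertPermutation(perm);  // [1, 2, 0]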
+ * + * @param input 1D indices for permutation (INT type) + * @return output 1D inverted permutation (INT type) + */ + public INDArray invertPermutation(INDArray input) { + NDValidation.validateInteger("invertPermutation", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; + } + + /** + * Is the given tensor a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
+ * + * @param x Input variable (NUMERIC type) + * @return output scalar boolean with value true or false (NDARRAY type) + */ + public INDArray isNumericTensor(INDArray x) { + NDValidation.validateNumerical("isNumericTensor", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(x))[0]; + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
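A minimal usage sketch of the linspace example above:
INDArray ls = new NDBase().linspace(DataType.DOUBLE, 3.0, 4.0, 3);   // [3.0, 3.5, 4.0]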
+ * + * @param dataType Data type of the output array + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public INDArray linspace(DataType dataType, double start, double stop, long number) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; + } + + /** + * Less than operation: elementwise x < y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lt(INDArray x, double y) { + NDValidation.validateNumerical("lt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan(x, y)); + } + + /** + * Less than operation: elementwise x < y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lt(INDArray x, INDArray y) { + NDValidation.validateNumerical("lt", "x", x); + NDValidation.validateNumerical("lt", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(x, y))[0]; + } + + /** + * Less than or equals operation: elementwise x <= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lte(INDArray x, double y) { + NDValidation.validateNumerical("lte", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(x, y)); + } + + /** + * Less than or equal to operation: elementwise x <= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lte(INDArray x, INDArray y) { + NDValidation.validateNumerical("lte", "x", x); + NDValidation.validateNumerical("lte", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(x, y))[0]; + } + + /** + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Boolean mask (NUMERIC type) + */ + public INDArray matchCondition(INDArray in, Condition condition) { + NDValidation.validateNumerical("matchCondition", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(in, condition)); + } + + /** + * Returns a count of the number of elements that satisfy the condition
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public INDArray matchConditionCount(INDArray in, Condition condition) { + NDValidation.validateNumerical("matchConditionCount", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition)); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
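A minimal usage sketch of a per-dimension count, assuming the Conditions factory from org.nd4j.linalg.indexing.conditions and an illustrative 2x3 input:
INDArray in = Nd4j.createFromArray(new double[][]{{-1.0, 2.0, 3.0}, {4.0, -5.0, 6.0}});
INDArray perRow = new NDBase().matchConditionCount(in, Conditions.greaterThan(0.0), 1);   // [2, 2]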
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public INDArray matchConditionCount(INDArray in, Condition condition, boolean keepDim, + int... dimensions) { + NDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, keepDim, dimensions)); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public INDArray matchConditionCount(INDArray in, Condition condition, int... dimensions) { + NDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, false, dimensions)); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray max(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, keepDims, dimensions)); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray max(INDArray x, int... dimensions) { + NDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, false, dimensions)); + } + + /** + * Element-wise maximum operation: out[i] = max(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray max(INDArray first, INDArray second) { + NDValidation.validateNumerical("max", "first", first); + NDValidation.validateNumerical("max", "second", second); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second))[0]; + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray mean(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, keepDims, dimensions)); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray mean(INDArray x, int... dimensions) { + NDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray min(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, keepDims, dimensions)); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray min(INDArray x, int... dimensions) { + NDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, false, dimensions)); + } + + /** + * Element-wise minimum operation: out[i] = min(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray min(INDArray first, INDArray second) { + NDValidation.validateNumerical("min", "first", first); + NDValidation.validateNumerical("min", "second", second); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
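A minimal usage sketch of mmul(a^T, b) via the transpose flags, with illustrative shapes:
INDArray x = Nd4j.create(DataType.FLOAT, 4, 3);
INDArray y = Nd4j.create(DataType.FLOAT, 4, 5);
INDArray z = new NDBase().mmul(x, y, true, false, false);   // x^T * y -> shape [3, 5]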
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false))[0]; + } + + /** + * Not equals operation: elementwise x != y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray neq(INDArray x, double y) { + NDValidation.validateNumerical("neq", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals(x, y)); + } + + /** + * Not equal to operation: elementwise x != y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray neq(INDArray x, INDArray y) { + NDValidation.validateNumerical("neq", "x", x); + NDValidation.validateNumerical("neq", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(x, y))[0]; + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm1(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, keepDims, dimensions)); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm1(INDArray x, int... dimensions) { + NDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, false, dimensions)); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm2(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, keepDims, dimensions)); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm2(INDArray x, int... dimensions) { + NDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, false, dimensions)); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray normmax(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, keepDims, dimensions)); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray normmax(INDArray x, int... dimensions) { + NDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, false, dimensions)); + } + + /** + * Convert the array to a one-hot array with values 'on' and 'off' for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on, and all other values set to off
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @param dataType Output data type + * @return output Output variable (NUMERIC type) + */ + public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off, + DataType dataType) { + NDValidation.validateNumerical("oneHot", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; + } + + /** + * Convert the array to a one-hot array with values 'on' and 'off' for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on, and all other values set to off
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @return output Output variable (NUMERIC type) + */ + public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off) { + NDValidation.validateNumerical("oneHot", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT))[0]; + } + + /** + * Convert the array to a one-hot array with values 0 and 1 for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
+ * see oneHot(SDVariable, int, int, double, double)
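A minimal usage sketch, assuming illustrative indices [0, 2, 1] and depth 3:
INDArray idx = Nd4j.createFromArray(0, 2, 1);
INDArray oh  = new NDBase().oneHot(idx, 3);   // [[1,0,0], [0,0,1], [0,1,0]]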
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @return output Output variable (NUMERIC type) + */ + public INDArray oneHot(INDArray indices, int depth) { + NDValidation.validateNumerical("oneHot", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth))[0]; + } + + /** + * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input INDArray (NUMERIC type) + * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) + */ + public INDArray onesLike(INDArray input) { + NDValidation.validateNumerical("onesLike", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input))[0]; + } + + /** + * As per onesLike(String, SDVariable) but the output datatype may be specified
+ * + * @param input (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public INDArray onesLike(INDArray input, DataType dataType) { + NDValidation.validateNumerical("onesLike", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0]; + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public INDArray permute(INDArray x, int... dimensions) { + NDValidation.validateNumerical("permute", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray prod(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, keepDims, dimensions)); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray prod(INDArray x, int... dimensions) { + NDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, false, dimensions)); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
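A minimal usage sketch of the range example above:
INDArray r = new NDBase().range(1.0, 3.0, 0.5, DataType.DOUBLE);   // [1.0, 1.5, 2.0, 2.5]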
+ * + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public INDArray range(double from, double to, double step, DataType dataType) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; + } + + /** + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + */ + public INDArray rank(INDArray in) { + NDValidation.validateNumerical("rank", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Rank(in))[0]; + } + + /** + * Element-wise replace where condition:
+ * out[i] = from[i] if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public INDArray replaceWhere(INDArray update, INDArray from, Condition condition) { + NDValidation.validateNumerical("replaceWhere", "update", update); + NDValidation.validateNumerical("replaceWhere", "from", from); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray reshape(INDArray x, INDArray shape) { + NDValidation.validateNumerical("reshape", "x", x); + NDValidation.validateNumerical("reshape", "shape", shape); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; + } + + /** + * Reverse the values of an array for the specified dimensions
+ * If input is:
+ * [ 1, 2, 3]
+ * [ 4, 5, 6]
+ * then
+ * reverse(in, 0):
+ * [3, 2, 1]
+ * [6, 5, 4]
+ * reverse(in, 1):
+ * [4, 5, 6]
+ * [1, 2, 3]<br>
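+ * Minimal usage sketch (illustrative; assumes an Nd4j.createFromArray overload for 2D double arrays):<br>
+ * <pre>{@code
+ * INDArray in = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}});
+ * INDArray r0 = reverse(in, 0);   // order reversed along dimension 0, as above
+ * INDArray r1 = reverse(in, 1);   // order reversed along dimension 1, as above
+ * }</pre>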
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Input variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray reverse(INDArray x, int... dimensions) { + NDValidation.validateNumerical("reverse", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions))[0]; + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @param seqDim Sequence dimension + * @param batchDim Batch dimension + * @return output Reversed sequences (NUMERIC type) + */ + public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { + NDValidation.validateNumerical("reverseSequence", "x", x); + NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim))[0]; + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @return output Reversed sequences (NUMERIC type) + */ + public INDArray reverseSequence(INDArray x, INDArray seq_lengths) { + NDValidation.validateNumerical("reverseSequence", "x", x); + NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0))[0]; + } + + /** + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
+ * i.e., returns the remainder after division by 'value'
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarFloorMod(INDArray in, double value) { + NDValidation.validateNumerical("scalarFloorMod", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(in, value)); + } + + /** + * Element-wise scalar maximum operation: out = max(in, value)
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarMax(INDArray in, double value) { + NDValidation.validateNumerical("scalarMax", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMax(in, value)); + } + + /** + * Element-wise scalar minimum operation: out = min(in, value)<br>
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarMin(INDArray in, double value) { + NDValidation.validateNumerical("scalarMin", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMin(in, value)); + } + + /** + * Return a variable with equal shape to the input, but all elements set to value 'set'
+ * + * @param in Input variable (NUMERIC type) + * @param set Value to set + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarSet(INDArray in, double set) { + NDValidation.validateNumerical("scalarSet", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarSet(in, set)); + } + + /** + * Scatter addition operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
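+ * Minimal usage sketch for the rank 1 indices case (illustrative values):<br>
+ * <pre>{@code
+ * INDArray ref     = Nd4j.createFromArray(1.0, 2.0, 3.0);
+ * INDArray indices = Nd4j.createFromArray(0, 2);
+ * INDArray updates = Nd4j.createFromArray(10.0, 20.0);
+ * INDArray out = scatterAdd(ref, indices, updates);
+ * // out => [11.0, 2.0, 23.0]
+ * }</pre>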
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterAdd(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterAdd", "ref", ref); + NDValidation.validateNumerical("scatterAdd", "indices", indices); + NDValidation.validateNumerical("scatterAdd", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(ref, indices, updates))[0]; + } + + /** + * Scatter division operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] / updates[...]<br>
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] / updates[i, ...]<br>
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] / updates[i, ..., k, ...]<br>
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterDiv(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterDiv", "ref", ref); + NDValidation.validateNumerical("scatterDiv", "indices", indices); + NDValidation.validateNumerical("scatterDiv", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(ref, indices, updates))[0]; + } + + /** + * Scatter max operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = max(out[index, ...], updates[...])<br>
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(out[indices[i], ...], updates[i, ...])<br>
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(out[indices[i], ..., indices[k], ...], updates[i, ..., k, ...])<br>
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterMax(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterMax", "ref", ref); + NDValidation.validateNumerical("scatterMax", "indices", indices); + NDValidation.validateNumerical("scatterMax", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(ref, indices, updates))[0]; + } + + /** + * Scatter min operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = min(out[index, ...], updates[...])<br>
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(out[indices[i], ...], updates[i, ...])<br>
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(out[indices[i], ..., indices[k], ...], updates[i, ..., k, ...])<br>
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterMin(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterMin", "ref", ref); + NDValidation.validateNumerical("scatterMin", "indices", indices); + NDValidation.validateNumerical("scatterMin", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(ref, indices, updates))[0]; + } + + /** + * Scatter multiplication operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] * updates[...]<br>
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] * updates[i, ...]<br>
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] * updates[i, ..., k, ...]<br>
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterMul(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterMul", "ref", ref); + NDValidation.validateNumerical("scatterMul", "indices", indices); + NDValidation.validateNumerical("scatterMul", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(ref, indices, updates))[0]; + } + + /** + * Scatter subtraction operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] - updates[...]<br>
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] - updates[i, ...]<br>
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] - updates[i, ..., k, ...]<br>
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterSub(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterSub", "ref", ref); + NDValidation.validateNumerical("scatterSub", "indices", indices); + NDValidation.validateNumerical("scatterSub", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(ref, indices, updates))[0]; + } + + /** + * Scatter update operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]<br>
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]<br>
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]<br>
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.<br>
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterUpdate(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterUpdate", "ref", ref); + NDValidation.validateNumerical("scatterUpdate", "indices", indices); + NDValidation.validateNumerical("scatterUpdate", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; + } + + /** + * Segment max operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See the unsortedSegment (String, SDVariable, SDVariable, int) ops<br>
+ * for the same op without this sorted requirement
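+ * Minimal usage sketch (values follow the example above):<br>
+ * <pre>{@code
+ * INDArray data = Nd4j.createFromArray(3.0, 6.0, 1.0, 4.0, 9.0, 2.0, 8.0);
+ * INDArray ids  = Nd4j.createFromArray(0, 0, 1, 1, 1, 2, 2);
+ * INDArray out  = segmentMax(data, ids);
+ * // out => [6.0, 9.0, 8.0]
+ * }</pre>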
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentMax(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; + } + + /** + * Segment mean operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [4.5, 4.667, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]<br>
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See the unsortedSegment (String, SDVariable, SDVariable, int) ops<br>
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentMean(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds))[0]; + } + + /** + * Segment min operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]<br>
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See the unsortedSegment (String, SDVariable, SDVariable, int) ops<br>
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentMin(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; + } + + /** + * Segment product operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]<br>
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See the unsortedSegment (String, SDVariable, SDVariable, int) ops<br>
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentProd(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds))[0]; + } + + /** + * Segment sum operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]<br>
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See the unsortedSegment (String, SDVariable, SDVariable, int) ops<br>
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentSum(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
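+ * Minimal usage sketch (illustrative values; assumes DataType.FLOAT):<br>
+ * <pre>{@code
+ * INDArray lengths = Nd4j.createFromArray(1, 3, 2);
+ * INDArray mask = sequenceMask(lengths, 4, DataType.FLOAT);
+ * // mask => [[1, 0, 0, 0],
+ * //          [1, 1, 1, 0],
+ * //          [1, 1, 0, 0]]
+ * }</pre>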
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) { + NDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + } + + /** + * see sequenceMask(String, SDVariable, SDVariable, DataType)
+ * + * @param lengths (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public INDArray sequenceMask(INDArray lengths, DataType dataType) { + NDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, dataType))[0]; + } + + /** + * Returns the shape of the specified INDArray as a 1D INDArray
+ * + * @param input Input variable (NUMERIC type) + * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) + */ + public INDArray shape(INDArray input) { + NDValidation.validateNumerical("shape", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Shape(input))[0]; + } + + /** + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + */ + public INDArray size(INDArray in) { + NDValidation.validateNumerical("size", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Size(in))[0]; + } + + /** + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
+ * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
+ * + * @param in Input variable (NUMERIC type) + * @param dimension Dimension to get size of + * @return output Scalar INDArray for size at specified variable (NUMERIC type) + */ + public INDArray sizeAt(INDArray in, int dimension) { + NDValidation.validateNumerical("sizeAt", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SizeAt(in, dimension))[0]; + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:<br>
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
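+ * Minimal usage sketch (numeric analogue of the example above):<br>
+ * <pre>{@code
+ * INDArray input = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}});
+ * INDArray out = slice(input, new int[]{0, 1}, 2, 1);
+ * // out => [[2.0],
+ * //         [5.0]]
+ * }</pre>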
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @return output Subset of the input (NUMERIC type) + */ + public INDArray slice(INDArray input, int[] begin, int... size) { + NDValidation.validateNumerical("slice", "input", input); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray squaredNorm(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, keepDims, dimensions)); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray squaredNorm(INDArray x, int... dimensions) { + NDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, false, dimensions)); + } + + /** + * Remove a single dimension of size 1.
+ * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
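+ * Minimal usage sketch (shape-only illustration):<br>
+ * <pre>{@code
+ * INDArray x = Nd4j.zeros(3, 1, 4);
+ * INDArray out = squeeze(x, 1);
+ * // out.shape() => [3, 4]
+ * }</pre>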
+ * + * @param x Input variable (NUMERIC type) + * @param axis Size 1 dimension to remove + * @return output Output variable (NUMERIC type) + */ + public INDArray squeeze(INDArray x, int axis) { + NDValidation.validateNumerical("squeeze", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x, axis))[0]; + } + + /** + * Stack a set of N INDArray of rank X into one rank X+1 variable.
+ * If inputs have shape [a,b,c] then output has shape:
+ * axis = 0: [N,a,b,c]
+ * axis = 1: [a,N,b,c]
+ * axis = 2: [a,b,N,c]
+ * axis = 3: [a,b,c,N]
+ * see unstack(String[], SDVariable, int, int)
+ * + * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) + * @param axis Axis to stack on + * @return output Output variable (NDARRAY type) + */ + public INDArray stack(INDArray values, int axis) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; + } + + /** + * Standard deviation array reduction operation, optionally along specified dimensions<br>
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray standardDeviation(INDArray x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + NDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, keepDims, dimensions)); + } + + /** + * Standard deviation array reduction operation, optionally along specified dimensions<br>
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray standardDeviation(INDArray x, boolean biasCorrected, int... dimensions) { + NDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, false, dimensions)); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @return output A subset of the input array (NUMERIC type) + */ + public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, + int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + NDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @return output A subset of the input array (NUMERIC type) + */ + public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int... strides) { + NDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0]; + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
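+ * Minimal usage sketch of the two keepDims modes (illustrative values):<br>
+ * <pre>{@code
+ * INDArray x = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}});
+ * INDArray kept    = sum(x, true, 1);    // shape [2, 1], values [[6.0], [15.0]]
+ * INDArray dropped = sum(x, false, 1);   // shape [2],    values [6.0, 15.0]
+ * }</pre>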
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray sum(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, keepDims, dimensions)); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray sum(INDArray x, int... dimensions) { + NDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output Output variable (NUMERIC type) + */ + public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + NDValidation.validateNumerical("tensorMmul", "x", x); + NDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ))[0]; + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dimensionsY) { + NDValidation.validateNumerical("tensorMmul", "x", x); + NDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false))[0]; + } + + /** + * Repeat (tile) the input tensor the specified number of times.
+ * For example, if input is
+ * [1, 2]
+ * [3, 4]
+ * and repeat is [2, 3]
+ * then output is
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
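+ * Minimal usage sketch (values follow the example above):<br>
+ * <pre>{@code
+ * INDArray x      = Nd4j.createFromArray(new double[][]{{1, 2}, {3, 4}});
+ * INDArray repeat = Nd4j.createFromArray(2, 3);
+ * INDArray out = tile(x, repeat);
+ * // out has shape [4, 6]: the 2x2 input repeated twice along rows and three times along columns
+ * }</pre>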
+ * + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @return output Output variable (NDARRAY type) + */ + public INDArray tile(INDArray x, INDArray repeat) { + NDValidation.validateInteger("tile", "repeat", repeat); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; + } + + /** + * see tile(String, SDVariable, int...)
+ * + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) + * @return output (NDARRAY type) + */ + public INDArray tile(INDArray x, int... repeat) { + Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; + } + + /** + * Matrix transpose operation: If input has shape [a,b] output has shape [b,a]
+ * + * @param x Input variable (NDARRAY type) + * @return output transposed input (NDARRAY type) + */ + public INDArray transpose(INDArray x) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Transpose(x))[0]; + } + + /** + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
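+ * Minimal usage sketch (values follow the example above; segment IDs need not be sorted):<br>
+ * <pre>{@code
+ * INDArray data = Nd4j.createFromArray(1.0, 3.0, 2.0, 6.0, 4.0, 9.0, 8.0);
+ * INDArray ids  = Nd4j.createFromArray(1, 0, 2, 0, 1, 1, 2);
+ * INDArray out  = unsortedSegmentMax(data, ids, 3);
+ * // out => [6.0, 9.0, 8.0]
+ * }</pre>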
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentMax", "data", data); + NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentMean", "data", data); + NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentMin", "data", data); + NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]<br>
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentProd", "data", data); + NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]<br>
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); + NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentSum", "data", data); + NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variance). If false: divide by N (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, keepDims, dimensions)); + } + + /** + * Variance array reduction operation, optionally along specified dimensions<br>
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variance). If false: divide by N (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray variance(INDArray x, boolean biasCorrected, int... dimensions) { + NDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, false, dimensions)); + } + + /** + * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:<br>
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input (NUMERIC type) + * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) + */ + public INDArray zerosLike(INDArray input) { + NDValidation.validateNumerical("zerosLike", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(input))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java index f77d5c823..d874b5bbf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java old mode 100755 new mode 100644 index 7bee44ace..cb00a28c2 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java @@ -32,7 +32,7 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.enums.DataFormat; +import org.nd4j.enums.DataFormat; public class NDCNN { public NDCNN() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java new file mode 100644 index 000000000..cb80c8092 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java @@ -0,0 +1,274 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDLinalg { + public NDLinalg() { + } + + /** + * Computes the Cholesky decomposition of one or more square matrices.
+ * + * @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type) + * @return output Transformed tensor (NUMERIC type) + */ + public INDArray cholesky(INDArray input) { + NDValidation.validateNumerical("Cholesky", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(input))[0]; + } + + /** + * Solver for linear least squares problems.<br>
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @param fast fast mode, defaults to True + * @return output Transformed tensor (FLOATING_POINT type) + */ + public INDArray lstsq(INDArray matrix, INDArray rhs, double l2_reguralizer, boolean fast) { + NDValidation.validateNumerical("Lstsq", "matrix", matrix); + NDValidation.validateNumerical("Lstsq", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Lstsq(matrix, rhs, l2_reguralizer, fast))[0]; + } + + /** + * Solver for linear least squares problems.<br>
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @return output Transformed tensor (FLOATING_POINT type) + */ + public INDArray lstsq(INDArray matrix, INDArray rhs, double l2_reguralizer) { + NDValidation.validateNumerical("Lstsq", "matrix", matrix); + NDValidation.validateNumerical("Lstsq", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Lstsq(matrix, rhs, l2_reguralizer, true))[0]; + } + + /** + * Computes LU decomposition.
+ * + * @param input input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray lu(INDArray input) { + NDValidation.validateNumerical("Lu", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Lu(input))[0]; + } + + /** + * Performs matrix multiplication on input tensors.<br>
+ * + * @param a input tensor (NUMERIC type) + * @param b input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray matmul(INDArray a, INDArray b) { + NDValidation.validateNumerical("Matmul", "a", a); + NDValidation.validateNumerical("Matmul", "b", b); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(a, b))[0]; + } + + /** + * Copy a tensor setting outside a central band in each innermost matrix.
+ * + * @param input input tensor (NUMERIC type) + * @param minLower lower diagonal count + * @param maxUpper upper diagonal count + */ + public INDArray[] matrixBandPart(INDArray input, int minLower, int maxUpper) { + NDValidation.validateNumerical("MatrixBandPart", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.MatrixBandPart(input, minLower, maxUpper)); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + * @param full full matrices mode + */ + public INDArray[] qr(INDArray input, boolean full) { + NDValidation.validateNumerical("Qr", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(input, full)); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + */ + public INDArray[] qr(INDArray input) { + NDValidation.validateNumerical("Qr", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(input, false)); + } + + /** + * Solver for systems of linear equations.
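+ * Minimal usage sketch (illustrative 2x2 system; assumes a 2D Nd4j.createFromArray overload):<br>
+ * <pre>{@code
+ * INDArray a   = Nd4j.createFromArray(new double[][]{{2, 0}, {0, 4}});
+ * INDArray rhs = Nd4j.createFromArray(new double[][]{{6}, {8}});
+ * INDArray x = solve(a, rhs, false);
+ * // x => [[3.0], [2.0]], since a.mmul(x) equals rhs
+ * }</pre>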
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param adjoint adjoint mode, defaults to False + * @return output Output tensor (FLOATING_POINT type) + */ + public INDArray solve(INDArray matrix, INDArray rhs, boolean adjoint) { + NDValidation.validateNumerical("Solve", "matrix", matrix); + NDValidation.validateNumerical("Solve", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.LinearSolve(matrix, rhs, adjoint))[0]; + } + + /** + * Solver for systems of linear equations.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @return output Output tensor (FLOATING_POINT type) + */ + public INDArray solve(INDArray matrix, INDArray rhs) { + NDValidation.validateNumerical("Solve", "matrix", matrix); + NDValidation.validateNumerical("Solve", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.LinearSolve(matrix, rhs, false))[0]; + } + + /** + * Solver for systems of linear equations.<br>
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param lower defines whether innermost matrices in matrix are lower or upper triangular + * @param adjoint adjoint mode + * @return output (FLOATING_POINT type) + */ + public INDArray triangularSolve(INDArray matrix, INDArray rhs, boolean lower, boolean adjoint) { + NDValidation.validateNumerical("TriangularSolve", "matrix", matrix); + NDValidation.validateNumerical("TriangularSolve", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.TriangularSolve(matrix, rhs, lower, adjoint))[0]; + } + + /** + * Computes pairwise cross product.
+ * + * @param a (NUMERIC type) + * @param b (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray cross(INDArray a, INDArray b) { + NDValidation.validateNumerical("cross", "a", a); + NDValidation.validateNumerical("cross", "b", b); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Cross(a, b))[0]; + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray diag(INDArray input) { + NDValidation.validateNumerical("diag", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Diag(input))[0]; + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray diag_part(INDArray input) { + NDValidation.validateNumerical("diag_part", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(input))[0]; + } + + /** + * Calculates log of determinant.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray logdet(INDArray input) { + NDValidation.validateNumerical("logdet", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Logdet(input))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false))[0]; + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @param switchNum + * @return output (FLOATING_POINT type) + */ + public INDArray svd(INDArray input, boolean fullUV, boolean computeUV, int switchNum) { + NDValidation.validateNumerical("svd", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(input, fullUV, computeUV, switchNum))[0]; + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @return output (FLOATING_POINT type) + */ + public INDArray svd(INDArray input, boolean fullUV, boolean computeUV) { + NDValidation.validateNumerical("svd", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(input, fullUV, computeUV, 16))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java index 4c1234514..cdee59ea1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index 66f8071e2..eddbe3db7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,6 +18,8 @@ package org.nd4j.linalg.factory.ops; +import static org.nd4j.linalg.factory.NDValidation.isSameType; + import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -66,12 +68,12 @@ public class NDMath { * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray amax(INDArray in, int... dimensions) { NDValidation.validateNumerical("amax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(in, dimensions)); } @@ -79,12 +81,12 @@ public class NDMath { * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray amean(INDArray in, int... dimensions) { NDValidation.validateNumerical("amean", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(in, dimensions)); } @@ -92,12 +94,12 @@ public class NDMath { * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray amin(INDArray in, int... dimensions) { NDValidation.validateNumerical("amin", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(in, dimensions)); } @@ -143,12 +145,12 @@ public class NDMath { * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray asum(INDArray in, int... dimensions) { NDValidation.validateNumerical("asum", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(in, dimensions)); } @@ -375,12 +377,12 @@ public class NDMath { * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray countNonZero(INDArray in, int... dimensions) { NDValidation.validateNumerical("countNonZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(in, dimensions)); } @@ -388,12 +390,12 @@ public class NDMath { * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray countZero(INDArray in, int... dimensions) { NDValidation.validateNumerical("countZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(in, dimensions)); } @@ -461,12 +463,12 @@ public class NDMath { * Entropy reduction: -sum(x * log(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray entropy(INDArray in, int... dimensions) { NDValidation.validateNumerical("entropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(in, dimensions)); } @@ -566,10 +568,12 @@ public class NDMath { * @param rows Number of rows * @param cols Number of columns * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) * @return output Identity matrix (NUMERIC type) */ - public INDArray eye(int rows, int cols, DataType dataType) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols, dataType))[0]; + public INDArray eye(int rows, int cols, DataType dataType, int... dimensions) { + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols, dataType, dimensions))[0]; } /** @@ -615,7 +619,7 @@ public class NDMath { public INDArray firstIndex(INDArray in, Condition condition, int... dimensions) { NDValidation.validateNumerical("firstIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, false, condition, dimensions)); } /** @@ -639,7 +643,7 @@ public class NDMath { int... dimensions) { NDValidation.validateNumerical("firstIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, keepDims, condition, dimensions)); } /** @@ -682,7 +686,7 @@ public class NDMath { public INDArray iamax(INDArray in, int... dimensions) { NDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, false, dimensions)); } /** @@ -711,7 +715,7 @@ public class NDMath { public INDArray iamin(INDArray in, int... dimensions) { NDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, false, dimensions)); } /** @@ -842,7 +846,7 @@ public class NDMath { public INDArray lastIndex(INDArray in, Condition condition, int... dimensions) { NDValidation.validateNumerical("lastIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, false, condition, dimensions)); } /** @@ -865,7 +869,7 @@ public class NDMath { public INDArray lastIndex(INDArray in, Condition condition, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("lastIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions)); } /** @@ -883,13 +887,12 @@ public class NDMath { * Element-wise logarithm function (with specified base): out = log_{base}(x)
* * @param x Input variable (NUMERIC type) - * @param base Logarithm base (NUMERIC type) + * @param base Logarithm base * @return output Output variable (NUMERIC type) */ - public INDArray log(INDArray x, INDArray base) { + public INDArray log(INDArray x, double base) { NDValidation.validateNumerical("log", "x", x); - NDValidation.validateNumerical("log", "base", base); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(x, base)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LogX(x, base)); } /** @@ -907,12 +910,12 @@ public class NDMath { * Log entropy reduction: log(-sum(x * log(x)))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray logEntropy(INDArray in, int... dimensions) { NDValidation.validateNumerical("logEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(in, dimensions)); } @@ -1017,12 +1020,11 @@ public class NDMath { * * @param input Input to calculate moments for (NUMERIC type) * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) - * @return output Mean and variance variables (NUMERIC type) */ - public INDArray moments(INDArray input, int... axes) { + public INDArray[] moments(INDArray input, int... axes) { NDValidation.validateNumerical("moments", "input", input); Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes)); } /** @@ -1043,14 +1045,13 @@ public class NDMath { * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) - * @return output Output variables: mean and population variance (NUMERIC type) */ - public INDArray normalizeMoments(INDArray counts, INDArray means, INDArray variances, + public INDArray[] normalizeMoments(INDArray counts, INDArray means, INDArray variances, double shift) { NDValidation.validateNumerical("normalizeMoments", "counts", counts); NDValidation.validateNumerical("normalizeMoments", "means", means); NDValidation.validateNumerical("normalizeMoments", "variances", variances); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(counts, means, variances, shift))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(counts, means, variances, shift)); } /** @@ -1153,12 +1154,12 @@ public class NDMath { * Shannon Entropy reduction: -sum(x * log2(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray shannonEntropy(INDArray in, int... dimensions) { NDValidation.validateNumerical("shannonEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(in, dimensions)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 815f22e5b..04a713ecf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -107,7 +107,7 @@ public class NDNN { NDValidation.validateNumerical("dotProductAttention", "keys", keys); NDValidation.validateNumerical("dotProductAttention", "values", values); NDValidation.validateNumerical("dotProductAttention", "mask", mask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled, false))[0]; } /** @@ -227,7 +227,7 @@ public class NDNN { NDValidation.validateNumerical("layerNorm", "input", input); NDValidation.validateNumerical("layerNorm", "gain", gain); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, channelsFirst, dimensions))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, null, channelsFirst, dimensions))[0]; } /** @@ -343,7 +343,7 @@ public class NDNN { NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java old mode 100755 new mode 100644 diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java index 1dfcd60ae..dc5e472e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,6 +18,8 @@ package org.nd4j.linalg.factory.ops; +import static org.nd4j.linalg.factory.NDValidation.isSameType; + import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -67,11 +69,12 @@ public class NDRandom { * @param lambda lambda parameter * @param datatype Data type of the output variable * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) */ - public INDArray[] exponential(double lambda, DataType datatype, long... shape) { + public INDArray exponential(double lambda, DataType datatype, long... shape) { Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. 
Expected: shape.length >= 0, got %s", shape.length); Preconditions.checkArgument(lambda > 0, "Must be positive"); - return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 9abe0a483..eab974821 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -257,7 +257,7 @@ public class LayerOpValidation extends BaseOpValidation { msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); - out = sd.cnn().upsampling2d(in, true, 2, 2); + out = sd.cnn().upsampling2d(in, 2, 2, true); break; default: throw new RuntimeException(); @@ -588,8 +588,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().sconv2d(vars, c); - out = sd.nn().tanh("out", out); + SDVariable out = sd.cnn().separableConv2d(in, dW, b, c); + out = sd.f().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -623,7 +623,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable pW = sd.var("pW", pointWeightArr); SDVariable b = sd.var("b", bArr); - SDVariable[] vars = new SDVariable[]{in, dW, pW, b}; + //SDVariable[] vars = new SDVariable[]{in, dW, pW, b}; Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) @@ -634,8 +634,8 @@ public class LayerOpValidation extends BaseOpValidation { .dataFormat(Conv2DConfig.NCHW) .build(); - SDVariable out = sd.cnn().sconv2d(vars, c); - out = sd.nn().tanh("out", out); + SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c); + out = sd.nn().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 @@ -685,8 +685,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().deconv2d(vars, deconv); - out = sd.nn().tanh("out", out); + SDVariable out = sd.f().deconv2d(vars, deconv); + out = sd.f().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 @@ -733,8 +733,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().conv2d("conv", vars, c); - out = sd.nn().tanh("out", out); + SDVariable out = sd.f().conv2d(vars, c); + out = sd.f().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -767,7 +767,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig); + SDVariable[] results = sd.f().maxPoolWithArgmax(/*new String[]{"out","idx"},*/ in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -797,7 +797,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); - SDVariable out = 
sd.nn().tanh("out", outPool); + SDVariable out = sd.f().tanh(/*"out",*/ outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -855,7 +855,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); - SDVariable out = sd.nn().tanh("out", outPool); + SDVariable out = sd.f().tanh(/*"out",*/ outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -906,7 +906,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.f().tanh(/*"loss", */out).shape().rename("out"); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L); @@ -942,7 +942,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -976,8 +976,8 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1018,7 +1018,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig); - SDVariable loss = sd.nn().tanh(out).std(true).rename("loss"); + SDVariable loss = sd.f().tanh(out).std(true).rename("loss"); sd.setLossVariables("loss"); @@ -1057,7 +1057,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); + SDVariable res = sd.cnn.conv1d(in, w, null, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ @@ -1113,7 +1113,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1156,7 +1156,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1335,7 +1335,7 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 7f8da282e..ca3f10d04 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -178,7 
+178,7 @@ public class LossOpValidation extends BaseOpValidation { predictionsArr = Transforms.log(Transforms.abs(predictionsArr)); labelsArr = Transforms.abs(labelsArr); expOut = Transforms.exp(predictionsArr).sub(labelsArr.mul(predictionsArr)); - loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction); + loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction,false); break; case "log_poisson_full": predictionsArr = Transforms.log(Transforms.abs(predictionsArr)); @@ -188,7 +188,7 @@ public class LossOpValidation extends BaseOpValidation { .add(labelsArr.mul(Transforms.log(labelsArr))) .sub(labelsArr) .add(Transforms.log(labelsArr.mul(Math.PI * 2)).mul(0.5)); - loss = sd.loss().logPoissonFull("loss", labels, predictions, w, reduction); + loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction,true); break; case "mse": //To match TF, this is actually sum of squares - 1/numExamples (prediction-label)^2 @@ -251,7 +251,7 @@ public class LossOpValidation extends BaseOpValidation { expOut.muli(1/((n*(n-1)) / 2)); - loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction); + loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions,w, reduction); break; case "sparsesoftmax": labelsArr = Nd4j.create(DataType.DOUBLE, minibatch); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 1f23e12ec..06c64445b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -1289,7 +1289,7 @@ public class MiscOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); - SDVariable merged = sd.math().mergeAvg("merged", var); + SDVariable merged = sd.math().mergeAvg("merged", new SDVariable[]{var}); SDVariable sum = sd.sum(merged); Map m = sd.output(Collections.emptyMap(), "merged"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 4585b4a15..053f3a70b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -71,7 +71,7 @@ public class RandomOpValidation extends BaseOpValidation { switch (i) { case 0: name = "randomUniform"; - rand = sd.random().uniform(1, 2, shapeVar); + rand = sd.random().uniform(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double min = in.minNumber().doubleValue(); double max = in.maxNumber().doubleValue(); @@ -83,7 +83,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 1: name = "randomNormal"; - rand = sd.random().normal(1, 1, shapeVar); + rand = sd.random().normal(1, 1, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double stdev = in.std(true).getDouble(0); @@ -94,7 +94,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 2: name = "randomBernoulli"; - rand = sd.random().bernoulli(0.5, shapeVar); + rand = sd.random().bernoulli(0.5, DataType.DOUBLE, shape); checkFn = in -> { double mean = 
in.meanNumber().doubleValue(); double min = in.minNumber().doubleValue(); @@ -110,7 +110,7 @@ public class RandomOpValidation extends BaseOpValidation { case 3: name = "randomExponential"; final double lambda = 2; - rand = sd.random().exponential(lambda, shapeVar); + rand = sd.random().exponential(lambda, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double min = in.minNumber().doubleValue(); @@ -168,7 +168,7 @@ public class RandomOpValidation extends BaseOpValidation { switch (i) { case 0: name = "randomBernoulli"; - rand = sd.random().bernoulli(0.5, shape); + rand = sd.random().bernoulli(0.5, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double min = in.minNumber().doubleValue(); @@ -183,7 +183,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 1: name = "normal"; - rand = sd.random().normal(1, 2, shape); + rand = sd.random().normal(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double stdev = in.std(true).getDouble(0); @@ -194,7 +194,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 2: name = "randomBinomial"; - rand = sd.random().binomial(4, 0.5, shape); + rand = sd.random().binomial(4, 0.5, DataType.DOUBLE, shape); checkFn = in -> { NdIndexIterator iter = new NdIndexIterator(in.shape()); while(iter.hasNext()){ @@ -209,7 +209,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 3: name = "randomUniform"; - rand = sd.random().uniform(1, 2, shape); + rand = sd.random().uniform(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double min = in.minNumber().doubleValue(); double max = in.maxNumber().doubleValue(); @@ -225,7 +225,7 @@ public class RandomOpValidation extends BaseOpValidation { continue; } name = "truncatednormal"; - rand = sd.random().normalTruncated(1, 2, shape); + rand = sd.random().normalTruncated(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double stdev = in.std(true).getDouble(0); @@ -236,7 +236,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 5: name = "lognormal"; - rand = sd.random().logNormal(1, 2, shape); + rand = sd.random().logNormal(1, 2, DataType.DOUBLE, shape); //Note: lognormal parameters are mean and stdev of LOGARITHM of values checkFn = in -> { INDArray log = Transforms.log(in, true); @@ -389,15 +389,25 @@ public class RandomOpValidation extends BaseOpValidation { for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ SameDiff sd = SameDiff.create(); SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100)); - SDVariable out = sd.random.uniform(0, 10, shape, t); + SDVariable out = sd.random.uniform(0, 10, t, 1, 100); INDArray arr = out.eval(); assertEquals(t, arr.dataType()); - double min = arr.minNumber().doubleValue(); - double max = arr.maxNumber().doubleValue(); - double mean = arr.meanNumber().doubleValue(); - assertEquals(0, min, 0.5); - assertEquals(10, max, 0.5); - assertEquals(5.5, mean, 1); + if (t.equals(DataType.DOUBLE)) { + double min = arr.minNumber().doubleValue(); + double max = arr.maxNumber().doubleValue(); + double mean = arr.meanNumber().doubleValue(); + assertEquals(0, min, 0.5); + assertEquals(10, max, 0.5); + assertEquals(5.5, mean, 1); + } + else if (t.equals(DataType.FLOAT)) { + float min = arr.minNumber().floatValue(); + float max = arr.maxNumber().floatValue(); + float mean = arr.meanNumber().floatValue(); + assertEquals(0, min, 
0.5); + assertEquals(10, max, 0.5); + assertEquals(5.0, mean, 1); + } } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 3027138a1..bb2287e03 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -235,39 +235,39 @@ public class ReductionOpValidation extends BaseOpValidation { tc.expectedOutput("loss", inputArr.normmax()); break; case 10: - loss = sd.math().countNonZero("loss", input); + loss = sd.math().countNonZero("loss", input, 0,1); name = "countNonZero"; tc.expectedOutput("loss", Nd4j.scalar(inputArr.length())); gradCheck = false; //Long out, not floating point break; case 11: - loss = sd.math().countZero("loss", input); + loss = sd.math().countZero("loss", input, 0,1); name = "countZero"; tc.expectedOutput("loss", Nd4j.scalar(0L)); gradCheck = false; //Long out, not floating point break; case 12: - loss = sd.math().amax("loss", input); + loss = sd.math().amax("loss", input, 0,1); name = "amax"; tc.expectedOutput("loss", inputArr.amax()); break; case 13: - loss = sd.math().amin("loss", input); + loss = sd.math().amin("loss", input, 0,1); name = "amin"; tc.expectedOutput("loss", inputArr.amin()); break; case 14: - loss = sd.math().asum("loss", input); + loss = sd.math().asum("loss", input, 0,1); name = "asum"; tc.expectedOutput("loss", Nd4j.getExecutioner().exec(new ASum(inputArr.dup()))); break; case 15: - loss = sd.math().amean("loss", input); + loss = sd.math().amean("loss", input, 0,1); name = "amean"; tc.expectedOutput("loss", Nd4j.getExecutioner().exec(new AMean(inputArr.dup()))); break; case 16: - loss = sd.math().entropy("loss", input); + loss = sd.math().entropy("loss", input, 0,1); name = "entropy"; inputArr = Nd4j.linspace(0.01, 0.99, length, DataType.DOUBLE).reshape('c', minibatch, nOut); tc.expected("loss", inputArr.mul(Transforms.log(inputArr, true)).sum(Integer.MAX_VALUE).negi()); @@ -290,14 +290,14 @@ public class ReductionOpValidation extends BaseOpValidation { case 19: inputArr = Nd4j.rand(minibatch, nOut); name = "logEntropy"; - loss = sd.math().logEntropy("loss", input); + loss = sd.math().logEntropy("loss", input, 0,1); double logEntropy = inputArr.logEntropyNumber().doubleValue(); tc.expected(loss, Nd4j.scalar(logEntropy)); break; case 20: inputArr = Nd4j.rand(minibatch, nOut); name = "shannonEntropy"; - loss = sd.math().shannonEntropy("loss", input); + loss = sd.math().shannonEntropy("loss", input, 0); double shannonEntropy = inputArr.shannonEntropyNumber().doubleValue(); tc.expected(loss, Nd4j.scalar(shannonEntropy)); if (OpValidationSuite.IGNORE_FAILING) { @@ -836,11 +836,11 @@ public class ReductionOpValidation extends BaseOpValidation { @Test public void testIndexAccum() { List failed = new ArrayList<>(); - List dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]); + List dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/); INDArray in = Nd4j.rand(DataType.DOUBLE,3, 4); - for (int t = 0; t < 4; t++) { + for (int t = 0; t < 3; t++) { int[] d = dims.get(t); for (int i = 0; i < 7; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 30d4baf5c..795cef3f1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1406,14 +1406,13 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable[] arr = new SDVariable[rank]; - List names = new ArrayList<>(); + String[] names = new String[rank]; for( int i=0; i ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4)); List outputs = Arrays.asList("in", "z", "softmax"); @@ -3522,13 +3521,13 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testRngSanityCheck(){ Nd4j.getRandom().setSeed(12345); - for(DataType dt : DataType.values()) { + for(DataType dt : new DataType[]{DataType.FLOAT, DataType.DOUBLE,DataType.BFLOAT16}) { if (!dt.isNumerical()) continue; SameDiff sameDiff = SameDiff.create(); INDArray indaShape = Nd4j.createFromArray(3, 10); SDVariable sdShape = sameDiff.constant(indaShape); - SDVariable random = sameDiff.random().uniform("data", 0.0, 10.0, sdShape, dt); + SDVariable random = sameDiff.random().uniform("data", 0.0, 10.0, dt, 3, 10); INDArray out = random.eval(); String s = out.toString(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 4db765c5e..13853f246 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -80,7 +80,7 @@ public class SameDiffTrainingTest extends BaseNd4jTest { SDVariable z0 = in.mmul(w0).add(b0); SDVariable a0 = sd.math().tanh(z0); SDVariable z1 = a0.mmul(w1).add("prediction", b1); - SDVariable a1 = sd.nn().softmax(z1); + SDVariable a1 = sd.nn().softmax(z1,-1); SDVariable diff = sd.f().squaredDifference(a1, label); SDVariable lossMse = diff.mul(diff).mean(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java index cb6c70a89..d9f942793 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java @@ -2,6 +2,7 @@ package org.nd4j.autodiff.samediff.listeners; import org.junit.Test; import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; +import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java index b2c33f386..4f105aecc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -85,7 +85,7 @@ public class ListenerTest 
extends BaseNd4jTest { SDVariable z1 = a0.mmul(w1).add(b1); SDVariable predictions = sd.nn().softmax("predictions", z1, 1); - SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions); + SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions, null); sd.setLossVariables("loss"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 2d178a210..01dc83ee4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -34,11 +34,14 @@ import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.ResizeArea; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.shape.OnesLike; import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; +import org.nd4j.linalg.api.ops.impl.transforms.Cholesky; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Qr; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal; @@ -1775,4 +1778,41 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32)); assertEquals(expected, ret[0]); } + + @Test + public void testCholesky() { + INDArray x = Nd4j.createFromArray(new double[] {4,12,-16, 12 ,37,-43, -16, -43, 98}).reshape(3,3); + INDArray exp = Nd4j.createFromArray(new double[] {2., 0., 0., 6., 1., 0., -8., 5., 3.}).reshape(3,3); + + INDArray[] res = Nd4j.exec(new Cholesky(x)); + assertEquals(res[0], exp); + } + + @Test + public void testQr() { + INDArray in = Nd4j.createFromArray(new double[]{ + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. 
+ }).reshape(5,3); + Qr op = new Qr(in); + INDArray[] ret = Nd4j.exec(op); + INDArray res = Nd4j.createUninitialized(in.shape()); + DynamicCustomOp matmul = DynamicCustomOp.builder("matmul") + .addInputs(ret[0], ret[1]) + .build(); + ret = Nd4j.exec(matmul); + assertEquals(ret[0], in); + } + + @Test + public void testLogdet() { + INDArray x = Nd4j.createFromArray(new double[]{ + 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 + }).reshape(2,3,3); + + INDArray expected = Nd4j.createFromArray(new double[]{3.5835189, 4.159008}); + INDArray[] ret = Nd4j.exec(new Logdet(x)); + assertEquals(ret[0], expected); + + } + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java index 40d32121d..6d0ec8f54 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java @@ -61,7 +61,7 @@ public class NDLossTest extends BaseNd4jTest { SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions, null, reduction); + SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions,null, reduction); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -251,8 +251,8 @@ public class NDLossTest extends BaseNd4jTest { INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); - SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction); + SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction, false); + SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction, false); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -285,7 +285,8 @@ public class NDLossTest extends BaseNd4jTest { INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); SDVariable loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, null, reduction); + SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, + null, reduction); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -318,7 +319,8 @@ public class NDLossTest extends BaseNd4jTest { INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); SDVariable loss = sd.loss().meanSquaredError("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, null, reduction); + SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, + null, reduction); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -352,7 +354,8 @@ public class NDLossTest extends BaseNd4jTest { double labelSmoothing = 0.01; SDVariable loss = sd.loss().sigmoidCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing); - SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, 
predictions, null, reduction, labelSmoothing); + SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, + null, reduction, labelSmoothing); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -388,7 +391,7 @@ public class NDLossTest extends BaseNd4jTest { double labelSmoothing = 0.0; - SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing); + SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, null, reduction, labelSmoothing); SDVariable loss2 = sd.loss().softmaxCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java new file mode 100644 index 000000000..fbce0db6b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java @@ -0,0 +1,285 @@ +/***************************************************************************** + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.generated; + +import org.junit.Before; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class SDLinalgTest extends BaseNd4jTest { + public SDLinalgTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering(){ + return 'c'; + } + + private SameDiff sameDiff; + + @Before + public void setup() { + sameDiff = SameDiff.create(); + } + + @Test + public void testCholesky() { + INDArray input = Nd4j.createFromArray( + new float[]{ + 10.f, 14.f, + 14.f, 20.f, + 74.f, 86.f, + 86.f, 100.f + } + ).reshape(2,2,2); + + INDArray expected = Nd4j.createFromArray( + new float[]{ + 3.1622777f, 0.f, 4.427189f, 0.6324552f, + 8.602325f, 0.f, 9.997296f, 0.23252854f + } + ).reshape(2,2,2); + + SDVariable sdinput = sameDiff.var(input); + SDVariable out = sameDiff.linalg().cholesky(sdinput); + assertEquals(expected, out.eval()); + } + + @Test + public void testLstsq() { + INDArray a = Nd4j.createFromArray(new float[]{ + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f + }).reshape(2,2,2); + + INDArray b = Nd4j.createFromArray(new float[]{ + 3.f, 7.f, 11.f, 15.f + }).reshape(2,2,1); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.831169367f, 1.090908766f, 0.920544624f, 1.063016534f + }).reshape(2,2,1); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().lstsq(sda,sdb,0.5,true); + assertEquals(expected, res.eval()); + } + + @Test + public void testLu() { + SDVariable sdInput = sameDiff.var(Nd4j.createFromArray(new double[]{ + 1., 2., 3., 0., 2., 3., 0., 0., 7. + }).reshape(3,3)); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 1., 2., 3., 0., 2., 3., 0., 0., 7 + }).reshape(3,3); + + SDVariable out = sameDiff.linalg().lu("lu", sdInput); + assertEquals(expected, out.eval()); + } + + @Test + public void testMatrixBandPart() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + + SDVariable sdx = sameDiff.var(x); + SDVariable[] res = sameDiff.linalg().matrixBandPart(sdx, 1, 1); + assertArrayEquals(x.shape(), res[0].eval().shape()); + } + + @Test + public void testQr() { + INDArray input = Nd4j.createFromArray(new double[]{ + 12., -51., 4., + 6., 167., -68., + -4., 24., -41., + -1., 1., 0., + 2., 0., 3. 
+ }).reshape(5,3); + + INDArray expectedQ = Nd4j.createFromArray(new double[]{ + 0.8464147390303179, -0.3912908119746455, 0.34312406418022884, + 0.42320736951515897, 0.9040872694197354, -0.02927016186366648, + -0.2821382463434393, 0.17042054976392634, 0.9328559865183932, + -0.07053456158585983, 0.01404065236547358, -0.00109937201747271, + 0.14106912317171966, -0.01665551070074392, -0.10577161246232346 + }).reshape(5,3); + + INDArray expectedR = Nd4j.createFromArray(new double[]{ + 14.177446878757824, 20.666626544656932, -13.401566701313369, + -0.0000000000000006, 175.04253925050244, -70.0803066408638, + 0.00000000000000017, -0.00000000000000881, -35.20154302119086 + }).reshape(3,3); + + SDVariable sdInput = sameDiff.var(input); + SDVariable[] res = sameDiff.linalg().qr(sdInput); + + SDVariable mmulResult = sameDiff.mmul(res[0], res[1]); + + assertEquals(input, mmulResult.eval()); + } + + @Test + public void testSolve() { + INDArray a = Nd4j.createFromArray(new float[] { + 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f + }).reshape(3,3); + + INDArray b = Nd4j.createFromArray(new float[] { + 2.f, 4.f, 3.f + }).reshape(3,1); + + INDArray expected = Nd4j.createFromArray(new float[] { + 7.625f, 3.25f, 5.f + }).reshape(3,1); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().solve(sda, sdb); + assertEquals(expected, res.eval()); + } + + @Test + public void testTriangularSolve() { + INDArray a = Nd4j.createFromArray(new float[] { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }).reshape(3,3); + + INDArray b = Nd4j.createFromArray(new float[] { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }).reshape(3,3); + + INDArray expected = Nd4j.createFromArray(new float[] { + 0.99088347f, 1.1917052f, 1.2642528f, + 0.35071516f, 0.50630623f, 0.42935497f, + -0.30013534f, -0.53690606f, -0.47959247f + }).reshape(3,3); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().triangularSolve(sda, sdb, true, false); + assertEquals(expected, res.eval()); + } + + @Test + public void testCross() { + INDArray a = Nd4j.createFromArray(new double[]{1, 2, 3}); + INDArray b = Nd4j.createFromArray(new double[]{6, 7, 8}); + INDArray expected = Nd4j.createFromArray(new double[]{-5, 10, -5}); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().cross(sda, sdb); + assertEquals(expected, res.eval()); + } + + @Test + public void testDiag() { + INDArray x = Nd4j.createFromArray(new double[]{1,2}); + INDArray expected = Nd4j.createFromArray(new double[]{1,0,0,2}).reshape(2,2); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().diag(sdx); + assertEquals(expected, res.eval()); + } + + @Test + public void testDiagPart() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4).reshape(2,2); + INDArray expected = Nd4j.createFromArray(new double[]{1,4}); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().diag_part(sdx); + assertEquals(expected, res.eval()); + } + + @Test + public void testLogdet() { + INDArray x = Nd4j.createFromArray(new double[]{ + 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 + }).reshape(2,3,3); + INDArray expected = Nd4j.createFromArray(new double[]{3.5835189, 4.159008}); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().logdet(sdx); + 
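+        // Sanity check on the expected values: the first 3x3 input factors as L*L^T with
+        // L = [[2,0,0],[6,1,0],[-8,5,3]], so det = (2*1*3)^2 = 36 and ln(36) ≈ 3.5835189;
+        // the second matrix has det = 64.008, giving ln(64.008) ≈ 4.159008.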
assertEquals(expected, res.eval()); + } + + @Test + public void testSvd() { + INDArray x = Nd4j.createFromArray(new double[]{ + 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f + }).reshape(3,3); + INDArray expected = Nd4j.createFromArray(new double[]{1.8967269987492157, 0.3709665595850617, 0.05524869852188223}); + + SDVariable sdx = sameDiff.var(x); + SDVariable res = sameDiff.linalg().svd(sdx, false, false); + assertEquals(expected, res.eval()); + } + + @Test + public void testLogdetName() { + INDArray x = Nd4j.createFromArray(new double[]{ + 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 + }).reshape(2,3,3); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().logdet("logdet", sdx); + assertEquals("logdet", res.name()); + } + + @Test + public void testQrNames() { + INDArray input = Nd4j.createFromArray(new double[]{ + 12., -51., 4., + 6., 167., -68., + -4., 24., -41., + -1., 1., 0., + 2., 0., 3. + }).reshape(5,3); + + SDVariable sdInput = sameDiff.var(input); + SDVariable[] res = sameDiff.linalg().qr(new String[]{"ret0", "ret1"}, sdInput); + + assertEquals("ret0", res[0].name()); + assertEquals("ret1", res[1].name()); + } +} From 81ebfeead11f91f49734e937e5747123bad45719 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 1 Apr 2020 15:11:39 +1100 Subject: [PATCH 05/19] Small fixes (#355) * #8787 DataVec test fix Signed-off-by: Alex Black * New nd4j test + fix bad datavec test Signed-off-by: Alex Black * #8745 Fix flaky arbiter test Signed-off-by: Alex Black * One more test Signed-off-by: Alex Black --- .../TestGraphLocalExecution.java | 9 +++++-- .../org/datavec/image/loader/LFWLoader.java | 8 +++--- .../org/datavec/image/loader/LoaderTests.java | 24 ++++++++++++----- .../opvalidation/MiscOpValidation.java | 22 +++++++++++++++ .../opvalidation/ShapeOpValidation.java | 27 +++++++++++++++++++ 5 files changed, 77 insertions(+), 13 deletions(-) diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java index 391139f32..5c48b4c18 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java @@ -96,6 +96,11 @@ public class TestGraphLocalExecution extends BaseDL4JTest { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Test public void testLocalExecutionDataSources() throws Exception { @@ -204,7 +209,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest { OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() .candidateGenerator(candidateGenerator).dataProvider(dataProvider) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), new MaxCandidatesCondition(3)) .build(); @@ -251,7 +256,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest { .candidateGenerator(candidateGenerator) .dataProvider(new TestMdsDataProvider(1, 32)) .modelSaver(new 
FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), new MaxCandidatesCondition(3)) .scoreFunction(ScoreFunctions.testSetAccuracy()) .build(); diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java index 7a8e23ac0..3ba918548 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java @@ -72,11 +72,11 @@ public class LFWLoader extends BaseImageLoader implements Serializable { protected File fullDir; protected boolean useSubset = false; - InputSplit[] inputSplit; + protected InputSplit[] inputSplit; - public static Map lfwData = new HashMap<>(); - public static Map lfwLabel = new HashMap<>(); - public static Map lfwSubsetData = new HashMap<>(); + public Map lfwData = new HashMap<>(); + public Map lfwLabel = new HashMap<>(); + public Map lfwSubsetData = new HashMap<>(); public LFWLoader() { this(false); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java index bf6d908e2..7afdb7ac0 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java @@ -45,15 +45,23 @@ import static org.junit.Assert.assertTrue; */ public class LoaderTests { + private static void ensureDataAvailable(){ + //Ensure test resources available by initializing CifarLoader and relying on auto download + boolean preProcessCifar = false; + int numExamples = 10; + int row = 28; + int col = 28; + int channels = 1; + for( boolean train : new boolean[]{true, false}){ + CifarLoader loader = new CifarLoader(row, col, channels, train, preProcessCifar); + loader.next(numExamples); + } + new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)).next(); + } + @Test public void testLfwReader() throws Exception { - String subDir = "lfw-a/lfw"; - File path = new File(FilenameUtils.concat(System.getProperty("user.home"), subDir)); - FileSplit fileSplit = new FileSplit(path, LFWLoader.ALLOWED_FORMATS, new Random(42)); - BalancedPathFilter pathFilter = new BalancedPathFilter(new Random(42), LFWLoader.LABEL_PATTERN, 1, 1, 1); - InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1); - RecordReader rr = new ImageRecordReader(250, 250, 3, LFWLoader.LABEL_PATTERN); - rr.initialize(inputSplit[0]); + RecordReader rr = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)); List exptedLabel = rr.getLabels(); RecordReader rr2 = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)); @@ -63,6 +71,7 @@ public class LoaderTests { @Test public void testCifarLoader() { + ensureDataAvailable(); File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin")); CifarLoader cifar = new CifarLoader(false, dir); assertTrue(dir.exists()); @@ -71,6 +80,7 @@ public class LoaderTests { @Test public void testCifarInputStream() throws Exception { + ensureDataAvailable(); // check train 
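        // Editor's note (not part of the original patch): ensureDataAvailable(), added above,
        // initializes CifarLoader and LFWLoader so their auto-download populates the hard-coded
        // paths under the user home directory before the assertions below are evaluated.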
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin"; String path = FilenameUtils.concat(System.getProperty("user.home"), subDir); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 06c64445b..3998bc184 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -2122,4 +2122,26 @@ public class MiscOpValidation extends BaseOpValidation { assertNull(err); } + + @Test + public void testSeqMask(){ + INDArray arr = Nd4j.createFromArray(1,2,3); + INDArray maxLen = Nd4j.scalar(4); + + INDArray out = Nd4j.create(DataType.INT32, 3, 4); + out.assign(Integer.MAX_VALUE); + + Nd4j.exec(DynamicCustomOp.builder("sequence_mask") + .addInputs(arr, maxLen) + .addOutputs(out) + .build() + ); + + INDArray exp = Nd4j.createFromArray(new int[][]{ + {1, 0, 0, 0}, + {1, 1, 0, 0}, + {1, 1, 1, 0}}); + + assertEquals(exp, out); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 795cef3f1..0cbe52479 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -22,6 +22,7 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.math3.linear.LUDecomposition; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; @@ -2498,4 +2499,30 @@ public class ShapeOpValidation extends BaseOpValidation { assertEquals(Nd4j.createFromArray(2, 2), out); } + + + @Test @Ignore //AB 2020/04/01 - https://github.com/eclipse/deeplearning4j/issues/8592 + public void testReshapeZeros(){ + int[][] shapes = new int[][]{{2,0}, {10,0}, {10, 0}, {2,0,0,10}, {10, 0}, {0, 0, 10}, {0,2,10}, {1,2,0}}; + int[][] reshape = new int[][]{{2,-1}, {2,0,-1}, {5,2,-1}, {2,0,-1}, {-1, 2, 0}, {2, -1, 0}, {2, 0, 0, 0, -1}, {2,0,-1}}; + int[][] expected = new int[][]{{2,0}, {2,0,5}, {5,2,0}, {2,0,10}, {5,2,0}, {2,5,0}, {2,0,0,0,10}, {2,0,1}}; + + for( int i=0; i Date: Wed, 1 Apr 2020 07:13:34 +0300 Subject: [PATCH 06/19] - correct reshape op for empty shapes (#354) * - correct reshape op for empty shape in case of -1 at the end Signed-off-by: Yurii * Fix test + new reshape op constructor Signed-off-by: Alex Black Co-authored-by: Alex Black --- .../ops/declarable/generic/shape/reshape.cpp | 104 ++++++++++++------ .../layers_tests/DeclarableOpsTests14.cpp | 7 +- .../linalg/api/ops/impl/shape/Reshape.java | 9 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 8 +- 4 files changed, 88 insertions(+), 40 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 5ac7686e2..023e9bf89 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -81,43 +81,79 @@ DECLARE_SHAPE_FN(reshape) { REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); - Nd4jLong xLen = x->lengthOf(); - 
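    // Editor's note (illustrative summary of the rewritten shape logic, not part of the patch):
    // the -1 placeholder is now resolved against the product of the positive dimensions only, so
    // reshaping an empty array of shape {2, 0} with arguments {2, 0, -1} yields {2, 0, 1} instead
    // of the previous {2, 0, 0}, while reshaping an empty {0, 1, 2} with {10, -1} still yields
    // {10, 0} because the requested shape contains no explicit 0 and the total length remains 0.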
if(x->isEmpty()) { - xLen = 1; - for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - if(x->sizeAt(i) != 0) - xLen *= x->sizeAt(i); + // Nd4jLong xLen = x->lengthOf(); + // if(x->isEmpty()) { + // xLen = 1; + // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + // if(x->sizeAt(i) != 0) + // xLen *= x->sizeAt(i); + // } + + // for (uint i = 0; i < reshapeArgs.size(); ++i) { + + // if (reshapeArgs[i] == -1) { + + // uint shapeLength = 1, numOfZeros = 0; + + // for(uint j = 0; j < i; ++j) + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + + // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + // } + + // const auto dim = xLen / shapeLength; + + // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + // shapeNew.push_back(0); + // else + // shapeNew.push_back(dim); + // } + // else + // shapeNew.push_back(reshapeArgs[i]); + // } + + Nd4jLong newShapeLen = 1; + int pos = -1; + bool newShapeEmpty = false; + + for (int i = 0; i < reshapeArgs.size(); ++i) { + + const int dim = reshapeArgs[i]; + + if (dim == -1) { + REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + pos = i; + shapeNew.push_back(1); + } + else if (dim == 0) { + shapeNew.push_back(0); + newShapeEmpty = true; + } + else { + shapeNew.push_back(dim); + newShapeLen *= dim; + } } - for (uint i = 0; i < reshapeArgs.size(); ++i) { + if (pos != -1) { - if (reshapeArgs[i] == -1) { - - uint shapeLength = 1, numOfZeros = 0; - - for(uint j = 0; j < i; ++j) - if(reshapeArgs[j] != 0) - shapeLength *= reshapeArgs[j]; - else - ++numOfZeros; - - for(uint j = i + 1; j < reshapeArgs.size(); ++j) { - REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - if(reshapeArgs[j] != 0) - shapeLength *= reshapeArgs[j]; - else - ++numOfZeros; - } - - const auto dim = xLen / shapeLength; - - if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) - shapeNew.push_back(0); - else - shapeNew.push_back(dim); + Nd4jLong xLen = x->lengthOf(); + if(x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + if(x->sizeAt(i) > 0 || !newShapeEmpty) + xLen *= x->sizeAt(i); } - else - shapeNew.push_back(reshapeArgs[i]); + + shapeNew[pos] = xLen / newShapeLen; } auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); @@ -126,6 +162,8 @@ DECLARE_SHAPE_FN(reshape) { return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew)); } + + } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index db49c12f2..b4c9839ab 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -2288,7 +2288,7 @@ TEST_F(DeclarableOpsTests14, Reshape15) { auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); - auto e0 = NDArrayFactory::create('c', {2, 0, 0}); + auto e0 = NDArrayFactory::create('c', {2, 0, 1}); auto e1 = NDArrayFactory::create('c', {0, 1}); sd::ops::reshape op; @@ -2374,6 +2374,7 @@ TEST_F(DeclarableOpsTests14, Reshape20) { NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); NDArray x6('c', {0,10,0}, 
sd::DataType::FLOAT32); NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); + NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32); sd::ops::reshape op; @@ -2416,4 +2417,8 @@ TEST_F(DeclarableOpsTests14, Reshape20) { result = op.evaluate({&x7}, {}, {10,0,50,100}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); + + result = op.evaluate({&x7}, {}, {2,0,-1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,1})); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index ddf0224db..c6b79b2b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -55,8 +56,12 @@ public class Reshape extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{i_v, shape}); } - public Reshape(INDArray in, INDArray shape, INDArray out){ - super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List)null); + public Reshape(INDArray in, INDArray shape){ + this(in, shape, null); + } + + public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){ + super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List)null); } public Reshape(INDArray in, INDArray shape) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 5424d3c50..da91fb6cf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8255,11 +8255,11 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); - INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; - INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; - INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; + INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1)))[0]; + INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1)))[0]; + INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1)))[0]; - assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); + assertArrayEquals(new long[]{2, 0, 1}, out0.shape()); assertArrayEquals(new long[]{0, 1}, out1.shape()); assertArrayEquals(new long[]{10, 0}, out2.shape()); } From 8ac89aeb190c55983835f0293161c4c04a15209f Mon Sep 17 00:00:00 2001 From: Chris Bamford Date: Wed, 1 Apr 2020 06:28:01 +0100 Subject: [PATCH 07/19] RL4J: Force shape fix (#352) * fix edge case where input to network needs to have shape > 1 Signed-off-by: Bam4d * adding test for single dimension Signed-off-by: Bam4d --- .../deeplearning4j/rl4j/helper/INDArrayHelper.java | 2 +- .../rl4j/helper/INDArrayHelperTest.java | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 7d93b1175..2e608db19 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -32,7 +32,7 @@ public class INDArrayHelper { * @return The source INDArray with the correct shape */ public static INDArray forceCorrectShape(INDArray source) { - return source.shape()[0] == 1 + return source.shape()[0] == 1 && source.shape().length > 1 ? source : Nd4j.expandDims(source, 0); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java index 9bfceadad..e1c5c64ed 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -35,4 +35,18 @@ public class INDArrayHelperTest { assertEquals(3, output.shape()[1]); } + @Test + public void when_inputHasOneDimension_expect_outputWithTwoDimensions() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0 }); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(1, output.shape()[1]); + } + } From fb1c41c51234dd050a5dea6390d432e56af87388 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 1 Apr 2020 11:09:48 +0300 Subject: [PATCH 08/19] Build fix (#357) --- .../main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index c6b79b2b1..6feace53f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -64,10 +64,6 @@ public class Reshape extends DynamicCustomOp { super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List)null); } - public Reshape(INDArray in, INDArray shape) { - addInputArgument(in, shape); - } - public Reshape() { } From 1a35ebec2ed7209e7e732246dd505c94d482a136 Mon Sep 17 00:00:00 2001 From: Chris Bamford Date: Mon, 6 Apr 2020 04:36:12 +0100 Subject: [PATCH 09/19] RL4J: Add Backwardly Compatible Builder patterns (#326) * Starting to switch configs of RL algorithms to use more fluent builder patterns. 
Many parameter choices in different algorithms default to SOTA and only be changed in specific cases Signed-off-by: Bam4d * remove personal gpu-build file Signed-off-by: Bam4d * refactored out configurations so they are heirarchical and re-usable, this is a step towards having a plug-and-play framework for different algorithms * backwardly compatible configurations * adding documentation to new configuration classes Signed-off-by: Bam4d * private access modifiers are better suited here Signed-off-by: Bam4d * RL4j does not compile without java 8 due to previous updates fixing null pointers when listener arrays are empty Signed-off-by: Bam4d * fixing copyright headers Signed-off-by: Bam4d * uncomment logging line Signed-off-by: Bam4d * fixing default value for learningUpdateFrequency fixing test failure due to #352 Signed-off-by: Bam4d Co-authored-by: Bam4d --- .../api/transform/split/RandomSplit.java | 1 + rl4j/rl4j-core/pom.xml | 12 ++ .../rl4j/learning/EpochStepCounter.java | 16 +++ .../rl4j/learning/ILearning.java | 16 +-- .../rl4j/learning/async/AsyncGlobal.java | 36 +++--- .../rl4j/learning/async/AsyncLearning.java | 30 +++-- .../rl4j/learning/async/AsyncThread.java | 8 +- .../learning/async/AsyncThreadDiscrete.java | 6 +- .../async/a3c/discrete/A3CDiscrete.java | 44 +++++-- .../async/a3c/discrete/A3CDiscreteConv.java | 44 +++++-- .../async/a3c/discrete/A3CDiscreteDense.java | 83 +++++++----- .../async/a3c/discrete/A3CThreadDiscrete.java | 13 +- .../discrete/AsyncNStepQLearningDiscrete.java | 56 +++++--- .../AsyncNStepQLearningDiscreteConv.java | 39 +++--- .../AsyncNStepQLearningDiscreteDense.java | 57 ++++++--- .../AsyncNStepQLearningThreadDiscrete.java | 28 ++-- .../A3CLearningConfiguration.java | 46 +++++++ .../AsyncQLearningConfiguration.java | 42 ++++++ .../IAsyncLearningConfiguration.java | 28 ++++ .../ILearningConfiguration.java} | 30 +---- .../configuration/LearningConfiguration.java | 59 +++++++++ .../configuration/QLearningConfiguration.java | 79 ++++++++++++ .../rl4j/learning/sync/SyncLearning.java | 5 +- .../learning/sync/qlearning/QLearning.java | 48 +++++-- .../qlearning/discrete/QLearningDiscrete.java | 27 ++-- .../discrete/QLearningDiscreteConv.java | 35 ++++- .../discrete/QLearningDiscreteDense.java | 28 +++- .../rl4j/network/ac/ActorCriticCompGraph.java | 11 +- .../ActorCriticFactoryCompGraphStdConv.java | 29 ++++- .../ActorCriticFactoryCompGraphStdDense.java | 38 ++---- .../ActorCriticFactorySeparateStdDense.java | 83 +++++++----- .../ActorCriticDenseNetworkConfiguration.java | 42 ++++++ .../ActorCriticNetworkConfiguration.java | 37 ++++++ .../DQNDenseNetworkConfiguration.java | 40 ++++++ .../configuration/NetworkConfiguration.java | 58 +++++++++ .../rl4j/network/dqn/DQNFactoryStdConv.java | 26 +++- .../rl4j/network/dqn/DQNFactoryStdDense.java | 63 ++++++--- .../deeplearning4j/rl4j/policy/EpsGreedy.java | 13 +- .../deeplearning4j/rl4j/util/DataManager.java | 28 +++- .../rl4j/learning/HistoryProcessorTest.java | 8 +- .../learning/async/AsyncLearningTest.java | 33 ++++- .../async/AsyncThreadDiscreteTest.java | 22 +++- .../rl4j/learning/async/AsyncThreadTest.java | 19 ++- .../a3c/discrete/A3CThreadDiscreteTest.java | 26 +++- ...AsyncNStepQLearningThreadDiscreteTest.java | 23 +++- .../rl4j/learning/sync/SyncLearningTest.java | 34 ++++- ...t.java => QLearningConfigurationTest.java} | 31 ++--- .../discrete/QLearningDiscreteTest.java | 120 +++++++++++------- .../rl4j/network/ac/ActorCriticTest.java | 49 +++---- .../rl4j/network/dqn/DQNTest.java | 20 +-- 
.../transform/TransformProcessTest.java | 2 +- .../rl4j/policy/PolicyTest.java | 37 ++++-- .../rl4j/support/MockAsyncConfiguration.java | 31 +++-- .../util/DataManagerTrainingListenerTest.java | 20 ++- .../malmo/MalmoObservationSpaceGrid.java | 24 ++-- 55 files changed, 1388 insertions(+), 495 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/{async/AsyncConfiguration.java => configuration/ILearningConfiguration.java} (61%) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java rename rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/{QLConfigurationTest.java => QLearningConfigurationTest.java} (52%) diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java index 290e26873..fe4718f48 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java @@ -16,6 +16,7 @@ package org.datavec.api.transform.split; + import lombok.AllArgsConstructor; import lombok.Data; diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index c08615250..a93ea6345 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -20,6 +20,18 @@ xmlns="http://maven.apache.org/POM/4.0.0" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 + + + + org.apache.maven.plugins + maven-compiler-plugin + + 8 + 8 + + + + rl4j diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java index 746a71396..533209ed7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java @@ -1,3 +1,19 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning; public interface EpochStepCounter { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index d151f093b..43ed508b0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,10 +17,10 @@ package org.deeplearning4j.rl4j.learning; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16. @@ -34,21 +35,12 @@ public interface ILearning> { int getStepCounter(); - LConfiguration getConfiguration(); + ILearningConfiguration getConfiguration(); MDP getMdp(); IHistoryProcessor getHistoryProcessor(); - interface LConfiguration { - Integer getSeed(); - - int getMaxEpochStep(); - - int getMaxStep(); - - double getGamma(); - } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java index 5501a29e1..01c519b57 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,6 +20,8 @@ package org.deeplearning4j.rl4j.learning.async; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; import org.nd4j.linalg.primitives.Pair; @@ -27,28 +30,26 @@ import java.util.concurrent.atomic.AtomicInteger; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * + *
* In the original paper, the authors use Asynchronous
 * Gradient Descent: Hogwild! It is a way to apply gradients
 * and modify a model in a lock-free manner.
- *
+ *
* As a way to implement this with dl4j, it is unfortunately * necessary at the time of writing to apply the gradient * (update the parameters) on a single separate global thread. - * + *
* This Central thread for the Asynchronous Method of reinforcement learning
 * enqueues the gradients coming from the different threads and updates its
 * model and target. Those neural nets are then synced by the other threads.
- *
+ *
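 * (Editor's illustration, not part of the original javadoc: in the code below, worker threads hand
 * their accumulated gradients to this thread via enqueue(gradient, nstep); the update loop drains
 * the queue, applies each gradient to the shared model through applyGradient(), and, unless
 * getLearnerUpdateFrequency() is -1, periodically copies the updated model into the target network.)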
* The benefit of this thread is that the updater is "shared" between all threads:
 * we have a single updater, which is the single updater of the model contained here
- *
+ *
* This is similar to RMSProp with shared g and momentum - * + *
* When Hogwild! is implemented, this could be replaced by a simple data * structure - * - * */ @Slf4j public class AsyncGlobal extends Thread implements IAsyncGlobal { @@ -56,7 +57,7 @@ public class AsyncGlobal extends Thread implements IAsyncG @Getter final private NN current; final private ConcurrentLinkedQueue> queue; - final private AsyncConfiguration a3cc; + final private IAsyncLearningConfiguration configuration; private final IAsyncLearning learning; @Getter private AtomicInteger T = new AtomicInteger(0); @@ -65,20 +66,20 @@ public class AsyncGlobal extends Thread implements IAsyncG @Getter private boolean running = true; - public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) { + public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) { this.current = initial; target = (NN) initial.clone(); - this.a3cc = a3cc; + this.configuration = configuration; this.learning = learning; queue = new ConcurrentLinkedQueue<>(); } public boolean isTrainingComplete() { - return T.get() >= a3cc.getMaxStep(); + return T.get() >= configuration.getMaxStep(); } public void enqueue(Gradient[] gradient, Integer nstep) { - if(running && !isTrainingComplete()) { + if (running && !isTrainingComplete()) { queue.add(new Pair<>(gradient, nstep)); } } @@ -94,9 +95,8 @@ public class AsyncGlobal extends Thread implements IAsyncG synchronized (this) { current.applyGradient(gradient, pair.getSecond()); } - if (a3cc.getTargetDqnUpdateFreq() != -1 - && T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond()) - / a3cc.getTargetDqnUpdateFreq()) { + if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond()) + / configuration.getLearnerUpdateFrequency()) { log.info("TARGET UPDATE at T = " + T.get()); synchronized (this) { target.copy(current); @@ -111,7 +111,7 @@ public class AsyncGlobal extends Thread implements IAsyncG * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too. */ public void terminate() { - if(running) { + if (running) { running = false; queue.clear(); learning.terminate(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 994ec9cb0..1c3c83972 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -21,14 +22,17 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.listener.*; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; +import org.deeplearning4j.rl4j.learning.listener.TrainingListener; +import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.factory.Nd4j; /** - * The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread() + * The entry point for async training. This class will start a number ({@link AsyncQLearningConfiguration#getNumThreads() * configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals * (see setProgressEventInterval(int)) * @@ -37,8 +41,8 @@ import org.nd4j.linalg.factory.Nd4j; */ @Slf4j public abstract class AsyncLearning, NN extends NeuralNet> - extends Learning - implements IAsyncLearning { + extends Learning + implements IAsyncLearning { private Thread monitorThread = null; @@ -56,9 +60,10 @@ public abstract class AsyncLearning, NN extends Ne } private void handleTraining(RunContext context) { - int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep); + int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep); SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); context.obs = subEpochReturn.getLastObs(); @@ -197,7 +199,7 @@ public abstract class AsyncThread, NN extends Ne protected abstract IAsyncGlobal getAsyncGlobal(); - protected abstract AsyncConfiguration getConf(); + protected abstract IAsyncLearningConfiguration getConf(); protected abstract IPolicy getPolicy(NN net); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index 27d49c366..a72abfa62 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -1,5 +1,7 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
+ * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -112,7 +114,7 @@ public abstract class AsyncThreadDiscrete rewards.add(new MiniTrans(obs.getData(), null, null, 0)); else { INDArray[] output = null; - if (getConf().getTargetDqnUpdateFreq() == -1) + if (getConf().getLearnerUpdateFrequency() == -1) output = current.outputAll(obs.getData()); else synchronized (getAsyncGlobal()) { output = getAsyncGlobal().getTarget().outputAll(obs.getData()); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 81308ba5a..0608ec5cc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,11 +17,15 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; -import lombok.*; -import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncLearning; import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.policy.ACPolicy; @@ -32,15 +37,14 @@ import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. * Training for A3C in the Discrete Domain - * + *
* All methods are fully implemented as described in the * https://arxiv.org/abs/1602.01783 paper. - * */ public abstract class A3CDiscrete extends AsyncLearning { @Getter - final public A3CConfiguration configuration; + final public A3CLearningConfiguration configuration; @Getter final protected MDP mdp; final private IActorCritic iActorCritic; @@ -49,15 +53,15 @@ public abstract class A3CDiscrete extends AsyncLearning policy; - public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf) { + public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { this.iActorCritic = iActorCritic; this.mdp = mdp; this.configuration = conf; asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this); - Integer seed = conf.getSeed(); + Long seed = conf.getSeed(); Random rnd = Nd4j.getRandom(); - if(seed != null) { + if (seed != null) { rnd.setSeed(seed); } @@ -65,7 +69,7 @@ public abstract class A3CDiscrete extends AsyncLearning extends AsyncLearning extends AsyncLearning * Training for A3C in the Discrete Domain - * + *
* Specialized constructors for the Conv (pixels input) case * Specialized conf + provide additional type safety - * + *
* It uses CompGraph because there is benefit to combine the * first layers since they're essentially doing the same dimension * reduction task - * **/ public class A3CDiscreteConv extends A3CDiscrete { @@ -46,12 +48,22 @@ public class A3CDiscreteConv extends A3CDiscrete { @Deprecated public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, actorCritic, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + + super(mdp, IActorCritic, conf.toLearningConfiguration()); + this.hpconf = hpconf; + setHistoryProcessor(hpconf); + } + + public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { super(mdp, IActorCritic, conf); this.hpconf = hpconf; setHistoryProcessor(hpconf); @@ -59,21 +71,35 @@ public class A3CDiscreteConv extends A3CDiscrete { @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } + + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { + this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); + } + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { - this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { + this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } + + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf); + } + + public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java index 16b8151df..74332bf3a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java @@ -1,5 +1,6 @@ 
/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,8 +17,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.*; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; @@ -25,67 +28,81 @@ import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. - * + *
* Training for A3C in the Discrete Domain - * + *
* We use specifically the Separate version because * the model is too small to have enough benefit by sharing layers - * */ public class A3CDiscreteDense extends A3CDiscrete { @Deprecated public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, - IDataManager dataManager) { + IDataManager dataManager) { this(mdp, IActorCritic, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { + super(mdp, actorCritic, conf.toLearningConfiguration()); + } + + public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { super(mdp, actorCritic, conf); } @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, - A3CConfiguration conf, IDataManager dataManager) { + A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, - dataManager); + dataManager); } + + @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - @Deprecated - public A3CDiscreteDense(MDP mdp, - ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, - IDataManager dataManager) { - this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); - } - public A3CDiscreteDense(MDP mdp, - ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { - this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); - } - - @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, - A3CConfiguration conf, IDataManager dataManager) { - this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, - dataManager); - } - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, - A3CConfiguration conf) { + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated public A3CDiscreteDense(MDP mdp, - ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, - IDataManager dataManager) { - this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager); - } - public A3CDiscreteDense(MDP mdp, - ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) { - this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf); + ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, + IDataManager dataManager) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } + @Deprecated + public A3CDiscreteDense(MDP mdp, + ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf); + } + + public A3CDiscreteDense(MDP mdp, + ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); + } + + @Deprecated + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CConfiguration conf, IDataManager dataManager) { + this(mdp, 
factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, + dataManager); + } + + @Deprecated + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CConfiguration conf) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CLearningConfiguration conf) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index 22b3894b2..c2a16d6b4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,10 +20,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import lombok.Getter; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; @@ -31,9 +32,9 @@ import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.linalg.api.rng.Random; import java.util.Stack; @@ -45,7 +46,7 @@ import java.util.Stack; public class A3CThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected A3CDiscrete.A3CConfiguration conf; + final protected A3CLearningConfiguration conf; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -54,14 +55,14 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< final private Random rnd; public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, - A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, + A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners, int threadNumber) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.conf = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - Integer seed = conf.getSeed(); + Long seed = conf.getSeed(); rnd = Nd4j.getRandom(); if(seed != null) { rnd.setSeed(seed + threadNumber); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index c18de9e10..9a8049f6f 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -1,49 +1,53 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. * - * SPDX-License-Identifier: Apache-2.0 + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; -import lombok.*; -import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncLearning; import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. 
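 *
 * Editor's sketch (hypothetical values, not part of the original patch) of the fluent configuration
 * style this PR introduces, using only builder methods that appear in toLearningConfiguration() below:
 *
 *   AsyncQLearningConfiguration conf = AsyncQLearningConfiguration.builder()
 *           .seed(123L).maxEpochStep(200).maxStep(500000)
 *           .numThreads(8).nStep(5).targetDqnUpdateFreq(100)
 *           .updateStart(10).rewardFactor(0.1).gamma(0.99)
 *           .errorClamp(1.0).minEpsilon(0.1f)
 *           .build();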
*/ public abstract class AsyncNStepQLearningDiscrete - extends AsyncLearning { + extends AsyncLearning { @Getter - final public AsyncNStepQLConfiguration configuration; + final public AsyncQLearningConfiguration configuration; @Getter final private MDP mdp; @Getter final private AsyncGlobal asyncGlobal; - public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { + public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { this.mdp = mdp; this.configuration = conf; this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this); @@ -62,12 +66,11 @@ public abstract class AsyncNStepQLearningDiscrete return new DQNPolicy(getNeuralNet()); } - @Data @AllArgsConstructor @Builder @EqualsAndHashCode(callSuper = false) - public static class AsyncNStepQLConfiguration implements AsyncConfiguration { + public static class AsyncNStepQLConfiguration { Integer seed; int maxEpochStep; @@ -82,5 +85,22 @@ public abstract class AsyncNStepQLearningDiscrete float minEpsilon; int epsilonNbStep; + public AsyncQLearningConfiguration toLearningConfiguration() { + return AsyncQLearningConfiguration.builder() + .seed(new Long(seed)) + .maxEpochStep(maxEpochStep) + .maxStep(maxStep) + .numThreads(numThread) + .nStep(nstep) + .targetDqnUpdateFreq(targetDqnUpdateFreq) + .updateStart(updateStart) + .rewardFactor(rewardFactor) + .gamma(gamma) + .errorClamp(errorClamp) + .minEpsilon(minEpsilon) + .build(); + } + } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index 83274b7f6..f92b704b6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -1,24 +1,27 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
* - * SPDX-License-Identifier: Apache-2.0 + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -38,12 +41,12 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); this.hpconf = hpconf; setHistoryProcessor(hpconf); @@ -51,21 +54,21 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java index b58e15902..b6216e849 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java +++ 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java @@ -1,22 +1,26 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * SPDX-License-Identifier: Apache-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -32,35 +36,56 @@ public class AsyncNStepQLearningDiscreteDense extends Async @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, - AsyncNStepQLConfiguration conf, IDataManager dataManager) { - super(mdp, dqn, conf); + AsyncNStepQLConfiguration conf, IDataManager dataManager) { + super(mdp, dqn, conf.toLearningConfiguration()); addListener(new DataManagerTrainingListener(dataManager)); } + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { + super(mdp, dqn, conf.toLearningConfiguration()); + } + + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); } @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, - AsyncNStepQLConfiguration conf, IDataManager dataManager) { + AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, - dataManager); + dataManager); } + + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + AsyncQLearningConfiguration conf) { + this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, - 
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { - this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); + DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } + + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); + } + + public AsyncNStepQLearningDiscreteDense(MDP mdp, + DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf), conf); } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index f8c470269..71199efaf 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -1,17 +1,18 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
* - * SPDX-License-Identifier: Apache-2.0 + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; @@ -22,6 +23,7 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -42,7 +44,7 @@ import java.util.Stack; public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf; + final protected AsyncQLearningConfiguration conf; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -51,7 +53,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn final private Random rnd; public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, - AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, + AsyncQLearningConfiguration conf, TrainingListenerList listeners, int threadNumber, int deviceNum) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.conf = conf; @@ -59,7 +61,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn this.threadNumber = threadNumber; rnd = Nd4j.getRandom(); - Integer seed = conf.getSeed(); + Long seed = conf.getSeed(); if(seed != null) { rnd.setSeed(seed + threadNumber); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java new file mode 100644 index 000000000..226fe4419 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class A3CLearningConfiguration extends LearningConfiguration implements IAsyncLearningConfiguration { + + /** + * The number of asynchronous threads to use to generate gradients + */ + private final int numThreads; + + /** + * The number of steps to calculate gradients over + */ + private final int nStep; + + /** + * The frequency of async training iterations to update the target network. 
+ *
+ * If this is set to -1 then the target network is updated after every training iteration
+ */
+ @Builder.Default
+ private int learnerUpdateFrequency = -1;
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java
new file mode 100644
index 000000000..a60903e59
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java
@@ -0,0 +1,42 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.experimental.SuperBuilder;
+
+
+@Data
+@SuperBuilder
+@EqualsAndHashCode(callSuper = true)
+public class AsyncQLearningConfiguration extends QLearningConfiguration implements IAsyncLearningConfiguration {
+
+ /**
+ * The number of asynchronous threads to use to generate experience data
+ */
+ private final int numThreads;
+
+ /**
+ * The number of steps in each training iteration
+ */
+ private final int nStep;
+
+ public int getLearnerUpdateFrequency() {
+ return getTargetDqnUpdateFreq();
+ }
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java
new file mode 100644
index 000000000..1e7cf3f2e
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java
@@ -0,0 +1,28 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +public interface IAsyncLearningConfiguration extends ILearningConfiguration { + + int getNumThreads(); + + int getNStep(); + + int getLearnerUpdateFrequency(); + + int getMaxStep(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java similarity index 61% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java index 0727db475..7ae215087 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,36 +14,16 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.rl4j.learning.async; +package org.deeplearning4j.rl4j.learning.configuration; -import org.deeplearning4j.rl4j.learning.ILearning; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/23/16. - * - * Interface configuration for all training method that inherit - * from AsyncLearning - */ -public interface AsyncConfiguration extends ILearning.LConfiguration { - - Integer getSeed(); +public interface ILearningConfiguration { + Long getSeed(); int getMaxEpochStep(); int getMaxStep(); - int getNumThread(); - - int getNstep(); - - int getTargetDqnUpdateFreq(); - - int getUpdateStart(); - - double getRewardFactor(); - double getGamma(); - double getErrorClamp(); - + double getRewardFactor(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java new file mode 100644 index 000000000..d1567e619 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@NoArgsConstructor
+public class LearningConfiguration implements ILearningConfiguration {
+
+ /**
+ * Seed value used for training
+ */
+ @Builder.Default
+ private Long seed = System.currentTimeMillis();
+
+ /**
+ * The maximum number of steps in each episode
+ */
+ @Builder.Default
+ private int maxEpochStep = 200;
+
+ /**
+ * The maximum number of steps to train for
+ */
+ @Builder.Default
+ private int maxStep = 150000;
+
+ /**
+ * Gamma parameter used for discounted rewards
+ */
+ @Builder.Default
+ private double gamma = 0.99;
+
+ /**
+ * Scaling parameter for rewards
+ */
+ @Builder.Default
+ private double rewardFactor = 1.0;
+
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java
new file mode 100644
index 000000000..26ac57f0c
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java
@@ -0,0 +1,79 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@NoArgsConstructor
+@EqualsAndHashCode(callSuper = false)
+public class QLearningConfiguration extends LearningConfiguration {
+
+ /**
+ * The maximum size of the experience replay buffer
+ */
+ @Builder.Default
+ private int expRepMaxSize = 150000;
+
+ /**
+ * The batch size of experience for each training iteration
+ */
+ @Builder.Default
+ private int batchSize = 32;
+
+ /**
+ * How many steps between target network updates
+ */
+ @Builder.Default
+ private int targetDqnUpdateFreq = 100;
+
+ /**
+ * The number of steps to initially wait for until sampling batches from the experience replay buffer
+ */
+ @Builder.Default
+ private int updateStart = 10;
+
+ /**
+ * Prevent the new Q-Value from being farther than errorClamp away from the previous value. Double.NaN will result in no clamping
+ */
+ @Builder.Default
+ private double errorClamp = 1.0;
+
+ /**
+ * The minimum probability for a random exploration action during epsilon-greedy annealing
+ */
+ @Builder.Default
+ private double minEpsilon = 0.1f;
+
+ /**
+ * The number of steps to anneal epsilon to its minimum value.
+ */ + @Builder.Default + private int epsilonNbStep = 10000; + + /** + * Whether to use the double DQN algorithm + */ + @Builder.Default + private boolean doubleDQN = false; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index 22d936fcf..c42756145 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -63,7 +64,7 @@ public abstract class SyncLearning, NN extends N /** * This method will train the model

* The training stop when:
- * - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})
+ * - the number of steps reaches the maximum defined in the configuration (see {@link ILearningConfiguration#getMaxStep() ILearningConfiguration.getMaxStep()})
* OR
* - a listener explicitly stops it
*

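For reference, a minimal usage sketch (not part of the patch) of how the new QLearningConfiguration introduced above can be built through its Lombok @SuperBuilder, in place of the deprecated QLearning.QLConfiguration. The builder method names follow the fields declared in LearningConfiguration and QLearningConfiguration; all numeric values are arbitrary illustrative choices, not recommended defaults.

// Illustrative sketch: constructing the new QLearningConfiguration via its
// Lombok @SuperBuilder. Values are example values only.
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;

public class QLearningConfigurationExample {
    public static void main(String[] args) {
        QLearningConfiguration conf = QLearningConfiguration.builder()
                .seed(123L)                  // inherited from LearningConfiguration
                .maxEpochStep(200)           // inherited
                .maxStep(150000)             // inherited
                .gamma(0.99)                 // inherited
                .rewardFactor(0.01)          // inherited
                .expRepMaxSize(150000)
                .batchSize(32)
                .targetDqnUpdateFreq(500)
                .updateStart(10)
                .errorClamp(1.0)
                .minEpsilon(0.1)
                .epsilonNbStep(1000)
                .doubleDQN(true)
                .build();

        System.out.println(conf);            // @Data generates toString()
    }
}

The same builder style applies to AsyncQLearningConfiguration and A3CLearningConfiguration, which add numThreads and nStep on top of the fields shown here.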
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 0757043f0..40704d4e9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,10 +19,19 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; -import lombok.*; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.Value; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.EpochStepCounter; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.ExpReplay; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.SyncLearning; @@ -59,15 +69,15 @@ public abstract class QLearning getLegacyMDPWrapper(); - public QLearning(QLConfiguration conf) { + public QLearning(QLearningConfiguration conf) { this(conf, getSeededRandom(conf.getSeed())); } - public QLearning(QLConfiguration conf, Random random) { + public QLearning(QLearningConfiguration conf, Random random) { expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); } - private static Random getSeededRandom(Integer seed) { + private static Random getSeededRandom(Long seed) { Random rnd = Nd4j.getRandom(); if(seed != null) { rnd.setSeed(seed); @@ -95,7 +105,7 @@ public abstract class QLearning scores; - float epsilon; + double epsilon; double startQ; double meanQ; } @@ -213,12 +223,14 @@ public abstract class QLearning * DQN or Deep Q-Learning in the Discrete domain - * + *

* http://arxiv.org/abs/1312.5602 - * */ public abstract class QLearningDiscrete extends QLearning { @Getter - final private QLConfiguration configuration; + final private QLearningConfiguration configuration; private final LegacyMDPWrapper mdp; @Getter private DQNPolicy policy; @@ -78,16 +79,15 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLConfiguration conf, - int epsilonNbStep) { + public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep) { this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); } - public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, + public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) { super(conf); this.configuration = conf; - this.mdp = new LegacyMDPWrapper(mdp, null, this); + this.mdp = new LegacyMDPWrapper<>(mdp, null, this); qNetwork = dqn; targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); @@ -125,6 +125,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning extends QLearning extends QLearning extends QLearningDiscret @Deprecated public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, - QLConfiguration conf, IDataManager dataManager) { + QLConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { + super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame()); + setHistoryProcessor(hpconf); + } + + public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, + QLearningConfiguration conf) { super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); setHistoryProcessor(hpconf); } @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } + + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } + public QLearningDiscreteConv(MDP mdp, DQNFactory factory, + HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { + this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); + } + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { - this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); + HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { + this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } + + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { + this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf); + } + + public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, 
+ HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java index ef69ea6fb..5b95cc84e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,8 +17,10 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -38,7 +41,13 @@ public class QLearningDiscreteDense extends QLearningDiscre this(mdp, dqn, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) { + super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep()); + } + + public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) { super(mdp, dqn, conf, conf.getEpsilonNbStep()); } @@ -48,18 +57,33 @@ public class QLearningDiscreteDense extends QLearningDiscre this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } + + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactory factory, QLearning.QLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } + public QLearningDiscreteDense(MDP mdp, DQNFactory factory, + QLearningConfiguration conf) { + this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, QLearning.QLConfiguration conf, IDataManager dataManager) { - this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); + + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } + + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, QLearning.QLConfiguration conf) { + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); + } + + public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf, + QLearningConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java index 274606ed9..63438bb74 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,7 +37,7 @@ import java.util.Collection; * * Standard implementation of ActorCriticCompGraph */ -public class ActorCriticCompGraph implements IActorCritic { +public class ActorCriticCompGraph implements IActorCritic { final protected ComputationGraph cg; @Getter @@ -73,13 +74,13 @@ public class ActorCriticCompGraph implements IA } } - public NN clone() { - NN nn = (NN)new ActorCriticCompGraph(cg.clone()); + public ActorCriticCompGraph clone() { + ActorCriticCompGraph nn = new ActorCriticCompGraph(cg.clone()); nn.cg.setListeners(cg.getListeners()); return nn; } - public void copy(NN from) { + public void copy(ActorCriticCompGraph from) { cg.setParams(from.cg.params()); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java index bdadd2969..eaccf2a10 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -31,12 +32,16 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration.ActorCriticNetworkConfigurationBuilder; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. 
* @@ -45,8 +50,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; @Value public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph { - - Configuration conf; + ActorCriticNetworkConfiguration conf; public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) { @@ -109,16 +113,33 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom return new ActorCriticCompGraph(model); } - @AllArgsConstructor @Builder @Value + @Deprecated public static class Configuration { double l2; IUpdater updater; TrainingListener[] listeners; boolean useLSTM; + + /** + * Converts the deprecated Configuration to the new NetworkConfiguration format + */ + public ActorCriticNetworkConfiguration toNetworkConfiguration() { + ActorCriticNetworkConfigurationBuilder builder = ActorCriticNetworkConfiguration.builder() + .l2(l2) + .updater(updater) + .useLSTM(useLSTM); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + + } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java index 7c9e3e21b..0d9dae3c6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,8 +17,6 @@ package org.deeplearning4j.rl4j.network.ac; -import lombok.AllArgsConstructor; -import lombok.Builder; import lombok.Value; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; @@ -29,12 +28,11 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; /** @@ -45,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; @Value public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph { - Configuration conf; + ActorCriticDenseNetworkConfiguration conf; public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) { int nIn = 1; @@ -65,27 +63,27 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo "input"); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) .activation(Activation.RELU).build(), (i - 1) + ""); } if 
(conf.isUseLSTM()) { - confB.addLayer(getConf().getNumLayer() + "", new LSTM.Builder().activation(Activation.TANH) - .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayer() - 1) + ""); + confB.addLayer(getConf().getNumLayers() + "", new LSTM.Builder().activation(Activation.TANH) + .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayers() - 1) + ""); confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nOut(1).build(), getConf().getNumLayer() + ""); + .nOut(1).build(), getConf().getNumLayers() + ""); confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX) - .nOut(numOutputs).build(), getConf().getNumLayer() + ""); + .nOut(numOutputs).build(), getConf().getNumLayers() + ""); } else { confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nOut(1).build(), (getConf().getNumLayer() - 1) + ""); + .nOut(1).build(), (getConf().getNumLayers() - 1) + ""); confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX) - .nOut(numOutputs).build(), (getConf().getNumLayer() - 1) + ""); + .nOut(numOutputs).build(), (getConf().getNumLayers() - 1) + ""); } confB.setOutputs("value", "softmax"); @@ -103,18 +101,4 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo return new ActorCriticCompGraph(model); } - @AllArgsConstructor - @Builder - @Value - public static class Configuration { - - int numLayer; - int numHiddenNodes; - double l2; - IUpdater updater; - TrainingListener[] listeners; - boolean useLSTM; - } - - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java index a55e351c0..4ac557096 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -31,21 +32,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; + +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration.ActorCriticDenseNetworkConfigurationBuilder; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. 
- * - * */ @Value public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate { - Configuration conf; + ActorCriticDenseNetworkConfiguration conf; public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) { int nIn = 1; @@ -53,27 +57,27 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep nIn *= i; } NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) - .weightInit(WeightInit.XAVIER) - .l2(conf.getL2()) - .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) + .weightInit(WeightInit.XAVIER) + .l2(conf.getL2()) + .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) + .activation(Activation.RELU).build()); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()); } if (conf.isUseLSTM()) { - confB.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); + confB.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); - confB.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nIn(conf.getNumHiddenNodes()).nOut(1).build()); + confB.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) + .nIn(conf.getNumHiddenNodes()).nOut(1).build()); } else { - confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nIn(conf.getNumHiddenNodes()).nOut(1).build()); + confB.layer(conf.getNumLayers(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) + .nIn(conf.getNumHiddenNodes()).nOut(1).build()); } confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn)); @@ -87,28 +91,28 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep } NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) - .weightInit(WeightInit.XAVIER) - //.regularization(true) - //.l2(conf.getL2()) - .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(conf.getUpdater() != null ? 
conf.getUpdater() : new Adam()) + .weightInit(WeightInit.XAVIER) + //.regularization(true) + //.l2(conf.getL2()) + .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) + .activation(Activation.RELU).build()); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()); } if (conf.isUseLSTM()) { - confB2.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); + confB2.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); - confB2.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss()) - .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); + confB2.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss()) + .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); } else { - confB2.layer(conf.getNumLayer(), new OutputLayer.Builder(new ActorCriticLoss()) - .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); + confB2.layer(conf.getNumLayers(), new OutputLayer.Builder(new ActorCriticLoss()) + .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); } confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn)); @@ -128,6 +132,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep @AllArgsConstructor @Value @Builder + @Deprecated public static class Configuration { int numLayer; @@ -136,6 +141,22 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep IUpdater updater; TrainingListener[] listeners; boolean useLSTM; + + public ActorCriticDenseNetworkConfiguration toNetworkConfiguration() { + ActorCriticDenseNetworkConfigurationBuilder builder = ActorCriticDenseNetworkConfiguration.builder() + .numHiddenNodes(numHiddenNodes) + .numLayers(numLayer) + .l2(l2) + .updater(updater) + .useLSTM(useLSTM); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java new file mode 100644 index 000000000..e85ec6356 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class ActorCriticDenseNetworkConfiguration extends ActorCriticNetworkConfiguration { + + /** + * The number of layers in the dense network + */ + @Builder.Default + private int numLayers = 3; + + /** + * The number of hidden neurons in each layer + */ + @Builder.Default + private int numHiddenNodes = 100; + + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java new file mode 100644 index 000000000..c043f458e --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class ActorCriticNetworkConfiguration extends NetworkConfiguration { + + /** + * Whether or not to add an LSTM layer to the network. + */ + @Builder.Default + private boolean useLSTM = false; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java new file mode 100644 index 000000000..452cb83c2 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class DQNDenseNetworkConfiguration extends NetworkConfiguration { + + /** + * The number of layers in the dense network + */ + @Builder.Default + private int numLayers = 3; + + /** + * The number of hidden neurons in each layer + */ + @Builder.Default + private int numHiddenNodes = 100; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java new file mode 100644 index 000000000..c77c379a2 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.Singular; +import lombok.experimental.SuperBuilder; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.learning.config.IUpdater; + +import java.util.List; + + +@Data +@SuperBuilder +@NoArgsConstructor +public class NetworkConfiguration { + + /** + * The learning rate of the network + */ + @Builder.Default + private double learningRate = 0.01; + + /** + * L2 regularization on the network + */ + @Builder.Default + private double l2 = 0.0; + + /** + * The network's gradient update algorithm + */ + private IUpdater updater; + + /** + * Training listeners attached to the network + */ + @Singular + private List listeners; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java index ec09d1c1c..077bbf1ce 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -30,12 +31,15 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16. */ @@ -43,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; public class DQNFactoryStdConv implements DQNFactory { - Configuration conf; + NetworkConfiguration conf; public DQN buildDQN(int shapeInputs[], int numOutputs) { @@ -80,7 +84,6 @@ public class DQNFactoryStdConv implements DQNFactory { return new DQN(model); } - @AllArgsConstructor @Builder @Value @@ -90,6 +93,23 @@ public class DQNFactoryStdConv implements DQNFactory { double l2; IUpdater updater; TrainingListener[] listeners; + + /** + * Converts the deprecated Configuration to the new NetworkConfiguration format + */ + public NetworkConfiguration toNetworkConfiguration() { + NetworkConfiguration.NetworkConfigurationBuilder builder = NetworkConfiguration.builder() + .learningRate(learningRate) + .l2(l2) + .updater(updater); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + + } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java index 323ca7ecb..ebe730b4d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,12 +29,16 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration.DQNDenseNetworkConfigurationBuilder; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16. 
*/ @@ -41,32 +46,41 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; @Value public class DQNFactoryStdDense implements DQNFactory { - - Configuration conf; + DQNDenseNetworkConfiguration conf; public DQN buildDQN(int[] numInputs, int numOutputs) { int nIn = 1; + for (int i : numInputs) { nIn *= i; } + NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - //.updater(Updater.NESTEROVS).momentum(0.9) - //.updater(Updater.RMSPROP).rho(conf.getRmsDecay())//.rmsDecay(conf.getRmsDecay()) - .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) - .weightInit(WeightInit.XAVIER) - .l2(conf.getL2()) - .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) + .weightInit(WeightInit.XAVIER) + .l2(conf.getL2()) + .list() + .layer(0, + new DenseLayer.Builder() + .nIn(nIn) + .nOut(conf.getNumHiddenNodes()) + .activation(Activation.RELU).build() + ); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()); } - confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); + confB.layer(conf.getNumLayers(), + new OutputLayer.Builder(LossFunctions.LossFunction.MSE) + .activation(Activation.IDENTITY) + .nIn(conf.getNumHiddenNodes()) + .nOut(numOutputs) + .build() + ); MultiLayerConfiguration mlnconf = confB.build(); @@ -83,6 +97,7 @@ public class DQNFactoryStdDense implements DQNFactory { @AllArgsConstructor @Value @Builder + @Deprecated public static class Configuration { int numLayer; @@ -90,7 +105,23 @@ public class DQNFactoryStdDense implements DQNFactory { double l2; IUpdater updater; TrainingListener[] listeners; + + /** + * Converts the deprecated Configuration to the new NetworkConfiguration format + */ + public DQNDenseNetworkConfiguration toNetworkConfiguration() { + DQNDenseNetworkConfigurationBuilder builder = DQNDenseNetworkConfiguration.builder() + .numHiddenNodes(numHiddenNodes) + .numLayers(numLayer) + .l2(l2) + .updater(updater); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + } } - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index 3ed375084..3454a37e6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -46,7 +47,7 @@ public class EpsGreedy> extends Policy { final private int updateStart; final private int epsilonNbStep; final private Random rnd; - final private float minEpsilon; + final private double minEpsilon; final private IEpochTrainer learning; public NeuralNet getNeuralNet() { @@ -55,10 +56,10 @@ public class EpsGreedy> extends Policy { public A nextAction(INDArray input) { - float ep = getEpsilon(); + double ep = getEpsilon(); if (learning.getStepCounter() % 500 == 1) log.info("EP: " + ep + " " + learning.getStepCounter()); - if (rnd.nextFloat() > ep) + if (rnd.nextDouble() > ep) return policy.nextAction(input); else return mdp.getActionSpace().randomAction(); @@ -68,7 +69,7 @@ public class EpsGreedy> extends Policy { return this.nextAction(observation.getData()); } - public float getEpsilon() { - return Math.min(1f, Math.max(minEpsilon, 1f - (learning.getStepCounter() - updateStart) * 1f / epsilonNbStep)); + public double getEpsilon() { + return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep)); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java index b639efdaa..bffafdb76 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,17 +23,30 @@ import lombok.Builder; import lombok.Getter; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; -import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.rl4j.learning.ILearning; -import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.linalg.primitives.Pair; -import java.io.*; -import java.nio.file.*; +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.nio.file.StandardOpenOption; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; import java.util.zip.ZipOutputStream; @@ -304,7 +318,7 @@ public class DataManager implements IDataManager { public static class Info { String trainingName; String mdpName; - ILearning.LConfiguration conf; + ILearningConfiguration conf; int stepCounter; long millisTime; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java index 26ec0708f..8718d252d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,14 +17,11 @@ package org.deeplearning4j.rl4j.learning; -import java.util.Arrays; import org.junit.Test; -import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; /** * @@ -32,7 +30,7 @@ import static org.junit.Assert.assertTrue; public class HistoryProcessorTest { @Test - public void testHistoryProcessor() throws Exception { + public void testHistoryProcessor() { HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder() .croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build(); IHistoryProcessor hp = new HistoryProcessor(conf); @@ -43,8 +41,6 @@ public class HistoryProcessorTest { hp.add(a); INDArray[] h = hp.getHistory(); assertEquals(4, h.length); -// System.out.println(Arrays.toString(a.shape())); -// System.out.println(Arrays.toString(h[0].shape())); assertEquals( 1, h[0].shape()[0]); assertEquals(a.shape()[0], h[0].shape()[1]); assertEquals(a.shape()[1], h[0].shape()[2]); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java index 2302117d2..f2941feef 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -1,9 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockEncodable; +import org.deeplearning4j.rl4j.support.MockNeuralNet; +import org.deeplearning4j.rl4j.support.MockPolicy; +import org.deeplearning4j.rl4j.support.MockTrainingListener; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -68,7 +91,7 @@ public class AsyncLearningTest { public static class TestContext { - MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0); + MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0); public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); public final MockPolicy policy = new MockPolicy(); public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy); @@ -82,11 +105,11 @@ public class AsyncLearningTest { } public static class TestAsyncLearning extends AsyncLearning { - private final AsyncConfiguration conf; + private final IAsyncLearningConfiguration conf; private final IAsyncGlobal asyncGlobal; private final IPolicy policy; - public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { + public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { this.conf = conf; this.asyncGlobal = asyncGlobal; this.policy = policy; @@ -98,7 +121,7 @@ public class AsyncLearningTest { } @Override - public AsyncConfiguration getConfiguration() { + public IAsyncLearningConfiguration getConfiguration() { return conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index bc396502f..72f374db5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -1,7 +1,25 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; @@ -32,7 +50,7 @@ public class AsyncThreadDiscreteTest { MockMDP mdpMock = new MockMDP(observationSpace); TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); - MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); + MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); @@ -173,7 +191,7 @@ public class AsyncThreadDiscreteTest { } @Override - protected AsyncConfiguration getConf() { + protected IAsyncLearningConfiguration getConf() { return config; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 3dea25936..ff29960f1 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -3,12 +3,20 @@ package org.deeplearning4j.rl4j.learning.async; import lombok.AllArgsConstructor; import lombok.Getter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockEncodable; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockNeuralNet; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.deeplearning4j.rl4j.support.MockTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; @@ -16,7 +24,6 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; public class AsyncThreadTest { @@ -126,7 +133,7 @@ public class AsyncThreadTest { public final MockNeuralNet neuralNet = new MockNeuralNet(); public final MockObservationSpace observationSpace = new MockObservationSpace(); public final MockMDP mdp = new MockMDP(observationSpace); - public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0); + public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 
10, 0, 0, 0, 0, 0, 0, 10, 0); public final TrainingListenerList listeners = new TrainingListenerList(); public final MockTrainingListener listener = new MockTrainingListener(); public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); @@ -149,11 +156,11 @@ public class AsyncThreadTest { private final MockAsyncGlobal asyncGlobal; private final MockNeuralNet neuralNet; - private final AsyncConfiguration conf; + private final IAsyncLearningConfiguration conf; private final List trainSubEpochParams = new ArrayList(); - public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { + public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) { super(asyncGlobal, mdp, listeners, threadNumber, 0); this.asyncGlobal = asyncGlobal; @@ -184,7 +191,7 @@ public class AsyncThreadTest { } @Override - protected AsyncConfiguration getConf() { + protected IAsyncLearningConfiguration getConf() { return conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java index ef7fec7d0..b812a5582 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java @@ -1,11 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.async.MiniTrans; -import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete; -import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.support.*; @@ -31,7 +47,7 @@ public class A3CThreadDiscreteTest { double gamma = 0.9; MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdpMock = new MockMDP(observationSpace); - A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0); + A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build(); MockActorCritic actorCriticMock = new MockActorCritic(); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); @@ -54,9 +70,9 @@ public class A3CThreadDiscreteTest { Nd4j.zeros(5) }; output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); } - minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans // Act sut.calcGradient(actorCriticMock, minitransList); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java index d105419df..2a8c5b832 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java @@ -1,7 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2020 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.support.*; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -19,7 +36,7 @@ public class AsyncNStepQLearningThreadDiscreteTest { double gamma = 0.9; MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdpMock = new MockMDP(observationSpace); - AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0); + AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build(); MockDQN dqnMock = new MockDQN(); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); @@ -42,9 +59,9 @@ public class AsyncNStepQLearningThreadDiscreteTest { Nd4j.zeros(5) }; output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); } - minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans // Act sut.calcGradient(dqnMock, minitransList); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index 79be025b5..22e4be3f6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -1,6 +1,26 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.sync; import lombok.Getter; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.mdp.MDP; @@ -17,7 +37,7 @@ public class SyncLearningTest { @Test public void when_training_expect_listenersToBeCalled() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -34,7 +54,7 @@ public class SyncLearningTest { @Test public void when_trainingStartCanContinueFalse_expect_trainingStopped() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -52,7 +72,7 @@ public class SyncLearningTest { @Test public void when_newEpochCanContinueFalse_expect_trainingStopped() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -70,7 +90,7 @@ public class SyncLearningTest { @Test public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -87,12 +107,12 @@ public class SyncLearningTest { public static class MockSyncLearning extends SyncLearning { - private final LConfiguration conf; + private final ILearningConfiguration conf; @Getter private int currentEpochStep = 0; - public MockSyncLearning(LConfiguration conf) { + public MockSyncLearning(ILearningConfiguration conf) { this.conf = conf; } @@ -119,7 +139,7 @@ public class SyncLearningTest { } @Override - public LConfiguration getConfiguration() { + public ILearningConfiguration getConfiguration() { return conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java similarity index 52% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java rename to rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java index b12866ed2..d7d9bf072 100644 --- 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,36 +18,24 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning; import com.fasterxml.jackson.databind.ObjectMapper; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -public class QLConfigurationTest { +public class QLearningConfigurationTest { @Rule public ExpectedException thrown = ExpectedException.none(); @Test public void serialize() throws Exception { ObjectMapper mapper = new ObjectMapper(); - QLearning.QLConfiguration qlConfiguration = - new QLearning.QLConfiguration( - 123, //Random seed - 200, //Max step By epoch - 8000, //Max step - 150000, //Max size of experience replay - 32, //size of batches - 500, //target update (hard) - 10, //num step noop warmup - 0.01, //reward scaling - 0.99, //gamma - 1.0, //td error clipping - 0.1f, //min epsilon - 10000, //num step for eps greedy anneal - true //double DQN - ); + + QLearningConfiguration qLearningConfiguration = QLearningConfiguration.builder() + .build(); // Should not throw.. - String json = mapper.writeValueAsString(qlConfiguration); - QLearning.QLConfiguration cnf = mapper.readValue(json, QLearning.QLConfiguration.class); + String json = mapper.writeValueAsString(qLearningConfiguration); + QLearningConfiguration cnf = mapper.readValue(json, QLearningConfiguration.class); } -} \ No newline at end of file +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 58aaab297..fe8dd6acc 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -1,6 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; @@ -27,7 +45,7 @@ public class QLearningDiscreteTest { // Arrange MockObservationSpace observationSpace = new MockObservationSpace(); MockDQN dqn = new MockDQN(); - MockRandom random = new MockRandom(new double[] { + MockRandom random = new MockRandom(new double[]{ 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, @@ -36,14 +54,26 @@ public class QLearningDiscreteTest { 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 - }, - new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); + }, + new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}); MockMDP mdp = new MockMDP(observationSpace, random); int initStepCount = 8; - QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000, - initStepCount, 1.0, 0, 0, 0, 0, true); + QLearningConfiguration conf = QLearningConfiguration.builder() + .seed(0L) + .maxEpochStep(24) + .maxStep(0) + .expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000) + .updateStart(initStepCount) + .rewardFactor(1.0) + .gamma(0) + .errorClamp(0) + .minEpsilon(0) + .epsilonNbStep(0) + .doubleDQN(true) + .build(); + MockDataManager dataManager = new MockDataManager(false); MockExpReplay expReplay = new MockExpReplay(); TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); @@ -58,9 +88,9 @@ public class QLearningDiscreteTest { // Assert // HistoryProcessor calls - double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 }; + double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; assertEquals(expectedRecords.length, hp.recordCalls.size()); - for(int i = 0; i < expectedRecords.length; ++i) { + for (int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } @@ -72,59 +102,59 @@ public class QLearningDiscreteTest { assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001); assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); assertEquals(14, dqn.outputParams.size()); - double[][] expectedDQNOutput = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, + double[][] expectedDQNOutput = new double[][]{ + new 
double[]{0.0, 2.0, 4.0, 6.0, 8.0}, + new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, + new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, + new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, + new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, + new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, + new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, + new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, + new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, + new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, + new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, + new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, + new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, + new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, }; - for(int i = 0; i < expectedDQNOutput.length; ++i) { + for (int i = 0; i < expectedDQNOutput.length; ++i) { INDArray outputParam = dqn.outputParams.get(i); assertEquals(5, outputParam.shape()[1]); assertEquals(1, outputParam.shape()[2]); double[] expectedRow = expectedDQNOutput[i]; - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001); + for (int j = 0; j < expectedRow.length; ++j) { + assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001); } } // MDP calls - assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); + assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray()); // ExpReplay calls - double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; - int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; - double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; - double[][] expectedTrObservations = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, + double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0}; + int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4}; + double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0}; + double[][] expectedTrObservations = new double[][]{ + new double[]{0.0, 2.0, 4.0, 6.0, 8.0}, + new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, + new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, + new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, + new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, + new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, + new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, + new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, }; assertEquals(expectedTrObservations.length, expReplay.transitions.size()); - for(int i = 0; i < expectedTrRewards.length; ++i) { + for (int i = 0; i < expectedTrRewards.length; ++i) { Transition tr = expReplay.transitions.get(i); assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); assertEquals(expectedTrActions[i], tr.getAction()); assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); - for(int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); + for (int j = 0; j < expectedTrObservations[i].length; ++j) { + 
assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); } } @@ -132,12 +162,12 @@ public class QLearningDiscreteTest { assertEquals(initStepCount + 16, result.getStepCounter()); assertEquals(300.0, result.getReward(), 0.00001); assertTrue(dqn.hasBeenReset); - assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset); + assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset); } public static class TestQLearningDiscrete extends QLearningDiscrete { public TestQLearningDiscrete(MDP mdp, IDQN dqn, - QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, + QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, int epsilonNbStep, Random rnd) { super(mdp, dqn, conf, epsilonNbStep, rnd); addListener(new DataManagerTrainingListener(dataManager)); @@ -146,10 +176,10 @@ public class QLearningDiscreteTest { @Override protected DataSet setTarget(ArrayList> transitions) { - return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); + return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0})); } - public void setExpReplay(IExpReplay exp){ + public void setExpReplay(IExpReplay exp) { this.expReplay = exp; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java index c43c26d50..821863054 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ package org.deeplearning4j.rl4j.network.ac; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; import org.junit.Test; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.ndarray.INDArray; @@ -29,30 +31,31 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** - * * @author saudet */ public class ActorCriticTest { - public static ActorCriticFactorySeparateStdDense.Configuration NET_CONF = - new ActorCriticFactorySeparateStdDense.Configuration( - 4, //number of layers - 32, //number of hidden nodes - 0.001, //l2 regularization - new RmsProp(0.0005), null, false - ); + public static ActorCriticDenseNetworkConfiguration NET_CONF = + ActorCriticDenseNetworkConfiguration.builder() + .numLayers(4) + .numHiddenNodes(32) + .l2(0.001) + .updater(new RmsProp(0.0005)) + .useLSTM(false) + .build(); - public static ActorCriticFactoryCompGraphStdDense.Configuration NET_CONF_CG = - new ActorCriticFactoryCompGraphStdDense.Configuration( - 2, //number of layers - 128, //number of hidden nodes - 0.00001, //l2 regularization - new RmsProp(0.005), null, true - ); + public static ActorCriticDenseNetworkConfiguration NET_CONF_CG = + ActorCriticDenseNetworkConfiguration.builder() + .numLayers(2) + .numHiddenNodes(128) + .l2(0.00001) + .updater(new RmsProp(0.005)) + .useLSTM(true) + .build(); @Test public void testModelLoadSave() throws IOException { - ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[] {7}, 5); + ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[]{7}, 5); File fileValue = File.createTempFile("rl4j-value-", ".model"); File filePolicy = File.createTempFile("rl4j-policy-", ".model"); @@ -63,7 +66,7 @@ public class ActorCriticTest { assertEquals(acs.valueNet, acs2.valueNet); assertEquals(acs.policyNet, acs2.policyNet); - ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[] {37}, 43); + ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[]{37}, 43); File file = File.createTempFile("rl4j-cg-", ".model"); accg.save(file.getAbsolutePath()); @@ -83,15 +86,15 @@ public class ActorCriticTest { for (double i = eps; i < n; i++) { for (double j = eps; j < n; j++) { - INDArray labels = Nd4j.create(new double[] {i / n, 1 - i / n}, new long[]{1,2}); - INDArray output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2}); + INDArray labels = Nd4j.create(new double[]{i / n, 1 - i / n}, new long[]{1, 2}); + INDArray output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2}); INDArray gradient = loss.computeGradient(labels, output, activation, null); - output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2}); + output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2}); double score = loss.computeScore(labels, output, activation, null, false); - INDArray output1 = Nd4j.create(new double[] {j / n + eps, 1 - j / n}, new long[]{1,2}); + INDArray output1 = Nd4j.create(new double[]{j / n + eps, 1 - j / n}, new long[]{1, 2}); double score1 = loss.computeScore(labels, output1, activation, null, false); - INDArray output2 = Nd4j.create(new double[] {j / n, 1 - j / n + eps}, new long[]{1,2}); + INDArray 
output2 = Nd4j.create(new double[]{j / n, 1 - j / n + eps}, new long[]{1, 2}); double score2 = loss.computeScore(labels, output2, activation, null, false); double gradient1 = (score1 - score) / eps; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java index 3f68b8f3c..a9997ec0c 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ package org.deeplearning4j.rl4j.network.dqn; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; import org.junit.Test; import org.nd4j.linalg.learning.config.RmsProp; @@ -25,22 +27,20 @@ import java.io.IOException; import static org.junit.Assert.assertEquals; /** - * * @author saudet */ public class DQNTest { - public static DQNFactoryStdDense.Configuration NET_CONF = - new DQNFactoryStdDense.Configuration( - 3, //number of layers - 16, //number of hidden nodes - 0.001, //l2 regularization - new RmsProp(0.0005), null - ); + private static DQNDenseNetworkConfiguration NET_CONF = + DQNDenseNetworkConfiguration.builder().numLayers(3) + .numHiddenNodes(16) + .l2(0.001) + .updater(new RmsProp(0.0005)) + .build(); @Test public void testModelLoadSave() throws IOException { - DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[] {42}, 13); + DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[]{42}, 13); File file = File.createTempFile("rl4j-dqn-", ".model"); dqn.save(file.getAbsolutePath()); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java index fe79bdfc7..3f5e761a6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java @@ -128,7 +128,7 @@ public class TransformProcessTest { // Assert assertFalse(result.isSkipped()); - assertEquals(1, result.getData().shape().length); + assertEquals(2, result.getData().shape().length); assertEquals(1, result.getData().shape()[0]); assertEquals(-10.0, result.getData().getDouble(0), 0.00001); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 0707e16ab..0dc16df09 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -24,16 +25,18 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest; -import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.support.MockDQN; +import org.deeplearning4j.rl4j.support.MockEncodable; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockNeuralNet; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.deeplearning4j.rl4j.support.MockRandom; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -43,8 +46,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.IOException; import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -186,8 +187,22 @@ public class PolicyTest { new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); MockMDP mdp = new MockMDP(observationSpace, 30, random); - QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, - 0, 1.0, 0, 0, 0, 0, true); + QLearningConfiguration conf = QLearningConfiguration.builder() + .seed(0L) + .maxEpochStep(0) + .maxStep(0) + .expRepMaxSize(5) + .batchSize(1) + .targetDqnUpdateFreq(0) + .updateStart(0) + .rewardFactor(1.0) + .gamma(0) + .errorClamp(0) + .minEpsilon(0) + .epsilonNbStep(0) + .doubleDQN(true) + .build(); + MockNeuralNet nnMock = new MockNeuralNet(); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java index 56581cc0d..08689b032 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java @@ -1,22 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.support; import lombok.AllArgsConstructor; -import lombok.Getter; import lombok.Value; -import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; -@AllArgsConstructor @Value -public class MockAsyncConfiguration implements AsyncConfiguration { +@AllArgsConstructor +public class MockAsyncConfiguration implements IAsyncLearningConfiguration { - private Integer seed; + private Long seed; private int maxEpochStep; private int maxStep; - private int numThread; - private int nstep; - private int targetDqnUpdateFreq; private int updateStart; private double rewardFactor; private double gamma; private double errorClamp; + private int numThreads; + private int nStep; + private int learnerUpdateFrequency; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java index a3a5598d4..3a2d5230a 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java @@ -1,3 +1,20 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.util; import lombok.Getter; @@ -5,6 +22,7 @@ import lombok.Setter; import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.mdp.MDP; @@ -162,7 +180,7 @@ public class DataManagerTrainingListenerTest { } @Override - public LConfiguration getConfiguration() { + public ILearningConfiguration getConfiguration() { return null; } diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java index 7400657ef..00b7c4f7a 100644 --- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java +++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,19 +17,18 @@ package org.deeplearning4j.malmo; -import java.util.HashMap; - +import com.microsoft.msr.malmo.TimestampedStringVector; +import com.microsoft.msr.malmo.WorldState; import org.json.JSONArray; import org.json.JSONObject; - import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import com.microsoft.msr.malmo.TimestampedStringVector; -import com.microsoft.msr.malmo.WorldState; +import java.util.HashMap; /** * Observation space that contains a grid of Minecraft blocks + * * @author howard-abrams (howard.abrams@ca.com) on 1/12/17. */ public class MalmoObservationSpaceGrid extends MalmoObservationSpace { @@ -43,11 +43,11 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace { /** * Construct observation space from a array of blocks policy should distinguish between. - * - * @param name Name given to Grid element in mission specification - * @param xSize total x size of grid - * @param ySize total y size of grid - * @param zSize total z size of grid + * + * @param name Name given to Grid element in mission specification + * @param xSize total x size of grid + * @param ySize total y size of grid + * @param zSize total z size of grid * @param blocks Array of block names to distinguish between. Supports combination of individual strings and/or arrays of strings to map multiple block types to a single observation value. If not specified, it will dynamically map block names to integers - however, because these will be mapped as they are seen, different missions may have different mappings! */ public MalmoObservationSpaceGrid(String name, int xSize, int ySize, int zSize, Object... 
blocks) { @@ -78,7 +78,7 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace { @Override public int[] getShape() { - return new int[] {totalSize}; + return new int[]{totalSize}; } @Override From 986ec4b51abf2f0b5da9f1cfbcc01c054155bdf9 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 6 Apr 2020 15:02:09 +1000 Subject: [PATCH 10/19] Add test from reported issue (confirmed fixed) (#359) Signed-off-by: Alex Black --- .../opvalidation/RandomOpValidation.java | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 053f3a70b..e6e8962ec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; @@ -410,4 +411,29 @@ public class RandomOpValidation extends BaseOpValidation { } } } + + @Test + public void testRandomExponential2(){ + Nd4j.getRandom().setSeed(12345); + DynamicCustomOp op = DynamicCustomOp.builder("random_exponential") + .addInputs(Nd4j.createFromArray(100)) + .addOutputs(Nd4j.create(DataType.FLOAT, 100)) + .addFloatingPointArguments(0.5) + .build(); + + Nd4j.exec(op); + + INDArray out = op.getOutputArgument(0); + int count0 = out.eq(0.0).castTo(DataType.INT32).sumNumber().intValue(); + int count1 = out.eq(1.0).castTo(DataType.INT32).sumNumber().intValue(); + + assertEquals(0, count0); + assertEquals(0, count1); + + double min = out.minNumber().doubleValue(); + double max = out.maxNumber().doubleValue(); + + assertTrue(String.valueOf(min), min > 0.0); + assertTrue(String.valueOf(max), max > 1.0); + } } From 04b2b4f9b6108c6eddbd417c172523778a694ab8 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 6 Apr 2020 21:01:59 +0300 Subject: [PATCH 11/19] Few fixes (#361) * INDArray.close() fix for CPU Signed-off-by: raver119 * - BroadcastableBoolOp introduced - ConfusionMatrix now supports explicit DataType argument Signed-off-by: raver119 * confusion_matrix: dtype is still optional Signed-off-by: raver119 * disable bert tests in debug builds Signed-off-by: raver119 * Affinity fix Signed-off-by: raver119 * minor workspace tweak to allow close() on scoped out borrowed workspace Signed-off-by: raver119 --- .../ops/declarable/BroadcastableBoolOp.h | 43 +++++ .../generic/broadcastable/equals.cpp | 2 +- .../generic/broadcastable/greater.cpp | 2 +- .../generic/broadcastable/greater_equal.cpp | 2 +- .../declarable/generic/broadcastable/less.cpp | 2 +- .../generic/broadcastable/less_equal.cpp | 2 +- .../generic/broadcastable/not_equals.cpp | 2 +- .../generic/parity_ops/confusion_matrix.cpp | 10 +- .../ops/declarable/headers/broadcastable.h | 13 +- .../declarable/impl/BroadcastableBoolOp.cpp | 72 ++++++++ libnd4j/include/system/op_boilerplate.h | 14 ++ .../layers_tests/DeclarableOpsTests12.cpp | 2 +- .../layers_tests/DeclarableOpsTests5.cpp | 2 +- .../layers_tests/PlaygroundTests.cpp | 
7 + .../linalg/api/buffer/BaseDataBuffer.java | 1 - .../api/memory/abstracts/Nd4jWorkspace.java | 7 + .../api/ops/impl/shape/ConfusionMatrix.java | 3 + .../org/nd4j/nativeblas/OpaqueDataBuffer.java | 7 + .../jita/concurrency/CudaAffinityManager.java | 13 +- .../jcublas/buffer/BaseCudaDataBuffer.java | 6 +- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 44 +++++ .../org/nd4j/nativeblas/Nd4jCudaPresets.java | 3 +- .../nativecpu/buffer/BaseCpuDataBuffer.java | 6 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 156 ++++++++++++++++-- .../org/nd4j/nativeblas/Nd4jCpuPresets.java | 3 +- .../linalg/broadcast/BasicBroadcastTests.java | 13 ++ .../workspace/SpecialWorkspaceTests.java | 37 +++++ .../workspace/WorkspaceProviderTests.java | 8 +- 28 files changed, 430 insertions(+), 52 deletions(-) create mode 100644 libnd4j/include/ops/declarable/BroadcastableBoolOp.h create mode 100644 libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp diff --git a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h new file mode 100644 index 000000000..c48650294 --- /dev/null +++ b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 6/6/2018. 
+// + +#ifndef SD_BROADCASTABLEBOOLOP_H +#define SD_BROADCASTABLEBOOLOP_H + +#include +#include "OpDescriptor.h" +#include "DeclarableOp.h" +#include "DeclarableCustomOp.h" + +namespace sd { + namespace ops { + class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{ + protected: + Nd4jStatus validateAndExecute(Context& block) override = 0; + public: + BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs); + + ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; + }; + } +} + + +#endif //SD_BROADCASTABLEBOOLOP_H diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp index c82fe6748..5d4aaef5e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp @@ -23,7 +23,7 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(equals, 0, 0) { + BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp index 961259946..084453dc8 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp @@ -23,7 +23,7 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(greater, 0, 0) { + BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp index 1adbad420..5f448585e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp @@ -22,7 +22,7 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(greater_equal, 0, 0) { + BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp index ba5c72fa4..5d9c73f1b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp @@ -22,7 +22,7 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(less, 0, 0) { + BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp index b602f1374..a0f0a0366 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp @@ -22,7 +22,7 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(less_equal, 0, 0) { + BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp index fddd653e9..9e2609f9d 100644 --- 
a/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp @@ -22,7 +22,7 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(not_equals, 0, 0) { + BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp index cc8a64fa6..f90513ca3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp @@ -45,8 +45,8 @@ namespace sd { weights = INPUT_VARIABLE(2); REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape"); } - auto output = OUTPUT_VARIABLE(0); - output->assign(0.); + auto output = OUTPUT_NULLIFIED(0); + int minPrediction = predictions->reduceNumber(reduce::Min).e(0); int minLabel = labels->reduceNumber(reduce::Min).e(0); @@ -64,11 +64,7 @@ namespace sd { DECLARE_SHAPE_FN(confusion_matrix) { auto labels = INPUT_VARIABLE(0); auto predictions = INPUT_VARIABLE(1); - auto dtype = block.dataType(); - dtype = sd::DataType::INT64; // dtype - should be a param with int argument - if (block.numI() > 1) - dtype = (sd::DataType)INT_ARG(1); - + auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; int numClasses = 0; if (block.getIArguments()->size() > 0) { diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index 691a1b7b2..7380412a4 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -22,6 +22,7 @@ #define LIBND4J_HEADERS_BROADCASTABLE_H #include +#include #include #include @@ -261,7 +262,7 @@ namespace sd { * */ #if NOT_EXCLUDED(OP_equals) - DECLARE_BROADCASTABLE_OP(equals, 0, 0); + DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0); #endif /** @@ -269,7 +270,7 @@ namespace sd { * Math is: _x != _y ? (T) 1.0f : (T) 0.0f; */ #if NOT_EXCLUDED(OP_not_equals) - DECLARE_BROADCASTABLE_OP(not_equals, 0, 0); + DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0); #endif /** @@ -277,7 +278,7 @@ namespace sd { * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; */ #if NOT_EXCLUDED(OP_less_equal) - DECLARE_BROADCASTABLE_OP(less_equal, 0, 0); + DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0); #endif /** @@ -285,7 +286,7 @@ namespace sd { * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; */ #if NOT_EXCLUDED(OP_greater_equal) - DECLARE_BROADCASTABLE_OP(greater_equal, 0, 0); + DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0); #endif /** @@ -293,7 +294,7 @@ namespace sd { * Math is: _x < _y ? (T) 1.0f : (T) 0.0f; */ #if NOT_EXCLUDED(OP_less) - DECLARE_BROADCASTABLE_OP(less, 0, 0); + DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0); #endif /** @@ -301,7 +302,7 @@ namespace sd { * Math is: _x > _y ? 
(T) 1.0f : (T) 0.0f; */ #if NOT_EXCLUDED(OP_greater) - DECLARE_BROADCASTABLE_OP(greater, 0, 0); + DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp new file mode 100644 index 000000000..66eade39f --- /dev/null +++ b/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 6/6/2018. +// + +#include +#include +#include +#include + +namespace sd { + namespace ops { + BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) { + // + } + + ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { + auto shapeList = SHAPELIST(); + auto x = inputShape->at(0); + auto y = inputShape->at(1); + sd::DataType dtype = sd::DataType::BOOL; + + if(shape::isEmpty(x) || shape::isEmpty(y)) { + // this is edge case, [3, 4] + [] = [] + if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype))); + return shapeList; + } + + Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); + } else if (shape::isScalar(x) && shape::isScalar(y)) { + if (shape::rank(x) >= shape::rank(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); + } else { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype))); + } + } else if (shape::equalsSoft(x, y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); + } else if (shape::isScalar(x) && !shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype))); + } else if (!shape::isScalar(x) && shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); + } else if (ShapeUtils::areShapesBroadcastable(x, y)) { + Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); + } else { + // in this case we'll throw exception later + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); + } + + 
return shapeList; + } + } +} \ No newline at end of file diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 4e7a288f0..1df4f0047 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1446,10 +1446,24 @@ };\ REGISTER_H(NAME) +#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \ + protected: \ + void registerTypes(); \ + Nd4jStatus validateAndExecute(Context& block); \ + public:\ + NAME(); \ + };\ + REGISTER_H(NAME) + + #define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { }; \ REGISTER_C(NAME) \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) +#define BROADCASTABLE_BOOL_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableBoolOp(#NAME, TARGS, IARGS) { }; \ + REGISTER_C(NAME) \ + Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) + #define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 615f95bbd..0684f7887 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -515,7 +515,7 @@ TEST_F(DeclarableOpsTests12, TestConfusionZero_1) { //exp1.assign(1.); //exp2.assign(-2.); sd::ops::confusion_matrix op; - Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}); + Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.equalsTo(exp)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 6ac9d34cd..26c6b5d53 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -2374,7 +2374,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3, sd::DataType::DOUBLE}); + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, {sd::DataType::DOUBLE}); auto output = results.at(0); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 94156d4bc..5636b2e29 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -91,6 +91,8 @@ TEST_F(PlaygroundTests, test_biasAdd_1) { TEST_F(PlaygroundTests, test_bert_full_1) { +#ifdef _RELEASE + // this test will run ONLY if this model exists if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0) return; @@ -147,10 +149,12 @@ TEST_F(PlaygroundTests, test_bert_full_1) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ delete graph; +#endif } TEST_F(PlaygroundTests, test_bert_1) { +#ifdef _RELEASE // this test will run ONLY if this model exists if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0) return; @@ -206,9 +210,11 @@ TEST_F(PlaygroundTests, test_bert_1) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ delete graph; +#endif } TEST_F(PlaygroundTests, test_bert_2) { +#ifdef _RELEASE // this test will run ONLY if this 
model exists if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0) return; @@ -256,6 +262,7 @@ TEST_F(PlaygroundTests, test_bert_2) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ delete graph; +#endif } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 6b226ce20..b99d9f105 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -1930,7 +1930,6 @@ public abstract class BaseDataBuffer implements DataBuffer { protected void release() { this.released = true; - this.pointer.deallocate(); this.indexer = null; this.pointer = null; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java index 088ec0056..823b67bbe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java @@ -580,6 +580,13 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace { public void close() { // first we check if this workspace was borrowed. if yes - just close without reset. if (isBorrowed.get()) { + if (tagScope.get() > 0) { + if (tagScope.decrementAndGet() == 0) { + Nd4j.getMemoryManager().setCurrentWorkspace(this); + } + return; + } + isBorrowed.set(false); Nd4j.getMemoryManager().setCurrentWorkspace(borrowingWorkspace); return; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 2bf94021a..dab70c801 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -42,6 +42,7 @@ public class ConfusionMatrix extends DynamicCustomOp { public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){ super(new INDArray[]{labels, predicted}, null); this.outputType = dataType; + addDArgument(dataType); } public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){ @@ -66,6 +67,7 @@ public class ConfusionMatrix extends DynamicCustomOp { if(numClasses != null) { addIArgument(numClasses); } + addDArgument(dataType); } @@ -77,6 +79,7 @@ public class ConfusionMatrix extends DynamicCustomOp { public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){ super(null, sameDiff, new SDVariable[]{labels, pred}); this.outputType = dataType; + addDArgument(dataType); } public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java 
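The ConfusionMatrix changes above pass the requested output type through addDArgument, matching the C++ side of this patch, which now reads it via block.numD() ? D_ARG(0) : INT64 so the dtype stays optional for existing callers. A minimal usage sketch of the constructor shown in this hunk, executed through Nd4j.exec; the example values are illustrative only:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
import org.nd4j.linalg.factory.Nd4j;

public class ConfusionMatrixDtypeExample {
    public static void main(String[] args) {
        INDArray labels    = Nd4j.createFromArray(0, 1, 2, 2);
        INDArray predicted = Nd4j.createFromArray(0, 2, 2, 2);

        // Explicitly request a DOUBLE confusion matrix; the data type travels as a D-argument.
        INDArray cm = Nd4j.exec(new ConfusionMatrix(labels, predicted, DataType.DOUBLE))[0];
        System.out.println(cm);   // 3x3 matrix of per-[label, prediction] counts
    }
}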
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java index d7c2e0ac0..49c4e8be3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java @@ -209,4 +209,11 @@ public class OpaqueDataBuffer extends Pointer { public void syncToPrimary() { NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this); } + + /** + * This method releases underlying buffer + */ + public void closeBuffer() { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(this); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index bd5c7a9a0..356df88e0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -72,9 +72,16 @@ public class CudaAffinityManager extends BasicAffinityManager { */ @Override public Integer getDeviceForThread(long threadId) { - val id = affinityMap.get(threadId); - if (id == null) - throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet"); + Integer id = affinityMap.get(threadId); + if (id == null) { + // if this is current thread - we're still able to fetch id from native side, and update map + if (threadId == Thread.currentThread().getId()) { + id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); + affinityMap.put(Long.valueOf(threadId), id); + } else + // TODO: we should get rid of this method, and forbid such kind of queries + throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet"); + } return id; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index f84f96384..dc20d9a5a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1792,11 +1792,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override protected void release() { if (!released) { - //AtomicAllocator.getInstance().freeMemory(allocationPoint);n - NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(allocationPoint.getPtrDataBuffer()); + ptrDataBuffer.closeBuffer(); allocationPoint.setReleased(true); } - released = true; + + super.release(); } /* diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 33260da70..8f30cdd82 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -9553,6 +9553,50 @@ public static final int PREALLOC_SIZE = 33554432; // #endif //LIBND4J_BROADCASTABLEOP_H +// 
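Together with the new OpaqueDataBuffer.closeBuffer() helper (a thin wrapper over the native dbClose call), the release() changes in this patch are what the "INDArray.close() fix for CPU" note in the commit message refers to: releasing an array now goes through dbClose on the native side instead of deallocating the JavaCPP pointer directly. The public API is unchanged; a short, illustrative sketch of the behaviour being fixed:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CloseExample {
    public static void main(String[] args) {
        INDArray arr = Nd4j.create(DataType.FLOAT, 100);
        arr.addi(1.0);

        // Eagerly releases the underlying native buffer instead of waiting for GC.
        // After close() the array (and any views of it) must not be used again.
        arr.close();
    }
}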
Parsed from ops/declarable/BroadcastableBoolOp.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 6/6/2018. +// + +// #ifndef SD_BROADCASTABLEBOOLOP_H +// #define SD_BROADCASTABLEBOOLOP_H + +// #include +// #include "OpDescriptor.h" +// #include "DeclarableOp.h" +// #include "DeclarableCustomOp.h" + @Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BroadcastableBoolOp(Pointer p) { super(p); } + + + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } + + + + +// #endif //SD_BROADCASTABLEBOOLOP_H + + // Parsed from helpers/OpArgsHolder.h /******************************************************************************* diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 1f3f7bde4..05b335c87 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -76,7 +76,8 @@ import org.bytedeco.javacpp.tools.InfoMapper; "ops/InputType.h", "ops/declarable/OpDescriptor.h", "ops/declarable/PlatformHelper.h", - "ops/declarable/BroadcastableOp.h", + "ops/declarable/BroadcastableOp.h", + "ops/declarable/BroadcastableBoolOp.h", "helpers/OpArgsHolder.h", "ops/declarable/DeclarableOp.h", "ops/declarable/DeclarableListOp.h", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index b3def0f71..7abc8e7be 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -837,6 +837,12 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo this(data, true, workspace); } + @Override + protected void release() { + ptrDataBuffer.closeBuffer(); + super.release(); + } + /** * Reallocate the native memory of the buffer * @param length the new length of the buffer diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 47791f865..e96325460 100644 --- 
a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -6,6 +6,7 @@ import java.nio.*; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.annotation.*; +import static org.bytedeco.javacpp.presets.javacpp.*; import static org.bytedeco.openblas.global.openblas_nolapack.*; import static org.bytedeco.openblas.global.openblas.*; @@ -11406,10 +11407,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // }; // REGISTER_H(NAME) +// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { +// protected: +// void registerTypes(); +// Nd4jStatus validateAndExecute(Context& block); +// public: +// NAME(); +// }; +// REGISTER_H(NAME) + + // #define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { }; // REGISTER_C(NAME) // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) +// #define BROADCASTABLE_BOOL_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableBoolOp(#NAME, TARGS, IARGS) { }; +// REGISTER_C(NAME) +// Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) + // #define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) @@ -11871,6 +11886,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif //LIBND4J_BROADCASTABLEOP_H +// Parsed from ops/declarable/BroadcastableBoolOp.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 6/6/2018. +// + +// #ifndef SD_BROADCASTABLEBOOLOP_H +// #define SD_BROADCASTABLEBOOLOP_H + +// #include +// #include "OpDescriptor.h" +// #include "DeclarableOp.h" +// #include "DeclarableCustomOp.h" + @Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ + public BroadcastableBoolOp(Pointer p) { super(p); } + + + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } + + + + +// #endif //SD_BROADCASTABLEBOOLOP_H + + // Parsed from ops/declarable/DeclarableOp.h /******************************************************************************* @@ -13636,6 +13695,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #define LIBND4J_HEADERS_BROADCASTABLE_H // #include +// #include // #include // #include // TODO: make broadcastables separate class @@ -14317,7 +14377,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * */ // #if NOT_EXCLUDED(OP_equals) - @Namespace("sd::ops") public static class equals extends BroadcastableOp { + @Namespace("sd::ops") public static class equals extends BroadcastableBoolOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public equals(Pointer p) { super(p); } @@ -14338,7 +14398,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Math is: _x != _y ? (T) 1.0f : (T) 0.0f; */ // #if NOT_EXCLUDED(OP_not_equals) - @Namespace("sd::ops") public static class not_equals extends BroadcastableOp { + @Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public not_equals(Pointer p) { super(p); } @@ -14359,7 +14419,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; */ // #if NOT_EXCLUDED(OP_less_equal) - @Namespace("sd::ops") public static class less_equal extends BroadcastableOp { + @Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public less_equal(Pointer p) { super(p); } @@ -14380,7 +14440,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; */ // #if NOT_EXCLUDED(OP_greater_equal) - @Namespace("sd::ops") public static class greater_equal extends BroadcastableOp { + @Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public greater_equal(Pointer p) { super(p); } @@ -14401,7 +14461,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Math is: _x < _y ? (T) 1.0f : (T) 0.0f; */ // #if NOT_EXCLUDED(OP_less) - @Namespace("sd::ops") public static class less extends BroadcastableOp { + @Namespace("sd::ops") public static class less extends BroadcastableBoolOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public less(Pointer p) { super(p); } @@ -14422,7 +14482,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Math is: _x > _y ? (T) 1.0f : (T) 0.0f; */ // #if NOT_EXCLUDED(OP_greater) - @Namespace("sd::ops") public static class greater extends BroadcastableOp { + @Namespace("sd::ops") public static class greater extends BroadcastableBoolOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ public greater(Pointer p) { super(p); } @@ -16672,6 +16732,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemax_bp position(long position) { + return (mergemax_bp)super.position(position); + } + + public mergemax_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /* * Complete tensor with max indices merged from all input tensors list @@ -16714,6 +16789,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeadd_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeadd_bp position(long position) { + return (mergeadd_bp)super.position(position); + } + + public mergeadd_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif // #if NOT_EXCLUDED(OP_mergeavg) @@ -16732,6 +16822,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeavg_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeavg_bp position(long position) { + return (mergeavg_bp)super.position(position); + } + + public mergeavg_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif // #if NOT_EXCLUDED(OP_scatter_update) @@ -19074,23 +19179,40 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * - 2D matrix MxN * - 1D vector with N elements * output value - 2D matrix NxN as multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul */ // #if NOT_EXCLUDED(OP_xw_plus_b) - @Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. 
Invokes {@link Pointer#Pointer(Pointer)}. */ - public xw_plus_b(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public xw_plus_b position(long position) { - return (xw_plus_b)super.position(position); - } - + @Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public xw_plus_b(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public xw_plus_b position(long position) { + return (xw_plus_b)super.position(position); + } + public xw_plus_b() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public xw_plus_b_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public xw_plus_b_bp position(long position) { + return (xw_plus_b_bp)super.position(position); + } + + public xw_plus_b_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 057dd5b95..c6e57e876 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -81,7 +81,8 @@ import java.util.Scanner; "ops/InputType.h", "ops/declarable/OpDescriptor.h", "ops/declarable/PlatformHelper.h", - "ops/declarable/BroadcastableOp.h", + "ops/declarable/BroadcastableOp.h", + "ops/declarable/BroadcastableBoolOp.h", "ops/declarable/DeclarableOp.h", "ops/declarable/DeclarableListOp.h", "ops/declarable/DeclarableReductionOp.h", diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index 00e294530..991bcf369 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -24,9 +24,11 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import 
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -316,6 +318,17 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(exp, sum); } + @Test + public void testBroadcatableBool_1() { + val op = DynamicCustomOp.builder("greater_equal") + .addInputs(Nd4j.create(DataType.FLOAT, 3), Nd4j.create(DataType.FLOAT, 3)) + .build(); + + val l = op.calculateOutputShape(); + assertEquals(1, l.size()); + assertEquals(DataType.BOOL, l.get(0).dataType()); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 8558a788a..8d7ebb040 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -36,6 +36,9 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; +import java.util.ArrayList; +import java.util.Arrays; + import static org.junit.Assert.*; /** @@ -298,6 +301,40 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { log.info("{} ns", ((timeEnd - timeStart) / (double) iterations)); } + @Test + public void testWorkspaceOrder_1(){ + WorkspaceConfiguration conf = WorkspaceConfiguration.builder() + .initialSize(1_000_000) + .overallocationLimit(0.05) + .policyLearning(LearningPolicy.NONE) + .build(); + + val exp = Arrays.asList("outer", null, "outer", "inner", "outer", null); + val res = new ArrayList(); + + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, "outer")){ + try(MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, "inner")){ + try(MemoryWorkspace ws3 = ws.notifyScopeBorrowed()){ + System.out.println("X: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //outer + res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId()); + try(MemoryWorkspace ws4 = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ + System.out.println("A: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //None (null) + res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId()); + } + System.out.println("B: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //outer + res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId()); + } + System.out.println("C: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //inner + res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId()); + } + System.out.println("D: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //outer + res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId()); + } + System.out.println("E: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //None (null) + res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? 
null : Nd4j.getMemoryManager().getCurrentWorkspace().getId()); + + assertEquals(exp, res); + } @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java index 1615abbc3..20e7f367f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java @@ -616,19 +616,17 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test + @Ignore("raver119: This test doesn't make any sense to me these days. We're borrowing from the same workspace. Why?") public void testNestedWorkspaces11() { for (int x = 1; x < 10; x++) { try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { INDArray array1 = Nd4j.create(100 * x); for (int i = 1; i < 10; i++) { - try (MemoryWorkspace ws2 = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { + try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { INDArray array2 = Nd4j.create(100 * x); for (int e = 1; e < 10; e++) { - try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(basicConfiguration, "WS_1") - .notifyScopeBorrowed()) { + try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) { INDArray array3 = Nd4j.create(100 * x); } } From e57f35c2e49233003b8ce77e2d3df2473f75c267 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 7 Apr 2020 13:14:43 +0300 Subject: [PATCH 12/19] mkldnn version bump Signed-off-by: raver119 --- libnd4j/CMakeLists.txt.mkldnn.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in index 36c426053..224f5d50d 100644 --- a/libnd4j/CMakeLists.txt.mkldnn.in +++ b/libnd4j/CMakeLists.txt.mkldnn.in @@ -5,7 +5,7 @@ project(mkldnn-download NONE) include(ExternalProject) ExternalProject_Add(mkldnn GIT_REPOSITORY https://github.com/intel/mkl-dnn.git - GIT_TAG v1.2.2 + GIT_TAG v1.3 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" CONFIGURE_COMMAND "" From ab083b916715d1d8b13e87c63186dabe6beeeaee Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Wed, 8 Apr 2020 21:09:45 +0900 Subject: [PATCH 13/19] Update versions of JavaCPP Presets for OpenCV and MKL (#363) Signed-off-by: Samuel Audet --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 5a8d49d88..17708b222 100644 --- a/pom.xml +++ b/pom.xml @@ -298,8 +298,8 @@ ${numpy.version}-${javacpp-presets.version} 0.3.9 - 2020.0 - 4.2.0 + 2020.1 + 4.3.0 4.2.2 1.79.0 1.12.0 From d86dd5b131404282d14cb089e0134299b0be2179 Mon Sep 17 00:00:00 2001 From: Andrii T <39699084+atuzhykov@users.noreply.github.com> Date: Wed, 8 Apr 2020 17:20:48 +0300 Subject: [PATCH 14/19] DL4J and SameDiff integration tests + LSTMLayer java op class (#353) * init in this branch Signed-off-by: Andrii Tuzhykov * Lenetet Mnist workflow Signed-off-by: Andrii Tuzhykov * small fix for calculations Signed-off-by: Andrii Tuzhykov * for Alex to check placeholder null pointer issue Signed-off-by: Andrii Tuzhykov * CNN3D workflow Signed-off-by: Andrii Tuzhykov * state for 
launching on dxg to regenterate dl4j examples Signed-off-by: Andrii Tuzhykov * SD RNN test case workflow Signed-off-by: Andrii Tuzhykov * small fixes Signed-off-by: Andrii Tuzhykov * checkpoint at lstmBlock: Input array 1 (x) rank must be got input with rank 2 issue Signed-off-by: Andrii Tuzhykov * Fix LSTMLayer inputs order Signed-off-by: Andrii Tuzhykov * lstm mismatch with c++ op issue Signed-off-by: Andrii Tuzhykov * LSTMLayer config draft Signed-off-by: Andrii Tuzhykov * LSTMLayer config draft v2 Signed-off-by: Andrii Tuzhykov * have doubt I had to do this Signed-off-by: Andrii Tuzhykov * NDRNN generated by codegen Signed-off-by: Andrii Tuzhykov * LSTMLayerTestCases draft Signed-off-by: Andrii Tuzhykov * minor fixes again * added LSTMLayer testcases to nd4j-tests + setted Preconditions in LSTMLayer constructors Signed-off-by: Andrii Tuzhykov * added lost SDCNNtestcases Signed-off-by: Andrii Tuzhykov * overrided getNumOutputs from DynamicCustomOp in LSTMLayer and reorganized LSTMLayerOutputs according to cpp op Signed-off-by: Andrii Tuzhykov * finished with LSTMLayerOutputs Signed-off-by: Andrii Tuzhykov * Fix MKLDNN platform checks (i.e., when MKLDNN can be used vs. not) Signed-off-by: Alex Black * Fix LSTMLayerWeights input order Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * minor fixes Signed-off-by: Andrii Tuzhykov * fixed LSTMLayer testcases Signed-off-by: Andrii Tuzhykov * finished SameDiffRNNTestCase Signed-off-by: Andrii Tuzhykov * finished all testcases + minor fixes Signed-off-by: Andrii Tuzhykov * Multiple generation-related fixes Signed-off-by: Alex Black * Fix multiple issues Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * LSTM fixes Signed-off-by: Alex Black * Regenerate ND4J namespaces and fix multiple issues Signed-off-by: Alex Black * changed SameDiffRNNTestCase Signed-off-by: Andrii Tuzhykov * Small fix Signed-off-by: Alex Black * added Nd4j.getRandom().setSeed(12345) where needed Signed-off-by: Andrii Tuzhykov * #8828 Fix ND4J profiler NaN/Inf checks when using OpContext Signed-off-by: Alex Black * #8828 Fix ND4J profiler NaN/Inf checks when using OpContext Signed-off-by: Alex Black * Tweak to weight init for SameDiff CNN test case Signed-off-by: Alex Black * Tweaks for test cases Signed-off-by: Alex Black * Ignore failing tests until fixed Signed-off-by: Alex Black * Fix Signed-off-by: Alex Black Co-authored-by: Alex Black --- .../eval/EvaluationCalibration.java | 2 +- .../conf/layers/RecurrentAttentionLayer.java | 4 +- .../IntegrationTestBaselineGenerator.java | 86 +- .../integration/IntegrationTestsSameDiff.java | 17 + .../testcases/dl4j/CNN2DTestCases.java | 5 +- .../testcases/dl4j/MLPTestCases.java | 7 +- .../testcases/dl4j/RNNTestCases.java | 8 +- .../testcases/dl4j/UnsupervisedTestCases.java | 4 +- .../testcases/samediff/SameDiffCNNCases.java | 398 + .../samediff/SameDiffMLPTestCases.java | 188 +- .../samediff/SameDiffRNNTestCases.java | 289 + .../declarable/platform/mkldnn/lstmLayer.cpp | 20 +- .../DifferentialFunctionFactory.java | 2 +- .../nd4j/autodiff/samediff/SDVariable.java | 7 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 271 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 8103 ++++++++++------- .../org/nd4j/autodiff/samediff/ops/SDCNN.java | 29 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 12 +- .../org/nd4j/autodiff/samediff/ops/SDRNN.java | 233 +- .../src/main/java/org/nd4j/enums/CellAct.java | 45 + .../src/main/java/org/nd4j/enums/GateAct.java | 45 + .../java/org/nd4j/enums/LSTMDataFormat.java | 36 + 
.../org/nd4j/enums/LSTMDirectionMode.java | 38 + .../src/main/java/org/nd4j/enums/OutAct.java | 45 + .../java/org/nd4j/enums/RnnDataFormat.java | 32 + .../converters/ImportClassMapping.java | 1 + .../ops/executioner/DefaultOpExecutioner.java | 68 +- .../ops/executioner/OpExecutionerUtil.java | 37 +- .../layers/convolution/MaxPoolWithArgmax.java | 8 +- .../ops/impl/layers/convolution/SConv2D.java | 2 +- .../ops/impl/layers/recurrent/LSTMBlock.java | 144 + .../ops/impl/layers/recurrent/LSTMLayer.java | 173 +- .../recurrent/config/LSTMActivations.java | 48 + .../recurrent/config/LSTMDataFormat.java | 41 + .../recurrent/config/LSTMDirectionMode.java | 38 + .../recurrent/config/LSTMLayerConfig.java | 119 + .../recurrent/outputs/LSTMLayerOutputs.java | 190 +- .../recurrent/weights/LSTMLayerWeights.java | 99 + .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 2 + .../api/ops/impl/reduce/custom/BatchMmul.java | 20 + .../linalg/api/ops/impl/shape/GatherNd.java | 13 +- .../linalg/api/ops/impl/shape/Linspace.java | 21 +- .../linalg/api/ops/impl/shape/MeshGrid.java | 7 + .../linalg/api/ops/impl/shape/Reshape.java | 7 +- .../api/ops/impl/shape/SequenceMask.java | 11 +- .../nd4j/linalg/api/ops/impl/shape/Slice.java | 4 + .../nd4j/linalg/api/ops/impl/shape/Stack.java | 2 +- .../api/ops/impl/shape/StridedSlice.java | 12 +- .../linalg/api/ops/impl/shape/Unstack.java | 11 +- .../linalg/api/ops/impl/transforms/Pad.java | 4 + .../transforms/custom/DynamicPartition.java | 5 +- .../ops/impl/transforms/custom/ListDiff.java | 8 +- .../ops/impl/transforms/custom/XwPlusB.java | 8 +- .../api/ops/impl/transforms/dtype/Cast.java | 22 +- .../linalg/api/ops/random/impl/Range.java | 7 + .../org/nd4j/linalg/factory/ops/NDBase.java | 214 +- .../org/nd4j/linalg/factory/ops/NDCNN.java | 14 +- .../org/nd4j/linalg/factory/ops/NDLoss.java | 7 +- .../org/nd4j/linalg/factory/ops/NDMath.java | 106 +- .../org/nd4j/linalg/factory/ops/NDNN.java | 46 +- .../org/nd4j/linalg/factory/ops/NDRNN.java | 106 +- .../ops/executioner/CudaExecutioner.java | 26 +- .../nativecpu/ops/NativeOpExecutioner.java | 13 +- .../opvalidation/LayerOpValidation.java | 251 +- .../opvalidation/MiscOpValidation.java | 10 +- .../opvalidation/RnnOpValidation.java | 6 +- .../opvalidation/ShapeOpValidation.java | 44 +- .../opvalidation/TransformOpValidation.java | 2 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 85 +- .../nd4j/linalg/nativ/OpsMappingTests.java | 2 +- .../profiling/OperationProfilerTests.java | 55 + .../java/org/nd4j/linalg/util/ArrayUtil.java | 15 + 72 files changed, 8063 insertions(+), 3997 deletions(-) create mode 100644 deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java create mode 100644 deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java create mode 100644 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java index 4a4299042..bda8b21b2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java @@ -25,7 +25,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty; */ @Deprecated @Getter -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = true) public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation { /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index e4e5b7d21..d12e0ec74 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -185,7 +185,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer { final val R = paramTable.get(RECURRENT_WEIGHT_KEY); final val b = paramTable.get(BIAS_KEY); - SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2); + long[] shape = layerInput.getShape(); + Preconditions.checkState(shape != null, "Null shape for input placeholder"); + SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]); this.timeSteps = inputSlices.length; SDVariable[] outputSlices = new SDVariable[timeSteps]; SDVariable prev = null; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index a493337c8..01b3b2e53 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -20,7 +20,10 @@ package org.deeplearning4j.integration; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; +import org.deeplearning4j.integration.testcases.dl4j.*; +import 
org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; +import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; @@ -66,14 +69,36 @@ public class IntegrationTestBaselineGenerator { } runGeneration( - SameDiffMLPTestCases.getMLPMnist() + + // DL4J integration test cases. + +// CNN1DTestCases.getCnn1dTestCaseCharRNN(), +// CNN2DTestCases.testLenetTransferDropoutRepeatability(), +//// CNN2DTestCases.getCnn2DSynthetic(), +// CNN2DTestCases.getLenetMnist(), +// CNN2DTestCases.getVGG16TransferTinyImagenet(), +// CNN2DTestCases.getYoloHouseNumbers(), +// CNN3DTestCases.getCnn3dTestCaseSynthetic(), +// MLPTestCases.getMLPMnist(), +// MLPTestCases.getMLPMoon(), +// RNNTestCases.getRnnCharacterTestCase(), +// RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), +// RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), +// UnsupervisedTestCases.getVAEMnistAnomaly(), + + // Samediff test cases done + SameDiffMLPTestCases.getMLPMnist(), + SameDiffMLPTestCases.getMLPMoon(), + SameDiffCNNCases.getLenetMnist(), + SameDiffCNNCases.getCnn3dSynthetic(), + SameDiffRNNTestCases.getRnnCsvSequenceClassificationTestCase1() ); } private static void runGeneration(TestCase... testCases) throws Exception { - for( TestCase tc : testCases ) { + for (TestCase tc : testCases) { final ModelType modelType = tc.modelType(); //Basic validation: @@ -122,18 +147,18 @@ public class IntegrationTestBaselineGenerator { mln = new MultiLayerNetwork(mlc); mln.init(); m = mln; - } else if (config instanceof ComputationGraphConfiguration){ + } else if (config instanceof ComputationGraphConfiguration) { ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; json = cgc.toJson(); cg = new ComputationGraph(cgc); cg.init(); m = cg; } else { - sd = (SameDiff)config; + sd = (SameDiff) config; } File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME); - if(modelType != ModelType.SAMEDIFF) { + if (modelType != ModelType.SAMEDIFF) { File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json")); FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8); log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath()); @@ -147,10 +172,10 @@ public class IntegrationTestBaselineGenerator { m = tc.getPretrainedModel(); if (m instanceof MultiLayerNetwork) { mln = (MultiLayerNetwork) m; - } else if(m instanceof ComputationGraph){ + } else if (m instanceof ComputationGraph) { cg = (ComputationGraph) m; } else { - sd = (SameDiff)m; + sd = (SameDiff) m; } } @@ -158,7 +183,7 @@ public class IntegrationTestBaselineGenerator { //Generate predictions to compare against if (tc.isTestPredictions()) { List> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null; - List> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null; + List> inputsSd = modelType == ModelType.SAMEDIFF ? 
tc.getPredictionsTestDataSameDiff() : null; // Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName()); @@ -178,7 +203,7 @@ public class IntegrationTestBaselineGenerator { Nd4j.write(out, dos); } } - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { for (Pair p : inputs) { INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null); @@ -192,11 +217,11 @@ public class IntegrationTestBaselineGenerator { } } else { List outNames = tc.getPredictionsNamesSameDiff(); - for( Map ph : inputsSd ){ - Map out = sd.output(ph, outNames); + for (Map ph : inputsSd) { + Map out = sd.output(ph, outNames); //Save the output... - for(String s : outNames){ + for (String s : outNames) { File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin"); try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) { Nd4j.write(out.get(s), dos); @@ -211,7 +236,7 @@ public class IntegrationTestBaselineGenerator { //Compute and save gradients: if (tc.isTestGradients()) { INDArray gradientFlat = null; - Map grad; + Map grad; if (modelType == ModelType.MLN) { MultiDataSet data = tc.getGradientsTestData(); mln.setInput(data.getFeatures(0)); @@ -220,7 +245,7 @@ public class IntegrationTestBaselineGenerator { mln.computeGradientAndScore(); gradientFlat = mln.getFlattenedGradients(); grad = m.gradient().gradientForVariable(); - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { MultiDataSet data = tc.getGradientsTestData(); cg.setInputs(data.getFeatures()); cg.setLabels(data.getLabels()); @@ -229,17 +254,17 @@ public class IntegrationTestBaselineGenerator { gradientFlat = cg.getFlattenedGradients(); grad = m.gradient().gradientForVariable(); } else { - Map ph = tc.getGradientsTestDataSameDiff(); + Map ph = tc.getGradientsTestDataSameDiff(); List allVars = new ArrayList<>(); - for(SDVariable v : sd.variables()){ - if(v.getVariableType() == VariableType.VARIABLE){ + for (SDVariable v : sd.variables()) { + if (v.getVariableType() == VariableType.VARIABLE) { allVars.add(v.name()); } } grad = sd.calculateGradients(ph, allVars); } - if(modelType != ModelType.SAMEDIFF) { + if (modelType != ModelType.SAMEDIFF) { File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME); IntegrationTestRunner.write(gradientFlat, gFlatFile); } @@ -254,25 +279,25 @@ public class IntegrationTestBaselineGenerator { } //Test pretraining - if(tc.isTestUnsupervisedTraining()){ + if (tc.isTestUnsupervisedTraining()) { log.info("Performing layerwise pretraining"); MultiDataSetIterator iter = tc.getUnsupervisedTrainData(); INDArray paramsPostTraining; - if(modelType == ModelType.MLN){ + if (modelType == ModelType.MLN) { int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN(); Preconditions.checkState(layersToTrain != null, "Layer indices must not be null"); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); - for( int i : layersToTrain){ + for (int i : layersToTrain) { mln.pretrainLayer(i, dsi); } paramsPostTraining = mln.params(); - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); Preconditions.checkState(layersToTrain != null, "Layer names must not be null"); - for( String i : layersToTrain){ + for (String i : layersToTrain) { cg.pretrainLayer(i, iter); } paramsPostTraining = cg.params(); @@ -290,20 +315,20 @@ public class IntegrationTestBaselineGenerator { 
MultiDataSetIterator trainData = tc.getTrainingData(); CollectScoresListener l = new CollectScoresListener(1); - if(modelType != ModelType.SAMEDIFF) + if (modelType != ModelType.SAMEDIFF) m.setListeners(l); History h = null; if (modelType == ModelType.MLN) { mln.fit(trainData); - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { cg.fit(trainData); } else { h = sd.fit(trainData, 1); } double[] scores; - if(modelType != ModelType.SAMEDIFF){ + if (modelType != ModelType.SAMEDIFF) { scores = l.getListScore().toDoubleArray(); } else { scores = h.lossCurve().getLossValues().toDoubleVector(); @@ -314,11 +339,11 @@ public class IntegrationTestBaselineGenerator { FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8); if (tc.isTestParamsPostTraining()) { - if(modelType == ModelType.SAMEDIFF){ + if (modelType == ModelType.SAMEDIFF) { File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR); p.mkdirs(); - for(SDVariable v : sd.variables()){ - if(v.getVariableType() == VariableType.VARIABLE){ + for (SDVariable v : sd.variables()) { + if (v.getVariableType() == VariableType.VARIABLE) { INDArray arr = v.getArr(); File p2 = new File(p, v.name() + ".bin"); IntegrationTestRunner.write(arr, p2); @@ -331,7 +356,6 @@ public class IntegrationTestBaselineGenerator { } } - if (tc.isTestEvaluation()) { IEvaluation[] evals = tc.getNewEvaluations(); MultiDataSetIterator iter = tc.getEvaluationTestData(); @@ -339,7 +363,7 @@ public class IntegrationTestBaselineGenerator { if (modelType == ModelType.MLN) { DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); mln.doEvaluation(dsi, evals); - } else if(modelType == ModelType.CG){ + } else if (modelType == ModelType.CG) { cg.doEvaluation(iter, evals); } else { evals = tc.doEvaluationSameDiff(sd, iter, evals); diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java index f16a5e187..de5bc0ea1 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java @@ -16,6 +16,7 @@ package org.deeplearning4j.integration; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; import org.junit.Rule; import org.junit.Test; @@ -37,4 +38,20 @@ public class IntegrationTestsSameDiff extends BaseDL4JTest { IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir); } + @Test + public void testMLPMoon() throws Exception { + IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMoon(), testDir); + } + + @Test + public void testLenetMnist() throws Exception { + IntegrationTestRunner.runTest(SameDiffCNNCases.getLenetMnist(), testDir); + } + + @Test + public void testCnn3dSynthetic() throws Exception { + IntegrationTestRunner.runTest(SameDiffCNNCases.getCnn3dSynthetic(), testDir); + } + + } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java index 71336c0a6..b857b2fb2 100644 --- 
a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java @@ -194,6 +194,8 @@ public class CNN2DTestCases { testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb) testEvaluation = false; testOverfitting = false; + maxRelativeErrorOutput = 0.2; + minAbsErrorOutput = 0.05; //Max value is around 0.22 } @Override @@ -314,6 +316,7 @@ public class CNN2DTestCases { ComputationGraph model = new TransferLearning.GraphBuilder(pretrained) .fineTuneConfiguration(fineTuneConf) .removeVertexKeepConnections("conv2d_9") + .removeVertexAndConnections("outputs") .addLayer("convolution2d_9", new ConvolutionLayer.Builder(1,1) .nIn(1024) @@ -393,7 +396,7 @@ public class CNN2DTestCases { @Override public ModelType modelType() { - return ModelType.CG; + return ModelType.MLN; } @Override diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java index 232219f04..4264531aa 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java @@ -77,6 +77,10 @@ public class MLPTestCases { testOverfitting = true; maxRelativeErrorOverfit = 2e-2; minAbsErrorOverfit = 1e-2; + maxRelativeErrorGradients = 0.01; + minAbsErrorGradients = 0.05; + maxRelativeErrorParamsPostTraining = 0.01; + minAbsErrorParamsPostTraining = 0.05; } @Override @@ -135,8 +139,7 @@ public class MLPTestCases { public IEvaluation[] getNewEvaluations(){ return new IEvaluation[]{ new Evaluation(), - new ROCMultiClass(), - new EvaluationCalibration() + new ROCMultiClass() }; } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java index f89643380..29f382735 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java @@ -24,6 +24,7 @@ import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.shade.guava.io.Files; import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator; @@ -91,7 +92,7 @@ public class RNNTestCases { } private int miniBatchSize = 32; - private int exampleLength = 1000; + private int exampleLength = 200; @Override @@ -101,6 +102,7 @@ public class RNNTestCases { @Override public Object getConfiguration() throws Exception { + Nd4j.getRandom().setSeed(12345); CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); int nOut = iter.totalOutcomes(); @@ -113,7 +115,7 @@ public class RNNTestCases { .seed(12345) .l2(0.001) .weightInit(WeightInit.XAVIER) - 
.updater(new RmsProp(0.1)) + .updater(new Adam(1e-3)) .list() .layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize) .activation(Activation.TANH).build()) @@ -140,7 +142,7 @@ public class RNNTestCases { @Override public MultiDataSetIterator getTrainingData() throws Exception { DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); - iter = new EarlyTerminationDataSetIterator(iter, 2); //3 minibatches, 1000/200 = 5 updates per minibatch + iter = new EarlyTerminationDataSetIterator(iter, 2); //2 minibatches, 200/50 = 4 updates per minibatch return new MultiDataSetIteratorAdapter(iter); } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java index 622a6e9cf..b627f06dc 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java @@ -72,12 +72,12 @@ public class UnsupervisedTestCases { return new NeuralNetConfiguration.Builder() .dataType(DataType.FLOAT) .seed(12345) - .updater(new Adam(0.05)) + .updater(new Adam(1e-3)) .weightInit(WeightInit.XAVIER) .l2(1e-4) .list() .layer(0, new VariationalAutoencoder.Builder() - .activation(Activation.LEAKYRELU) + .activation(Activation.TANH) .encoderLayerSizes(256, 256) //2 encoder layers, each of size 256 .decoderLayerSizes(256, 256) //2 decoder layers, each of size 256 .pzxActivationFunction(Activation.IDENTITY) //p(z|data) activation function diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java new file mode 100644 index 000000000..74c4f3bfb --- /dev/null +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java @@ -0,0 +1,398 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.integration.testcases.samediff; + +import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; +import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; +import org.deeplearning4j.integration.ModelType; +import org.deeplearning4j.integration.TestCase; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROCMultiClass; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.Nesterovs; + +import java.util.*; + +public class SameDiffCNNCases { + + + public static TestCase getLenetMnist() { + return new TestCase() { + { + testName = "LenetMnistSD"; + testType = TestType.RANDOM_INIT; + testPredictions = true; + testTrainingCurves = true; + testGradients = true; + testParamsPostTraining = true; + testEvaluation = true; + testOverfitting = false; + } + + @Override + public ModelType modelType() { + return ModelType.SAMEDIFF; + } + + public Object getConfiguration() throws Exception { + Nd4j.getRandom().setSeed(12345); + + int nChannels = 1; // Number of input channels + int outputNum = 10; // The number of possible outcomes + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum); + + //input [minibatch, channels=1, Height = 28, Width = 28] + SDVariable in4d = in.reshape(-1, nChannels, 28, 28); + + int kernelHeight = 5; + int kernelWidth = 5; + + + // w0 [kernelHeight = 5, kernelWidth = 5 , inputChannels = 1, outputChannels = 20] + // b0 [20] + SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, nChannels, 20).muli(0.01)); + SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 20).muli(0.01)); + + + SDVariable layer0 = sd.nn.relu(sd.cnn.conv2d("layer0", in4d, w0, b0, Conv2DConfig.builder() + .kH(kernelHeight) + .kW(kernelWidth) + .sH(1) + .sW(1) + .dataFormat("NCHW") + .build()), 0); + + // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1 + // outputsize_H(W) = ( 28 - 5 + 2*0 ) / 1 + 1 = 24 + // [minibatch,20,24,24] + + + SDVariable layer1 = sd.cnn.maxPooling2d("layer1", layer0, Pooling2DConfig.builder() + .kH(2).kW(2) + .sH(2).sW(2) + .isNHWC(false) + .build()); + + // outputSize = (inputSize - kernelSize + 2*padding) / 
stride + 1 + // outputsize_H(W) = ( 24 - 2 + 2*0 ) / 2 + 1 = 12 + // [minibatch,12,12,20] + + + // w2 [kernelHeight = 5, kernelWidth = 5 , inputChannels = 20, outputChannels = 50] + // b0 [50] + SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, 20, 50).muli(0.01)); + SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 50).muli(0.01)); + + + SDVariable layer2 = sd.nn.relu(sd.cnn.conv2d("layer2", layer1, w2, b2, Conv2DConfig.builder() + .kH(kernelHeight) + .kW(kernelWidth) + .sH(1) + .sW(1) + .dataFormat("NCHW") + .build()), 0); + + // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1 + // outputsize_H(W) = ( 12 - 5 + 2*0 ) / 1 + 1 = 8 + // [minibatch,8,8,50] + + + SDVariable layer3 = sd.cnn.maxPooling2d("layer3", layer2, Pooling2DConfig.builder() + .kH(2).kW(2) + .sH(2).sW(2) + .isNHWC(false) + .build()); + + + // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1 + // outputsize_H(W) = ( 8 - 2 + 2*0 ) / 2 + 1 = 4 + // [minibatch,4,4,50] + + int channels_height_width = 4 * 4 * 50; + SDVariable layer3_reshaped = layer3.reshape(-1, channels_height_width); + + SDVariable w4 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width, 500).muli(0.01)); + SDVariable b4 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 500).muli(0.01)); + + + SDVariable layer4 = sd.nn.relu("layer4", layer3_reshaped.mmul(w4).add(b4), 0); + + SDVariable w5 = sd.var("w5", Nd4j.rand(DataType.FLOAT, 500, outputNum)); + SDVariable b5 = sd.var("b5", Nd4j.rand(DataType.FLOAT, outputNum)); + + SDVariable out = sd.nn.softmax("out", layer4.mmul(w5).add(b5)); + SDVariable loss = sd.loss.logLoss("loss", label, out); + + //Also set the training configuration: + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Adam(1e-3)) + .l2(1e-3) + .dataSetFeatureMapping("in") //features[0] -> "in" placeholder + .dataSetLabelMapping("label") //labels[0] -> "label" placeholder + .build()); + + + return sd; + + + } + + @Override + public Map getGradientsTestDataSameDiff() throws Exception { + DataSet ds = new MnistDataSetIterator(8, true, 12345).next(); + Map map = new HashMap<>(); + map.put("in", ds.getFeatures()); + map.put("label", ds.getLabels()); + return map; + } + + @Override + public MultiDataSetIterator getTrainingData() throws Exception { + DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); + + iter = new EarlyTerminationDataSetIterator(iter, 60); + return new MultiDataSetIteratorAdapter(iter); + } + + @Override + public MultiDataSetIterator getEvaluationTestData() throws Exception { + return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10)); + } + + @Override + public List> getPredictionsTestDataSameDiff() throws Exception { + DataSetIterator iter = new MnistDataSetIterator(8, true, 12345); + + List> list = new ArrayList<>(); + + org.nd4j.linalg.dataset.DataSet ds = iter.next(); + ds = ds.asList().get(0); + + list.add(Collections.singletonMap("in", ds.getFeatures())); + ds = iter.next(); + list.add(Collections.singletonMap("in", ds.getFeatures())); + return list; + } + + @Override + public List getPredictionsNamesSameDiff() { + return Collections.singletonList("out"); + + } + + @Override + public IEvaluation[] getNewEvaluations() { + return new IEvaluation[]{ + new Evaluation(), + new ROCMultiClass(), + new EvaluationCalibration()}; + } + + + + @Override + public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) { + sd.evaluate(iter, 
"out", 0, evaluations); + return evaluations; + } + + }; + } + + + public static TestCase getCnn3dSynthetic() { + return new TestCase() { + { + testName = "Cnn3dSynthetic"; + testType = TestType.RANDOM_INIT; + testPredictions = true; + testTrainingCurves = true; + testGradients = true; + testParamsPostTraining = true; + testEvaluation = true; + testOverfitting = false; + } + + @Override + public ModelType modelType() { + return ModelType.SAMEDIFF; + } + + public Object getConfiguration() throws Exception { + Nd4j.getRandom().setSeed(12345); + + int nChannels = 3; // Number of input channels + int outputNum = 10; // The number of possible outcomes + + SameDiff sd = SameDiff.create(); + + + //input in NCDHW [minibatch, channels=3, Height = 8, Width = 8, Depth = 8] + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, nChannels, 8, 8, 8); + + SDVariable label = sd.placeHolder("label", DataType.FLOAT, nChannels, outputNum); + + //input in NCDHW [minibatch, channels=3, Height = 8, Width = 8, Depth = 8] + + // Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels] + // [kernelDepth = 3, kernelHeight = 3, kernelWidth = 3, inputChannels = 3, outputChannels = 8] + SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 3, 3, 3, nChannels, 8)); + // Optional 1D bias array with shape [outputChannels]. May be null. + SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 8)); + + + SDVariable layer0 = sd.nn.relu(sd.cnn.conv3d("layer0", in, w0, b0, Conv3DConfig.builder() + .kH(3) + .kW(3) + .kD(3) + .sH(2) + .sW(2) + .sD(2) + .dataFormat("NCDHW") + .build()), 0); + + // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1 + // outputsize_H(W)(D) = (8 - 3 + 2*0 ) / 2 + 1 = 3 + // [minibatch,8,3,3,3] + + + SDVariable layer1 = sd.cnn.maxPooling3d("layer1", layer0, Pooling3DConfig.builder() + .kH(2).kW(2).kD(2) + .sH(2).sW(2).sD(2) + .isNCDHW(true) + .build()); + + // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1 + // outputsize_H(W)(D) = ( 3 - 2 + 2*0 ) / 2 + 1 = 1 + // [minibatch,8,1,1,1] + + + int channels_height_width_depth = 8 * 1 * 1 * 1; + + SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth); + + SDVariable w1 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width_depth, 10)); + SDVariable b1 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 10)); + + + SDVariable out = sd.nn.softmax("out", layer1_reshaped.mmul(w1).add(b1)); + SDVariable loss = sd.loss.logLoss("loss", label, out); + + //Also set the training configuration: + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Nesterovs(0.01, 0.9)) + .dataSetFeatureMapping("in") //features[0] -> "in" placeholder + .dataSetLabelMapping("label") //labels[0] -> "label" placeholder + .build()); + + return sd; + + } + + @Override + public Map getGradientsTestDataSameDiff() throws Exception { + Nd4j.getRandom().setSeed(12345); + //NCDHW format + INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10); + + Map map = new HashMap<>(); + map.put("in", arr); + map.put("label", labels); + return map; + + } + + + + @Override + public List getPredictionsNamesSameDiff() { + + return Collections.singletonList("out"); + + } + + + + @Override + public List> getPredictionsTestDataSameDiff() throws Exception { + Nd4j.getRandom().setSeed(12345); + + List> list = new ArrayList<>(); + INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + + list.add(Collections.singletonMap("in", 
arr)); + + return list; + } + + @Override + public MultiDataSet getGradientsTestData() throws Exception { + Nd4j.getRandom().setSeed(12345); + //NCDHW format + INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10); + return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels); + } + + @Override + public MultiDataSetIterator getTrainingData() throws Exception { + return new SingletonMultiDataSetIterator(getGradientsTestData()); + } + + + @Override + public MultiDataSetIterator getEvaluationTestData() throws Exception { + return getTrainingData(); + } + + @Override + public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){ + sd.evaluate(iter, "out", 0, evaluations); + return evaluations; + } + + @Override + public IEvaluation[] getNewEvaluations(){ + return new IEvaluation[]{new Evaluation()}; + } + + + }; + + } +} \ No newline at end of file diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java index ced461089..9761c87b0 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java @@ -15,9 +15,14 @@ ******************************************************************************/ package org.deeplearning4j.integration.testcases.samediff; +import org.datavec.api.records.reader.RecordReader; +import org.datavec.api.records.reader.impl.csv.CSVRecordReader; +import org.datavec.api.split.FileSplit; +import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; +import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.integration.ModelType; import org.deeplearning4j.integration.TestCase; import org.nd4j.autodiff.loss.LossReduce; @@ -26,21 +31,34 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.Nesterovs; +import 
org.nd4j.linalg.primitives.Pair; +import org.nd4j.resources.Resources; +import java.io.File; import java.util.*; +import static org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig.*; + public class SameDiffMLPTestCases { - public static TestCase getMLPMnist(){ + public static TestCase getMLPMnist() { return new TestCase() { { testName = "MLPMnistSD"; @@ -69,10 +87,10 @@ public class SameDiffMLPTestCases { SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784); SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10); - SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256)); - SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256)); - SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10)); - SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10)); + SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256).muli(0.1)); + SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256).muli(0.1)); + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10).muli(0.1)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10).muli(0.1)); SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0)); SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1)); @@ -91,7 +109,7 @@ public class SameDiffMLPTestCases { @Override public List> getPredictionsTestDataSameDiff() throws Exception { - List> out = new ArrayList<>(); + List> out = new ArrayList<>(); DataSetIterator iter = new MnistDataSetIterator(1, true, 12345); out.add(Collections.singletonMap("in", iter.next().getFeatures())); @@ -110,7 +128,7 @@ public class SameDiffMLPTestCases { @Override public Map getGradientsTestDataSameDiff() throws Exception { DataSet ds = new MnistDataSetIterator(8, true, 12345).next(); - Map map = new HashMap<>(); + Map map = new HashMap<>(); map.put("in", ds.getFeatures()); map.put("label", ds.getLabels()); return map; @@ -153,4 +171,160 @@ public class SameDiffMLPTestCases { }; } + + public static TestCase getMLPMoon() { + return new TestCase() { + { + testName = "MLPMoonSD"; + testType = TestType.RANDOM_INIT; + testPredictions = true; + testTrainingCurves = true; + testGradients = true; + testParamsPostTraining = true; + testEvaluation = true; + testOverfitting = true; + maxRelativeErrorOverfit = 2e-2; + minAbsErrorOverfit = 1e-2; + } + + @Override + public ModelType modelType() { + return ModelType.SAMEDIFF; + } + + @Override + public Object getConfiguration() throws Exception { + + int numInputs = 2; + int numOutputs = 2; + int numHiddenNodes = 20; + double learningRate = 0.005; + + + Nd4j.getRandom().setSeed(12345); + + //Define the network structure: + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, numInputs); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, numOutputs); + + SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, numInputs, numHiddenNodes)); + SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, numHiddenNodes)); + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numHiddenNodes, numOutputs)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numOutputs)); + + SDVariable a0 = sd.nn.relu(in.mmul(w0).add(b0), 0); + SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1)); + SDVariable loss = sd.loss.logLoss("loss", label, out); + + //Also set the training configuration: + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Nesterovs(learningRate, 0.9)) + .weightDecay(1e-3, true) + .dataSetFeatureMapping("in") //features[0] -> "in" placeholder + 
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder + .build()); + + return sd; + } + + @Override + public List> getPredictionsTestDataSameDiff() throws Exception { + List> out = new ArrayList<>(); + + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 0, 2); + + out.add(Collections.singletonMap("in", iter.next().getFeatures())); + + + return out; + } + + + @Override + public List getPredictionsNamesSameDiff() throws Exception { + return Collections.singletonList("out"); + } + + @Override + public Map getGradientsTestDataSameDiff() throws Exception { + + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + org.nd4j.linalg.dataset.DataSet ds = new RecordReaderDataSetIterator(rr, 5, 0, 2).next(); + + Map map = new HashMap<>(); + map.put("in", ds.getFeatures()); + map.put("label", ds.getLabels()); + return map; + } + + @Override + public MultiDataSetIterator getTrainingData() throws Exception { + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_train.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2); + + iter = new EarlyTerminationDataSetIterator(iter, 32); + return new MultiDataSetIteratorAdapter(iter); + } + + @Override + public IEvaluation[] getNewEvaluations() { + return new IEvaluation[]{ + new Evaluation(), + new ROCMultiClass(), + new EvaluationCalibration()}; + } + + @Override + public MultiDataSetIterator getEvaluationTestData() throws Exception { + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2); + return new MultiDataSetIteratorAdapter(iter); + } + + + @Override + public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) { + sd.evaluate(iter, "out", 0, evaluations); + return evaluations; + } + + @Override + public MultiDataSet getOverfittingData() throws Exception { + + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + return new RecordReaderDataSetIterator(rr, 1, 0, 2).next().toMultiDataSet(); + } + + @Override + public int getOverfitNumIterations() { + return 200; + } + }; + + } } + + + + + + + + + + + + diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java new file mode 100644 index 000000000..6bc6254c9 --- /dev/null +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java @@ -0,0 +1,289 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.integration.testcases.samediff; + +import org.datavec.api.records.reader.SequenceRecordReader; +import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; +import org.datavec.api.split.NumberedFileInputSplit; +import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; +import org.deeplearning4j.integration.ModelType; +import org.deeplearning4j.integration.TestCase; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROCMultiClass; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor; +import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization; +import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.resources.Resources; +import org.nd4j.shade.guava.io.Files; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class SameDiffRNNTestCases { + + public static TestCase getRnnCsvSequenceClassificationTestCase1() { + return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1(); + } + + protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase { + protected RnnCsvSequenceClassificationTestCase1() { + testName = "RnnCsvSequenceClassification1"; + testType = TestType.RANDOM_INIT; + testPredictions = true; + testTrainingCurves = false; + testGradients = false; + testParamsPostTraining = false; + testEvaluation = true; + testOverfitting = false; //Not much point on this one - it already fits very well... 
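            // Only predictions and evaluation are validated for this RNN sequence-classification
            // case; the training-curve, gradient, post-training-parameter and overfitting checks
            // are switched off by the flags above.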
+ } + + + protected MultiDataNormalization normalizer; + + protected MultiDataNormalization getNormalizer() throws Exception { + if (normalizer != null) { + return normalizer; + } + + normalizer = new MultiNormalizerStandardize(); + normalizer.fit(getTrainingDataUnnormalized()); + + return normalizer; + } + + + @Override + public ModelType modelType() { + return ModelType.SAMEDIFF; + } + + + @Override + public Object getConfiguration() throws Exception { + Nd4j.getRandom().setSeed(12345); + + + int miniBatchSize = 10; + int numLabelClasses = 6; + int nIn = 60; + int numUnits = 7; + int timeSteps = 3; + + + SameDiff sd = SameDiff.create(); + + SDVariable in = sd.placeHolder("in", DataType.FLOAT, miniBatchSize, timeSteps, nIn); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, miniBatchSize, numLabelClasses); + + + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NTS) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(true) + .retLastH(true) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits))) + .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))) + .build(), + c), c); + + +// Behaviour with default settings: 3d (time series) input with shape +// [miniBatchSize, vectorSize, timeSeriesLength] -> 2d output [miniBatchSize, vectorSize] + SDVariable layer0 = outputs.getOutput(); + + SDVariable layer1 = layer0.mean(1); + + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numUnits, numLabelClasses)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numLabelClasses)); + + + SDVariable out = sd.nn.softmax("out", layer1.mmul(w1).add(b1)); + SDVariable loss = sd.loss.logLoss("loss", label, out); + + //Also set the training configuration: + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Adam(5e-2)) + .l1(1e-3).l2(1e-3) + .dataSetFeatureMapping("in") //features[0] -> "in" placeholder + .dataSetLabelMapping("label") //labels[0] -> "label" placeholder + .build()); + + return sd; + + } + + + @Override + public List> getPredictionsTestDataSameDiff() throws Exception { + + MultiDataSet mds = getTrainingData().next(); + + List> list = new ArrayList<>(); + + list.add(Collections.singletonMap("in", mds.getFeatures()[0].reshape(10, 1, 60))); + //[batchsize, insize] + + return list; + } + + @Override + public List getPredictionsNamesSameDiff() throws Exception { + return Collections.singletonList("out"); + } + + + @Override + public MultiDataSetIterator getTrainingData() throws Exception { + MultiDataSetIterator iter = getTrainingDataUnnormalized(); + MultiDataSetPreProcessor pp = multiDataSet -> { + INDArray l = multiDataSet.getLabels(0); + l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1)); + multiDataSet.setLabels(0, l); + multiDataSet.setLabelsMaskArray(0, null); + }; + + + iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp)); + + return 
iter; + } + + protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception { + int miniBatchSize = 10; + int numLabelClasses = 6; + + File featuresDirTrain = Files.createTempDir(); + File labelsDirTrain = Files.createTempDir(); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain); + + SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); + trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); + SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); + trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); + + DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, + false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); + + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData); + + return iter; + } + + @Override + public IEvaluation[] getNewEvaluations() { + return new IEvaluation[]{ + new Evaluation(), + new ROCMultiClass(), + new EvaluationCalibration() + }; + } + + @Override + public MultiDataSetIterator getEvaluationTestData() throws Exception { + int miniBatchSize = 10; + int numLabelClasses = 6; + +// File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile(); +// File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile(); + File featuresDirTest = Files.createTempDir(); + File labelsDirTest = Files.createTempDir(); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest); + + SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); + trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); + SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); + trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); + + DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, + false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); + + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData); + + MultiDataSetPreProcessor pp = multiDataSet -> { + INDArray l = multiDataSet.getLabels(0); + l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1)); + multiDataSet.setLabels(0, l); + multiDataSet.setLabelsMaskArray(0, null); + }; + + + iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp)); + + return iter; + } + + @Override + public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) { + sd.evaluate(iter, "out", 0, evaluations); + return evaluations; + } + } + + +} + + + + + + + + + + + + + diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 94c795401..d09a40120 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -368,7 +368,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { REQUIRE_TRUE(hasSeqLen 
== false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); - REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); count = 0; auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output @@ -464,13 +464,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { } PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided const auto hasInitH = B_ARG(2); // indicates whether initial output is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto x = INPUT_VARIABLE(0); // input const auto Wx = INPUT_VARIABLE(1); // input weights const auto Wr = INPUT_VARIABLE(2); // recurrent weights @@ -495,7 +503,15 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { DataType hLType = hL != nullptr ? hL->dataType() : xType; DataType cLType = cL != nullptr ? 
cL->dataType() : xType; - return block.isUseMKLDNN() && ( + auto featuresSupported = (cellClip == 0) //Cell clipping not supported + && retFullSeq //Always return full sequence in case of MKL DNN + && !hasPH //Peephole connections not supported in MKL DNN + && !hasSeqLen //Sequence length array not supported in MKL DNN + && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] + && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat + && retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other) + + return block.isUseMKLDNN() && featuresSupported && ( (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) || (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index fcb63ea0a..093e3099b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -2148,7 +2148,7 @@ public class DifferentialFunctionFactory { public SDVariable gatherNd(SDVariable df, SDVariable indices) { validateDifferentialFunctionsameDiff(df); - return new GatherNd(sameDiff(), df, indices, false).outputVariable(); + return new GatherNd(sameDiff(), df, indices).outputVariable(); } public SDVariable trace(SDVariable in){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 65416a659..3b29e6ccb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.weightinit.WeightInitScheme; import java.io.Serializable; @@ -244,7 +245,7 @@ public class SDVariable implements Serializable { * @return new variable */ public SDVariable assign(Number value){ - return sameDiff.scalarSet(this, value); + return sameDiff.scalarSet(this, value.doubleValue()); } /** @@ -538,7 +539,7 @@ public class SDVariable implements Serializable 
{ * @return Output variable (result of mmul) */ public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMulTranspose) { - return sameDiff.mmul(name, this, other, mMulTranspose); + return sameDiff.mmul(name, this, other, mMulTranspose.isTransposeA(), mMulTranspose.isTransposeB(), mMulTranspose.isTransposeResult()); } @@ -1403,7 +1404,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable reshape(int... newShape){ - return sameDiff.reshape(this, newShape); + return sameDiff.reshape(this, ArrayUtil.toLongArray(newShape)); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index ab3279fd0..c51ac28a1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; @@ -78,6 +79,7 @@ import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ND4JFileUtils; import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Sets; import org.nd4j.shade.guava.collect.Table; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.weightinit.WeightInitScheme; @@ -104,7 +106,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; *

* In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)} */ -@AllArgsConstructor @Slf4j public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; @@ -914,6 +915,8 @@ public class SameDiff extends SDBaseOps { } private SameDiff() { + super(null); + super.sd = this; functionFactory = new DifferentialFunctionFactory(this); sameDiffFunctionInstances = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); @@ -4544,7 +4547,7 @@ public class SameDiff extends SDBaseOps { } //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user - if (v.getOutputOfOp() != null) { + if (v.getOutputOfOp() != null && v.getVariable().dataType().isFPType()) { String opName = v.getOutputOfOp(); SameDiffOp o = ops.get(opName); if (o.getOp() instanceof Assert) { @@ -4621,12 +4624,6 @@ public class SameDiff extends SDBaseOps { return varToUpdate; } - @Override - protected SameDiff sd() { - //Helper method for SDBaseOps etc - return this; - } - /** * Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable. @@ -5840,7 +5837,6 @@ public class SameDiff extends SDBaseOps { * See {@link #generateNewVarName(String, int, boolean)} * existingOp is true. */ - @Override public String generateNewVarName(String base, int argIndex) { return generateNewVarName(base, argIndex, true); } @@ -5868,4 +5864,261 @@ public class SameDiff extends SDBaseOps { public String toString(){ return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")"; } + + + + /** + * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} + */ + public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + return ifCond(null, null, cond, trueBody, falseBody); + } + + + /** + * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} + */ + public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + return ifCond(null, ifName, cond, trueBody, falseBody); + } + + /** + * Constructs a If statement using the tensorflow style control flow operations (Switch and Merge) + * + * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody + * + * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate. + * + * See Tensorflow Control Flow Implementation + * + * @param outputName Name to give the output variable. If null, doesn't rename + * @param ifName The name of the if block. If null, uses "if" + * @param cond A lambda evaluating to the if condition + * @param trueBody A lambda to be executed if cond is true (the if block) + * @param falseBody A lambda to be executed if cond is false (the else block) + * @return The value of trueBody if cond is true, or falseBody if it isn't + */ + public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + + ifName = newBlockName(ifName == null ? 
"if" : ifName); + + NameScope ifScope = sd.withNameScope(ifName); + + NameScope condScope = withNameScope("cond"); + final SDVariable pred = cond.define(this); + condScope.close(); + + if (pred.dataType() != DataType.BOOL) { + //cleanup partially added block + + for(SDVariable v : getVariablesInScope(ifScope)) + this.getVariables().remove(v.name()); + + for(SameDiffOp op : this.getOpsInScope(ifScope)) { + for(String in : op.getInputsToOp()){ + this.removeArgFromOp(in, op.getOp()); + } + this.getOps().remove(op.getName()); + } + + + throw new IllegalStateException("Can not use " + pred.name() + + " as the condition of an If statement, the condition must be a boolean."); + } + + final Map switches = new HashMap<>(); + + final Set declared = Sets.newHashSet(this.variableMap().keySet()); + + this.addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + // if its declared in the if, we don't care acout it + if(!declared.contains(argument.name())) + return argument; + + // if we've already added a switch, move on + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[1]; + + SDVariable[] s = f().switchOp(argument, pred); + switches.put(argument.name(), s); + return s[1]; + } + }); + NameScope trueScope = this.withNameScope("trueBody"); + SDVariable trueOut = trueBody.define(this); + this.removeArgumentInterceptor(); + + if(declared.contains(trueOut.name())) { + SDVariable[] s = f().switchOp(trueOut, pred); + switches.put(trueOut.name(), s); + trueOut = s[1]; + } + + trueScope.close(); + + final Set declared2 = Sets.newHashSet(variableMap().keySet()); + sd.addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + // if its declared in the if, we don't care acout it + if(!declared2.contains(argument.name())) + return argument; + + // if we've already added a switch, move on + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[0]; + + SDVariable[] s = f().switchOp(argument, pred); + switches.put(argument.name(), s); + return s[0]; + } + }); + NameScope falseScope = this.withNameScope("falseBody"); + SDVariable falseOut = falseBody.define(this); + this.removeArgumentInterceptor(); + + if(declared2.contains(falseOut.name())) { + SDVariable[] s = f().switchOp(falseOut, pred); + switches.put(falseOut.name(), s); + falseOut = s[0]; + } + falseScope.close(); + + SDVariable output = f().merge(trueOut, falseOut); + + ifScope.close(); + + return updateVariableNameAndReference(output, outputName); + } + + /** + * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} + */ + public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + return whileLoop(null, null, loopVars, cond, body); + } + + /** + * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} + */ + public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + return whileLoop(null, loopName, loopVars, cond, body); + } + + + /** + * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration) + * + * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false + * + * Note that cond and body lambdas are only called once to 
construct the graph. The constructed graph is used for further iterations. + * + * See Tensorflow Control Flow Implementation + * + * @param outputNames Names to give the output variables. If null, doesn't rename + * @param loopName The name of the loop block and frame (must be unique). If null, uses "if" + * @param loopVars Loop variables' inputs + * @param cond A lambda evaluating to the loop condition + * @param body A lambda doing the loop operation and returning the new loop variable values + * @return The values of the loop variables once condition is false + */ + public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + + final String frameName = this.newBlockName(loopName == null ? "while" : loopName); + + NameScope loopScope = this.withNameScope(frameName); + + //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0); + + SDVariable[] entered = new SDVariable[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + entered[i] = f().enter(loopVars[i], frameName); + } + + //counter = SD.f().enter(counter, frameName); + + SDVariable[] merged = new SDVariable[loopVars.length]; + Merge[] mergeOps = new Merge[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + // the second arg will later be replaced with the output of NextIteration + // but that isn't available yet (and can't be, as it depends on this) + mergeOps[i] = new Merge(this, entered[i], entered[i]); + merged[i] = mergeOps[i].outputVariable(); + } + + //Merge counterMerge = new Merge(SD, counter, counter); + //counter = counterMerge.outputVariable(); + + NameScope condScope = this.withNameScope("cond"); + SDVariable cond_result = cond.define(this, merged); + condScope.close(); + + + if (cond_result.dataType() != DataType.BOOL) + throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); + + + final Set alreadyEntered = Sets.newHashSet(); + SDVariable[] trueSwitches = new SDVariable[loopVars.length]; + SDVariable[] exits = new SDVariable[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + SDVariable[] s = f().switchOp(merged[i], cond_result); + trueSwitches[i] = s[1]; + alreadyEntered.add(s[1].name()); + exits[i] = f().exit(s[0]); + } + + //SDVariable[] cs = SD.f().switchOp(counter, cond_result); + //SDVariable counterExit = SD.f().exit(cs[0]); + //counter = cs[1]; + + final Set declared = Sets.newHashSet(this.variableMap().keySet()); + final Map done = new HashMap<>(); + + this.addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + if(!declared.contains(argument.name())) + return argument; + + if(alreadyEntered.contains(argument.name())) + return argument; + + if(done.containsKey(argument.name())) + return done.get(argument.name()); + + SDVariable e = f().enter(argument, frameName, true); + done.put(argument.name(), e); + return e; + } + }); + + NameScope bodyScope = this.withNameScope("body"); + SDVariable[] outs = body.define(this, trueSwitches); + bodyScope.close(); + this.removeArgumentInterceptor(); + + //counter.add(1); + + for(int i = 0 ; i < loopVars.length ; i++){ + SDVariable n = f().nextIteration(outs[i]); + mergeOps[i].replaceArg(1,n); + } + + //counterMerge.replaceArg(1, counter); + + loopScope.close(); + return updateVariableNamesAndReferences(exits, outputNames); + } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 8e5d1ca36..3b53e5b65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,3403 +14,4720 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import org.nd4j.shade.guava.collect.Sets; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import lombok.NonNull; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; -import org.nd4j.autodiff.samediff.ArgumentInterceptor; -import org.nd4j.autodiff.samediff.NameScope; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffLambda; -import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda; -import org.nd4j.autodiff.samediff.SameDiffSingleLambda; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.linalg.api.blas.params.MMulTranspose; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.shape.OneHot; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.indexing.conditions.Condition; -import static org.nd4j.autodiff.samediff.ops.SDValidation.*; - -/** - * Core op creator methods available via SameDiff class directly - * - * @author Alex Black - * @see SDMath SDMath for Math operations - * @see SDRandom SDRandom for random number generator operations - * @see SDNN SDNN for general neural network operations - * @see SDCNN SDCNN for Convolutional Neural Network operations - * @see SDRNN SDRNN for Recurrent Neural Network operations - * @see SDLoss SDLoss for loss function operations - */ -public abstract class SDBaseOps { - - /** - * Intended for internal/developer use - */ - protected SDVariable gradientBackwardsMarker(SDVariable x) { - return gradientBackwardsMarker(generateNewVarName(new GradientBackwardsMarker().opName(), 0), x); - } - - /** - * Intended for internal/developer use - */ - protected SDVariable gradientBackwardsMarker(String name, SDVariable x) { - SDVariable result = f().gradientBackwardsMarker(x); - return updateVariableNameAndReference(result, name); - } - - protected abstract String generateNewVarName(String baseName, int argIndex); - - protected abstract DifferentialFunctionFactory f(); - - protected abstract SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName); - - protected abstract SameDiff sd(); - - /** - * Argmax array reduction operation, optionally along specified dimensions.
- * Output values are the index of the maximum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmax(SDVariable in, int... dimensions) { - return argmax(null, in, false, dimensions); - } - - /** - * Argmax array reduction operation, optionally along specified dimensions.
- * Output values are the index of the maximum value of each slice along the specified dimension.
- *
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("argmax", in); - SDVariable ret = f().argmax(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #argmax(String, SDVariable, boolean, int...) - */ - public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { - return argmax(null, in, keepDims, dimensions); - } - - /** - * Argmax array reduction operation, optionally along specified dimensions.
- * Output values are the index of the maximum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmax(String name, SDVariable in, int... dimensions) { - return argmax(name, in, false, dimensions); - } - - /** - * Argmin array reduction operation, optionally along specified dimensions.
- * Output values are the index of the minimum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmin(SDVariable in, int... dimensions) { - return argmin(null, in, dimensions); - } - - /** - * Argmin array reduction operation, optionally along specified dimensions.
- * Output values are the index of the minimum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmin(String name, SDVariable in, int... dimensions) { - return argmin(name, in, false, dimensions); - } - - /** - * Argmin array reduction operation, optionally along specified dimensions.
- * Output values are the index of the minimum value of each slice along the specified dimension.
- *
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("argmin", in); - SDVariable ret = f().argmin(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #argmin(String, SDVariable, boolean, int...) - */ - public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { - return argmin(null, in, keepDims, dimensions); - } - - /** - * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same - * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead. - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N). - *

- *

- * The result of this operation will be a batch of multiplied matrices. The - * result has the same length as both input batches and each output matrix is of shape (M, K). - * - * @param matricesA First array of input matrices, all of shape (M, N) or (N, M) - * @param matricesB Second array of input matrices, all of shape (N, K) or (K, N) - * @param transposeA whether first batch of matrices is transposed. - * @param transposeB whether second batch of matrices is transposed. - * @return Array of multiplied SDVariables of shape (M, K) - */ - public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB, - boolean transposeA, boolean transposeB) { - return batchMmul(null, matricesA, matricesB, transposeA, transposeB); - } - - /** - * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same - * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead. - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N). - *

- *

- * The result of this operation will be a batch of multiplied matrices. The - * result has the same length as both input batches and each output matrix is of shape (M, K). - * - * @param matricesA First array of input matrices, all of shape (M, N) or (N, M) - * @param matricesB Second array of input matrices, all of shape (N, K) or (K, N) - * @param transposeA whether first batch of matrices is transposed. - * @param transposeB whether second batch of matrices is transposed. - * @param names names for all provided SDVariables - * @return Array of multiplied SDVariables of shape (M, K) - */ - public SDVariable[] batchMmul(String[] names, SDVariable[] matricesA, SDVariable[] matricesB, - boolean transposeA, boolean transposeB) { - validateSameType("batchMmul", true, matricesA); - validateSameType("batchMmul", true, matricesB); - SDVariable[] result = f().batchMmul(matricesA, matricesB, transposeA, transposeB); - return updateVariableNamesAndReferences(result, names); - } - - protected abstract SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames); - - /** - * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same - * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), - * respectively. The result of this operation will be a batch of multiplied matrices. The - * result has the same length as both input batches and each output matrix is of shape (M, K). - * - * @param matricesA First array of input matrices, all of shape (M, N) - * @param matricesB Second array of input matrices, all of shape (N, K) - * @return Array of multiplied SDVariables of shape (M, K) - */ - public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB) { - return batchMmul(null, matricesA, matricesB, false, false); - } - - public SDVariable castTo(SDVariable toCast, org.nd4j.linalg.api.buffer.DataType toType) { - return castTo(null, toCast, toType); - } - - public SDVariable castTo(String name, SDVariable toCast, org.nd4j.linalg.api.buffer.DataType toType) { - SDVariable ret = f().cast(toCast, toType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #concat(String, int, SDVariable...) - */ - public SDVariable concat(int dimension, SDVariable... inputs) { - return concat(null, dimension, inputs); - } - - /** - * Concatenate a set of inputs along the specified dimension.
- * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
- * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c] - * - * @param name Name of the output variable - * @param dimension Dimension to concatenate on - * @param inputs Input variables - * @return Output variable - * @see #stack(String, int, SDVariable...) - */ - public SDVariable concat(String name, int dimension, SDVariable... inputs) { - validateSameType("concat", false, inputs); - SDVariable result = f().concat(dimension, inputs); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #cumprod(String, SDVariable, boolean, boolean, int...) - */ - public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return cumprod(null, in, exclusive, reverse, axis); - } - - /** - * Cumulative product operation.
- * For input: [ a, b, c], output is:
- * exclusive=false, reverse=false: [a, a*b, a*b*c]
- * exclusive=true, reverse=false: [0, a, a*b]
- * exclusive=false, reverse=true: [a*b*c, b*c, c]
- * exclusive=true, reverse=true: [b*c, c, 0]

- * - * @param name Name of the output variable - * @param in Input variable - * @param axis Scalar axis argument for dimension to perform cumulative product operations along - * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @return Output variable - */ - public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { - validateNumerical("cumprod", in); - SDVariable ret = f().cumprod(in, exclusive, reverse, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #cumsum(String, SDVariable, boolean, boolean, int...) - */ - public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return cumsum(null, in, exclusive, reverse, axis); - } - - /** - * Cumulative sum operation.
- * For input: [ a, b, c], output is:
- * exclusive=false, reverse=false: [a, a+b, a+b+c]
- * exclusive=true, reverse=false: [0, a, a+b]
- * exclusive=false, reverse=true: [a+b+c, b+c, c]
- * exclusive=true, reverse=true: [b+c, c, 0]

- * - * @param name Name of the output variable - * @param in Input variable - * @param axis Scalar axis argument for dimension to perform cumulative sum operations along - * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @return Output variable - */ - public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { - validateNumerical("cumsum", in); - SDVariable ret = f().cumsum(in, exclusive, reverse, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * TODO doc string - * - * @param x - * @param y - * @param dimensions - * @return - */ - public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { - return dot(null, x, y, dimensions); - } - - /** - * TODO doc string - * - * @param name - * @param x - * @param y - * @param dimensions - * @return - */ - public SDVariable dot(String name, SDVariable x, SDVariable y, int... dimensions) { - SDValidation.validateNumerical("dot", x, y); - SDVariable ret = f().dot(x, y, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #dynamicPartition(String[], SDVariable, SDVariable, int) - */ - public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) { - return dynamicPartition(null, x, partitions, numPartitions); - } - - /** - * Dynamically partition the input variable values into the specified number of partitions, using the indices.
- * Example:
- *

-     * {@code input = [1,2,3,4,5]
-     * numPartitions = 2
-     * partitions = [1,0,0,1,0]
-     * out[0] = [2,3,5]
-     * out[1] = [1,4] }
-     * 
- * - * @param name Names for the output variables. Length must be equal to numPartitions - * @param x Input variable - * @param partitions 1D input with values 0 to numPartitions-1 - * @param numPartitions Number of partitions, >= 1 - * @return Output variables (equal in number to numPartitions) - */ - public SDVariable[] dynamicPartition(String[] name, SDVariable x, SDVariable partitions, int numPartitions) { - SDVariable[] ret = f().dynamicPartition(x, partitions, numPartitions); - return updateVariableNamesAndReferences(ret, name); - } - - /** - * @see #dynamicStitch(String, SDVariable[], SDVariable[]) - */ - public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] x) { - return dynamicStitch(null, indices, x); - } - - /** - * Dynamically merge the specified input arrays into a single array, using the specified indices - * - * @param name Name of the output variable - * @param indices Indices to use when merging. Must be >= 1, same length as input variables - * @param x Input variables. - * @return Merged output variable - */ - public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] x) { - SDVariable ret = f().dynamicStitch(indices, x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Equals operation: elementwise x == y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(SDVariable x, double y) { - return eq(null, x, y); - } - - /** - * Equals operation: elementwise x == y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(String name, SDVariable x, double y) { - SDVariable result = f().eq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Equal to operation: elementwise x == y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(SDVariable x, SDVariable y) { - return eq(null, x, y); - } - - /** - * Equal to operation: elementwise x == y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(String name, SDVariable x, SDVariable y) { - SDVariable result = f().eq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #expandDims(String, SDVariable, int) - */ - public SDVariable expandDims(SDVariable x, int axis) { - return expandDims(null, x, axis); - } - - /** - * Reshape the input by adding a 1 at the specified location.
- * For example, if input has shape [a, b], then output shape is:
- * axis = 0: [1, a, b]
- * axis = 1: [a, 1, b]
- * axis = 2: [a, b, 1]
- * - * @param name Name of the output variable - * @param x Input variable - * @param axis Axis to expand - * @return Output variable - * @see #squeeze(String, SDVariable, int) - */ - public SDVariable expandDims(String name, SDVariable x, int axis) { - SDVariable result = f().expandDims(x, axis); - return updateVariableNameAndReference(result, name); - } - - /** - * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value - * - * @param shape Shape: must be a 1D array/variable - * @param value Value to set all elements to - * @return Output variable - */ - public SDVariable fill(SDVariable shape, org.nd4j.linalg.api.buffer.DataType dataType, double value) { - return fill(null, shape, dataType, value); - } - - /** - * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value - * - * @param name Name of the output variable - * @param shape Shape: must be a 1D array/variable - * @param value Value to set all elements to - * @return Output variable - */ - public SDVariable fill(String name, SDVariable shape, org.nd4j.linalg.api.buffer.DataType dataType, double value) { - SDVariable result = f().fill(shape, dataType, value); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #gather(String, SDVariable, int[], int) - */ - public SDVariable gather(SDVariable df, int[] indices, int axis) { - return gather(null, df, indices, axis); - } - - /** - * Gather slices from the input variable where the indices are specified as fixed int[] values.
- * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length. - * - * @param name name of the output variable - * @param df Input variable - * @param indices Indices to get - * @param axis Axis that the indices refer to - * @return Output variable with slices pulled from the specified axis - */ - public SDVariable gather(String name, SDVariable df, int[] indices, int axis) { - SDVariable ret = f().gather(df, indices, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #gather(String, SDVariable, SDVariable, int) - */ - public SDVariable gather(SDVariable df, SDVariable indices, int axis) { - return gather(null, df, indices, axis); - } - - /** - * Gather slices from the input variable where the indices are specified as dynamic SDVariable values.
- * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length. - * - * @param name name of the output variable - * @param df Input variable - * @param indices Indices to get slices for. Rank 0 or 1 input - * @param axis Axis that the indices refer to - * @return Output variable with slices pulled from the specified axis - */ - public SDVariable gather(String name, SDVariable df, SDVariable indices, int axis) { - SDVariable ret = f().gather(df, indices, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * TODO doc string - * - * @param df - * @param indices - * @return - */ - public SDVariable gatherNd(SDVariable df, SDVariable indices) { - return gatherNd(null, df, indices); - } - - /** - * TODO doc string - * - * @param name - * @param df - * @param indices - * @return - */ - public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) { - SDVariable ret = f().gatherNd(df, indices); - return updateVariableNameAndReference(ret, name); - } - - /** - * Greater than operation: elementwise x > y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(SDVariable x, double y) { - return gt(null, x, y); - } - - /** - * Greater than operation: elementwise x > y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(String name, SDVariable x, double y) { - validateNumerical("greater than (gt)", x); - SDVariable result = f().gt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Greater than operation: elementwise x > y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(SDVariable x, SDVariable y) { - return gt(null, x, y); - } - - /** - * Greater than operation: elementwise x > y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(String name, SDVariable x, SDVariable y) { - SDValidation.validateNumerical("greater than (gt)", x, y); - SDVariable result = f().gt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Greater than or equals operation: elementwise x >= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(SDVariable x, double y) { - return gte(null, x, y); - } - - /** - * Greater than or equals operation: elementwise x >= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(String name, SDVariable x, double y) { - validateNumerical("greater than or equal (gte)", x); - SDVariable result = f().gte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Greater than or equal to operation: elementwise x >= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(SDVariable x, SDVariable y) { - return gte(null, x, y); - } - - /** - * Greater than or equal to operation: elementwise x >= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(String name, SDVariable x, SDVariable y) { - validateNumerical("greater than or equal (gte)", x, y); - SDVariable result = f().gte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise identity operation: out = x - * - * @param input Input variable - * @return Output variable - */ - public SDVariable identity(SDVariable input) { - return identity(null, input); - } - - /** - * Elementwise identity operation: out = x - * - * @param name name of the output variable - * @param input Input variable - * @return Output variable - */ - public SDVariable identity(String name, SDVariable input) { - SDVariable s = f().identity(input); - return updateVariableNameAndReference(s, name); - } - - /** - * Compute the inverse permutation indices for a permutation operation
- * Example: if input is [2, 0, 1] then output is [1, 2, 0]
- * The idea is that x.permute(input).permute(invertPermutation(input)) == x - * - * @param input 1D indices for permutation - * @return 1D inverted permutation - */ - public SDVariable invertPermutation(SDVariable input) { - return invertPermutation(null, input); - } - - /** - * Compute the inverse permutation indices for a permutation operation
- * Example: if input is [2, 0, 1] then output is [1, 2, 0]
- * The idea is that x.permute(input).permute(invertPermutation(input)) == x - * - * @param name name of the output variable - * @param input 1D indices for permutation - * @return 1D inverted permutation - */ - public SDVariable invertPermutation(String name, SDVariable input) { - validateInteger("invert permutation", input); - SDVariable ret = f().invertPermutation(input, false); - return updateVariableNameAndReference(ret, name); - } - - /** - * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1 - * - * @param x Input variable - * @return Scalar variable with value 1 - */ - public SDVariable isNumericTensor(SDVariable x) { - return isNumericTensor(null, x); - } - - /** - * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1 - * - * @param name Output variable name - * @param x Input variable - * @return Scalar variable with value 1 - */ - public SDVariable isNumericTensor(String name, SDVariable x) { - SDVariable result = f().isNumericTensor(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop' - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] - * - * @param start Start value - * @param stop Stop value - * @param number Number of values to generate - * @return SDVariable with linearly spaced elements - */ - // TODO: fix or remove, currently it is internal recursion - /*public SDVariable linspace(DataType dataType, double start, double stop, long number) { - return linspace(dataType, start, stop, number); - }*/ - - /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop' - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] - * - * @param name Name of the new variable - * @param dataType Data type of the output array - * @param start Start value - * @param stop Stop value - * @param number Number of values to generate - * @return SDVariable with linearly spaced elements - */ - public SDVariable linspace(String name, DataType dataType, double start, double stop, long number) { - SDVariable ret = f().linspace(sd().constant(start), sd().constant(stop), sd().constant(number), dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop' - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] - * - * @param name Name of the new variable - * @param from Start value - * @param to Stop value - * @param length Number of values to generate - * @param dt Data type of the output array - * @return SDVariable with linearly spaced elements - */ - public SDVariable linspace(String name, SDVariable from, SDVariable to, SDVariable length, DataType dt) { - SDVariable ret = f().linspace(from, to, length, dt); - return updateVariableNameAndReference(ret, name); - } - - /** - * Less than operation: elementwise x < y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(SDVariable x, double y) { - return lt(null, x, y); - } - - /** - * Less than operation: elementwise x < y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(String name, SDVariable x, double y) { - validateNumerical("less than (lt)", x); - SDVariable result = f().lt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Less than operation: elementwise x < y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(SDVariable x, SDVariable y) { - return lt(null, x, y); - } - - /** - * Less than operation: elementwise x < y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(String name, SDVariable x, SDVariable y) { - validateNumerical("less than (lt)", x, y); - SDVariable result = f().lt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Less than or equals operation: elementwise x <= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(SDVariable x, double y) { - return lte(null, x, y); - } - - /** - * Less than or equals operation: elementwise x <= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(String name, SDVariable x, double y) { - validateNumerical("less than or equal (lte)", x); - SDVariable result = f().lte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Less than or equal to operation: elementwise x <= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(SDVariable x, SDVariable y) { - return lte(null, x, y); - } - - /** - * Less than or equal to operation: elementwise x <= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(String name, SDVariable x, SDVariable y) { - validateNumerical("less than or equal (lte)", x, y); - SDVariable result = f().lte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise - * - * @param in Input variable - * @param condition Condition - * @return Boolean mask mariable - */ - public SDVariable matchCondition(SDVariable in, Condition condition) { - return matchCondition(null, in, condition); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise - * - * @param in Input - * @param condition Condition - * @return Boolean mask - */ - public SDVariable matchCondition(String name, SDVariable in, Condition condition) { - SDVariable ret = f().matchCondition(in, condition); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(SDVariable in, Condition condition) { - return matchConditionCount(null, in, condition); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param name Name of the output variable - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(String name, SDVariable in, Condition condition) { - return matchConditionCount(name, in, condition, false); - } - - /** - * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition - * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, boolean keepDim, int... dimensions) { - SDVariable ret = f().matchConditionCount(in, condition, keepDim, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Max array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable max(SDVariable x, int... dimensions) { - return max(null, x, dimensions); - } - - /** - * Max array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable max(String name, SDVariable x, int... dimensions) { - return max(name, x, false, dimensions); - } - - /** - * Max array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("max reduction", x); - SDVariable result = f().max(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise maximum operation: out[i] = max(first[i], second[i])
- * Supports broadcasting - * - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable max(SDVariable first, SDVariable second) { - return max(null, first, second); - } - - /** - * Element-wise maximum operation: out[i] = max(first[i], second[i])
- * Supports broadcasting - * - * @param name Name of the output variable - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable max(String name, SDVariable first, SDVariable second) { - validateNumerical("pairwise maxiumum (max)", first, second); - SDVariable result = f().max(first, second); - return updateVariableNameAndReference(result, name); - } - - /** - * Full array mean reduction operation - * - * @param x Input variable - * @return Output variable - scalar - */ - public SDVariable mean(SDVariable x) { - return mean(null, x); - } - - /** - * Mean (average) array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable mean(String name, SDVariable x, int... dimension) { - return mean(name, x, false, dimension); - } - - /** - * Mean (average) array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimension) { - validateNumerical("mean reduction", x); - SDVariable result = f().mean(x, keepDims, dimension); - return updateVariableNameAndReference(result, name); - } - - /** - * Mean (average) array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable mean(SDVariable x, int... dimension) { - return mean(null, x, dimension); - } - - /** - * Minimum array reduction operation, optionally along specified dimensions. out = min(in) - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable min(SDVariable x, int... dimensions) { - return min(null, x, dimensions); - } - - /** - * Minimum array reduction operation, optionally along specified dimensions. out = min(in) - * - * @param name Output variable name - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable min(String name, SDVariable x, int... dimensions) { - return min(name, x, false, dimensions); - } - - /** - * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("min reduction", x); - SDVariable result = f().min(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - - } - - /** - * Element-wise minimum operation: out[i] = min(first[i], second[i])
- * Supports broadcasting - * - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable min(SDVariable first, SDVariable second) { - return min(null, first, second); - } - - /** - * Element-wise minimum operation: out[i] = min(first[i], second[i])
- * Supports broadcasting - * - * @param name Name of the output variable - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable min(String name, SDVariable first, SDVariable second) { - validateNumerical("mean (pairwise)", first, second); - SDVariable result = f().min(first, second); - return updateVariableNameAndReference(result, name); - } - - /** - * Matrix multiplication: out = mmul(x,y)
- * Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc. - * - * @param x First input variable - * @param y Second input variable - * @param transpose Transpose arguments - * @return Output variable - */ - public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose) { - return mmul(null, x, y, transpose); - - } - - /** - * Matrix multiplication: out = mmul(x,y)
- * Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc. - * - * @param name Output variable name - * @param x First input variable - * @param y Second input variable - * @param transpose Transpose arguments - * @return Output variable - */ - public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose) { - validateNumerical("matrix multiplication (mmul)", x, y); - SDVariable result = f().mmul(x, y, transpose); - return updateVariableNameAndReference(result, name); - } - - /** - * Matrix multiplication: out = mmul(x,y) - * - * @param x First input variable - * @param y Second input variable - * @return Output variable - */ - public SDVariable mmul(SDVariable x, SDVariable y) { - return mmul(null, x, y); - } - - /** - * Matrix multiplication: out = mmul(x,y) - * - * @param name Output variable name - * @param x First input variable - * @param y Second input variable - * @return Output variable - */ - public SDVariable mmul(String name, SDVariable x, SDVariable y) { - return mmul(name, x, y, MMulTranspose.allFalse()); - } - - /** - * Not equals operation: elementwise x != y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(SDVariable x, double y) { - return neq(null, x, y); - } - - /** - * Not equals operation: elementwise x != y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(String name, SDVariable x, double y) { - validateNumerical("not equals (neq)", x); - SDVariable result = f().neq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Not equal to operation: elementwise x != y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(SDVariable x, SDVariable y) { - return neq(null, x, y); - } - - /** - * Not equal to operation: elementwise x != y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(String name, SDVariable x, SDVariable y) { - validateNumerical("not equals (neq)", x, y); - SDVariable result = f().neq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]) - * - * @param name Output variable name - * @param x Input variable - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm1(String name, SDVariable x, int... dimensions) { - return norm1(name, x, false, dimensions); - } - - /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i])
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("norm1 reduction", x); - SDVariable result = f().norm1(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
- * out = sqrt(sum_i x[i]^2) - * - * @param name Output variable name - * @param x Input variable - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm2(String name, SDVariable x, int... dimensions) { - return norm2(name, x, false, dimensions); - } - - /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
- * out = sqrt(sum_i x[i]^2)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("norm2 reduction", x); - SDVariable result = f().norm2(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the - * specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable normmax(String name, SDVariable x, int... dimensions) { - return normmax(name, x, false, dimensions); - } - - /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the - * specified dimensions:
- * out = max(abs(x[i]))
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("norm max reduction", x); - SDVariable result = f().normmax(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #oneHot(String, SDVariable, int) - */ - public SDVariable oneHot(SDVariable indices, int depth) { - return oneHot(null, indices, depth, -1, 1.00, 0.00); - } - - /** - * Convert the array to a one-hot array with values {@code on} and {@code off} for each entry<br>
- * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth], - * with {@code out[i, ..., j, in[i,...,j]] = on} with other values being set to {@code off} - * - * @param name Output variable name - * @param indices Indices - value 0 to depth-1 - * @param depth Number of classes - * @return Output variable - */ - public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off) { - return oneHot(name, indices, depth, axis, on, off, OneHot.DEFAULT_DTYPE); - } - - /** - * As per {@link #oneHot(String, SDVariable, int, int, double, double)} but allows configuring the output datatype - */ - public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - validateInteger("oneHot", "indices", indices); - SDVariable ret = f().onehot(indices, depth, axis, on, off, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #oneHot(String, SDVariable, int, int, double, double) - */ - public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) { - return oneHot(null, indices, depth, axis, on, off, OneHot.DEFAULT_DTYPE); - } - - /** - * @see #oneHot(String, SDVariable, int, int, double, double, DataType) - */ - public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - return oneHot(null, indices, depth, axis, on, off, dataType); - } - - /** - * Convert the array to a one-hot array with values 0 and 1 for each entry<br>
- * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth], - * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0 - * - * @param name Output variable name - * @param indices Indices - value 0 to depth-1 - * @param depth Number of classes - * @return Output variable - * @see #oneHot(SDVariable, int, int, double, double) - */ - public SDVariable oneHot(String name, SDVariable indices, int depth) { - return oneHot(name, indices, depth, -1, 1.00, 0.00); - } - - /** - * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable onesLike(SDVariable input) { - return onesLike(null, input); - } - - /** - * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param name Name of the new SDVariable - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable onesLike(String name, SDVariable input) { - return onesLike(name, input, input.dataType()); - } - - /** - * As per {@link #onesLike(String, SDVariable)} but the output datatype may be specified - */ - public SDVariable onesLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) { - SDVariable ret = f().onesLike(name, input, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #stack(String, int, SDVariable...) - */ - public SDVariable parallel_stack(SDVariable[] values) { - return parallel_stack(null, values); - } - - /** - * @see #stack(String, int, SDVariable...) - */ - public SDVariable parallel_stack(String name, SDVariable[] values) { - validateSameType("parallel_stack", false, values); - SDVariable ret = f().parallel_stack(values); - return updateVariableNameAndReference(ret, name); - } - - /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
- * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] - * - * @param x Input variable - * @return Output variable (permuted input) - */ - public SDVariable permute(SDVariable x, int... dimensions) { - return permute(null, x, dimensions); - } - - /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
- * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] - * - * @param name Output variable name - * @param x Input variable - * @return Output variable (permuted input) - */ - public SDVariable permute(String name, SDVariable x, int... dimensions) { - SDVariable result = f().permute(x, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * As per {@link #permute(String, SDVariable, int...)} but with SDVariable permute dimension - */ - public SDVariable permute(String name, SDVariable x, SDVariable dimensions){ - SDVariable result = f().permute(x, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Product array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable prod(SDVariable x, int... dimensions) { - return prod(null, x, dimensions); - } - - /** - * Product array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable prod(String name, SDVariable x, int... dimensions) { - return prod(name, x, false, dimensions); - } - - /** - * Product array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("product reduction (prod)", x); - SDVariable result = f().prod(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step} - * up to (but not including) limit.
- * For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]} - * - * @param from Initial/smallest value - * @param to Largest value (exclusive) - * @param step Step size - * @param dataType The output variable datatype - * @return 1D SDVariable with the specified values - */ - public SDVariable range(double from, double to, double step, DataType dataType) { - return range(null, from, to, step, dataType); - } - - /** - * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step} - * up to (but not including) limit.
- * For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]} - * - * @param name Name of the new variable - * @param from Initial/smallest value - * @param to Largest value (exclusive) - * @param step Step size - * @return 1D SDVariable with the specified values - */ - public SDVariable range(String name, double from, double to, double step, DataType dataType) { - SDVariable ret = f().range(from, to, step, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * As per {@link #range(String, double, double, double, DataType)} but with SDVariable arguments - */ - public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step, DataType dataType) { - SDVariable ret = f().range(from, to, step, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the rank of the input variable - */ - public SDVariable rank(SDVariable in) { - return rank(null, in); - } - - /** - * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param name Name of the output variable - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the rank of the input variable - */ - public SDVariable rank(String name, SDVariable in) { - SDVariable ret = f().rank(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #repeat(String, SDVariable, int) - */ - public SDVariable repeat(SDVariable df, int axis) { - return repeat(null, df, axis); - } - - /** - * @see #repeat(String, SDVariable, int) - */ - public SDVariable repeat(String name, SDVariable df, int axis) { - SDVariable ret = f().repeat(df, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise replace where condition:
- * out[i] = from[i] if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param update Source array - * @param from Replacement values array (used conditionally). Must be same shape as 'update' array - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(SDVariable update, SDVariable from, Condition condition) { - return replaceWhere(null, update, from, condition); - } - - /** - * Element-wise replace where condition:
- * out[i] = from[i] if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param name Name of the output variable - * @param update Source array - * @param from Replacement values array (used conditionally). Must be same shape as 'update' array - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, Condition condition) { - SDVariable ret = f().replaceWhere(update, from, condition); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise replace where condition:
- * out[i] = value if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param update Source array - * @param value Value to set at the output, if the condition is satisfied - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(SDVariable update, Number value, Condition condition) { - return replaceWhere(null, update, value, condition); - } - - /** - * Element-wise replace where condition:
- * out[i] = value if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param name Name of the output variable - * @param update Source array - * @param value Value to set at the output, if the condition is satisfied - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(String name, SDVariable update, Number value, Condition condition) { - SDVariable ret = f().replaceWhere(update, value, condition); - return updateVariableNameAndReference(ret, name); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(SDVariable x, long... shape) { - return reshape(null, x, shape); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param name Output variable name - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(String name, SDVariable x, long... shape) { - SDVariable result = f().reshape(x, shape); - return updateVariableNameAndReference(result, name); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(SDVariable x, int... shape) { - return reshape(null, x, shape); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param name Output variable name - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(String name, SDVariable x, int... shape) { - SDVariable result = f().reshape(x, shape); - return updateVariableNameAndReference(result, name); - } - - /** - * Reshape the input variable to the specified (dynamic) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, int[]) - */ - public SDVariable reshape(SDVariable x, SDVariable shape) { - return reshape(null, x, shape); - } - - /** - * Reshape the input variable to the specified (dynamic) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param name Output variable name - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, int[]) - */ - public SDVariable reshape(String name, SDVariable x, SDVariable shape) { - validateInteger("reshape", "shape", shape); - SDVariable result = f().reshape(x, shape); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #reverse(String, SDVariable, int...) - */ - public SDVariable reverse(SDVariable x, int... dimensions) { - return reverse(null, x, dimensions); - } - - /** - * Reverse the values of an array for the specified dimensions
- * If input is:
- * [ 1, 2, 3]
- * [ 4, 5, 6]
- * then
- * reverse(in, 1):<br>
- * [3, 2, 1]
- * [6, 5, 4]
- * reverse(in, 0):
- * [4, 5, 6]
- * [1, 2, 3]<br>
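For reference, a minimal usage sketch of the reverse example above. It assumes the instance methods shown in this class and standard Nd4j factory methods; the class name and the expected results in the comments are illustrative, not verified output.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class ReverseExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}}));
            SDVariable revCols = sd.reverse(in, 1);  // reverse within each row: [[3,2,1],[6,5,4]]
            SDVariable revRows = sd.reverse(in, 0);  // reverse the row order:   [[4,5,6],[1,2,3]]
            System.out.println(revCols.eval());
            System.out.println(revRows.eval());
        }
    }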
- * - * @param x Input variable - * @param dimensions Dimensions - * @return Output variable - */ - public SDVariable reverse(String name, SDVariable x, int... dimensions) { - SDVariable ret = f().reverse(x, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #reverseSequence(String, SDVariable, SDVariable, int, int) - */ - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) { - return reverseSequence(null, x, seq_lengths, seqDim, batchDim); - } - - /** - * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed - * - * @param name Name of the output variable - * @param x Input variable - * @param seq_lengths Length of the sequences - * @param seqDim Sequence dimension - * @param batchDim Batch dimension - * @return Reversed sequences - */ - public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) { - SDVariable ret = f().reverseSequence(x, seq_lengths, seqDim, batchDim); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #reverseSequence(String, SDVariable, SDVariable, int, int) - */ - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { - return reverseSequence(null, x, seq_lengths); - } - - /** - * @see #reverseSequence(String, SDVariable, SDVariable, int, int) - */ - public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) { - SDVariable ret = f().reverseSequence(x, seq_lengths); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise scalar floor modulus operation: out = floorMod(in, value). - * i.e., returns the remainder after division by 'value' - * - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarFloorMod(SDVariable in, Number value) { - return scalarFloorMod(null, in, value); - } - - /** - * Element-wise scalar floor modulus operation: out = floorMod(in, value). 
- * i.e., returns the remainder after division by 'value' - * - * @param name Name of the output variable - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarFloorMod(String name, SDVariable in, Number value) { - validateNumerical("floorMod", in); - SDVariable ret = f().scalarFloorMod(in, value); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise scalar maximum operation: out = max(in, value) - * - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMax(SDVariable in, Number value) { - return scalarMax(null, in, value); - } - - /** - * Element-wise scalar maximum operation: out = max(in, value) - * - * @param name Name of the output variable - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMax(String name, SDVariable in, Number value) { - validateNumerical("max", in); - SDVariable ret = f().scalarMax(in, value); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise scalar minimum operation: out = min(in, value) - * - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMin(SDVariable in, Number value) { - return scalarMin(null, in, value); - } - - /** - * Element-wise scalar minimum operation: out = min(in, value) - * - * @param name Name of the output variable - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMin(String name, SDVariable in, Number value) { - validateNumerical("min", in); - SDVariable ret = f().scalarMin(in, value); - return updateVariableNameAndReference(ret, name); - } - - /** - * Return an array with equal shape to the input, but all elements set to value 'set' - * - * @param in Input variable - * @param set Value to set - * @return Output variable - */ - public SDVariable scalarSet(SDVariable in, Number set) { - return scalarSet(null, in, set); - } - - /** - * Return a variable with equal shape to the input, but all elements set to value 'set' - * - * @param name Name of the output variable - * @param in Input variable - * @param set Value to set - * @return Output variable - */ - public SDVariable scalarSet(String name, SDVariable in, Number set) { - SDVariable ret = f().scalarSet(in, set); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterAdd(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterAdd(null, ref, indices, updates); - } - - /** - * Scatter addition operation.
- * If indices is rank 0 (a scalar), then out[index, ...] += updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] += updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] += updates[i, ..., k, ...]
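A minimal sketch of the rank-1 indices case for scatterAdd, assuming a SameDiff instance and Nd4j factory methods (class name and expected result are illustrative, not verified):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class ScatterAddExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable ref     = sd.var("ref", Nd4j.createFromArray(new float[][]{{1, 1}, {2, 2}, {3, 3}}));
            SDVariable indices = sd.constant("idx", Nd4j.createFromArray(new int[]{0, 2}));
            SDVariable updates = sd.var("upd", Nd4j.createFromArray(new float[][]{{10, 10}, {20, 20}}));
            // Rows 0 and 2 of ref receive the corresponding update rows: [[11,11],[2,2],[23,23]]
            SDVariable out = sd.scatterAdd(ref, indices, updates);
            System.out.println(out.eval());
        }
    }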
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterAdd", "indices", indices); - SDVariable ret = f().scatterAdd(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterDiv(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterDiv(null, ref, indices, updates); - } - - /** - * Scatter division operation.
- * If indices is rank 0 (a scalar), then out[index, ...] /= updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] /= updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] /= updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterDiv", "indices", indices); - SDVariable ret = f().scatterDiv(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterMax(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterMax(null, ref, indices, updates); - } - - /** - * Scatter max operation.
- * If indices is rank 0 (a scalar), then out[index, ...] = max(updates[...], in[index,...])
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(updates[i,...], in[indices[i],...])
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...])<br>
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterMax", "indices", indices); - SDVariable ret = f().scatterMax(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterMin(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterMin(null, ref, indices, updates); - } - - /** - * Scatter min operation.
- * If indices is rank 0 (a scalar), then out[index, ...] = min(updates[...], in[index,...])
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(updates[i,...], in[indices[i],...])
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...])<br>
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterMin", "indices", indices); - SDVariable ret = f().scatterMin(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterMul(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterMul(null, ref, indices, updates); - } - - /** - * Scatter multiplication operation.
- * If indices is rank 0 (a scalar), then out[index, ...] *= updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] *= updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] *= updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterMul", "indices", indices); - SDVariable ret = f().scatterMul(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterSub(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterSub(null, ref, indices, updates); - } - - /** - * Scatter subtraction operation.
- * If indices is rank 0 (a scalar), then out[index, ...] -= updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] -= updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] -= updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterSub", "indices", indices); - SDVariable ret = f().scatterSub(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterUpdate(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterUpdate(null, ref, indices, updates); - } - - /** - * Scatter update operation.
- * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]
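A minimal sketch contrasting scatterUpdate with the accumulating scatter ops above: the addressed rows are overwritten rather than combined (same assumptions as the earlier sketches; expected result is illustrative):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class ScatterUpdateExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable ref     = sd.var("ref", Nd4j.createFromArray(new float[][]{{1, 1}, {2, 2}, {3, 3}}));
            SDVariable indices = sd.constant("idx", Nd4j.createFromArray(new int[]{0, 2}));
            SDVariable updates = sd.var("upd", Nd4j.createFromArray(new float[][]{{10, 10}, {20, 20}}));
            // Rows 0 and 2 are replaced by the update rows: [[10,10],[2,2],[20,20]]
            SDVariable out = sd.scatterUpdate(ref, indices, updates);
            System.out.println(out.eval());
        }
    }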
- * Note that if multiple indices refer to the same location, the output at those locations is undefined - different - * updates may occur in different orders - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterUpdate", "indices", indices); - SDVariable ret = f().scatterUpdate(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentMax(String, SDVariable, SDVariable) - */ - public SDVariable segmentMax(SDVariable data, SDVariable segmentIds) { - return segmentMax(null, data, segmentIds); - } - - /** - * Segment max operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
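A minimal sketch of the segmentMax example above, assuming a SameDiff instance and Nd4j factory methods (class name and expected result are illustrative, not verified):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class SegmentMaxExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable data = sd.var("data", Nd4j.createFromArray(new float[]{3, 6, 1, 4, 9, 2, 8}));
            SDVariable ids  = sd.constant("ids", Nd4j.createFromArray(new int[]{0, 0, 1, 1, 1, 2, 2}));
            // Segment IDs are sorted, as this op requires; expected output: [6, 9, 8]
            SDVariable max = sd.segmentMax(data, ids);
            System.out.println(max.eval());
        }
    }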
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentMax(String, SDVariable, SDVariable, int)} - * for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment max output - */ - public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentMax", "data", data); - validateInteger("segmentMax", "segmentIds", segmentIds); - SDVariable ret = f().segmentMax(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentMean(String, SDVariable, SDVariable) - */ - public SDVariable segmentMean(SDVariable data, SDVariable segmentIds) { - return segmentMean(null, data, segmentIds); - } - - /** - * Segment mean operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentMean(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment mean output - */ - public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentMean", "data", data); - validateInteger("segmentMean", "segmentIds", segmentIds); - SDVariable ret = f().segmentMean(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentMin(String, SDVariable, SDVariable) - */ - public SDVariable segmentMin(SDVariable data, SDVariable segmentIds) { - return segmentMin(null, data, segmentIds); - } - - /** - * Segment min operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentMin(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment min output - */ - public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentMin", "data", data); - validateInteger("segmentMin", "segmentIds", segmentIds); - SDVariable ret = f().segmentMin(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentProd(String, SDVariable, SDVariable) - */ - public SDVariable segmentProd(SDVariable data, SDVariable segmentIds) { - return segmentProd(null, data, segmentIds); - } - - /** - * Segment product operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentProd(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment product output - */ - public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentProd", "data", data); - validateInteger("segmentProd", "segmentIds", segmentIds); - SDVariable ret = f().segmentProd(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentSum(String, SDVariable, SDVariable) - */ - public SDVariable segmentSum(SDVariable data, SDVariable segmentIds) { - return segmentSum(null, data, segmentIds); - } - - /** - * Segment sum operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentSum(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment sum output - */ - public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentSum", "data", data); - validateInteger("segmentSum", "segmentIds", segmentIds); - SDVariable ret = f().segmentSum(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { - return sequenceMask(null, lengths, maxLen, dataType); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) { - validateInteger("sequenceMask", "lengths", lengths); - SDVariable ret = f().sequenceMask(lengths, maxLen, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) { - SDVariable ret = f().sequenceMask(lengths, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { - return sequenceMask(lengths, null, dataType); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { - return sequenceMask(null, lengths, maxLen, dataType); - } - - /** - * Generate a sequence mask (with values 0 or 1) based on the specified lengths
- * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0) - * - * @param name Name of the output variable - * @param lengths Lengths of the sequences - * @param maxLen Maximum sequence length - * @return Output variable - */ - public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen, DataType dataType) { - validateInteger("sequenceMask", "lengths", lengths); - SDVariable ret = f().sequenceMask(lengths, maxLen, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns the shape of the specified SDVariable as a 1D SDVariable - * - * @param input Input variable - * @return 1D output variable with contents equal to the shape of the input - */ - public SDVariable shape(SDVariable input) { - return shape(null, input); - } - - /** - * Returns the shape of the specified SDVariable as a 1D SDVariable - * - * @param name Name of the output variable - * @param input Input variable - * @return 1D output variable with contents equal to the shape of the input - */ - public SDVariable shape(String name, SDVariable input) { - SDVariable ret = f().shape(input); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns the size (number of elements, i.e., prod(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the number of elements in the specified array - */ - public SDVariable size(SDVariable in) { - return size(null, in); - } - - /** - * Returns the size (number of elements, i.e., prod(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param name Name of the output variable - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the number of elements in the specified array - */ - public SDVariable size(String name, SDVariable in) { - SDVariable ret = f().size(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sizeAt(String, SDVariable, int) - */ - public SDVariable sizeAt(SDVariable in, int dimension) { - return sizeAt(null, in, dimension); - } - - /** - * Returns a rank 0 (scalar) variable for the size of the specified dimension. - * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30 - * - * @param name Name of the output variable - * @param in Input variable - * @param dimension Dimension to get size of - * @return Scalar SDVariable for size at specified variable - */ - public SDVariable sizeAt(String name, SDVariable in, int dimension) { - SDVariable ret = f().sizeAt(in, dimension); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #slice(String, SDVariable, int[], int[]) - */ - public SDVariable slice(SDVariable input, int[] begin, int[] size) { - return slice(null, input, begin, size); - } - - public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { - return slice(null, input, begin, size); - } - - /** - * Get a subset of the specified input, by specifying the first element and the size of the array.
- * For example, if input is:
- * [a, b, c]
- * [d, e, f]
- * then slice(input, begin=[0,1], size=[2,1]) will return:<br>
- * [b]
- * [e]
- *
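A minimal sketch of the slice example above with concrete numbers, assuming a SameDiff instance and Nd4j factory methods (class name and expected result are illustrative, not verified):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class SliceExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}}));
            // begin=[0,1], size=[2,1]: rows 0..1, column 1 only -> [[2],[5]]
            SDVariable out = sd.slice(in, new int[]{0, 1}, new int[]{2, 1});
            System.out.println(out.eval());
        }
    }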
- * Note that for each dimension i, begin[i] + size[i] <= input.size(i) - * - * @param name Output variable name - * @param input Variable to get subset of - * @param begin Beginning index. Must be same length as rank of input array - * @param size Size of the output array. Must be same length as rank of input array - * @return Subset of the input - */ - public SDVariable slice(String name, SDVariable input, int[] begin, int[] size) { - SDVariable ret = f().slice(input, begin, size); - return updateVariableNameAndReference(ret, name); - } - - public SDVariable slice(String name, SDVariable input, @NonNull SDVariable begin, @NonNull SDVariable size) { - SDVariable ret = f().slice(input, begin, size); - return updateVariableNameAndReference(ret, name); - } - - - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)} - */ - public SDVariable squaredNorm(SDVariable x, int... dimensions) { - return squaredNorm(null, x, false, dimensions); - } - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)} - */ - public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("squaredNorm", x); - SDVariable result = f().squaredNorm(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)} - */ - public SDVariable squaredNorm(String name, SDVariable x, int... dimensions) { - return squaredNorm(name, x, false, dimensions); - } - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)} - */ - public SDVariable squaredNorm(SDVariable x, boolean keepDims, int... dimensions) { - return squaredNorm(null, x, keepDims, dimensions); - } - - /** - * @see #squeeze(String, SDVariable, int) - */ - public SDVariable squeeze(SDVariable x, int axis) { - return squeeze(null, x, axis); - } - - /** - * Remove a single dimension of size 1. - * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c] - * - * @param name Name of the output variable - * @param x Input variable - * @param axis Size 1 dimension to remove - * @return Output variable - */ - public SDVariable squeeze(String name, SDVariable x, int axis) { - SDVariable result = f().squeeze(x, axis); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #stack(String, int, SDVariable...) - */ - public SDVariable stack(int axis, SDVariable... values) { - return stack(null, axis, values); - } - - /** - * Stack a set of N SDVariables of rank X into one rank X+1 variable. - * If inputs have shape [a,b,c] then output has shape:
- * axis = 0: [N,a,b,c]
- * axis = 1: [a,N,b,c]
- * axis = 2: [a,b,N,c]
- * axis = 3: [a,b,c,N]
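A minimal sketch of stacking two rank-1 variables along different axes, assuming a SameDiff instance and Nd4j factory methods (class name and expected shapes are illustrative, not verified):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class StackExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable a = sd.var("a", Nd4j.createFromArray(new float[]{1, 2, 3}));
            SDVariable b = sd.var("b", Nd4j.createFromArray(new float[]{4, 5, 6}));
            SDVariable axis0 = sd.stack(0, a, b);  // shape [2,3]: [[1,2,3],[4,5,6]]
            SDVariable axis1 = sd.stack(1, a, b);  // shape [3,2]: [[1,4],[2,5],[3,6]]
            System.out.println(axis0.eval());
            System.out.println(axis1.eval());
        }
    }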
- * - * @param name Name of the output variable - * @param axis Axis to stack on - * @param values Input variables to stack. Must have the same shape for all inputs - * @return Output variable - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable stack(String name, int axis, SDVariable... values) { - validateSameType("stack", false, values); - SDVariable ret = f().stack(values, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #standardDeviation(String, SDVariable, boolean, int...) - */ - public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, int... dimensions) { - return standardDeviation(null, x, biasCorrected, dimensions); - } - - /** - * Standard deviation array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, int... dimensions) { - return standardDeviation(name, x, biasCorrected, false, dimensions); - } - - /** - * Standard deviation array reduction operation, optionally along specified dimensions<br>
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { - validateNumerical("standard deviation", x); - SDVariable result = f().std(x, biasCorrected, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[]) - */ - public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) { - return stridedSlice(null, input, begin, end, strides); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[]) - */ - public SDVariable stridedSlice(String name, SDVariable input, int[] begin, int[] end, int[] strides) { - return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int) - */ - public SDVariable stridedSlice(String name, SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[]) - */ - public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) { - return stridedSlice(null, input, begin, end, strides); - } - - /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
- * For example, if input is:
- * [a, b, c]
- * [d, e, f]
- * [g, h, i]
- * then stridedSlice(input, begin=[0,1], end=[3,3], strides=[2,1]) will return:<br>
- * [b, c]
- * [h, i]
- *
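A minimal sketch of the stridedSlice example above with concrete numbers, assuming a SameDiff instance and Nd4j factory methods (class name and expected result are illustrative, not verified):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class StridedSliceExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}));
            // begin=[0,1], end=[3,3], strides=[2,1]: rows 0 and 2, columns 1..2 -> [[2,3],[8,9]]
            SDVariable out = sd.stridedSlice(in, new long[]{0, 1}, new long[]{3, 3}, new long[]{2, 1});
            System.out.println(out.eval());
        }
    }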
- * - * @param name Output variable name - * @param input Variable to get subset of - * @param begin Beginning index. Must be same length as rank of input array - * @param end End index. Must be same length as the rank of the array - * @param strides Stride ("step size") for each dimension. Must be same length as the rank of the array. For example, - * stride of 2 means take every second element. - * @return Subset of the input - */ - public SDVariable stridedSlice(String name, SDVariable input, long[] begin, long[] end, long[] strides) { - return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0); - } - - /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
- * Operates as described in {@link #stridedSlice(SDVariable, long[], long[], long[])} with some extra mask arrays - * as described below. - * - * @param name Output variable name - * @param in Variable to get subset of - * @param begin Beginning index - * @param end End index - * @param strides Stride ("step size") for each dimension. For example, - * stride of 2 means take every second element. - * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, - * and a value of 0 is used instead for the beginning index for that dimension - * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, - * and a value of size(i)-1 is used instead for the end index for that dimension - * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other - * dimensions are inserted as required at the specified position - * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and - * a size 1 dimension is inserted at this point - * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and - * a size 1 dimension is removed at this point. Note that begin/end/stride values must - * result in a size 1 output for these dimensions - * @return A subset of the input array - */ - public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int) - */ - public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int) - */ - public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - } - - /** - * Sum array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable sum(SDVariable x, int... dimensions) { - return sum(null, x, dimensions); - } - - /** - * Sum array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable sum(String name, SDVariable x, int... 
dimensions) { - return sum(name, x, false, dimensions); - } - - /** - * Sum array reduction operation, optionally along specified dimensions.
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable sum(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("sum reduction", x); - SDVariable result = f().sum(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #sum(String, SDVariable, boolean, int...) - */ - public SDVariable sum(SDVariable x, boolean keepDims, int... dimensions) { - return sum(null, x, keepDims, dimensions); - } - - /** - * @param x - * @param y - * @param dimensions - * @return - */ - public SDVariable tensorMmul(SDVariable x, - SDVariable y, - int[][] dimensions) { - return tensorMmul(null, x, y, dimensions); - } - - /** - * @param x Input variable x - * @param y Input variable y - * @param dimensions dimensions - * @return Output variable - */ - public SDVariable tensorMmul(String name, - SDVariable x, - SDVariable y, - int[][] dimensions) { - validateNumerical("tensorMmul", x, y); - SDVariable result = f().tensorMmul(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #tile(String, SDVariable, int...) - */ - public SDVariable tile(SDVariable x, int... repeat) { - return tile(null, x, repeat); - } - - /** - * Repeat (tile) the input tensor the specified number of times.
- * For example, if input is
- * [1, 2]
- * [3, 4]
- * and repeat is [2, 3]
- * then output is
- * [1, 2, 1, 2, 1, 2]
- * [3, 4, 3, 4, 3, 4]
- * [1, 2, 1, 2, 1, 2]
- * [3, 4, 3, 4, 3, 4]
- *
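For illustration only, the repeat semantics above can be reproduced in plain Java; the TileSketch class below is hypothetical and is not the ND4J implementation.

// Minimal sketch of the tile/repeat semantics for a rank-2 input.
public class TileSketch {
    static int[][] tile(int[][] in, int repeatRows, int repeatCols) {
        int rows = in.length, cols = in[0].length;
        int[][] out = new int[rows * repeatRows][cols * repeatCols];
        for (int r = 0; r < out.length; r++)
            for (int c = 0; c < out[0].length; c++)
                out[r][c] = in[r % rows][c % cols];   // wrap around the source indices
        return out;
    }

    public static void main(String[] args) {
        int[][] in = {{1, 2}, {3, 4}};
        int[][] out = tile(in, 2, 3);                 // repeat = [2, 3] -> shape [4, 6]
        System.out.println(java.util.Arrays.deepToString(out));
    }
}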
- * - * @param name Output variable name - * @param x Input variable - * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array - * @return Output variable - */ - public SDVariable tile(String name, SDVariable x, int... repeat) { - SDVariable result = f().tile(x, repeat); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #tile(String, SDVariable, int...) - */ - public SDVariable tile(SDVariable x, SDVariable repeat) { - return tile(null, x, repeat); - } - - /** - * @see #tile(String, SDVariable, int...) - */ - public SDVariable tile(String name, SDVariable x, SDVariable repeat) { - SDVariable result = f().tile(x, repeat); - return updateVariableNameAndReference(result, name); - } - /** - * Matrix transpose operation: If input has shape [a,b] output has shape [b,a] - * - * @param x Input variable - * @return Output variable (transposed input) - */ - public SDVariable transpose(SDVariable x) { - return transpose(null, x); - } - - /** - * Matrix transpose operation: If input has shape [a,b] output has shape [b,a] - * - * @param name Output variable name - * @param x Input variable - * @return Output variable (transposed input) - */ - public SDVariable transpose(String name, SDVariable x) { - SDVariable result = f().transpose(x); - return updateVariableNameAndReference(result, name); - } - - /** - * See {@link #unsortedSegmentMax(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentMax(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment max operation. As per {@link #segmentMax(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
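A minimal plain-Java sketch of the same computation (the SegmentMaxSketch class is hypothetical, not the ND4J implementation) reproduces the values above:

// Per-segment maxima for data = [1,3,2,6,4,9,8], segmentIds = [1,0,2,0,1,1,2].
public class SegmentMaxSketch {
    public static void main(String[] args) {
        double[] data = {1, 3, 2, 6, 4, 9, 8};
        int[] segmentIds = {1, 0, 2, 0, 1, 1, 2};
        int numSegments = 3;
        double[] out = new double[numSegments];
        java.util.Arrays.fill(out, Double.NEGATIVE_INFINITY);
        for (int i = 0; i < data.length; i++)
            out[segmentIds[i]] = Math.max(out[segmentIds[i]], data[i]);
        System.out.println(java.util.Arrays.toString(out));   // [6.0, 9.0, 8.0]
    }
}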
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment max on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment max output - */ - public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentMax", "data", data); - validateInteger("unsortedSegmentMax", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentMax(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentMean(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentMean(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment mean operation. As per {@link #segmentMean(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment mean on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment mean output - */ - public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentMean", "data", data); - validateInteger("unsortedSegmentMean", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentMean(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentMin(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentMin(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment min operation. As per {@link #segmentMin(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment min on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment min output - */ - public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentMin", "data", data); - validateInteger("unsortedSegmentMin", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentMin(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentProd(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentProd(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment product operation. As per {@link #segmentProd(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment product on - * @param segmentIds Variable for the segment IDs - * @return Unsorted segment product output - */ - public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentProd", "data", data); - validateInteger("unsortedSegmentProd", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentProd(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentSqrtN(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentSqrtN(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment sqrtN on - * @param segmentIds Variable for the segment IDs - * @return Unsorted segment sqrtN output - */ - public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - SDVariable ret = f().unsortedSegmentSqrtN(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentSum(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentSum(@NonNull SDVariable data, @NonNull SDVariable segmentIds, int numSegments) { - return unsortedSegmentSum(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment sum operation. As per {@link #segmentSum(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment sum on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment sum output - */ - public SDVariable unsortedSegmentSum(String name, @NonNull SDVariable data, @NonNull SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentSum", "data", data); - validateInteger("unsortedSegmentSum", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentSum(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable[] unstack(SDVariable value, int axis) { - return unstack(null, value, axis); - } - - /** - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable[] unstack(String[] names, @NonNull SDVariable value, int axis) { - SDVariable[] ret = f().unstack(value, axis); - return updateVariableNamesAndReferences(ret, names); - } - - /** - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable[] unstack(@NonNull SDVariable value, int axis, int num) { - return unstack(null, value, axis, num); - } - - /** - * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis. - * If input has shape [a,b,c] then output has shape: - * axis = 0: [b,c]
- * axis = 1: [a,c]
- * axis = 2: [a,b]
- * - * @param names Output variable names. May be null - * @param value Input variable to unstack - * @param axis Axis to unstack on - * @param num Number of output variables - * @return Output variables - * @see #stack(String, int, SDVariable...) - */ - public SDVariable[] unstack(String[] names, @NonNull SDVariable value, int axis, int num) { - SDVariable[] ret = f().unstack(value, axis, num); - return updateVariableNamesAndReferences(ret, names); - } - - /** - * @see #variance(String, SDVariable, boolean, int...) - */ - public SDVariable variance(@NonNull SDVariable x, boolean biasCorrected, int... dimensions) { - return variance(null, x, biasCorrected, dimensions); - } - - /** - * Variance array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable variance(String name, @NonNull SDVariable x, boolean biasCorrected, int... dimensions) { - return variance(name, x, biasCorrected, false, dimensions); - } - - /** - * Variance array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable variance(String name, @NonNull SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { - validateNumerical("variance", x); - SDVariable result = f().variance(x, biasCorrected, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable zerosLike(@NonNull SDVariable input) { - return zerosLike(null, input); - } - - public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) { - return zerosLike(null, input, dataType); - } - /** - * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param name Name of the new SDVariable - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable zerosLike(String name, @NonNull SDVariable input) { - SDVariable ret = f().zerosLike(name, input); - return updateVariableNameAndReference(ret, name); - } - - public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) { - SDVariable ret = f().zerosLike(name, input, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #any(String, SDVariable, int...)} - */ - public SDVariable any(SDVariable x, int... dimensions){ - return any(null, x, dimensions); - } - //TODO check any w/ no dimensions - - /** - * Boolean or array reduction operation, optionally along specified dimensions - * - * @param name Name of the output variable - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable any(String name, SDVariable x, int... dimensions){ - validateBool("any", x); - SDVariable ret = f().any(x, dimensions); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #all(String, SDVariable, int...)} - */ - public SDVariable all(SDVariable x, int... dimensions){ - return all(null, x, dimensions); - } - - - /** - * Boolean and array reduction operation, optionally along specified dimensions - * - * @param name Name of the output variable - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable all(String name, SDVariable x, int... 
dimensions){ - validateBool("all", x); - SDVariable ret = f().all(x, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} - */ - public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, - @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ - return whileLoop(null, null, loopVars, cond, body); - } - - /** - * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} - */ - public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, - @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ - return whileLoop(null, loopName, loopVars, cond, body); - } - - /** - * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration) - * - * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false - * - * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations. - * - * See Tensorflow Control Flow Implementation - * - * @param outputNames Names to give the output variables. If null, doesn't rename - * @param loopName The name of the loop block and frame (must be unique). If null, uses "if" - * @param loopVars Loop variables' inputs - * @param cond A lambda evaluating to the loop condition - * @param body A lambda doing the loop operation and returning the new loop variable values - * @return The values of the loop variables once condition is false - */ - public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars, - @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ - - final String frameName = sd().newBlockName(loopName == null ? 
"while" : loopName); - - NameScope loopScope = sd().withNameScope(frameName); - - //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0); - - SDVariable[] entered = new SDVariable[loopVars.length]; - for(int i = 0 ; i < loopVars.length ; i++){ - entered[i] = f().enter(loopVars[i], frameName); - } - - //counter = SD.f().enter(counter, frameName); - - SDVariable[] merged = new SDVariable[loopVars.length]; - Merge[] mergeOps = new Merge[loopVars.length]; - for(int i = 0 ; i < loopVars.length ; i++){ - // the second arg will later be replaced with the output of NextIteration - // but that isn't available yet (and can't be, as it depends on this) - mergeOps[i] = new Merge(sd(), entered[i], entered[i]); - merged[i] = mergeOps[i].outputVariable(); - } - - //Merge counterMerge = new Merge(SD, counter, counter); - //counter = counterMerge.outputVariable(); - - NameScope condScope = sd().withNameScope("cond"); - SDVariable cond_result = cond.define(sd(), merged); - condScope.close(); - - - if (cond_result.dataType() != DataType.BOOL) - throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); - - - final Set alreadyEntered = Sets.newHashSet(); - SDVariable[] trueSwitches = new SDVariable[loopVars.length]; - SDVariable[] exits = new SDVariable[loopVars.length]; - for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable[] s = f().switchOp(merged[i], cond_result); - trueSwitches[i] = s[1]; - alreadyEntered.add(s[1].name()); - exits[i] = f().exit(s[0]); - } - - //SDVariable[] cs = SD.f().switchOp(counter, cond_result); - //SDVariable counterExit = SD.f().exit(cs[0]); - //counter = cs[1]; - - final Set declared = Sets.newHashSet(sd().variableMap().keySet()); - final Map done = new HashMap<>(); - - sd().addArgumentInterceptor(new ArgumentInterceptor() { - @Override - public SDVariable intercept(SDVariable argument) { - - if(!declared.contains(argument.name())) - return argument; - - if(alreadyEntered.contains(argument.name())) - return argument; - - if(done.containsKey(argument.name())) - return done.get(argument.name()); - - SDVariable e = f().enter(argument, frameName, true); - done.put(argument.name(), e); - return e; - } - }); - - NameScope bodyScope = sd().withNameScope("body"); - SDVariable[] outs = body.define(sd(), trueSwitches); - bodyScope.close(); - sd().removeArgumentInterceptor(); - - //counter.add(1); - - for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable n = f().nextIteration(outs[i]); - mergeOps[i].replaceArg(1,n); - } - - //counterMerge.replaceArg(1, counter); - - loopScope.close(); - return updateVariableNamesAndReferences(exits, outputNames); - } - - /** - * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} - */ - public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, - @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ - return ifCond(null, null, cond, trueBody, falseBody); - } - - - /** - * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} - */ - public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, - @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ - return ifCond(null, ifName, cond, trueBody, falseBody); - } - - /** - * Constructs a If statement using the tensorflow style control flow operations (Switch and Merge) - * - * If 
the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody - * - * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate. - * - * See Tensorflow Control Flow Implementation - * - * @param outputName Name to give the output variable. If null, doesn't rename - * @param ifName The name of the if block. If null, uses "if" - * @param cond A lambda evaluating to the if condition - * @param trueBody A lambda to be executed if cond is true (the if block) - * @param falseBody A lambda to be executed if cond is false (the else block) - * @return The value of trueBody if cond is true, or falseBody if it isn't - */ - public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, - @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ - - ifName = sd().newBlockName(ifName == null ? "if" : ifName); - - NameScope ifScope = sd().withNameScope(ifName); - - NameScope condScope = sd().withNameScope("cond"); - final SDVariable pred = cond.define(sd()); - condScope.close(); - - if (pred.dataType() != DataType.BOOL) { - //cleanup partially added block - - for(SDVariable v : sd().getVariablesInScope(ifScope)) - sd().getVariables().remove(v.name()); - - for(SameDiffOp op : sd().getOpsInScope(ifScope)) { - for(String in : op.getInputsToOp()){ - sd().removeArgFromOp(in, op.getOp()); - } - sd().getOps().remove(op.getName()); - } - - - throw new IllegalStateException("Can not use " + pred.name() - + " as the condition of an If statement, the condition must be a boolean."); - } - - final Map switches = new HashMap<>(); - - final Set declared = Sets.newHashSet(sd().variableMap().keySet()); - - sd().addArgumentInterceptor(new ArgumentInterceptor() { - @Override - public SDVariable intercept(SDVariable argument) { - - // if its declared in the if, we don't care acout it - if(!declared.contains(argument.name())) - return argument; - - // if we've already added a switch, move on - if(switches.containsKey(argument.name())) - return switches.get(argument.name())[1]; - - SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.name(), s); - return s[1]; - } - }); - NameScope trueScope = sd().withNameScope("trueBody"); - SDVariable trueOut = trueBody.define(sd()); - sd().removeArgumentInterceptor(); - - if(declared.contains(trueOut.name())) { - SDVariable[] s = f().switchOp(trueOut, pred); - switches.put(trueOut.name(), s); - trueOut = s[1]; - } - - trueScope.close(); - - final Set declared2 = Sets.newHashSet(sd().variableMap().keySet()); - sd().addArgumentInterceptor(new ArgumentInterceptor() { - @Override - public SDVariable intercept(SDVariable argument) { - - // if its declared in the if, we don't care acout it - if(!declared2.contains(argument.name())) - return argument; - - // if we've already added a switch, move on - if(switches.containsKey(argument.name())) - return switches.get(argument.name())[0]; - - SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.name(), s); - return s[0]; - } - }); - NameScope falseScope = sd().withNameScope("falseBody"); - SDVariable falseOut = falseBody.define(sd()); - sd().removeArgumentInterceptor(); - - if(declared2.contains(falseOut.name())) { - SDVariable[] s = f().switchOp(falseOut, pred); - switches.put(falseOut.name(), s); - falseOut = s[0]; - } - falseScope.close(); - - SDVariable output = f().merge(trueOut, falseOut); - - ifScope.close(); - - return 
updateVariableNameAndReference(output, outputName); - } +public class SDBaseOps { + protected SameDiff sd; + + public SDBaseOps(SameDiff sameDiff) { + this.sd = sameDiff; + } + + /** + * Boolean and array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable all(SDVariable x, int... dimensions) { + SDValidation.validateBool("all", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); + } + + /** + * Boolean and array reduction operation, optionally along specified dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable all(String name, SDVariable x, int... dimensions) { + SDValidation.validateBool("all", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean or array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable any(SDVariable x, int... dimensions) { + SDValidation.validateBool("any", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); + } + + /** + * Boolean or array reduction operation, optionally along specified dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable any(String name, SDVariable x, int... dimensions) { + SDValidation.validateBool("any", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
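As a sketch of the keepDims shape behaviour only (the ArgmaxSketch class is hypothetical, not the ND4J kernel): argmax along dimension 1 of a [2,3] input gives shape [2], or [2,1] when the reduced dimension is kept.

// Index of the maximum along dimension 1 for each row of a small 2D input.
public class ArgmaxSketch {
    public static void main(String[] args) {
        double[][] in = {{0.1, 0.7, 0.2}, {0.9, 0.3, 0.5}};
        long[] idx = new long[in.length];              // keepDims = false -> shape [2]
        for (int r = 0; r < in.length; r++) {
            int best = 0;
            for (int c = 1; c < in[r].length; c++)
                if (in[r][c] > in[r][best]) best = c;
            idx[r] = best;
        }
        System.out.println(java.util.Arrays.toString(idx));        // [1, 0]
        long[][] kept = {{idx[0]}, {idx[1]}};          // keepDims = true -> shape [2,1]
        System.out.println(java.util.Arrays.deepToString(kept));   // [[1], [0]]
    }
}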
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable(); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable(); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
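A plain-Java sketch of the shape contract, assuming transposeA = transposeB = false (the BatchMmulSketch class is hypothetical, not the ND4J implementation): each (M,N) x (N,K) pair yields an (M,K) output and the batch length is preserved.

// Multiplies each (M,N) matrix in batchA with the matching (N,K) matrix in batchB.
public class BatchMmulSketch {
    static double[][] mmul(double[][] a, double[][] b) {
        int m = a.length, n = b.length, k = b[0].length;
        double[][] out = new double[m][k];
        for (int i = 0; i < m; i++)
            for (int j = 0; j < k; j++)
                for (int p = 0; p < n; p++)
                    out[i][j] += a[i][p] * b[p][j];
        return out;
    }

    public static void main(String[] args) {
        double[][][] batchA = {{{1, 2, 3}, {4, 5, 6}}};        // one (2,3) matrix
        double[][][] batchB = {{{1, 0}, {0, 1}, {1, 1}}};      // one (3,2) matrix
        for (int i = 0; i < batchA.length; i++)                // output batch keeps the same length
            System.out.println(java.util.Arrays.deepToString(mmul(batchA[i], batchB[i])));  // [[4.0, 5.0], [10.0, 11.0]], shape (2,2)
    }
}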
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param transposeA Whether to transpose A arrays or not + * @param transposeB Whether to transpose B arrays or not + */ + public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable[] inputsB, boolean transposeA, + boolean transposeB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, transposeA, transposeB).outputVariables(); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param transposeA Whether to transpose A arrays or not + * @param transposeB Whether to transpose B arrays or not + */ + public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable[] inputsB, + boolean transposeA, boolean transposeB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, transposeA, transposeB).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + */ + public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable... inputsB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, false, false).outputVariables(); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + */ + public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable... inputsB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, false, false).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Cast the array to a new datatype - for example, Integer -> Float
+ * + * @param arg Input variable to cast (NDARRAY type) + * @param datatype Datatype to cast to + * @return output Output array (after casting) (NDARRAY type) + */ + public SDVariable castTo(SDVariable arg, DataType datatype) { + return new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(sd,arg, datatype).outputVariable(); + } + + /** + * Cast the array to a new datatype - for example, Integer -> Float
+ * + * @param name name May be null. Name for the output variable + * @param arg Input variable to cast (NDARRAY type) + * @param datatype Datatype to cast to + * @return output Output array (after casting) (NDARRAY type) + */ + public SDVariable castTo(String name, SDVariable arg, DataType datatype) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(sd,arg, datatype).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Concatenate a set of inputs along the specified dimension.
+ * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
+ * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
+ * + * Inputs must satisfy the following constraints:
+ * Input arrays must all be the same datatype: isSameType(inputs)
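The shape rule above can be sketched in plain Java for two 2D inputs concatenated along dimension 1 (the ConcatSketch class is hypothetical and illustrative only):

// Concatenates two matrices with the same row count along dimension 1 (columns).
public class ConcatSketch {
    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};          // shape [2, 2]
        double[][] b = {{5}, {6}};                // shape [2, 1]
        double[][] out = new double[a.length][a[0].length + b[0].length];   // shape [2, 3]
        for (int r = 0; r < a.length; r++) {
            System.arraycopy(a[r], 0, out[r], 0, a[r].length);
            System.arraycopy(b[r], 0, out[r], a[r].length, b[r].length);
        }
        System.out.println(java.util.Arrays.deepToString(out));   // [[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]
    }
}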
+ * + * @param inputs Input variables (NUMERIC type) + * @param dimension Dimension to concatenate on + * @return output (NUMERIC type) + */ + public SDVariable concat(int dimension, SDVariable... inputs) { + SDValidation.validateNumerical("concat", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); + return new org.nd4j.linalg.api.ops.impl.shape.Concat(sd,inputs, dimension).outputVariable(); + } + + /** + * Concatenate a set of inputs along the specified dimension.
+ * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
+ * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
+ * + * Inputs must satisfy the following constraints:
+ * Input arrays must all be the same datatype: isSameType(inputs)
+ * + * @param name name May be null. Name for the output variable + * @param dimension Dimension to concatenate on + * @param inputs Input variables (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable concat(String name, int dimension, SDVariable... inputs) { + SDValidation.validateNumerical("concat", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Concat(sd,inputs, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false: [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
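For illustration only, a plain-Java sketch of the inclusive rows of the table above (exclusive = false); the CumprodSketch class is hypothetical, not the ND4J implementation, and reverse = true simply accumulates from the tail.

// Cumulative product over a 1D input; reverse = true walks the array backwards.
public class CumprodSketch {
    static double[] cumprod(double[] in, boolean reverse) {
        int n = in.length;
        double[] out = new double[n];
        double acc = 1.0;
        for (int step = 0; step < n; step++) {
            int i = reverse ? n - 1 - step : step;
            acc *= in[i];
            out[i] = acc;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] in = {2, 3, 4};   // stands in for [a, b, c]
        System.out.println(java.util.Arrays.toString(cumprod(in, false)));  // [2.0, 6.0, 24.0] = [a, a*b, a*b*c]
        System.out.println(java.util.Arrays.toString(cumprod(in, true)));   // [24.0, 12.0, 4.0] = [a*b*c, b*c, c]
    }
}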
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative product operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable(); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false: [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative product operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, + int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false: [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumulative product operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(SDVariable in, int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, false, false, axis).outputVariable(); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false: [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumulative product operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(String name, SDVariable in, int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, false, false, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false: [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
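A plain-Java sketch of the forward (reverse = false) rows of the table above; the CumsumSketch class is hypothetical, not the ND4J implementation, and the reverse variants run the same accumulation starting from the last element.

// Cumulative sum over a 1D input; exclusive = true shifts the running total by one position.
public class CumsumSketch {
    static double[] cumsum(double[] in, boolean exclusive) {
        double[] out = new double[in.length];
        double acc = 0.0;
        for (int i = 0; i < in.length; i++) {
            if (exclusive) { out[i] = acc; acc += in[i]; }  // current element excluded from its own entry
            else           { acc += in[i]; out[i] = acc; }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] in = {2, 3, 4};   // stands in for [a, b, c]
        System.out.println(java.util.Arrays.toString(cumsum(in, false)));  // [2.0, 5.0, 9.0] = [a, a+b, a+b+c]
        System.out.println(java.util.Arrays.toString(cumsum(in, true)));   // [0.0, 2.0, 5.0] = [0, a, a+b]
    }
}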
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable(); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false: [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, + int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false: [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(SDVariable in, int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, false, false, axis).outputVariable(); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false: [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(String name, SDVariable in, int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, false, false, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise dot product reduction along dimension
+ * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
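For illustration, the formula above reduces two 1D inputs to a single scalar (the DotSketch class is hypothetical, not the ND4J implementation):

// Dot product reduction over the only dimension of two 1D inputs.
public class DotSketch {
    public static void main(String[] args) {
        double[] x = {1, 2, 3};
        double[] y = {4, 5, 6};
        double out = 0;
        for (int i = 0; i < x.length; i++)
            out += x[i] * y[i];          // output = sum_i x[i] * y[i]
        System.out.println(out);          // 32.0
    }
}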
+ * + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output output variable (NUMERIC type) + */ + public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("dot", "x", x); + SDValidation.validateNumerical("dot", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.Dot(sd,x, y, dimensions).outputVariable(); + } + + /** + * Pairwise dot product reduction along dimension
+ * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
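A small illustrative sketch of this reduction, under the same assumptions as above (org.nd4j imports, methods exposed on a SameDiff instance):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.createFromArray(1.0f, 2.0f, 3.0f));
    SDVariable y = sd.var("y", Nd4j.createFromArray(4.0f, 5.0f, 6.0f));
    SDVariable d = sd.dot(x, y, 0);   // 1*4 + 2*5 + 3*6 = 32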
+ * + * @param name name May be null. Name for the output variable + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output output variable (NUMERIC type) + */ + public SDVariable dot(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("dot", "x", x); + SDValidation.validateNumerical("dot", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.Dot(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Dynamically partition the input variable values into the specified number of partitions, using the indices.<br>
+ * Example:
+ *

+ * input = [1,2,3,4,5]
+ * numPartitions = 2
+ * partitions = [1,0,0,1,0]
+ * out[0] = [2,3,5]
+ * out[1] = [1,4]<br>
+ *
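A short usage sketch matching the example above (illustrative only; variable names are arbitrary):

    SameDiff sd = SameDiff.create();
    SDVariable input      = sd.var("input", Nd4j.createFromArray(1.0f, 2.0f, 3.0f, 4.0f, 5.0f));
    SDVariable partitions = sd.constant("partitions", Nd4j.createFromArray(1, 0, 0, 1, 0));
    SDVariable[] out = sd.dynamicPartition(input, partitions, 2);
    // out[0] -> [2, 3, 5], out[1] -> [1, 4]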

+ * + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 + */ + public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) { + SDValidation.validateNumerical("dynamicPartition", "x", x); + SDValidation.validateInteger("dynamicPartition", "partitions", partitions); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables(); + } + + /** + * Dynamically partition the input variable values into the specified number of partitions, using the indices.<br>
+ * Example:
+ *

+ * input = [1,2,3,4,5]
+ * numPartitions = 2
+ * partitions = [1,0,0,1,0]
+ * out[0] = [2,3,5]
+ * out[1] = [1,4]<br>
+ *

+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 + */ + public SDVariable[] dynamicPartition(String[] names, SDVariable x, SDVariable partitions, + int numPartitions) { + SDValidation.validateNumerical("dynamicPartition", "x", x); + SDValidation.validateInteger("dynamicPartition", "partitions", partitions); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Dynamically merge the specified input arrays into a single array, using the specified indices
+ * + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param x Input variables. (NUMERIC type) + * @return output Merged output variable (NUMERIC type) + */ + public SDVariable dynamicStitch(SDVariable[] indices, SDVariable... x) { + SDValidation.validateInteger("dynamicStitch", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + SDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(sd,indices, x).outputVariable(); + } + + /** + * Dynamically merge the specified input arrays into a single array, using the specified indices
+ * + * @param name name May be null. Name for the output variable + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param x Input variables. (NUMERIC type) + * @return output Merged output variable (NUMERIC type) + */ + public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable... x) { + SDValidation.validateInteger("dynamicStitch", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + SDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(sd,indices, x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Equals operation: elementwise x == y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(SDVariable x, double y) { + SDValidation.validateNumerical("eq", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals(sd,x, y).outputVariable(); + } + + /** + * Equals operation: elementwise x == y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(String name, SDVariable x, double y) { + SDValidation.validateNumerical("eq", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Equal to operation: elementwise x == y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
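For example, a brief sketch of the scalar and pairwise comparison forms (illustrative, same assumptions as above):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.createFromArray(1.0f, 2.0f, 3.0f));
    SDVariable y = sd.var("y", Nd4j.createFromArray(1.0f, 0.0f, 3.0f));
    SDVariable eqScalar = sd.eq(x, 2.0);   // [false, true,  false]
    SDVariable eqArray  = sd.eq(x, y);     // [true,  false, true]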
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("eq", "x", x); + SDValidation.validateNumerical("eq", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(sd,x, y).outputVariable(); + } + + /** + * Equal to operation: elementwise x == y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("eq", "x", x); + SDValidation.validateNumerical("eq", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reshape the input by adding a 1 at the specified location.
+ * For example, if input has shape [a, b], then output shape is:
+ * axis = 0: [1, a, b]
+ * axis = 1: [a, 1, b]
+ * axis = 2: [a, b, 1]
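Concretely, as an illustrative sketch (assuming the usual org.nd4j imports):

    SameDiff sd = SameDiff.create();
    SDVariable x  = sd.var("x", Nd4j.rand(3, 4));   // shape [a, b] = [3, 4]
    SDVariable e0 = sd.expandDims(x, 0);            // shape [1, 3, 4]
    SDVariable e1 = sd.expandDims(x, 1);            // shape [3, 1, 4]
    SDVariable e2 = sd.expandDims(x, 2);            // shape [3, 4, 1]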
+ * + * @param x Input variable (NDARRAY type) + * @param axis Axis to expand + * @return output Output variable (NUMERIC type) + */ + public SDVariable expandDims(SDVariable x, int axis) { + return new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(sd,x, axis).outputVariable(); + } + + /** + * Reshape the input by adding a 1 at the specified location.
+ * For example, if input has shape [a, b], then output shape is:
+ * axis = 0: [1, a, b]
+ * axis = 1: [a, 1, b]
+ * axis = 2: [a, b, 1]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NDARRAY type) + * @param axis Axis to expand + * @return output Output variable (NUMERIC type) + */ + public SDVariable expandDims(String name, SDVariable x, int axis) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(sd,x, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
+ * + * @param shape Shape: must be a 1D array/variable (INT type) + * @param dataType Datatype of the output array + * @param value Value to set all elements to + * @return output Output variable (NUMERIC type) + */ + public SDVariable fill(SDVariable shape, DataType dataType, double value) { + SDValidation.validateInteger("fill", "shape", shape); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(sd,shape, dataType, value).outputVariable(); + } + + /** + * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
+ * + * @param name name May be null. Name for the output variable + * @param shape Shape: must be a 1D array/variable (INT type) + * @param dataType Datatype of the output array + * @param value Value to set all elements to + * @return output Output variable (NUMERIC type) + */ + public SDVariable fill(String name, SDVariable shape, DataType dataType, double value) { + SDValidation.validateInteger("fill", "shape", shape); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(sd,shape, dataType, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Gather slices from the input variable where the indices are specified as fixed int[] values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
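For instance, gathering two rows along axis 0 (illustrative sketch, not part of the generated code):

    SameDiff sd = SameDiff.create();
    SDVariable df   = sd.var("df", Nd4j.rand(5, 4));        // shape [5, 4]
    SDVariable rows = sd.gather(df, new int[]{0, 3}, 0);    // shape [2, 4]: rows 0 and 3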
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get (Size: AtLeast(min=1)) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(SDVariable df, int[] indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + } + + /** + * Gather slices from the input variable where the indices are specified as fixed int[] values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param name name May be null. Name for the output variable + * @param df Input variable (NUMERIC type) + * @param indices Indices to get (Size: AtLeast(min=1)) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(String name, SDVariable df, int[] indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Gather slices from the input variable where the indices are specified as dynamic array values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(SDVariable df, SDVariable indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + SDValidation.validateInteger("gather", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + } + + /** + * Gather slices from the input variable where the indices are specified as dynamic array values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param name name May be null. Name for the output variable + * @param df Input variable (NUMERIC type) + * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(String name, SDVariable df, SDVariable indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + SDValidation.validateInteger("gather", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Gather slices from df with shape specified by indices.
+ * + * @param df (NUMERIC type) + * @param indices (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gatherNd(SDVariable df, SDVariable indices) { + SDValidation.validateNumerical("gatherNd", "df", df); + SDValidation.validateNumerical("gatherNd", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.GatherNd(sd,df, indices).outputVariable(); + } + + /** + * Gather slices from df with shape specified by indices.
+ * + * @param name name May be null. Name for the output variable + * @param df (NUMERIC type) + * @param indices (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) { + SDValidation.validateNumerical("gatherNd", "df", df); + SDValidation.validateNumerical("gatherNd", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.GatherNd(sd,df, indices).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than operation: elementwise x > y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(SDVariable x, double y) { + SDValidation.validateNumerical("gt", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan(sd,x, y).outputVariable(); + } + + /** + * Greater than operation: elementwise x > y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(String name, SDVariable x, double y) { + SDValidation.validateNumerical("gt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than operation: elementwise x > y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gt", "x", x); + SDValidation.validateNumerical("gt", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(sd,x, y).outputVariable(); + } + + /** + * Greater than operation: elementwise x > y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gt", "x", x); + SDValidation.validateNumerical("gt", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than or equals operation: elementwise x >= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gte(SDVariable x, double y) { + SDValidation.validateNumerical("gte", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Greater than or equals operation: elementwise x >= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gte(String name, SDVariable x, double y) { + SDValidation.validateNumerical("gte", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than or equal to operation: elementwise x >= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gte(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gte", "x", x); + SDValidation.validateNumerical("gte", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Greater than or equal to operation: elementwise x >= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gte(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gte", "x", x); + SDValidation.validateNumerical("gte", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise identity operation: out = x
+ * + * @param input Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable identity(SDVariable input) { + SDValidation.validateNumerical("identity", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(sd,input).outputVariable(); + } + + /** + * Elementwise identity operation: out = x
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable identity(String name, SDVariable input) { + SDValidation.validateNumerical("identity", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the inverse permutation indices for a permutation operation
+ * Example: if input is [2, 0, 1] then output is [1, 2, 0]
+ * The idea is that x.permute(input).permute(invertPermutation(input)) == x
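A minimal sketch of the example above (illustrative only):

    SameDiff sd = SameDiff.create();
    SDVariable perm = sd.constant("perm", Nd4j.createFromArray(2, 0, 1));
    SDVariable inv  = sd.invertPermutation(perm);   // [1, 2, 0]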
+ * + * @param input 1D indices for permutation (INT type) + * @return output 1D inverted permutation (INT type) + */ + public SDVariable invertPermutation(SDVariable input) { + SDValidation.validateInteger("invertPermutation", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(sd,input).outputVariable(); + } + + /** + * Compute the inverse permutation indices for a permutation operation
+ * Example: if input is [2, 0, 1] then output is [1, 2, 0]
+ * The idea is that x.permute(input).permute(invertPermutation(input)) == x
+ * + * @param name name May be null. Name for the output variable + * @param input 1D indices for permutation (INT type) + * @return output 1D inverted permutation (INT type) + */ + public SDVariable invertPermutation(String name, SDVariable input) { + SDValidation.validateInteger("invertPermutation", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is the given variable a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1<br>
+ * + * @param x Input variable (NUMERIC type) + * @return output scalar boolean with value true or false (NDARRAY type) + */ + public SDVariable isNumericTensor(SDVariable x) { + SDValidation.validateNumerical("isNumericTensor", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(sd,x).outputVariable(); + } + + /** + * Is the given variable a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1<br>
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output scalar boolean with value true or false (NDARRAY type) + */ + public SDVariable isNumericTensor(String name, SDVariable x) { + SDValidation.validateNumerical("isNumericTensor", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
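That example as a call (illustrative sketch; DataType here is assumed to be org.nd4j.linalg.api.buffer.DataType):

    SameDiff sd = SameDiff.create();
    SDVariable ls = sd.linspace(DataType.FLOAT, 3.0, 4.0, 3);   // [3.0, 3.5, 4.0]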
+ * + * @param dataType Data type of the output array + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(DataType dataType, double start, double stop, long number) { + return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable(); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param name name May be null. Name for the output variable + * @param dataType Data type of the output array + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(String name, DataType dataType, double start, double stop, + long number) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) + * @param dataType Data type of the output array + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(SDVariable start, SDVariable stop, SDVariable number, + DataType dataType) { + SDValidation.validateNumerical("linspace", "start", start); + SDValidation.validateNumerical("linspace", "stop", stop); + SDValidation.validateInteger("linspace", "number", number); + return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,start, stop, number, dataType).outputVariable(); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param name name May be null. Name for the output variable + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) + * @param dataType Data type of the output array + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(String name, SDVariable start, SDVariable stop, SDVariable number, + DataType dataType) { + SDValidation.validateNumerical("linspace", "start", start); + SDValidation.validateNumerical("linspace", "stop", stop); + SDValidation.validateInteger("linspace", "number", number); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,start, stop, number, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than operation: elementwise x < y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(SDVariable x, double y) { + SDValidation.validateNumerical("lt", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan(sd,x, y).outputVariable(); + } + + /** + * Less than operation: elementwise x < y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(String name, SDVariable x, double y) { + SDValidation.validateNumerical("lt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than operation: elementwise x < y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lt", "x", x); + SDValidation.validateNumerical("lt", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(sd,x, y).outputVariable(); + } + + /** + * Less than operation: elementwise x < y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lt", "x", x); + SDValidation.validateNumerical("lt", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than or equals operation: elementwise x <= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(SDVariable x, double y) { + SDValidation.validateNumerical("lte", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Less than or equals operation: elementwise x <= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(String name, SDVariable x, double y) { + SDValidation.validateNumerical("lte", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than or equal to operation: elementwise x <= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lte", "x", x); + SDValidation.validateNumerical("lte", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Less than or equal to operation: elementwise x <= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lte", "x", x); + SDValidation.validateNumerical("lte", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Boolean mask (NUMERIC type) + */ + public SDVariable matchCondition(SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchCondition", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(sd,in, condition).outputVariable(); + } + + /** + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Boolean mask (NUMERIC type) + */ + public SDVariable matchCondition(String name, SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchCondition", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(sd,in, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a count of the number of elements that satisfy the condition
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition).outputVariable(); + } + + /** + * Returns a count of the number of elements that satisfy the condition
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(String name, SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
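A sketch of the keepDims behaviour for this op (illustrative; Conditions refers to org.nd4j.linalg.indexing.conditions.Conditions):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{1, -2, 3}, {-4, 5, -6}}));  // shape [2, 3]
    SDVariable cnt     = sd.matchConditionCount(in, Conditions.greaterThan(0), false, 1);  // shape [2]:    [2, 1]
    SDVariable cntKeep = sd.matchConditionCount(in, Conditions.greaterThan(0), true,  1);  // shape [2, 1]: [[2], [1]]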
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDim, + int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable(); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, + boolean keepDim, int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable(); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, + int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, false, dimensions).outputVariable(); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise maximum operation: out[i] = max(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable max(SDVariable first, SDVariable second) { + SDValidation.validateNumerical("max", "first", first); + SDValidation.validateNumerical("max", "second", second); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,first, second).outputVariable(); + } + + /** + * Element-wise maximum operation: out[i] = max(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param name name May be null. Name for the output variable + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable max(String name, SDVariable first, SDVariable second) { + SDValidation.validateNumerical("max", "first", first); + SDValidation.validateNumerical("max", "second", second); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,first, second).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
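The same keepDims contrast for this reduction, as a sketch (illustrative only):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(new int[]{2, 3, 4}));   // shape [a, b, c] = [2, 3, 4]
    SDVariable m     = sd.mean(x, false, 1);                     // shape [2, 4]
    SDVariable mKeep = sd.mean(x, true, 1);                      // shape [2, 1, 4]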
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, false, dimensions).outputVariable(); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, false, dimensions).outputVariable(); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise minimum operation: out[i] = min(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ *
+ * @param first First input array (NUMERIC type)
+ * @param second Second input array (NUMERIC type)
+ * @return output Output array, containing the element-wise minimum of first and second (NUMERIC type)
+ */
+ public SDVariable min(SDVariable first, SDVariable second) {
+ SDValidation.validateNumerical("min", "first", first);
+ SDValidation.validateNumerical("min", "second", second);
+ return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,first, second).outputVariable();
+ }
+
+ /**
+ * Element-wise minimum operation: out[i] = min(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ *
+ * @param name name May be null. Name for the output variable
+ * @param first First input array (NUMERIC type)
+ * @param second Second input array (NUMERIC type)
+ * @return output Output array, containing the element-wise minimum of first and second (NUMERIC type)
+ */
+ public SDVariable min(String name, SDVariable first, SDVariable second) {
+ SDValidation.validateNumerical("min", "first", first);
+ SDValidation.validateNumerical("min", "second", second);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,first, second).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
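A short sketch of the transpose flags (same setup and imports as the min sketch earlier; array values are placeholders): passing transposeX = true computes mmul(x^T, y) without materialising the transposed array.

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.var("a", Nd4j.createFromArray(new float[][]{{1, 2}, {3, 4}, {5, 6}})); // shape [3, 2]
    SDVariable b = sd.var("b", Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 1}})); // shape [3, 2]
    // a^T is [2, 3], so a^T * b is a [2, 2] result
    SDVariable c = sd.mmul(a, b, true, false, false);
    System.out.println(java.util.Arrays.toString(c.eval().shape())); // [2, 2]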
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX, + boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Not equals operation: elementwise x != y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(SDVariable x, double y) { + SDValidation.validateNumerical("neq", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals(sd,x, y).outputVariable(); + } + + /** + * Not equals operation: elementwise x != y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(String name, SDVariable x, double y) { + SDValidation.validateNumerical("neq", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Not equal to operation: elementwise x != y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("neq", "x", x); + SDValidation.validateNumerical("neq", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(sd,x, y).outputVariable(); + } + + /** + * Not equal to operation: elementwise x != y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("neq", "x", x); + SDValidation.validateNumerical("neq", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, false, dimensions).outputVariable(); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
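A small sketch contrasting the L1 and L2 reductions (same setup as the earlier sketches; values are placeholders):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.createFromArray(new float[][]{{3, -4}, {0, 5}})); // shape [2, 2]
    SDVariable l1 = sd.norm1("l1", x, false, 1); // per row: |3| + |-4| = 7 and |0| + |5| = 5
    SDVariable l2 = sd.norm2("l2", x, false, 1); // per row: sqrt(9 + 16) = 5 and sqrt(0 + 25) = 5
    System.out.println(l1.eval()); // values [7.0, 5.0]
    System.out.println(l2.eval()); // values [5.0, 5.0]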
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * @param x Input variable (NUMERIC type)
+ * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
+ * @param dimensions dimensions to reduce over (Size: AtLeast(min=0))
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable norm2(SDVariable x, boolean keepDims, int... dimensions) {
+ SDValidation.validateNumerical("norm2", "x", x);
+ Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
+ return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, keepDims, dimensions).outputVariable();
+ }
+
+ /**
+ * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x Input variable (NUMERIC type)
+ * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
+ * @param dimensions dimensions to reduce over (Size: AtLeast(min=0))
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... dimensions) {
+ SDValidation.validateNumerical("norm2", "x", x);
+ Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, keepDims, dimensions).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * @param x Input variable (NUMERIC type)
+ * @param dimensions dimensions to reduce over (Size: AtLeast(min=0))
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable norm2(SDVariable x, int... dimensions) {
+ SDValidation.validateNumerical("norm2", "x", x);
+ Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
+ return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, false, dimensions).outputVariable();
+ }
+
+ /**
+ * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x Input variable (NUMERIC type)
+ * @param dimensions dimensions to reduce over (Size: AtLeast(min=0))
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable norm2(String name, SDVariable x, int... dimensions) {
+ SDValidation.validateNumerical("norm2", "x", x);
+ Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, false, dimensions).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, false, dimensions).outputVariable(); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x Input variable (NUMERIC type)
+ * @param dimensions dimensions to reduce over (Size: AtLeast(min=0))
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable normmax(String name, SDVariable x, int... dimensions) {
+ SDValidation.validateNumerical("normmax", "x", x);
+ Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, false, dimensions).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on, with other values being set to off
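For example, a depth-3 one-hot encoding of the indices [0, 2, 1] using the two-argument convenience overload defined further below (a sketch under the same assumptions as the earlier snippets):

    SameDiff sd = SameDiff.create();
    SDVariable indices = sd.constant("indices", Nd4j.createFromArray(0, 2, 1)); // shape [3]
    SDVariable oneHot = sd.oneHot("oneHot", indices, 3);                        // shape [3, 3]
    System.out.println(oneHot.eval());
    // [[1, 0, 0],
    //  [0, 0, 1],
    //  [0, 1, 0]]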
+ *
+ * @param indices Indices - value 0 to depth-1 (NUMERIC type)
+ * @param depth Number of classes
+ * @param axis
+ * @param on
+ * @param off
+ * @param dataType Output data type
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off,
+ DataType dataType) {
+ SDValidation.validateNumerical("oneHot", "indices", indices);
+ return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable();
+ }
+
+ /**
+ * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on, with other values being set to off
+ *
+ * @param name name May be null. Name for the output variable
+ * @param indices Indices - value 0 to depth-1 (NUMERIC type)
+ * @param depth Number of classes
+ * @param axis
+ * @param on
+ * @param off
+ * @param dataType Output data type
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on,
+ double off, DataType dataType) {
+ SDValidation.validateNumerical("oneHot", "indices", indices);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on, with other values being set to off
+ *
+ * @param indices Indices - value 0 to depth-1 (NUMERIC type)
+ * @param depth Number of classes
+ * @param axis
+ * @param on
+ * @param off
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
+ SDValidation.validateNumerical("oneHot", "indices", indices);
+ return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable();
+ }
+
+ /**
+ * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on, with other values being set to off
+ *
+ * @param name name May be null. Name for the output variable
+ * @param indices Indices - value 0 to depth-1 (NUMERIC type)
+ * @param depth Number of classes
+ * @param axis
+ * @param on
+ * @param off
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on,
+ double off) {
+ SDValidation.validateNumerical("oneHot", "indices", indices);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Convert the array to a one-hot array with values 0 and 1 for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
+ * see oneHot(SDVariable, int, int, double, double)
+ *
+ * @param indices Indices - value 0 to depth-1 (NUMERIC type)
+ * @param depth Number of classes
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable oneHot(SDVariable indices, int depth) {
+ SDValidation.validateNumerical("oneHot", "indices", indices);
+ return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth).outputVariable();
+ }
+
+ /**
+ * Convert the array to a one-hot array with values 0 and 1 for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
+ * see oneHot(SDVariable, int, int, double, double)
+ * + * @param name name May be null. Name for the output variable + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(String name, SDVariable indices, int depth) { + SDValidation.validateNumerical("oneHot", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input INDArray (NUMERIC type) + * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable onesLike(SDVariable input) { + SDValidation.validateNumerical("onesLike", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input).outputVariable(); + } + + /** + * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param name name May be null. Name for the output variable + * @param input Input INDArray (NUMERIC type) + * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable onesLike(String name, SDVariable input) { + SDValidation.validateNumerical("onesLike", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per onesLike(String, SDVariable) but the output datatype may be specified
+ * + * @param input (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable onesLike(SDVariable input, DataType dataType) { + SDValidation.validateNumerical("onesLike", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input, dataType).outputVariable(); + } + + /** + * As per onesLike(String, SDVariable) but the output datatype may be specified
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable onesLike(String name, SDVariable input, DataType dataType) { + SDValidation.validateNumerical("onesLike", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
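A sketch of the example above with concrete sizes (same assumptions as the earlier snippets): a [2, 3, 4] input permuted with dimensions = [2, 0, 1] has shape [4, 2, 3].

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(new int[]{2, 3, 4})); // shape [2, 3, 4]
    SDVariable y = sd.permute(x, 2, 0, 1);                     // shape [4, 2, 3]
    System.out.println(java.util.Arrays.toString(y.eval().shape())); // [4, 2, 3]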
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Permute dimensions (INT type) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(SDVariable x, SDVariable dimensions) { + SDValidation.validateNumerical("permute", "x", x); + SDValidation.validateInteger("permute", "dimensions", dimensions); + return new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Permute dimensions (INT type) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(String name, SDVariable x, SDVariable dimensions) { + SDValidation.validateNumerical("permute", "x", x); + SDValidation.validateInteger("permute", "dimensions", dimensions); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("permute", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("permute", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, false, dimensions).outputVariable(); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
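The same example as a short fragment (same assumptions as the earlier sketches; DataType is org.nd4j.linalg.api.buffer.DataType):

    SameDiff sd = SameDiff.create();
    SDVariable r = sd.range(1.0, 3.0, 0.5, DataType.FLOAT); // 'to' is exclusive
    System.out.println(r.eval()); // [1.0, 1.5, 2.0, 2.5]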
+ * + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(double from, double to, double step, DataType dataType) { + return new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param name name May be null. Name for the output variable + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(String name, double from, double to, double step, DataType dataType) { + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { + SDValidation.validateNumerical("range", "from", from); + SDValidation.validateNumerical("range", "to", to); + SDValidation.validateNumerical("range", "step", step); + return new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param name name May be null. Name for the output variable + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step, + DataType dataType) { + SDValidation.validateNumerical("range", "from", from); + SDValidation.validateNumerical("range", "to", to); + SDValidation.validateNumerical("range", "step", step); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + */ + public SDVariable rank(SDVariable in) { + SDValidation.validateNumerical("rank", "in", in); + return new org.nd4j.linalg.api.ops.impl.shape.Rank(sd,in).outputVariable(); + } + + /** + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + */ + public SDVariable rank(String name, SDVariable in) { + SDValidation.validateNumerical("rank", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Rank(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise replace where condition:
+ * out[i] = from[i] if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
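A sketch of the conditional form described above (same assumptions as earlier; Conditions is org.nd4j.linalg.indexing.conditions.Conditions): negative entries of update are replaced by the corresponding entries of from.

    SameDiff sd = SameDiff.create();
    SDVariable update = sd.var("update", Nd4j.createFromArray(new float[]{-1, 2, -3, 4}));
    SDVariable from = sd.var("from", Nd4j.createFromArray(new float[]{10, 20, 30, 40}));
    SDVariable out = sd.replaceWhere(update, from, Conditions.lessThan(0));
    System.out.println(out.eval()); // [10.0, 2.0, 30.0, 4.0]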
+ * + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(SDVariable update, SDVariable from, Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + SDValidation.validateNumerical("replaceWhere", "from", from); + return new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable(); + } + + /** + * Element-wise replace where condition:
+ * out[i] = from[i] if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param name name May be null. Name for the output variable + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, + Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + SDValidation.validateNumerical("replaceWhere", "from", from); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise replace where condition:
+ * out[i] = value if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(SDVariable update, double value, Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + return new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable(); + } + + /** + * Element-wise replace where condition:
+ * out[i] = value if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param name name May be null. Name for the output variable + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(String name, SDVariable update, double value, + Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
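A sketch of the length constraint (same assumptions as the earlier snippets): a [2, 3] input may be reshaped to [3, 2] or [6], since each preserves the 6 elements, but not to a shape with a different element count.

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}})); // shape [2, 3]
    SDVariable twoD = sd.reshape(x, 3, 2); // ok: 3 * 2 == 6
    SDVariable flat = sd.reshape(x, 6);    // ok: 6 elements
    System.out.println(flat.eval()); // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]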
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(SDVariable x, SDVariable shape) { + SDValidation.validateNumerical("reshape", "x", x); + SDValidation.validateNumerical("reshape", "shape", shape); + return new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(String name, SDVariable x, SDVariable shape) { + SDValidation.validateNumerical("reshape", "x", x); + SDValidation.validateNumerical("reshape", "shape", shape); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(SDVariable x, long... shape) { + SDValidation.validateNumerical("reshape", "x", x); + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(String name, SDVariable x, long... shape) { + SDValidation.validateNumerical("reshape", "x", x); + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reverse the values of an array for the specified dimensions
+ * If input is:
+ * [ 1, 2, 3]
+ * [ 4, 5, 6]
+ * then
+ * reverse(in, 0):
+ * [3, 2, 1]
+ * [6, 5, 4]
+ * reverse(in, 1):
+ * [4, 5, 6]
+ * [1, 2, 3]
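The 2x3 example above as a fragment (same assumptions as the earlier sketches); each call reverses the array along the given dimension:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}}));
    System.out.println(sd.reverse(in, 0).eval()); // reversed along dimension 0
    System.out.println(sd.reverse(in, 1).eval()); // reversed along dimension 1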
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Input variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reverse(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("reverse", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(sd,x, dimensions).outputVariable(); + } + + /** + * Reverse the values of an array for the specified dimensions
+ * If input is:
+ * [ 1, 2, 3]
+ * [ 4, 5, 6]
+ * then
+ * reverse(in, 0):
+ * [3, 2, 1]
+ * [6, 5, 4]
+ * reverse(in, 1):
+ * [4, 5, 6]
+ * [1, 2, 3]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Input variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reverse(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("reverse", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @param seqDim Sequence dimension + * @param batchDim Batch dimension + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, + int batchDim) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable(); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @param seqDim Sequence dimension + * @param batchDim Batch dimension + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, + int batchDim) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, -1, 0).outputVariable(); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, -1, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
+ * i.e., returns the remainder after division by 'value'
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarFloorMod(SDVariable in, double value) { + SDValidation.validateNumerical("scalarFloorMod", "in", in); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,in, value).outputVariable(); + } + + /** + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
+ * i.e., returns the remainder after division by 'value'
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarFloorMod(String name, SDVariable in, double value) { + SDValidation.validateNumerical("scalarFloorMod", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,in, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise scalar maximum operation: out = max(in, value)
+ *
+ * @param in Input variable (NUMERIC type)
+ * @param value Scalar value to compare
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable scalarMax(SDVariable in, double value) {
+ SDValidation.validateNumerical("scalarMax", "in", in);
+ return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMax(sd,in, value).outputVariable();
+ }
+
+ /**
+ * Element-wise scalar maximum operation: out = max(in, value)
+ *
+ * @param name name May be null. Name for the output variable
+ * @param in Input variable (NUMERIC type)
+ * @param value Scalar value to compare
+ * @return output Output variable (NUMERIC type)
+ */
+ public SDVariable scalarMax(String name, SDVariable in, double value) {
+ SDValidation.validateNumerical("scalarMax", "in", in);
+ SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMax(sd,in, value).outputVariable();
+ return sd.updateVariableNameAndReference(out, name);
+ }
+
+ /**
+ * Element-wise scalar minimum operation: out = min(in, value)
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarMin(SDVariable in, double value) { + SDValidation.validateNumerical("scalarMin", "in", in); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMin(sd,in, value).outputVariable(); + } + + /** + * Element-wise scalar minimum operation: out = min(in, value)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarMin(String name, SDVariable in, double value) { + SDValidation.validateNumerical("scalarMin", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMin(sd,in, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return a variable with equal shape to the input, but all elements set to value 'set'
+ * + * @param in Input variable (NUMERIC type) + * @param set Value to set + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarSet(SDVariable in, double set) { + SDValidation.validateNumerical("scalarSet", "in", in); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSet(sd,in, set).outputVariable(); + } + + /** + * Return a variable with equal shape to the input, but all elements set to value 'set'
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param set Value to set + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarSet(String name, SDVariable in, double set) { + SDValidation.validateNumerical("scalarSet", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSet(sd,in, set).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter addition operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterAdd", "ref", ref); + SDValidation.validateNumerical("scatterAdd", "indices", indices); + SDValidation.validateNumerical("scatterAdd", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter addition operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterAdd", "ref", ref); + SDValidation.validateNumerical("scatterAdd", "indices", indices); + SDValidation.validateNumerical("scatterAdd", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter division operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterDiv", "ref", ref); + SDValidation.validateNumerical("scatterDiv", "indices", indices); + SDValidation.validateNumerical("scatterDiv", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter division operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterDiv", "ref", ref); + SDValidation.validateNumerical("scatterDiv", "indices", indices); + SDValidation.validateNumerical("scatterDiv", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter max operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterMax", "ref", ref); + SDValidation.validateNumerical("scatterMax", "indices", indices); + SDValidation.validateNumerical("scatterMax", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter max operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterMax", "ref", ref); + SDValidation.validateNumerical("scatterMax", "indices", indices); + SDValidation.validateNumerical("scatterMax", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter min operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterMin", "ref", ref); + SDValidation.validateNumerical("scatterMin", "indices", indices); + SDValidation.validateNumerical("scatterMin", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter min operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterMin", "ref", ref); + SDValidation.validateNumerical("scatterMin", "indices", indices); + SDValidation.validateNumerical("scatterMin", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter multiplication operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterMul", "ref", ref); + SDValidation.validateNumerical("scatterMul", "indices", indices); + SDValidation.validateNumerical("scatterMul", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter multiplication operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterMul", "ref", ref); + SDValidation.validateNumerical("scatterMul", "indices", indices); + SDValidation.validateNumerical("scatterMul", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter subtraction operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterSub", "ref", ref); + SDValidation.validateNumerical("scatterSub", "indices", indices); + SDValidation.validateNumerical("scatterSub", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter subtraction operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterSub", "ref", ref); + SDValidation.validateNumerical("scatterSub", "indices", indices); + SDValidation.validateNumerical("scatterSub", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter update operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterUpdate", "ref", ref); + SDValidation.validateNumerical("scatterUpdate", "indices", indices); + SDValidation.validateNumerical("scatterUpdate", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter update operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
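The scatter ops above all share one calling pattern: ref is the initial array, indices selects positions along dimension 0, and updates supplies the values that get combined at those positions. A minimal usage sketch follows, assuming the standard SameDiff/Nd4j factory methods (SameDiff.create(), sd.var/sd.constant, Nd4j.zeros/ones/createFromArray) and assuming these generated methods are reachable from a SameDiff instance; variable names and values are illustrative only.

    // assumed imports: org.nd4j.autodiff.samediff.SameDiff, org.nd4j.autodiff.samediff.SDVariable,
    // org.nd4j.linalg.factory.Nd4j, org.nd4j.linalg.api.buffer.DataType
    SameDiff sd = SameDiff.create();
    SDVariable ref = sd.var("ref", Nd4j.zeros(DataType.FLOAT, 3, 4));           // initial 3x4 array of zeros
    SDVariable indices = sd.constant("idx", Nd4j.createFromArray(0, 2));        // rows to update
    SDVariable updates = sd.constant("upd", Nd4j.ones(DataType.FLOAT, 2, 4));   // one row of updates per index
    SDVariable added = sd.scatterAdd(ref, indices, updates);                    // rows 0 and 2 become 1.0, row 1 stays 0.0
    System.out.println(added.eval());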
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterUpdate", "ref", ref); + SDValidation.validateNumerical("scatterUpdate", "indices", indices); + SDValidation.validateNumerical("scatterUpdate", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment max operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMax(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment max operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment mean operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment mean on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMean(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment mean operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment mean on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment min operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment min on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMin(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment min operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment min on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment product operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment product on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentProd(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment product operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment product on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment sum operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment sum on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentSum(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment sum operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
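A small sketch of the sorted-segment ops above, reusing the data/segmentIds values from the Javadoc example (same SameDiff/Nd4j assumptions as the scatter sketch earlier):

    SameDiff sd = SameDiff.create();
    SDVariable data = sd.constant("data", Nd4j.createFromArray(3f, 6f, 1f, 4f, 9f, 2f, 8f));
    SDVariable segmentIds = sd.constant("ids", Nd4j.createFromArray(0, 0, 1, 1, 1, 2, 2)); // already sorted
    SDVariable maxPerSegment = sd.segmentMax(data, segmentIds); // expected [6, 9, 8]
    SDVariable sumPerSegment = sd.segmentSum(data, segmentIds); // expected [9, 14, 10]
    System.out.println(maxPerSegment.eval());
    System.out.println(sumPerSegment.eval());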
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment sum on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param name name May be null. Name for the output variable + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDValidation.validateInteger("sequenceMask", "maxLen", maxLen); + return new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
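For instance, lengths = [1, 3] with maxLen = 4 gives the mask [[1,0,0,0],[1,1,1,0]] under the formula above. A hedged sketch (same assumptions as the earlier sketches; DataType.FLOAT is chosen arbitrarily):

    SameDiff sd = SameDiff.create();
    SDVariable lengths = sd.constant("lengths", Nd4j.createFromArray(1, 3));
    SDVariable mask = sd.sequenceMask(lengths, 4, DataType.FLOAT);
    // expected:
    // [[1, 0, 0, 0],
    //  [1, 1, 1, 0]]
    System.out.println(mask.eval());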
+ * + * @param name name May be null. Name for the output variable + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen, + DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDValidation.validateInteger("sequenceMask", "maxLen", maxLen); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * see sequenceMask(String, SDVariable, SDVariable, DataType)
+ * + * @param lengths (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, dataType).outputVariable(); + } + + /** + * see sequenceMask(String, SDVariable, SDVariable, DataType)
+ * + * @param name name May be null. Name for the output variable + * @param lengths (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the shape of the specified INDArray as a 1D INDArray
+ * + * @param input Input variable (NUMERIC type) + * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) + */ + public SDVariable shape(SDVariable input) { + SDValidation.validateNumerical("shape", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.Shape(sd,input).outputVariable(); + } + + /** + * Returns the shape of the specified INDArray as a 1D INDArray
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) + */ + public SDVariable shape(String name, SDVariable input) { + SDValidation.validateNumerical("shape", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Shape(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + */ + public SDVariable size(SDVariable in) { + SDValidation.validateNumerical("size", "in", in); + return new org.nd4j.linalg.api.ops.impl.shape.Size(sd,in).outputVariable(); + } + + /** + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + */ + public SDVariable size(String name, SDVariable in) { + SDValidation.validateNumerical("size", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Size(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
+ * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
+ * + * @param in Input variable (NUMERIC type) + * @param dimension Dimension to get size of + * @return output Scalar INDArray for size at specified variable (NUMERIC type) + */ + public SDVariable sizeAt(SDVariable in, int dimension) { + SDValidation.validateNumerical("sizeAt", "in", in); + return new org.nd4j.linalg.api.ops.impl.shape.SizeAt(sd,in, dimension).outputVariable(); + } + + /** + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
+ * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimension Dimension to get size of + * @return output Scalar INDArray for size at specified variable (NUMERIC type) + */ + public SDVariable sizeAt(String name, SDVariable in, int dimension) { + SDValidation.validateNumerical("sizeAt", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SizeAt(sd,in, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(SDVariable input, int[] begin, int... size) { + SDValidation.validateNumerical("slice", "input", input); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + return new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param name name May be null. Name for the output variable + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(String name, SDVariable input, int[] begin, int... size) { + SDValidation.validateNumerical("slice", "input", input); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { + SDValidation.validateNumerical("slice", "input", input); + SDValidation.validateInteger("slice", "begin", begin); + SDValidation.validateInteger("slice", "size", size); + return new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
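The Javadoc example above, written out as a sketch with concrete numbers standing in for a..f (same assumptions as the earlier sketches):

    SameDiff sd = SameDiff.create();
    SDVariable input = sd.constant("in", Nd4j.create(new float[][]{{1, 2, 3}, {4, 5, 6}})); // [[a,b,c],[d,e,f]]
    SDVariable sub = sd.slice(input, new int[]{0, 1}, 2, 1); // begin=[0,1], size=[2,1]
    // expected: [[2], [5]]  i.e. [[b], [e]]
    System.out.println(sub.eval());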
+ * + * @param name name May be null. Name for the output variable + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(String name, SDVariable input, SDVariable begin, SDVariable size) { + SDValidation.validateNumerical("slice", "input", input); + SDValidation.validateInteger("slice", "begin", begin); + SDValidation.validateInteger("slice", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, false, dimensions).outputVariable(); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
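Squared L2 norm is simply the sum of squared elements over the reduced dimensions (i.e. norm2 squared). A sketch showing the plain and keepDims forms (same assumptions as the earlier sketches):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.create(new float[][]{{3, 4}, {0, 5}}));
    SDVariable perRow = sd.squaredNorm(x, 1);           // reduce dimension 1 -> [25, 25], shape [2]
    SDVariable perRowKeep = sd.squaredNorm(x, true, 1); // keepDims=true      -> [[25], [25]], shape [2,1]
    System.out.println(perRow.eval());
    System.out.println(perRowKeep.eval());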
+ * + * @param name name May be null. Name for the output variable + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Remove a single dimension of size 1.
+ * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
+ * + * @param x Input variable (NUMERIC type) + * @param axis Size 1 dimension to remove + * @return output Output variable (NUMERIC type) + */ + public SDVariable squeeze(SDVariable x, int axis) { + SDValidation.validateNumerical("squeeze", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x, axis).outputVariable(); + } + + /** + * Remove a single dimension of size 1.
+ * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
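A one-line shape check for squeeze, as a sketch under the same assumptions as the earlier sketches: an input of shape [2,1,3] with axis 1 removed becomes shape [2,3].

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.zeros(DataType.FLOAT, 2, 1, 3));
    SDVariable squeezed = sd.squeeze(x, 1); // shape [2,1,3] -> [2,3]
    System.out.println(java.util.Arrays.toString(squeezed.eval().shape()));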
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param axis Size 1 dimension to remove + * @return output Output variable (NUMERIC type) + */ + public SDVariable squeeze(String name, SDVariable x, int axis) { + SDValidation.validateNumerical("squeeze", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Stack a set of N INDArray of rank X into one rank X+1 variable.
+ * If inputs have shape [a,b,c] then output has shape:
+ * axis = 0: [N,a,b,c]
+ * axis = 1: [a,N,b,c]
+ * axis = 2: [a,b,N,c]
+ * axis = 3: [a,b,c,N]
+ * see unstack(String[], SDVariable, int, int)
+ * + * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) + * @param axis Axis to stack on + * @return output Output variable (NDARRAY type) + */ + public SDVariable stack(int axis, SDVariable... values) { + Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); + return new org.nd4j.linalg.api.ops.impl.shape.Stack(sd,values, axis).outputVariable(); + } + + /** + * Stack a set of N INDArray of rank X into one rank X+1 variable.
+ * If inputs have shape [a,b,c] then output has shape:
+ * axis = 0: [N,a,b,c]
+ * axis = 1: [a,N,b,c]
+ * axis = 2: [a,b,N,c]
+ * axis = 3: [a,b,c,N]
+ * see unstack(String[], SDVariable, int, int)
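To make the axis rules above concrete, a sketch stacking two 2x3 inputs at different axes (same assumptions as the earlier sketches):

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.var("a", Nd4j.zeros(DataType.FLOAT, 2, 3));
    SDVariable b = sd.var("b", Nd4j.ones(DataType.FLOAT, 2, 3));
    SDVariable s0 = sd.stack(0, a, b); // shape [2,2,3]
    SDVariable s2 = sd.stack(2, a, b); // shape [2,3,2]
    System.out.println(java.util.Arrays.toString(s0.eval().shape()));
    System.out.println(java.util.Arrays.toString(s2.eval().shape()));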
+ * + * @param name name May be null. Name for the output variable + * @param axis Axis to stack on + * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) + * @return output Output variable (NDARRAY type) + */ + public SDVariable stack(String name, int axis, SDVariable... values) { + Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Stack(sd,values, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Standard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + } + + /** + * Standard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, + boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Standard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable(); + } + + /** + * Standard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
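A sketch of the biasCorrected flag on a full-array reduction, using values whose result is easy to verify by hand (same assumptions as the earlier sketches):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant("x", Nd4j.createFromArray(2f, 4f, 4f, 4f, 5f, 5f, 7f, 9f));
    SDVariable population = sd.standardDeviation(x, false, false); // divide by N     -> 2.0
    SDVariable sample = sd.standardDeviation(x, true, false);      // divide by (N-1) -> ~2.138
    System.out.println(population.eval());
    System.out.println(sample.eval());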
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, + int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[3,3], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, + int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[3,3], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param name name May be null. Name for the output variable + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, + long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, + int shrinkAxisMask) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[3,3], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long... strides) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, 0, 0, 0, 0, 0).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[3,3], strides=[2,1], all masks = 0) will return:<br>
+ * [b, c]
+ * [h, i]
+ * + * @param name name May be null. Name for the output variable + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, + long... strides) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, 0, 0, 0, 0, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
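A minimal sketch of the keepDims behaviour described above, under the same assumptions as the earlier strided-slice sketch (generated ops exposed on the SameDiff instance; illustrative names and shapes).

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(new int[]{2, 3, 4}));            // shape [a,b,c] = [2,3,4]
    SDVariable kept = sd.sum(x, true, 1);                                 // keepDims = true  -> shape [2, 1, 4]
    SDVariable dropped = sd.sum(x, false, 1);                             // keepDims = false -> shape [2, 4]
    System.out.println(java.util.Arrays.toString(kept.eval().shape()));
    System.out.println(java.util.Arrays.toString(dropped.eval().shape()));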
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, false, dimensions).outputVariable(); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ).outputVariable(); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX, + int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int... dimensionsY) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, false, false, false).outputVariable(); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX, + int... dimensionsY) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, false, false, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Repeat (tile) the input tensor the specified number of times.
+ * For example, if input is
+ * [1, 2]
+ * [3, 4]
+ * and repeat is [2, 3]
+ * then output is
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
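A minimal sketch of the tile example above, under the same assumptions (illustrative values), using the int[] repeat variant defined later in this file.

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.createFromArray(new double[][]{{1, 2}, {3, 4}}));
    SDVariable tiled = sd.tile(x, 2, 3);     // repeat 2x along rows, 3x along columns -> shape [4, 6]
    System.out.println(tiled.eval());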
+ * + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @return output Output variable (NDARRAY type) + */ + public SDVariable tile(SDVariable x, SDVariable repeat) { + SDValidation.validateInteger("tile", "repeat", repeat); + return new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + } + + /** + * Repeat (tile) the input tensor the specified number of times.
+ * For example, if input is
+ * [1, 2]
+ * [3, 4]
+ * and repeat is [2, 3]
+ * then output is
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @return output Output variable (NDARRAY type) + */ + public SDVariable tile(String name, SDVariable x, SDVariable repeat) { + SDValidation.validateInteger("tile", "repeat", repeat); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * see tile(String, SDVariable, int...)
+ * + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) + * @return output (NDARRAY type) + */ + public SDVariable tile(SDVariable x, int... repeat) { + Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + return new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + } + + /** + * see tile(String, SDVariable, int...)
+ * + * @param name name May be null. Name for the output variable + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) + * @return output (NDARRAY type) + */ + public SDVariable tile(String name, SDVariable x, int... repeat) { + Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix transpose operation: If input has shape [a,b] output has shape [b,a]
+ * + * @param x Input variable (NDARRAY type) + * @return output transposed input (NDARRAY type) + */ + public SDVariable transpose(SDVariable x) { + return new org.nd4j.linalg.api.ops.impl.shape.Transpose(sd,x).outputVariable(); + } + + /** + * Matrix transpose operation: If input has shape [a,b] output has shape [b,a]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NDARRAY type) + * @return output transposed input (NDARRAY type) + */ + public SDVariable transpose(String name, SDVariable x) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Transpose(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
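A minimal sketch of the unsorted segment max example above, under the same assumptions; the data and segment IDs mirror the Javadoc values.

    SameDiff sd = SameDiff.create();
    SDVariable data = sd.constant(Nd4j.createFromArray(1f, 3f, 2f, 6f, 4f, 9f, 8f));
    SDVariable segmentIds = sd.constant(Nd4j.createFromArray(1, 0, 2, 0, 1, 1, 2));
    SDVariable maxPerSegment = sd.unsortedSegmentMax(data, segmentIds, 3);
    System.out.println(maxPerSegment.eval());     // expected [6, 9, 8] as described above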
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMax", "data", data); + SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMax", "data", data); + SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMean", "data", data); + SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMean", "data", data); + SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMin", "data", data); + SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMin", "data", data); + SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]<br>
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentProd", "data", data); + SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]<br>
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentProd", "data", data); + SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]<br>
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); + SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]<br>
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); + SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
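Analogously, a minimal sketch of the unsorted segment sum example above (same assumptions and data as the earlier segment sketch).

    SameDiff sd = SameDiff.create();
    SDVariable data = sd.constant(Nd4j.createFromArray(1f, 3f, 2f, 6f, 4f, 9f, 8f));
    SDVariable segmentIds = sd.constant(Nd4j.createFromArray(1, 0, 2, 0, 1, 1, 2));
    SDVariable sumPerSegment = sd.unsortedSegmentSum(data, segmentIds, 3);
    System.out.println(sumPerSegment.eval());     // expected [9, 14, 10] as described above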
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSum(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSum", "data", data); + SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSum(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSum", "data", data); + SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
+ * If input has shape [a,b,c] then output has shape:
+ * axis = 0: [b,c]
+ * axis = 1: [a,c]
+ * axis = 2: [a,b]
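A minimal sketch of the unstack shapes above, under the same assumptions (illustrative shapes).

    SameDiff sd = SameDiff.create();
    SDVariable value = sd.var("value", Nd4j.rand(new int[]{2, 3, 4}));   // shape [a,b,c] = [2,3,4]
    SDVariable[] slices = sd.unstack(value, 0, 2);                       // axis = 0 -> two outputs of shape [3, 4]
    System.out.println(java.util.Arrays.toString(slices[0].eval().shape()));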
+ * + * @param value Input variable to unstack (NDARRAY type) + * @param axis Axis to unstack on + * @param num Number of output variables + */ + public SDVariable[] unstack(SDVariable value, int axis, int num) { + return new org.nd4j.linalg.api.ops.impl.shape.Unstack(sd,value, axis, num).outputVariables(); + } + + /** + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
+ * If input has shape [a,b,c] then output has shape:
+ * axis = 0: [b,c]
+ * axis = 1: [a,c]
+ * axis = 2: [a,b]
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param value Input variable to unstack (NDARRAY type) + * @param axis Axis to unstack on + * @param num Number of output variables + */ + public SDVariable[] unstack(String[] names, SDVariable value, int axis, int num) { + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Unstack(sd,value, axis, num).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
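A minimal sketch of the variance reduction above, under the same assumptions (illustrative shapes).

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(new int[]{2, 3, 4}));
    SDVariable v = sd.variance(x, true, true, 1);    // sample variance along dim 1, keepDims = true -> shape [2, 1, 4]
    System.out.println(java.util.Arrays.toString(v.eval().shape()));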
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(SDVariable x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(SDVariable x, boolean biasCorrected, int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, false, dimensions).outputVariable(); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(String name, SDVariable x, boolean biasCorrected, int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
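A minimal sketch of zerosLike, under the same assumptions (illustrative shapes).

    SameDiff sd = SameDiff.create();
    SDVariable input = sd.var("input", Nd4j.rand(new int[]{3, 4}));
    SDVariable z = sd.zerosLike(input);              // same (dynamic) shape as the input, filled with 0s
    System.out.println(z.eval());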
+ * + * @param input Input (NUMERIC type) + * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable zerosLike(SDVariable input) { + SDValidation.validateNumerical("zerosLike", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(sd,input).outputVariable(); + } + + /** + * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param name name May be null. Name for the output variable + * @param input Input (NUMERIC type) + * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable zerosLike(String name, SDVariable input) { + SDValidation.validateNumerical("zerosLike", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index d367e3d4a..bb9f027c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -23,8 +23,8 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.enums.DataFormat; import org.nd4j.base.Preconditions; +import org.nd4j.enums.DataFormat; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; @@ -753,6 +753,33 @@ public class SDCNN extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + */ + public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPoolWithArgmax", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables(); + } + + /** + * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + */ + public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input, + Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPoolWithArgmax", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** * 2D Convolution layer operation - max pooling 2d
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index f4a490813..ead137a57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2205,7 +2205,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAdd(SDVariable[] inputs) { + public SDVariable mergeAdd(SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); @@ -2219,7 +2219,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAdd(String name, SDVariable[] inputs) { + public SDVariable mergeAdd(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); @@ -2233,7 +2233,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAvg(SDVariable[] inputs) { + public SDVariable mergeAvg(SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); @@ -2247,7 +2247,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAvg(String name, SDVariable[] inputs) { + public SDVariable mergeAvg(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); @@ -2261,7 +2261,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeMax(SDVariable[] inputs) { + public SDVariable mergeMax(SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. 
Expected: inputs.length >= 1, got %s", inputs.length); return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); @@ -2275,7 +2275,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeMax(String name, SDVariable[] inputs) { + public SDVariable mergeMax(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index 6b1831de7..de8148c02 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -18,17 +18,15 @@ package org.nd4j.autodiff.samediff.ops; -import java.lang.String; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; -import lombok.NonNull; +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; @@ -43,28 +41,26 @@ public class SDRNN extends SDOps { * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object - * @return output The cell's outputs. (NUMERIC type) */ - public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + public SDVariable[] gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { SDValidation.validateNumerical("gru", "x", x); SDValidation.validateNumerical("gru", "hLast", hLast); - return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables(); } /** * The GRU cell. Does a single time step operation
* - * @param name name May be null. Name for the output variable + * @param names names May be null. Arrays of names for the output variables. * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object - * @return output The cell's outputs. (NUMERIC type) */ - public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + public SDVariable[] gru(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { SDValidation.validateNumerical("gru", "x", x); SDValidation.validateNumerical("gru", "hLast", hLast); - GRUCell c = new GRUCell(sd,x, hLast, GRUWeights); - return new GRUCellOutputs(c.outputVariables(name)); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); } /** @@ -75,39 +71,172 @@ public class SDRNN extends SDOps { * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param LSTMWeights Configuration Object * @param LSTMConfiguration Configuration Object - * @return output The cell's outputs (NUMERIC type) */ - public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, + public SDVariable[] lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { SDValidation.validateNumerical("lstmCell", "x", x); SDValidation.validateNumerical("lstmCell", "cLast", cLast); SDValidation.validateNumerical("lstmCell", "yLast", yLast); - LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); - return new LSTMCellOutputs(c.outputVariables()); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables(); } /** * The LSTM cell. Does a single time step operation.
* - * @param name name May be null. Name for the output variable + * @param names names May be null. Arrays of names for the output variables. * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param LSTMWeights Configuration Object * @param LSTMConfiguration Configuration Object - * @return output The cell's outputs (NUMERIC type) */ - public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast, + public SDVariable[] lstmCell(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { SDValidation.validateNumerical("lstmCell", "x", x); SDValidation.validateNumerical("lstmCell", "cLast", cLast); SDValidation.validateNumerical("lstmCell", "yLast", yLast); - LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); - return new LSTMCellOutputs(c.outputVariables(name)); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); } /** - * The LSTM layer. Does multiple time steps.
+ * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n" +
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n" +
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"
+ * You may use different gate configurations:
+ * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(SDVariable x, SDVariable cLast, SDVariable yLast, + SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n" +
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n" +
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"
+ * You may use different gate configurations:
+ * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast, + SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n" +
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n" +
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"
+ * You may use different gate configurations:
+ * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(SDVariable x, LSTMLayerWeights LSTMLayerWeights, + LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n" +
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n" +
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"
+ * You may use different gate configurations:
+ * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(String[] names, SDVariable x, LSTMLayerWeights LSTMLayerWeights, + LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * The LSTM block
* * @param maxTSLength (NUMERIC type) * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) @@ -117,17 +246,17 @@ public class SDRNN extends SDOps { * @param LSTMConfiguration Configuration Object * @return output The layer's outputs. (NUMERIC type) */ - public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast, + public SDVariable lstmblock(SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { - SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); - SDValidation.validateNumerical("lstmLayer", "x", x); - SDValidation.validateNumerical("lstmLayer", "cLast", cLast); - SDValidation.validateNumerical("lstmLayer", "yLast", yLast); - return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmblock", "x", x); + SDValidation.validateNumerical("lstmblock", "cLast", cLast); + SDValidation.validateNumerical("lstmblock", "yLast", yLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); } /** - * The LSTM layer. Does multiple time steps.
+ * The LSTM block
* * @param name name May be null. Name for the output variable * @param maxTSLength (NUMERIC type) @@ -138,13 +267,43 @@ public class SDRNN extends SDOps { * @param LSTMConfiguration Configuration Object * @return output The layer's outputs. (NUMERIC type) */ - public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, + public SDVariable lstmblock(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { - SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); - SDValidation.validateNumerical("lstmLayer", "x", x); - SDValidation.validateNumerical("lstmLayer", "cLast", cLast); - SDValidation.validateNumerical("lstmLayer", "yLast", yLast); - SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmblock", "x", x); + SDValidation.validateNumerical("lstmblock", "cLast", cLast); + SDValidation.validateNumerical("lstmblock", "yLast", yLast); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * The LSTM block
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmblock(SDVariable x, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmblock", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable(); + } + + /** + * The LSTM block
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmblock(String name, SDVariable x, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmblock", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java new file mode 100644 index 000000000..f7f458ffd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Activations */ +public enum CellAct { + TANH, + + RELU, + + SIGMOID, + + AFFINE, + + LEAKY_RELU, + + THRESHHOLD_RELU, + + SCALED_TAHN, + + HARD_SIGMOID, + + ELU, + + SOFTSIGN, + + SOFTPLUS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java new file mode 100644 index 000000000..498f825fd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Activations */ +public enum GateAct { + TANH, + + RELU, + + SIGMOID, + + AFFINE, + + LEAKY_RELU, + + THRESHHOLD_RELU, + + SCALED_TAHN, + + HARD_SIGMOID, + + ELU, + + SOFTSIGN, + + SOFTPLUS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java new file mode 100644 index 000000000..cd8855b05 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */ +public enum LSTMDataFormat { + TNS, + + NST, + + NTS, + + T2NS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java new file mode 100644 index 000000000..4732fc611 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * direction
+ * FWD: 0 = fwd + * BWD: 1 = bwd + * BIDIR_SUM: 2 = bidirectional sum + * BIDIR_CONCAT: 3 = bidirectional concat + * BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */ +public enum LSTMDirectionMode { + FWD, + + BWD, + + BIDIR_SUM, + + BIDIR_CONCAT, + + BIDIR_EXTRA_DIM +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java new file mode 100644 index 000000000..df034a294 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Activations */ +public enum OutAct { + TANH, + + RELU, + + SIGMOID, + + AFFINE, + + LEAKY_RELU, + + THRESHHOLD_RELU, + + SCALED_TAHN, + + HARD_SIGMOID, + + ELU, + + SOFTSIGN, + + SOFTPLUS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java new file mode 100644 index 000000000..8b6e2fbd6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * The data format of the input. Input shape depends on data format (in config):
+ * TNS -> [timeSteps, batchSize, inSize]
+ * NST -> [batchSize, inSize, timeSteps]
+ * NTS -> [batchSize, timeSteps, inSize]
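The three layouts above are the same data with the axes permuted. A small editorial sketch (not part of the patch; sizes and variable names are illustrative, batchSize=2, timeSteps=5, inSize=3; assumes the usual org.nd4j.linalg.factory.Nd4j and INDArray imports):

    // editorial sketch - converting between the RnnDataFormat layouts with a permute
    INDArray nts = Nd4j.rand(DataType.FLOAT, 2, 5, 3);   // NTS: [batchSize, timeSteps, inSize]
    INDArray tns = nts.permute(1, 0, 2);                 // TNS: [timeSteps, batchSize, inSize]
    INDArray nst = nts.permute(0, 2, 1);                 // NST: [batchSize, inSize, timeSteps]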
*/ +public enum RnnDataFormat { + TNS, + + NST, + + NTS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 62edb778f..043a16e87 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -146,6 +146,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class, org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index c60b11d23..ff139c236 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -301,24 +301,27 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } } - protected void checkForWorkspaces(CustomOp op) { - for (val input: op.inputArguments()) + protected void checkForWorkspaces(CustomOp op, OpContext oc) { + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments(); + + for (val input: inArgs) checkWorkspace(op.opName(), input); - for (val output: op.outputArguments()) + for (val output: outArgs) checkWorkspace(op.opName(), output); } - protected void checkForWorkspaces(Op op) { - val x = op.x(); + protected void checkForWorkspaces(Op op, OpContext oc) { + val x = oc != null ? oc.getInputArray(0) : op.x(); if (x != null) checkWorkspace(op.opName(), x); - val y = op.y(); + val y = oc != null && oc.getInputArrays().size() > 1 ? oc.getInputArray(1) : op.y(); if (y != null) checkWorkspace(op.opName(), y); - val z = op.z(); + val z = oc != null ? 
oc.getOutputArray(0) : op.z(); if (z != null) checkWorkspace(op.opName(), z); } @@ -346,7 +349,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().processOpCall(op, tadBuffers); break; case SCOPE_PANIC: - checkForWorkspaces(op); + checkForWorkspaces(op, null); return 0L; case DISABLED: default: @@ -357,7 +360,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } @Deprecated - public long profilingHookIn(CustomOp op) { + public long profilingHookIn(CustomOp op, OpContext oc) { switch (profilingMode) { case ALL: OpProfiler.getInstance().processOpCall(op); @@ -368,7 +371,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().processOpCall(op); break; case SCOPE_PANIC: - checkForWorkspaces(op); + checkForWorkspaces(op, oc); return 0L; case DISABLED: default: @@ -379,7 +382,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } @Deprecated - public void profilingHookOut(Op op, long timeStart) { + public void profilingHookOut(Op op, OpContext oc, long timeStart) { switch (profilingMode) { case ALL: OpProfiler.getInstance().processStackCall(op, timeStart); @@ -392,14 +395,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); break; case NAN_PANIC: - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); break; case INF_PANIC: - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); break; case ANY_PANIC: - OpExecutionerUtil.checkForNaN(op); - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForNaN(op, oc); + OpExecutionerUtil.checkForInf(op, oc); break; case DISABLED: default: @@ -413,7 +416,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } @Deprecated - public void profilingHookOut(CustomOp op, long timeStart) { + public void profilingHookOut(CustomOp op, OpContext oc, long timeStart) { switch (profilingMode) { case ALL: OpProfiler.getInstance().processStackCall(op, timeStart); @@ -426,14 +429,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); break; case NAN_PANIC: - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); break; case INF_PANIC: - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); break; case ANY_PANIC: - OpExecutionerUtil.checkForNaN(op); - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForNaN(op, oc); + OpExecutionerUtil.checkForInf(op, oc); break; case DISABLED: default: @@ -442,12 +445,15 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } - public long profilingConfigurableHookIn(CustomOp op) { - for (val arr: op.inputArguments()) + public long profilingConfigurableHookIn(CustomOp op, OpContext oc) { + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? 
oc.getOutputArrays() : op.outputArguments(); + + for (val arr: inArgs) if (arr.wasClosed()) throw new IllegalStateException("One of Input arguments was closed before call"); - for (val arr: op.outputArguments()) + for (val arr: outArgs) if (arr.wasClosed()) throw new IllegalStateException("One of Output arguments was closed before call"); @@ -460,7 +466,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) { - checkForWorkspaces(op); + checkForWorkspaces(op, oc); } return System.nanoTime(); @@ -491,14 +497,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().processOpCall(op, tadBuffers); } if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) { - checkForWorkspaces(op); + checkForWorkspaces(op, null); } return System.nanoTime(); } - public void profilingConfigurableHookOut(Op op, long timeStart) { + public void profilingConfigurableHookOut(Op op, OpContext oc, long timeStart) { if (OpProfiler.getInstance().getConfig() == null) return; @@ -509,10 +515,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); } if (OpProfiler.getInstance().getConfig().isCheckForNAN()) { - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); } if (OpProfiler.getInstance().getConfig().isCheckForINF()) { - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); } if (OpProfiler.getInstance().getConfig().isNativeStatistics()) { if (op.z() != null) { @@ -531,7 +537,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } } - public void profilingConfigurableHookOut(CustomOp op, long timeStart) { + public void profilingConfigurableHookOut(CustomOp op, OpContext oc, long timeStart) { if (OpProfiler.getInstance().getConfig() == null) return; @@ -542,10 +548,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); } if (OpProfiler.getInstance().getConfig().isCheckForNAN()) { - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); } if (OpProfiler.getInstance().getConfig().isCheckForINF()) { - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java index 080825433..83421e247 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java @@ -22,12 +22,15 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.profiler.OpProfiler; +import java.util.List; + /**Utility functions for the DefaultOpExecutioner * @author Alex Black */ @@ -58,7 +61,7 @@ public class OpExecutionerUtil { } if (match > 0) - throw new 
ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): "); + throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s)"); } public static void checkForAny(INDArray z) { @@ -92,44 +95,52 @@ public class OpExecutionerUtil { } - public static void checkForNaN(Op op) { + public static void checkForNaN(Op op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; - if (op.z() != null && !(op instanceof MatchCondition)) { - checkForNaN(op.z()); + INDArray z = oc != null ? oc.getOutputArray(0) : op.z(); + if (z != null && !(op instanceof MatchCondition)) { + checkForNaN(z); } } - public static void checkForInf(Op op) { + public static void checkForInf(Op op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; - if (op.z() != null && !(op instanceof MatchCondition)) { - checkForInf(op.z()); + INDArray z = oc != null ? oc.getOutputArray(0) : op.z(); + if (z != null && !(op instanceof MatchCondition)) { + checkForInf(z); } } - public static void checkForInf(CustomOp op) { + public static void checkForInf(CustomOp op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; - for (val input: op.inputArguments()) + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments(); + + for (val input: inArgs) checkForInf(input); - for (val output: op.outputArguments()) + for (val output: outArgs) checkForInf(output); } - public static void checkForNaN(CustomOp op) { + public static void checkForNaN(CustomOp op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; - for (val input: op.inputArguments()) + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? 
oc.getOutputArrays() : op.outputArguments(); + + for (val input: inArgs) checkForNaN(input); - for (val output: op.outputArguments()) + for (val output: outArgs) checkForNaN(output); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java index 0a7338814..5f5f6747a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -57,8 +57,12 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { addArgs(); } - public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ - super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax}); + public MaxPoolWithArgmax(@NonNull INDArray input, @NonNull Pooling2DConfig config){ + this(input, null, null, config); + } + + public MaxPoolWithArgmax(@NonNull INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ + super(null, new INDArray[]{input}, wrapFilterNull(output, outArgMax)); config.setType(Pooling2D.Pooling2DType.MAX); this.config = config; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index cf4e87814..fdc1b40fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -45,7 +45,7 @@ public class SConv2D extends Conv2D { } public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, - @NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java new file mode 100644 index 000000000..20d84a2d6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java @@ -0,0 +1,144 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * LSTM layer implemented as a single operation. + * Implementation of operation for LSTM layer with optional peep hole connections.
+ * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and https://research.google.com/pubs/archive/43905.pdf
+ * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
+ * See also: https://arxiv.org/pdf/1503.04069.pdf
+ *

+ * See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.
+ *
+ * Input arrays:
+ * 0: max sequence length; long/int64 scalar
+ * 1: input [seqLength, bS, inSize] at time t
+ * 2: previous/initial cell state [bS, numUnits]
+ * 3: previous/initial output [bS, numUnits]
+ * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]
+ * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
+ * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]
+ * 7: weights - cell peephole (t) connections to output gate, [numUnits]
+ * 8: biases, shape [4*numUnits]
+ *
+ * Input integer arguments: set via {@link LSTMConfiguration}
+ * 0: if not zero, provide peephole connections
+ * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
+ *
+ * Input float arguments: set via {@link LSTMConfiguration}
+ * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
+ * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
+ *
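In practice these integer and float arguments are populated from an LSTMConfiguration rather than written by hand. A minimal sketch (editorial, not part of the patch; the values are only illustrative) using the same builder methods this class calls in initFromTensorFlow:

    // editorial sketch - mapping LSTMConfiguration fields to the args listed above
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(false)                  // integer arg 0: provide peephole connections
            .dataFormat(RnnDataFormat.TNS)    // integer arg 1: data format
            .forgetBias(1.0)                  // float arg 0: forget gate bias
            .clippingCellValue(0.0)           // float arg 1: cell state clipping, 0 = no clipping
            .build();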

+ * Output arrays:
+ * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat
+ * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat
+ * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat
+ * 3: o - Output - output gate activations, rank 3, shape as per dataFormat
+ * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat
+ * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat
+ * 6: y (h) - Current cell output, rank 3, shape as per dataFormat
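For orientation, a short usage sketch (editorial, not part of the patch; it assumes sd, maxTSLength, x, cLast, yLast, weights and conf already exist with the shapes listed above). The op returns the seven arrays in the order given, so the per-timestep output y is the last element:

    // editorial sketch - retrieving the block outputs listed above
    SDVariable[] outs = new LSTMBlock(sd, maxTSLength, x, cLast, yLast, weights, conf).outputVariables();
    SDVariable y = outs[6];   // 6: y (h) - current cell output, rank 3, shape as per dataFormat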
+ * + * @author Alex Black + */ +public class LSTMBlock extends DynamicCustomOp { + + private LSTMConfiguration configuration; + + @Getter + private LSTMWeights weights; + + public LSTMBlock() { + } + + public LSTMBlock(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { + super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); + this.configuration = configuration; + this.weights = weights; + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + + public LSTMBlock(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { + super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast)); + this.configuration = lstmConfiguration; + this.weights = lstmWeights; + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMBlock, got %s", inputDataTypes); + //7 outputs, all of same type as input. Note that input 0 is max sequence length (int64), input 1 is actual input + DataType dt = inputDataTypes.get(1); + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); + return Arrays.asList(dt, dt, dt, dt, dt, dt, dt); + } + + @Override + public List doDiff(List grads) { + throw new UnsupportedOperationException("Not yet implemented"); + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + configuration = LSTMConfiguration.builder() + .forgetBias(attributesForNode.get("forget_bias").getF()) + .clippingCellValue(attributesForNode.get("cell_clip").getF()) + .peepHole(attributesForNode.get("use_peephole").getB()) + .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM + .build(); + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + + @Override + public String opName() { + return "lstmBlock"; + } + + @Override + public Map propertiesForFunction() { + return configuration.toProperties(true); + } + + @Override + public String tensorflowName() { + return "BlockLSTM"; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index 59b85f500..a433b23d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -1,5 +1,5 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -13,7 +13,6 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - package org.nd4j.linalg.api.ops.impl.layers.recurrent; import lombok.Getter; @@ -24,89 +23,103 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.shade.guava.primitives.Booleans; +import javax.xml.crypto.Data; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; + /** * LSTM layer implemented as a single operation. * Implementation of operation for LSTM layer with optional peep hole connections.
* S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and https://research.google.com/pubs/archive/43905.pdf
* Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
* See also: https://arxiv.org/pdf/1503.04069.pdf
- *

- * See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.
- *
* Input arrays:
- * 0: max sequence length; long/int64 scalar
- * 1: input [seqLength, bS, inSize] at time t
- * 2: previous/initial cell state [bS, numUnits]
- * 3: previous/initial output [bS, numUnits]
- * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]
- * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
- * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]
- * 7: weights - cell peephole (t) connections to output gate, [numUnits]
- * 8: biases, shape [4*numUnits]
- *
- * Input integer arguments: set via {@link LSTMConfiguration}
- * 0: if not zero, provide peephole connections
- * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
- *
- * Input float arguments: set via {@link LSTMConfiguration}
- * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
- * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
+ * 0: input
+ * [sL, bS, nIn] when dataFormat - TNS
+ * [bS, sL, nIn] when dataFormat - NTS
+ * [bS, nIn, sL] when dataFormat - NST
+ * 1: previous/initial cell state
+ * shapes [bS, nOut] for FWD, BWD Direction Mode
+ * shapes [2, bS, nOut] for BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode
+ * 2: previous/initial output
+ * shapes [bS, nOut] for FWD, BWD Direction Mode
+ * shapes [2, bS, nOut] for BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode
+ * 3: max sequence length [bS]
+ * 4: LSTMLayerWeights - {@link LSTMLayerWeights}
+ * 5: LSTMLayerConfig - {@link LSTMLayerConfig}
*

* Output arrays:
- * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat
- * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat
- * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat
- * 3: o - Output - output gate activations, rank 3, shape as per dataFormat
- * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat
- * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat
- * 6: y (h) - Current cell output, rank 3, shape as per dataFormat
- * - * @author Alex Black + * 0: output h - rank 3 or 4, depends on DirectionMode and dataFormat
+ * 1: output at last step hL - rank 2 or 3, depends on DirectionMode
+ * 2: cell state at last step cL - same shape as in hL
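A usage sketch for the new op through SameDiff (editorial, not part of the patch; nIn=3, nOut=4 and all names are illustrative, only the forward-direction weights Wx, Wr and b are supplied, and the standard SameDiff/Nd4j imports are assumed):

    // editorial sketch - building and running lstmLayer via the SameDiff API
    SameDiff sd = SameDiff.create();
    SDVariable x = sd.placeHolder("x", DataType.FLOAT, 5, 2, 3);               // [sL, bS, nIn], TNS layout
    LSTMLayerWeights weights = LSTMLayerWeights.builder()
            .weights(sd.var("Wx", Nd4j.rand(DataType.FLOAT, 3, 4 * 4)))        // input weights [nIn, 4*nOut]
            .rWeights(sd.var("Wr", Nd4j.rand(DataType.FLOAT, 4, 4 * 4)))       // recurrent weights [nOut, 4*nOut]
            .bias(sd.var("b", Nd4j.rand(DataType.FLOAT, 4 * 4)))               // bias [4*nOut]
            .build();
    LSTMLayerConfig config = LSTMLayerConfig.builder()
            .lstmdataformat(LSTMDataFormat.TNS)
            .directionMode(LSTMDirectionMode.FWD)
            .retFullSequence(true)
            .retLastH(true)
            .retLastC(true)
            .build();
    SDVariable[] outs = sd.rnn().lstmLayer(new String[]{"h", "hL", "cL"}, x, weights, config);
    LSTMLayerOutputs wrapped = new LSTMLayerOutputs(outs, config);
    SDVariable h  = wrapped.getOutput();       // full time series
    SDVariable hL = wrapped.getLastOutput();   // output at the last step
    SDVariable cL = wrapped.getLastState();    // cell state at the last step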
*/ public class LSTMLayer extends DynamicCustomOp { - private LSTMConfiguration configuration; + @Getter + private LSTMLayerConfig configuration; @Getter - private LSTMWeights weights; + private LSTMLayerWeights weights; + public LSTMLayer() { } - public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { - super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast)); + public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) { + super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); this.configuration = configuration; this.weights = weights; - addIArgument(configuration.iArgs(true)); - addTArgument(configuration.tArgs()); + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them"); + + } - public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { + public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig LSTMLayerConfig) { super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast)); - this.configuration = lstmConfiguration; + this.configuration = LSTMLayerConfig; this.weights = lstmWeights; - addIArgument(configuration.iArgs(true)); - addTArgument(configuration.tArgs()); + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them"); } @Override public List calculateOutputDataTypes(List inputDataTypes) { - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes); + Preconditions.checkState(inputDataTypes != null && 3 <= inputDataTypes.size() && inputDataTypes.size() <= 8, "Expected amount of inputs to LSTMLayer between 3 inputs minimum (input, Wx, Wr only) or 8 maximum, got %s", inputDataTypes); //7 outputs, all of same type as input. 
Note that input 0 is max sequence length (int64), input 1 is actual input DataType dt = inputDataTypes.get(1); + ArrayList list = new ArrayList<>(); + if (configuration.isRetFullSequence()) { + + list.add(dt); + } + + if (configuration.isRetLastC()) { + + list.add(dt); + } + if (configuration.isRetLastH()){ + + list.add(dt); + } + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); - return Arrays.asList(dt, dt, dt, dt, dt, dt, dt); + return list; } @Override @@ -114,31 +127,61 @@ public class LSTMLayer extends DynamicCustomOp { throw new UnsupportedOperationException("Not yet implemented"); } - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - configuration = LSTMConfiguration.builder() - .forgetBias(attributesForNode.get("forget_bias").getF()) - .clippingCellValue(attributesForNode.get("cell_clip").getF()) - .peepHole(attributesForNode.get("use_peephole").getB()) - .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM - .build(); - addIArgument(configuration.iArgs(true)); - addTArgument(configuration.tArgs()); - } @Override public String opName() { - return "lstmBlock"; + return "lstmLayer"; } @Override public Map propertiesForFunction() { - return configuration.toProperties(true); + return configuration.toProperties(true, true); + } + + + public long[] iArgs() { + return new long[]{ + configuration.getLstmdataformat().ordinal(),// INT_ARG(0) + configuration.getDirectionMode().ordinal(), // INT_ARG(1) + configuration.getGateAct().ordinal(), // INT_ARG(2) + configuration.getOutAct().ordinal(), // INT_ARG(3) + configuration.getCellAct().ordinal() // INT_ARG(4) + + }; + } + + public double[] tArgs() { + return new double[]{this.configuration.getCellClip()}; // T_ARG(0) + } + + + public boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + return new boolean[]{ + weights.hasBias(), // hasBiases: B_ARG(0) + maxTSLength != null, // hasSeqLen: B_ARG(1) + yLast != null, // hasInitH: B_ARG(2) + cLast != null, // hasInitC: B_ARG(3) + weights.hasPH(), // hasPH: B_ARG(4) + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // retLastC: B_ARG(7) + }; + } @Override - public String tensorflowName() { - return "BlockLSTM"; + public int getNumOutputs(){ + + return Booleans.countTrue( + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // retLastC: B_ARG(7) + ); } + + + } + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java new file mode 100644 index 000000000..27ebbc82f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + + /** + * integer numbers corresponding to activations: + * 0=tanh, + * 1=relu, + * 2=sigmoid, + * 3=affine, + * 4=leaky relu, + * 5= thresholded relu, + * 6=scaled tanh, + * 7=hard sigmoid, + * 8=ELU, + * 9=softsign, + * 10=softplus + */ + public enum LSTMActivations { + //Note: ordinal (order) here matters for C++ level. Any new formats hsould be added at end + + TANH, + RELU, + SIGMOID, + AFFINE, + LEAKY_RELU, + THRESHHOLD_RELU, + SCALED_TAHN, + HARD_SIGMOID, + ELU, + SOFTSIGN, + SOFTPLUS + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java new file mode 100644 index 000000000..788e87d59 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java @@ -0,0 +1,41 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + + /** + * notations
+ * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize]
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + */ + + public enum LSTMDataFormat { + //Note: ordinal (order) here matters for C++ level. Any new formats hsould be added at end + + + TNS, + NTS, + NST, + T2NS + + } + + + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java new file mode 100644 index 000000000..c93bc05f9 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java @@ -0,0 +1,38 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + +/** + * direction
+ * FWD: 0 = fwd + * BWD: 1 = bwd + * BIDIR_SUM: 2 = bidirectional sum + * BIDIR_CONCAT: 3 = bidirectional concat + * BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */ + +// const auto directionMode = INT_ARG(1); // direction: + +public enum LSTMDirectionMode { + //Note: ordinal (order) here matters for C++ level. Any new formats hsould be added at end + + + FWD, + BWD, + BIDIR_SUM, + BIDIR_CONCAT, + BIDIR_EXTRA_DIM + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java new file mode 100644 index 000000000..9901213da --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java @@ -0,0 +1,119 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + +import lombok.Builder; +import lombok.Data; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; + +import java.util.LinkedHashMap; +import java.util.Map; + + +@Builder +@Data +public class LSTMLayerConfig { + + + /** + * notations
+ * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + */ + @Builder.Default + private LSTMDataFormat lstmdataformat = LSTMDataFormat.TNS; //INT_ARG(0) + + + /** + * direction
+ * FWD: 0 = fwd + * BWD: 1 = bwd + * BS: 2 = bidirectional sum + * BC: 3 = bidirectional concat + * BE: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + */ + @Builder.Default + private LSTMDirectionMode directionMode = LSTMDirectionMode.FWD; //INT_ARG(1) + + /** + * Activation for input (i), forget (f) and output (o) gates + */ + @Builder.Default + private LSTMActivations gateAct = LSTMActivations.SIGMOID; // INT_ARG(2) + + @Builder.Default + private LSTMActivations cellAct = LSTMActivations.TANH; // INT_ARG(3) + + @Builder.Default + private LSTMActivations outAct = LSTMActivations.TANH; // INT_ARG(4) + + + + + /** + * indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + */ + @Builder.Default + private boolean retFullSequence = true; //B_ARG(5) + + /** + * indicates whether to return output at last time step only, + * in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + */ + private boolean retLastH; //B_ARG(6) + + /** + * indicates whether to return cells state at last time step only, + * in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + */ + private boolean retLastC; // B_ARG(7) + + /** + * Cell clipping value, if it = 0 then do not apply clipping + */ + @Builder.Default + private double cellClip; //T_ARG(0) + + + public Map toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) { + Map ret = new LinkedHashMap<>(); + ret.put("gateAct", gateAct.ordinal()); + ret.put("outAct", outAct.ordinal()); + ret.put("cellAct", cellAct.ordinal()); + ret.put("retFullSequence", retFullSequence); + ret.put("retLastH", retLastH); + ret.put("retLastC", retLastC); + ret.put("cellClip", cellClip); + + if (includeLSTMDataFormat) + ret.put("LSTMDataFormat", lstmdataformat.ordinal()); + if (includeLSTMDirectionMode) + ret.put("LSTMDirectionMode", directionMode.ordinal()); + return ret; + } + +} + + + + + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java index a01be219f..d8a2e6e9a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java @@ -2,13 +2,18 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; import java.util.Arrays; import java.util.List; + import lombok.AccessLevel; import lombok.Getter; import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; +import org.nd4j.shade.guava.primitives.Booleans; /** * The outputs of a LSTM layer ({@link LSTMLayer}. 
@@ -16,165 +21,78 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; @Getter public class LSTMLayerOutputs { - private RnnDataFormat dataFormat; + /** + * The LSTM layer data format ({@link LSTMDataFormat}). + */ + private LSTMDataFormat dataFormat; + /** - * Output - input modulation gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * output h: + * [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + * [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + * [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + * [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + * [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + * [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + * [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + * numbers mean index in corresponding enums {@link LSTMDataFormat} and {@link LSTMDirectionMode} */ - private SDVariable i; + private SDVariable timeSeriesOutput; /** - * Activations, cell state (pre tanh). - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * cell state at last step cL: + * [bS, nOut] when directionMode FWD or BWD + * [2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM */ - private SDVariable c; + private SDVariable lastCellStateOutput; /** - * Output - forget gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * output at last step hL: + * [bS, nOut] when directionMode FWD or BWD + * [2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM */ - private SDVariable f; + private SDVariable lastTimeStepOutput; - /** - * Output - output gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable o; - /** - * Output - input gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable z; + public LSTMLayerOutputs(SDVariable[] outputs, LSTMLayerConfig lstmLayerConfig) { + Preconditions.checkArgument(outputs.length > 0 && outputs.length <= 3, + "Must have from 1 to 3 LSTM layer outputs, got %s", outputs.length); - /** - * Cell state, post tanh. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable h; + int i = 0; + timeSeriesOutput = lstmLayerConfig.isRetFullSequence() ? outputs[i++] : null; + lastTimeStepOutput = lstmLayerConfig.isRetLastH() ? outputs[i++] : null; + lastCellStateOutput = lstmLayerConfig.isRetLastC() ? outputs[i++] : null; - /** - * Current cell output. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable y; - public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){ - Preconditions.checkArgument(outputs.length == 7, - "Must have 7 LSTM layer outputs, got %s", outputs.length); - - i = outputs[0]; - c = outputs[1]; - f = outputs[2]; - o = outputs[3]; - z = outputs[4]; - h = outputs[5]; - y = outputs[6]; - this.dataFormat = dataFormat; + this.dataFormat = lstmLayerConfig.getLstmdataformat(); } - /** - * Get all outputs returned by the cell. - */ - public List getAllOutputs(){ - return Arrays.asList(i, c, f, o, z, h, y); - } /** - * Get y, the output of the cell for all time steps. - * - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * Get h, the output of the cell for all time steps. + *

+ * Shape depends on data format defined in {@link LSTMLayerConfig}:
+ * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize]
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */ - public SDVariable getOutput(){ - return y; + public SDVariable getOutput() { + Preconditions.checkArgument(timeSeriesOutput != null, "retFullSequence was set to false in LSTMLayerConfig"); + return timeSeriesOutput; } - /** - * Get c, the cell's state for all time steps. - * - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - public SDVariable getState(){ - return c; + public SDVariable getLastState() { + Preconditions.checkArgument(lastCellStateOutput != null, "retLastC was setted as false in LSTMLayerConfig"); + return lastCellStateOutput; } - private SDVariable lastOutput = null; - - /** - * Get y, the output of the cell, for the last time step. - * - * Has shape [batchSize, numUnits]. - */ - public SDVariable getLastOutput(){ - if(lastOutput != null) - return lastOutput; - - switch (dataFormat){ - case TNS: - lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all()); - break; - case NST: - lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); - break; - case NTS: - lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all()); - break; - } - return lastOutput; + public SDVariable getLastOutput() { + Preconditions.checkArgument(lastTimeStepOutput != null, "retLastH was setted as false in LSTMLayerConfig"); + return lastTimeStepOutput; } - private SDVariable lastState = null; - - /** - * Get c, the state of the cell, for the last time step. - * - * Has shape [batchSize, numUnits]. - */ - public SDVariable getLastState(){ - if(lastState != null) - return lastState; - - switch (dataFormat){ - case TNS: - lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all()); - break; - case NST: - lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); - break; - case NTS: - lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all()); - break; - } - return lastState; - } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java new file mode 100644 index 000000000..98985df57 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java @@ -0,0 +1,99 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; +import org.nd4j.linalg.util.ArrayUtil; + +/** + * The weight configuration of a LSTMLayer. 
For {@link LSTMLayer} + * @author Alex Black + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder +public class LSTMLayerWeights extends RNNWeights { + + /** + * Input to hidden weights with a shape of [inSize, 4*numUnits]. + * + * Input to hidden and hidden to hidden are concatenated in dimension 0, + * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :]. + */ + private SDVariable weights; + private INDArray iWeights; + + /** + * hidden to hidden weights (aka "recurrent weights", with a shape of [numUnits, 4*numUnits]. + * + */ + private SDVariable rWeights; + private INDArray irWeights; + + /** + * Peephole weights, with a shape of [3*numUnits]. + */ + private SDVariable peepholeWeights; + private INDArray iPeepholeWeights; + + /** + * Input to hidden and hidden to hidden biases, with shape [4*numUnits]. + */ + private SDVariable bias; + private INDArray iBias; + + @Override + public SDVariable[] args() { + return filterNonNull(weights, rWeights, peepholeWeights, bias); + } + + @Override + public INDArray[] arrayArgs() { + return filterNonNull(iWeights, irWeights, iPeepholeWeights, iBias); + } + + @Override + public SDVariable[] argsWithInputs(SDVariable... inputs){ + Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast + //lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast + return ArrayUtil.filterNull(inputs[0], weights, rWeights, bias, inputs[1], inputs[2], inputs[3], peepholeWeights); + } + + @Override + public INDArray[] argsWithInputs(INDArray... inputs) { + Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast + //lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast + return ArrayUtil.filterNull(inputs[0], iWeights, irWeights, iBias, inputs[1], inputs[2], inputs[3], iPeepholeWeights); + } + + + public boolean hasBias() { + return (bias!=null||iBias!=null); + } + + public boolean hasPH() { + return (peepholeWeights!=null||iPeepholeWeights!=null); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index d22478e71..30ca8ebc5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -98,6 +98,7 @@ public class Mmul extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); + mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); } public Mmul(INDArray x, INDArray y) { @@ -110,6 +111,7 @@ public class Mmul extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); + mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); } public Mmul() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java index 7bccc8035..97fd5e538 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -49,6 +50,9 @@ public class BatchMmul extends DynamicCustomOp { protected int N; protected int K; + public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) { + this(sameDiff, ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB); + } public BatchMmul(SameDiff sameDiff, SDVariable[] matrices, @@ -85,6 +89,22 @@ public class BatchMmul extends DynamicCustomOp { addArgs(); } + public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB){ + super(ArrayUtils.addAll(matricesA, matricesB), null); + this.batchSize = matricesA.length; + + this.transposeA = transposeA ? 1 : 0; + this.transposeB = transposeB ? 1 : 0; + + long[] firstShape = matricesA[0].shape(); + long[] lastShape = matricesB[0].shape(); + + this.M = transposeA ? (int) firstShape[1]: (int) firstShape[0]; + this.N = transposeA ? (int) firstShape[0]: (int) firstShape[1]; + this.K = transposeB ? (int) lastShape[0]: (int) lastShape[1]; + addArgs(); + } + @Override public int getNumOutputs(){ return batchSize; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index a239bd9ec..593531098 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -34,17 +34,12 @@ import java.util.List; @NoArgsConstructor public class GatherNd extends DynamicCustomOp { - public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) { - super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false); + public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices) { + super(null, sameDiff, new SDVariable[] {input, indices}); } - public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) { - super(null, sameDiff, new SDVariable[] {input, indices}, inPlace); - } - - public GatherNd(INDArray[] df, INDArray[] indices) { - addInputArgument(df); - addInputArgument(indices); + public GatherNd(INDArray df, INDArray indices) { + super(new INDArray[]{df, indices}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index 6fca99eae..4bc3b3f63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -16,13 +16,16 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import 
org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -41,21 +44,27 @@ public class Linspace extends DynamicCustomOp { private DataType dataType; public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { - super(sameDiff, new SDVariable[0]); - addTArgument(start,stop); - addIArgument(number); - addDArgument(dataType); + this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType); } public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){ super(sameDiff, new SDVariable[]{from, to, length}); this.dataType = dataType; + addDArgument(dataType); } public Linspace(DataType dataType, double start, double stop, long number) { + this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number)); + } + + public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) { + this(start, stop, number, dataType); + } + + public Linspace(@NonNull INDArray start, @NonNull INDArray stop, @NonNull INDArray number, @NonNull DataType dataType) { + super(new INDArray[]{start, stop, number}, null); + this.dataType = dataType; addDArgument(dataType); - addTArgument(start, stop); - addIArgument(number); } public Linspace(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java index f2c11f1ef..ce83a808c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java @@ -16,9 +16,11 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.ArrayList; @@ -41,6 +43,11 @@ public class MeshGrid extends DynamicCustomOp { this(sd, cartesian, inputs); } + public MeshGrid(@NonNull INDArray[] inputs, boolean cartesian){ + super(inputs, null); + addIArgument(cartesian ? 
1 : 0); + } + public MeshGrid(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 6feace53f..2126dfe27 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -44,7 +44,6 @@ import java.util.Map; public class Reshape extends DynamicCustomOp { private long[] shape; - private String arrName; public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) { super(null, sameDiff, new SDVariable[]{i_v}); @@ -56,6 +55,12 @@ public class Reshape extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{i_v, shape}); } + public Reshape(INDArray in, long... shape){ + super(new INDArray[]{in}, null); + this.shape = shape; + addIArgument(shape); + } + public Reshape(INDArray in, INDArray shape){ this(in, shape, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index a2ca91c65..3c3baf1f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -64,15 +65,19 @@ public class SequenceMask extends DynamicCustomOp { addDArgument(dataType); } - public SequenceMask(INDArray input, int maxLen, DataType dataType) { + public SequenceMask(@NonNull INDArray input, int maxLen, DataType dataType) { addInputArgument(input); addIArgument(maxLen); this.dataType = dataType; addDArgument(dataType); } - public SequenceMask(INDArray input, DataType dataType) { - addInputArgument(input); + public SequenceMask(@NonNull INDArray input, @NonNull DataType dataType) { + this(input, null, dataType); + } + + public SequenceMask(@NonNull INDArray input, INDArray maxLength, @NonNull DataType dataType) { + super(wrapFilterNull(input, maxLength), null); this.dataType = dataType; addDArgument(dataType); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index c5f7cdd70..46b8f6286 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -59,6 +59,10 @@ public class Slice extends DynamicCustomOp { addIArgument(size); } + public Slice(@NonNull INDArray input, @NonNull INDArray begin, @NonNull INDArray end){ + super(new INDArray[]{input, begin, end}, null); + } + @Override public String opName() { return "slice"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 89c459be3..17a8beb3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -50,7 +50,7 @@ public class Stack extends DynamicCustomOp { addArgs(); } - public Stack(INDArray input, int axis) { + public Stack(INDArray[] input, int axis) { addInputArgument(input); this.jaxis = axis; addArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index a053403af..456edfe1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -98,10 +98,16 @@ public class StridedSlice extends DynamicCustomOp { public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + this(in, ArrayUtil.toLongArray(begin), ArrayUtil.toLongArray(end), ArrayUtil.toLongArray(strides), + beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + } + + public StridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, + int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { addInputArgument(in); - this.begin = ArrayUtil.toLongArray(begin); - this.end = ArrayUtil.toLongArray(end); - this.strides = ArrayUtil.toLongArray(strides); + this.begin = begin; + this.end = end; + this.strides = strides; this.beginMask = beginMask; this.endMask = endMask; this.ellipsisMask = ellipsisMask; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java index 6f9e94de0..b8d952b37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -67,6 +68,13 @@ public class Unstack extends DynamicCustomOp { addArgs(); } + public Unstack(@NonNull INDArray value, int axis, int num){ + super(new INDArray[]{value}, null); + this.jaxis = axis; + this.num = num; + addArgs(); + } + public Unstack(INDArray in, INDArray[] out, int axis){ super(null, new INDArray[]{in}, out, null, (int[])null); this.jaxis = axis; @@ -136,7 +144,8 @@ public class Unstack extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[f1.size()]))); + return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[0]))); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index 8d0a9d0d6..b7bd0e0f6 100644 
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -58,6 +58,10 @@ public class Pad extends DynamicCustomOp { this(sd, in, padding, Mode.CONSTANT, padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, double padValue){ + this(in, padding, null, Mode.CONSTANT, padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index b64581b49..3efc13af0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -66,11 +66,8 @@ public class DynamicPartition extends DynamicCustomOp { addArgs(); } - public DynamicPartition(INDArray input, INDArray[] partitions, int numPartitions) { + public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) { addInputArgument(input); - for (INDArray part : partitions) - addInputArgument(part); - addIArgument(numPartitions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java index d1b2fdfdb..880717fb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java @@ -16,9 +16,11 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -30,10 +32,14 @@ public class ListDiff extends DynamicCustomOp { // } - public ListDiff(SameDiff sd, SDVariable x, SDVariable y){ + public ListDiff(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){ super(sd, new SDVariable[]{x, y}); } + public ListDiff(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x, y}, null); + } + @Override public String tensorflowName() { return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API? 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java index 1ae979ec7..7b8dbf209 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java @@ -73,12 +73,8 @@ public class XwPlusB extends DynamicCustomOp { SDVariable dLdOut = gradient.get(0); SDVariable dLdb = dLdOut.sum(0); - SDVariable dLdIn = sameDiff.mmul(dLdOut, w, MMulTranspose.builder() - .transposeB(true) - .build()); - SDVariable dLdW = sameDiff.mmul(in, dLdOut, MMulTranspose.builder() - .transposeA(true) - .build()); + SDVariable dLdIn = sameDiff.mmul(dLdOut, w, false, true, false); + SDVariable dLdW = sameDiff.mmul(in, dLdOut, true, false, false); return Arrays.asList(dLdIn, dLdW, dLdb); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index a7dd7c2b9..d588ef4a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -28,6 +28,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -55,24 +56,11 @@ public class Cast extends BaseDynamicTransformOp { addArgs(); } -/* - @Override - public void setValueFor(Field target, Object value) { - if(value == null) { - throw new ND4JIllegalStateException("Unable to set field " + target + " using null value!"); - } - - // FIXME! 
- if (!(value instanceof DataType)) - return; - - try { - target.set(this, (DataType) value); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } + public Cast(@NonNull INDArray arg, @NonNull DataType dataType){ + super(new INDArray[]{arg}, null); + this.typeDst = dataType; + addArgs(); } - */ @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java index ade01281c..b024659c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -73,6 +74,12 @@ public class Range extends DynamicCustomOp { addDArgument(dataType); } + public Range(INDArray from, INDArray to, INDArray step, DataType dataType){ + super(new INDArray[]{from, to, step}, null); + this.dataType = dataType; + addDArgument(dataType); + } + @Override public int opNum() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index cfaf00d18..64e6a96b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -149,6 +149,60 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions)); } + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
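+ * A minimal usage sketch (illustrative only, not part of this patch; the input arrays, shapes and the no-arg NDBase constructor are assumptions):
+ *   INDArray[] a = new INDArray[]{Nd4j.rand(2, 3), Nd4j.rand(2, 3)};
+ *   INDArray[] b = new INDArray[]{Nd4j.rand(3, 4), Nd4j.rand(3, 4)};
+ *   INDArray[] out = new NDBase().batchMmul(a, b, false, false);   // two outputs, each of shape [2, 4]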
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param transposeA Whether to transpose A arrays or not + * @param transposeB Whether to transpose B arrays or not + */ + public INDArray[] batchMmul(INDArray[] inputsA, INDArray[] inputsB, boolean transposeA, + boolean transposeB) { + NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, transposeB)); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + */ + public INDArray[] batchMmul(INDArray[] inputsA, INDArray... inputsB) { + NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false)); + } + + /** + * Cast the array to a new datatype - for example, Integer -> Float
+ * + * @param arg Input variable to cast (NDARRAY type) + * @param datatype Datatype to cast to + * @return output Output array (after casting) (NDARRAY type) + */ + public INDArray castTo(INDArray arg, DataType datatype) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype))[0]; + } + /** * Concatenate a set of inputs along the specified dimension.
* Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
@@ -161,7 +215,7 @@ public class NDBase { * @param dimension Dimension to concatenate on * @return output (NUMERIC type) */ - public INDArray concat(INDArray[] inputs, int dimension) { + public INDArray concat(int dimension, INDArray... inputs) { NDValidation.validateNumerical("concat", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); @@ -274,28 +328,26 @@ public class NDBase { * @param x Input variable (NUMERIC type) * @param partitions 1D input with values 0 to numPartitions-1 (INT type) * @param numPartitions Number of partitions, >= 1 - * @return output Output variables (equal in number to numPartitions) (NUMERIC type) */ - public INDArray dynamicPartition(INDArray x, INDArray[] partitions, int numPartitions) { + public INDArray[] dynamicPartition(INDArray x, INDArray partitions, int numPartitions) { NDValidation.validateNumerical("dynamicPartition", "x", x); NDValidation.validateInteger("dynamicPartition", "partitions", partitions); - Preconditions.checkArgument(partitions.length >= 1, "partitions has incorrect size/length. Expected: partitions.length >= 1, got %s", partitions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions)); } /** * Dynamically merge the specified input arrays into a single array, using the specified indices
* - * @param x Input variables. (NUMERIC type) * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param x Input variables. (NUMERIC type) * @return output Merged output variable (NUMERIC type) */ - public INDArray dynamicStitch(INDArray[] x, INDArray[] indices) { - NDValidation.validateNumerical("dynamicStitch", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + public INDArray dynamicStitch(INDArray[] indices, INDArray... x) { NDValidation.validateInteger("dynamicStitch", "indices", indices); Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(x, indices))[0]; + NDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; } /** @@ -395,11 +447,9 @@ public class NDBase { * @param indices (NUMERIC type) * @return output (NUMERIC type) */ - public INDArray gatherNd(INDArray[] df, INDArray[] indices) { + public INDArray gatherNd(INDArray df, INDArray indices) { NDValidation.validateNumerical("gatherNd", "df", df); - Preconditions.checkArgument(df.length >= 1, "df has incorrect size/length. Expected: df.length >= 1, got %s", df.length); NDValidation.validateNumerical("gatherNd", "indices", indices); - Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0]; } @@ -516,6 +566,23 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; } + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
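+ * A minimal usage sketch (illustrative only; scalar INDArray inputs are an assumption of this example):
+ *   INDArray out = new NDBase().linspace(Nd4j.scalar(3.0), Nd4j.scalar(4.0), Nd4j.scalar(3L), DataType.DOUBLE);   // [3.0, 3.5, 4.0]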
+ * + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) + * @param dataType Data type of the output array + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public INDArray linspace(INDArray start, INDArray stop, INDArray number, DataType dataType) { + NDValidation.validateNumerical("linspace", "start", start); + NDValidation.validateNumerical("linspace", "stop", stop); + NDValidation.validateInteger("linspace", "number", number); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; + } + /** * Less than operation: elementwise x < y
* @@ -1071,6 +1138,20 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0]; } + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
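+ * A minimal usage sketch (illustrative only; shape and permutation values are assumed):
+ *   INDArray out = new NDBase().permute(Nd4j.rand(new int[]{2, 3, 4}), Nd4j.createFromArray(2, 0, 1));   // result has shape [4, 2, 3]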
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Permute dimensions (INT type) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public INDArray permute(INDArray x, INDArray dimensions) { + NDValidation.validateNumerical("permute", "x", x); + NDValidation.validateInteger("permute", "dimensions", dimensions); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; + } + /** * Array permutation operation: permute the dimensions according to the specified permutation indices.
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
@@ -1141,6 +1222,24 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; } + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
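+ * A minimal usage sketch (illustrative only; scalar INDArray arguments are assumed):
+ *   INDArray out = new NDBase().range(Nd4j.scalar(1.0), Nd4j.scalar(3.0), Nd4j.scalar(0.5), DataType.DOUBLE);   // [1.0, 1.5, 2.0, 2.5]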
+ * + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataType) { + NDValidation.validateNumerical("range", "from", from); + NDValidation.validateNumerical("range", "to", to); + NDValidation.validateNumerical("range", "step", step); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; + } + /** * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
* @@ -1168,6 +1267,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); } + /** + * Element-wise replace where condition:
+ * out[i] = value if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
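+ * A minimal usage sketch (illustrative only; uses org.nd4j.linalg.indexing.conditions.Conditions, which is not part of this patch):
+ *   INDArray out = new NDBase().replaceWhere(Nd4j.createFromArray(1.0, -2.0, 3.0), 0.0, Conditions.lessThan(0.0));   // [1.0, 0.0, 3.0]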
+ * + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public INDArray replaceWhere(INDArray update, double value, Condition condition) { + NDValidation.validateNumerical("replaceWhere", "update", update); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, condition)); + } + /** * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
* input, but with the specified shape.
@@ -1183,6 +1297,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; } + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
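+ * A minimal usage sketch (illustrative only; shapes are assumed):
+ *   INDArray out = new NDBase().reshape(Nd4j.rand(2, 6), 3, 4);   // 12 values rearranged into shape [3, 4]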
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray reshape(INDArray x, long... shape) { + NDValidation.validateNumerical("reshape", "x", x); + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; + } + /** * Reverse the values of an array for the specified dimensions
* If input is:
@@ -1532,6 +1661,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; } + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
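+ * A minimal usage sketch (illustrative only; lengths and maxLen values are assumed):
+ *   INDArray mask = new NDBase().sequenceMask(Nd4j.createFromArray(1, 3), Nd4j.scalar(4), DataType.FLOAT);   // rows [1,0,0,0] and [1,1,1,0]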
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) { + NDValidation.validateNumerical("sequenceMask", "lengths", lengths); + NDValidation.validateInteger("sequenceMask", "maxLen", maxLen); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + } + /** * see sequenceMask(String, SDVariable, SDVariable, DataType)
* @@ -1601,6 +1745,28 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; } + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:&#13;
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
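+ * A minimal usage sketch (illustrative only; begin and size are passed as integer INDArrays, values assumed):
+ *   INDArray out = new NDBase().slice(Nd4j.rand(2, 3), Nd4j.createFromArray(0, 1), Nd4j.createFromArray(2, 1));   // result has shape [2, 1]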
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @return output Subset of the input (NUMERIC type) + */ + public INDArray slice(INDArray input, INDArray begin, INDArray size) { + NDValidation.validateNumerical("slice", "input", input); + NDValidation.validateInteger("slice", "begin", begin); + NDValidation.validateInteger("slice", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; + } + /** * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
* @@ -1668,7 +1834,8 @@ public class NDBase { * @param axis Axis to stack on * @return output Output variable (NDARRAY type) */ - public INDArray stack(INDArray values, int axis) { + public INDArray stack(int axis, INDArray... values) { + Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; } @@ -1737,7 +1904,7 @@ public class NDBase { * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions * @return output A subset of the input array (NUMERIC type) */ - public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, + public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { NDValidation.validateNumerical("stridedSlice", "in", in); Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); @@ -1762,7 +1929,7 @@ public class NDBase { * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) * @return output A subset of the input array (NUMERIC type) */ - public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int... strides) { + public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... strides) { NDValidation.validateNumerical("stridedSlice", "in", in); Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); @@ -1999,6 +2166,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; } + /** + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
+ * If input has shape [a,b,c] then output has shape:
+ * axis = 0: [b,c]
+ * axis = 1: [a,c]
+ * axis = 2: [a,b]
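+ * A minimal usage sketch (illustrative only; shape and axis values are assumed):
+ *   INDArray[] parts = new NDBase().unstack(Nd4j.rand(new int[]{3, 4, 5}), 0, 3);   // three arrays, each of shape [4, 5]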
+ * + * @param value Input variable to unstack (NDARRAY type) + * @param axis Axis to unstack on + * @param num Number of output variables + */ + public INDArray[] unstack(INDArray value, int axis, int num) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Unstack(value, axis, num)); + } + /** * Variance array reduction operation, optionally along specified dimensions
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java index cb00a28c2..1e3c89111 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.DataFormat; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; @@ -32,7 +33,6 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.enums.DataFormat; public class NDCNN { public NDCNN() { @@ -370,6 +370,18 @@ public class NDCNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0]; } + /** + * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices
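+ * A minimal usage sketch (illustrative only; the Pooling2DConfig builder fields shown are assumptions, not part of this patch):
+ *   Pooling2DConfig cfg = Pooling2DConfig.builder().kH(2).kW(2).sH(2).sW(2).build();
+ *   INDArray[] out = new NDCNN().maxPoolWithArgmax(Nd4j.rand(new int[]{1, 1, 4, 4}), cfg);   // out[0] = pooled values, out[1] = argmax indices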
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + */ + public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) { + NDValidation.validateNumerical("maxPoolWithArgmax", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(input, Pooling2DConfig)); + } + /** * 2D Convolution layer operation - max pooling 2d
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java index cdee59ea1..184f3edea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java @@ -222,15 +222,12 @@ public class NDLoss { * * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param epsilon epsilon * @return output Log loss (NUMERIC type) */ - public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) { + public INDArray logLoss(INDArray label, INDArray predictions) { NDValidation.validateNumerical("logLoss", "label", label); NDValidation.validateNumerical("logLoss", "predictions", predictions); - NDValidation.validateNumerical("logLoss", "weights", weights); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index eddbe3db7..bee0da889 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -190,6 +190,58 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x)); } + /** + * Bit shift operation
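+ * A minimal usage sketch (illustrative only; integer input values are assumed):
+ *   INDArray out = new NDMath().bitShift(Nd4j.createFromArray(1, 2, 4), Nd4j.createFromArray(1, 1, 1));   // [2, 4, 8]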
+ * + * @param x input (NUMERIC type) + * @param shift shift value (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public INDArray bitShift(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShift", "x", x); + NDValidation.validateNumerical("bitShift", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0]; + } + + /** + * Right bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public INDArray bitShiftRight(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShiftRight", "x", x); + NDValidation.validateNumerical("bitShiftRight", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0]; + } + + /** + * Cyclic bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public INDArray bitShiftRotl(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShiftRotl", "x", x); + NDValidation.validateNumerical("bitShiftRotl", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0]; + } + + /** + * Cyclic right shift operation&#13;
+ * + * @param x Input tensor (NUMERIC type) + * @param shift Shift argument (NUMERIC type) + * @return output Shifted output (NUMERIC type) + */ + public INDArray bitShiftRotr(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShiftRotr", "x", x); + NDValidation.validateNumerical("bitShiftRotr", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0]; + } + /** * Element-wise ceiling function: out = ceil(x).
* Rounds each value up to the nearest integer value (if not already an integer)
@@ -346,13 +398,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("cosineDistance", "x", x); NDValidation.validateNumerical("cosineDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions)); } @@ -363,13 +415,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("cosineSimilarity", "x", x); NDValidation.validateNumerical("cosineSimilarity", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions)); } @@ -501,13 +553,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("euclideanDistance", "x", x); NDValidation.validateNumerical("euclideanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions)); } @@ -665,13 +717,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray hammingDistance(INDArray x, INDArray y, int... 
dimensions) { NDValidation.validateNumerical("hammingDistance", "x", x); NDValidation.validateNumerical("hammingDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions)); } @@ -817,13 +869,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("jaccardDistance", "x", x); NDValidation.validateNumerical("jaccardDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions)); } @@ -872,6 +924,18 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions)); } + /** + * Calculates difference between inputs X and Y.
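+ * A minimal usage sketch (illustrative only; input values are assumed):
+ *   INDArray[] out = new NDMath().listDiff(Nd4j.createFromArray(1.0, 2.0, 3.0, 4.0), Nd4j.createFromArray(2.0, 4.0));   // out[0] = values of x not in y, out[1] = their indices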
+ * + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) + */ + public INDArray[] listDiff(INDArray x, INDArray y) { + NDValidation.validateNumerical("listDiff", "x", x); + NDValidation.validateNumerical("listDiff", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(x, y)); + } + /** * Element-wise logarithm function (base e - natural logarithm): out = log(x)
* @@ -940,13 +1004,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("manhattanDistance", "x", x); NDValidation.validateNumerical("manhattanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions)); } @@ -983,7 +1047,7 @@ public class NDMath { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public INDArray mergeAdd(INDArray[] inputs) { + public INDArray mergeAdd(INDArray... inputs) { NDValidation.validateNumerical("mergeAdd", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0]; @@ -996,7 +1060,7 @@ public class NDMath { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public INDArray mergeAvg(INDArray[] inputs) { + public INDArray mergeAvg(INDArray... inputs) { NDValidation.validateNumerical("mergeAvg", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0]; @@ -1009,12 +1073,24 @@ public class NDMath { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public INDArray mergeMax(INDArray[] inputs) { + public INDArray mergeMax(INDArray... inputs) { NDValidation.validateNumerical("mergeMax", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0]; } + /** + * Broadcasts parameters for evaluation on an N-D grid.
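+ * A minimal usage sketch (illustrative only; input values are assumed):
+ *   INDArray[] grids = new NDMath().meshgrid(new INDArray[]{Nd4j.createFromArray(1.0, 2.0), Nd4j.createFromArray(3.0, 4.0, 5.0)}, true);   // two 2d grids spanning the 2x3 cartesian product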
+ * + * @param inputs (NUMERIC type) + * @param cartesian + */ + public INDArray[] meshgrid(INDArray[] inputs, boolean cartesian) { + NDValidation.validateNumerical("meshgrid", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian)); + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 04a713ecf..3f9e1431a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -237,12 +237,11 @@ public class NDNN { * Alpha value is most commonly set to 0.01
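The change below replaces the INDArray alpha argument with a plain double. A minimal sketch of the updated call (values and names illustrative, not part of the patch):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class LeakyReluSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(-2.0, -0.5, 0.0, 3.0);
            // alpha is now passed as a scalar cutoff rather than as an array
            INDArray out = new NDNN().leakyRelu(x, 0.01);
            System.out.println(out); // expected: [-0.02, -0.005, 0.0, 3.0]
        }
    }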
* * @param x Input variable (NUMERIC type) - * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ - public INDArray leakyRelu(INDArray x, INDArray alpha) { + public INDArray leakyRelu(INDArray x, double alpha) { NDValidation.validateNumerical("leakyRelu", "x", x); - NDValidation.validateNumerical("leakyRelu", "alpha", alpha); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha)); } @@ -250,12 +249,11 @@ public class NDNN { * Leaky ReLU derivative: dOut/dIn given input.
* * @param x Input variable (NUMERIC type) - * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ - public INDArray leakyReluDerivative(INDArray x, INDArray alpha) { + public INDArray leakyReluDerivative(INDArray x, double alpha) { NDValidation.validateNumerical("leakyReluDerivative", "x", x); - NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha)); } @@ -346,6 +344,20 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } + /** + * Padding operation
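A minimal sketch of the pad helper introduced in this hunk (illustrative, not part of the patch; the padding array is assumed to hold one [before, after] pair per input dimension, mirroring the underlying Pad op, which the patch does not spell out):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class PadSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.ones(2, 2);
            // Assumed convention: pad one element before and after each dimension -> 4x4 result
            INDArray padding = Nd4j.createFromArray(new int[][]{{1, 1}, {1, 1}});
            INDArray out = new NDNN().pad(in, padding, 0.0);
            System.out.println(out.shapeInfoToString());
        }
    }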
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public INDArray pad(INDArray input, INDArray padding, double constant) { + NDValidation.validateNumerical("pad", "input", input); + NDValidation.validateNumerical("pad", "padding", padding); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0]; + } + /** * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
* out[i] = in[i] if in[i] >= 0
@@ -461,6 +473,17 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0]; } + /** + * Softmax activation, along the specified dimension
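The overload added in this hunk defaults to the last dimension (-1), so it is expected to match the existing explicit-dimension call. A minimal sketch (names illustrative, not part of the patch):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class SoftmaxSketch {
        public static void main(String[] args) {
            NDNN nn = new NDNN();
            INDArray x = Nd4j.rand(DataType.FLOAT, 3, 5);
            INDArray a = nn.softmax(x);      // new overload, softmax over the last axis
            INDArray b = nn.softmax(x, -1);  // existing overload with the dimension given explicitly
            System.out.println(a.equals(b)); // expected: true
        }
    }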
+ * + * @param x Input (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray softmax(INDArray x) { + NDValidation.validateNumerical("softmax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1))[0]; + } + /** * Softmax derivative function
* @@ -519,4 +542,15 @@ public class NDNN { NDValidation.validateNumerical("swish", "x", x); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x)); } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
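The same hunk also adds an elementwise tanh to NDNN. A minimal sketch (values and names illustrative, not part of the patch):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class TanhSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(-1.0, 0.0, 1.0);
            INDArray out = new NDNN().tanh(x);
            System.out.println(out); // expected: approximately [-0.7616, 0.0, 0.7616]
        }
    }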
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray tanh(INDArray x) { + NDValidation.validateNumerical("tanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java index 0587aeda5..9bb7d9640 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java @@ -22,7 +22,9 @@ import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.nd4j.linalg.factory.NDValidation; @@ -38,12 +40,11 @@ public class NDRNN { * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object - * @return output The cell's outputs. (NUMERIC type) */ - public INDArray gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { + public INDArray[] gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { NDValidation.validateNumerical("gru", "x", x); NDValidation.validateNumerical("gru", "hLast", hLast); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights)); } /** @@ -54,18 +55,83 @@ public class NDRNN { * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param LSTMWeights Configuration Object * @param LSTMConfiguration Configuration Object - * @return output The cell's outputs (NUMERIC type) */ - public INDArray lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, + public INDArray[] lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { NDValidation.validateNumerical("lstmCell", "x", x); NDValidation.validateNumerical("lstmCell", "cLast", cLast); NDValidation.validateNumerical("lstmCell", "yLast", yLast); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration)); } /** - * The LSTM layer. Does multiple time steps.
+ * Long Short-Term Memory layer - Hochreiter 1997.
+   * Supports the following data formats:
+   * for unidirectional:
+   * TNS: shapes [timeLength, numExamples, inOutSize]
+   * NST: shapes [numExamples, inOutSize, timeLength]
+   * NTS: shapes [numExamples, timeLength, inOutSize]
+   * for bidirectional:
+   * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)
+   * Supports the following direction modes:
+   * FWD: forward
+   * BWD: backward
+   * BIDIR_SUM: bidirectional sum
+   * BIDIR_CONCAT: bidirectional concat
+   * BIDIR_EXTRA_DIM: bidirectional extra output dimension (in conjunction with the T2NS data format)
+   * You may use different gate configurations:
+   * specify gate/cell/out alpha/beta and the gate/cell/out activations, chosen from the activations enum
+   * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")
+   * This layer also supports MKLDNN (DNNL) and cuDNN acceleration
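As an illustration of how the formats, direction modes and activations above are selected in code, a minimal configuration sketch (field values are illustrative, not part of the patch; the same builder is exercised by the LSTMLayer test cases in the LayerOpValidation changes below):

    import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;

    public class LstmLayerConfigSketch {
        public static void main(String[] args) {
            // Unidirectional forward LSTM over NST-formatted input ([numExamples, inOutSize, timeLength]),
            // returning the full output sequence plus the last cell state and last output
            LSTMLayerConfig config = LSTMLayerConfig.builder()
                    .lstmdataformat(LSTMDataFormat.NST)
                    .directionMode(LSTMDirectionMode.FWD)
                    .gateAct(LSTMActivations.SIGMOID)
                    .cellAct(LSTMActivations.TANH)
                    .outAct(LSTMActivations.TANH)
                    .retFullSequence(true)
                    .retLastC(true)
                    .retLastH(true)
                    .build();
            System.out.println(config);
        }
    }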
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public INDArray[] lstmLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, + LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) { + NDValidation.validateNumerical("lstmLayer", "x", x); + NDValidation.validateNumerical("lstmLayer", "cLast", cLast); + NDValidation.validateNumerical("lstmLayer", "yLast", yLast); + NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig)); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+   * Supports the following data formats:
+   * for unidirectional:
+   * TNS: shapes [timeLength, numExamples, inOutSize]
+   * NST: shapes [numExamples, inOutSize, timeLength]
+   * NTS: shapes [numExamples, timeLength, inOutSize]
+   * for bidirectional:
+   * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)
+   * Supports the following direction modes:
+   * FWD: forward
+   * BWD: backward
+   * BIDIR_SUM: bidirectional sum
+   * BIDIR_CONCAT: bidirectional concat
+   * BIDIR_EXTRA_DIM: bidirectional extra output dimension (in conjunction with the T2NS data format)
+   * You may use different gate configurations:
+   * specify gate/cell/out alpha/beta and the gate/cell/out activations, chosen from the activations enum
+   * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")
+   * This layer also supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public INDArray[] lstmLayer(INDArray x, LSTMLayerWeights LSTMLayerWeights, + LSTMLayerConfig LSTMLayerConfig) { + NDValidation.validateNumerical("lstmLayer", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, null, null, null, LSTMLayerWeights, LSTMLayerConfig)); + } + + /** + * The LSTM block
* * @param maxTSLength (NUMERIC type) * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) @@ -75,13 +141,27 @@ public class NDRNN { * @param LSTMConfiguration Configuration Object * @return output The layer's outputs. (NUMERIC type) */ - public INDArray lstmLayer(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, + public INDArray lstmblock(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { - NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); - NDValidation.validateNumerical("lstmLayer", "x", x); - NDValidation.validateNumerical("lstmLayer", "cLast", cLast); - NDValidation.validateNumerical("lstmLayer", "yLast", yLast); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + NDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength); + NDValidation.validateNumerical("lstmblock", "x", x); + NDValidation.validateNumerical("lstmblock", "cLast", cLast); + NDValidation.validateNumerical("lstmblock", "yLast", yLast); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + } + + /** + * The LSTM block
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public INDArray lstmblock(INDArray x, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + NDValidation.validateNumerical("lstmblock", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(null, x, null, null, LSTMWeights, LSTMConfiguration))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index a6ccd25ed..bda208ce7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -199,7 +199,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -436,7 +436,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -524,7 +524,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { long st = profilingConfigurableHookIn(op); naiveExec(op, dimension); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -607,7 +607,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -772,7 +772,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; } @@ -863,7 +863,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; @@ -1113,7 +1113,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); Nd4j.getExecutioner().commit(); @@ -1200,7 +1200,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return null; } @@ -1296,7 +1296,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, 
st); return null; } @@ -1460,7 +1460,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (ret != null) ret.elementWiseStride(); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; } @@ -1579,7 +1579,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return z; } @@ -2292,7 +2292,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, OpContext context) { - long st = profilingConfigurableHookIn(op); + long st = profilingConfigurableHookIn(op, context); val ctx = AtomicAllocator.getInstance().getDeviceContext(); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); @@ -2304,7 +2304,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (status != 0) throw new RuntimeException("Op [" + op.opName() + "] execution failed"); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, context, st); if (context.getOutputArrays().isEmpty()) return new INDArray[0]; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 7a29f71d7..f0488636f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -236,7 +236,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return getZ(op, oc); } @@ -690,7 +690,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return getZ(op, oc); } @@ -774,7 +774,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (z == null) setZ(Nd4j.create(op.resultType(), x.shape()), op, oc); -// op.setZ(Nd4j.create(op.resultType(), op.x().shape())); op.validateDataTypes(oc, experimentalMode.get()); @@ -884,7 +883,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); } public INDArray exec(BroadcastOp op) { @@ -1306,7 +1305,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return z; } @@ -2040,7 +2039,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, @NonNull OpContext context) { - long st = profilingConfigurableHookIn(op); + long st = profilingConfigurableHookIn(op, context); boolean mklOverride = false; try { if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) { @@ -2125,7 
+2124,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } finally { if (mklOverride) Nd4jCpu.Environment.getInstance().setUseMKLDNN(true); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, context, st); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index eab974821..794348369 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -20,8 +20,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; + import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; @@ -36,6 +38,12 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.factory.Nd4j; @@ -257,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation { msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); - out = sd.cnn().upsampling2d(in, 2, 2, true); + out = sd.cnn().upsampling2d(in, 2, 2, true); break; default: throw new RuntimeException(); @@ -578,8 +586,6 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable dW = sd.var("dW", depthWeightArr); SDVariable b = sd.var("b", bArr); - SDVariable[] vars = new SDVariable[]{in, dW, b}; - Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) .pH(0).pW(0) @@ -588,8 +594,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().separableConv2d(in, dW, b, c); - out = sd.f().tanh(out); + SDVariable out = sd.cnn().separableConv2d(in, dW, null, b, c); + out = sd.nn().tanh("out", out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -623,8 +629,6 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable pW = sd.var("pW", pointWeightArr); SDVariable b = sd.var("b", bArr); - //SDVariable[] vars = new SDVariable[]{in, dW, pW, b}; - Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) .pH(0).pW(0) @@ -635,7 +639,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c); - out = sd.nn().tanh(out); + out = sd.nn().tanh("out", out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 @@ 
-675,8 +679,6 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable w = sd.var("W", wArr); SDVariable b = sd.var("b", bArr); - SDVariable[] vars = new SDVariable[]{in, w, b}; - DeConv2DConfig deconv = DeConv2DConfig.builder() .kH(kH).kW(kW) .pH(0).pW(0) @@ -685,8 +687,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.f().deconv2d(vars, deconv); - out = sd.f().tanh(out); + SDVariable out = sd.cnn().deconv2d(in, w, b, deconv); + out = sd.nn().tanh("out", out); INDArray outArr = out.eval(); //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 @@ -723,7 +725,6 @@ public class LayerOpValidation extends BaseOpValidation { //Order: https://github.com/deeplearning4j/libnd4j/blob/6c41ea5528bb1f454e92a9da971de87b93ff521f/include/ops/declarable/generic/convo/conv2d.cpp#L20-L22 //in, w, b - bias is optional - SDVariable[] vars = new SDVariable[]{in, w, b}; Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) @@ -733,8 +734,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.f().conv2d(vars, c); - out = sd.f().tanh(out); + SDVariable out = sd.cnn().conv2d("conv", in, w, b, c); + out = sd.nn().tanh("out", out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -767,7 +768,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.f().maxPoolWithArgmax(/*new String[]{"out","idx"},*/ in, pooling2DConfig); + SDVariable[] results = sd.cnn().maxPoolWithArgmax(new String[]{"out", "idx"}, in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -797,7 +798,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); - SDVariable out = sd.f().tanh(/*"out",*/ outPool); + SDVariable out = sd.nn().tanh("out", outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -855,7 +856,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); - SDVariable out = sd.f().tanh(/*"out",*/ outPool); + SDVariable out = sd.nn().tanh("out", outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -906,7 +907,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig); - out = sd.f().tanh(/*"loss", */out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L); @@ -942,7 +943,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -976,8 +977,8 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -996,7 +997,7 @@ public class LayerOpValidation 
extends BaseOpValidation { int nOut = 4; int mb = 2; - for( int k : new int[]{2, 3}) { + for (int k : new int[]{2, 3}) { for (int sz : new int[]{3, 4, 5}) { for (int s : new int[]{1, 2}) { for (int d : new int[]{1, 2}) { @@ -1018,7 +1019,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig); - SDVariable loss = sd.f().tanh(out).std(true).rename("loss"); + SDVariable loss = sd.nn().tanh(out).std(true).rename("loss"); sd.setLossVariables("loss"); @@ -1039,7 +1040,7 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testConv1dForward(){ + public void testConv1dForward() { int nIn = 2; int nOut = 1; int kernel = 3; @@ -1057,7 +1058,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, null, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); + SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ @@ -1113,7 +1114,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1156,7 +1157,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1201,13 +1202,13 @@ public class LayerOpValidation extends BaseOpValidation { public void testLayerNorm4d() { int mb = 3; int ch = 4; - for(boolean nchw : new boolean[]{true, false}) { + for (boolean nchw : new boolean[]{true, false}) { double eps = 0.0; INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch}); INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? 
new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); INDArray mean = x.mean(true, 1, 2, 3); - INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1); + INDArray std = Transforms.sqrt(x.var(false, 1, 2, 3).addi(eps)).reshape(mb, 1, 1, 1); INDArray standardized = x.sub(mean).div(std); INDArray exp = standardized.mul(gain4d).add(bias4d); @@ -1274,7 +1275,7 @@ public class LayerOpValidation extends BaseOpValidation { final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(DataType.DOUBLE,4); + final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4); final INDArray res = standardized.mulRowVector(gain); final INDArray output = Nd4j.zerosLike(res); @@ -1287,7 +1288,7 @@ public class LayerOpValidation extends BaseOpValidation { public void testLayerNormNoDeviation() { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); for (int i = 0; i < 4; i++) { - random.putScalar(1,i, 7); + random.putScalar(1, i, 7); } final INDArray standardized = random.ulike(); @@ -1335,7 +1336,7 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); + SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); } @@ -1391,16 +1392,16 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLayerNormMixedOrders(){ + public void testLayerNormMixedOrders() { Nd4j.getRandom().setSeed(12345); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); - INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); - INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); - INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); - INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); + INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f'); + INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c'); + INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c'); + INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f'); //F in, F out case Nd4j.exec(DynamicCustomOp.builder("layer_norm") @@ -1441,11 +1442,11 @@ public class LayerOpValidation extends BaseOpValidation { public void testBiasAdd_nchw_nhwc() { Nd4j.getRandom().setSeed(12345); - for(boolean nchw : new boolean[]{true, false}) { + for (boolean nchw : new boolean[]{true, false}) { log.info("Starting test: {}", nchw ? "nchw" : "nhwc"); SameDiff sameDiff = SameDiff.create(); - SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4})); + SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? 
new long[]{2, 4, 3, 3} : new long[]{2, 3, 3, 4})); SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4})); SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw); @@ -1453,10 +1454,10 @@ public class LayerOpValidation extends BaseOpValidation { INDArray exp = in.getArr().dup(); - if(nchw){ - exp.addi(b.getArr().reshape(1,4,1,1)); + if (nchw) { + exp.addi(b.getArr().reshape(1, 4, 1, 1)); } else { - exp.addi(b.getArr().reshape(1,1,1,4)); + exp.addi(b.getArr().reshape(1, 1, 1, 4)); } TestCase tc = new TestCase(sameDiff) @@ -1467,4 +1468,168 @@ public class LayerOpValidation extends BaseOpValidation { assertNull(err); } } + + + @Test + public void LSTMLayerTestCase1() { + + int bS = 5; + int nIn = 3; + int numUnits = 7; + int sL = 10; //small just for test + + SameDiff sd = SameDiff.create(); + + // notations: + // bS - batch size, numExamples + // sL - sequence length, number of time steps, timeLength + // nIn - input size, inOutSize + + // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ // NST: shape [numExamples, inOutSize, timeLength]
+ // NTS: shape [numExamples, timeLength, inOutSize]
+ // for bidirectional: + // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + + + SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL)); + + + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NST) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(true) + .retLastH(true) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits))) + .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(), + c), c); + + long[] out = new long[]{bS, numUnits, sL}; + long[] hL = new long[]{bS, numUnits}; + long[] cL = new long[]{bS, numUnits}; + + assertArrayEquals(out, outputs.getOutput().eval().shape()); + assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape()); + assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape()); + + + } + + + @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + public void LSTMLayerTestCase2() { + int bS = 5; + int nIn = 3; + int numUnits = 7; + int sL = 10; //small just for test + + SameDiff sd = SameDiff.create(); + + // notations: + // bS - batch size, numExamples + // sL - sequence length, number of time steps, timeLength + // nIn - input size, inOutSize + + // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ // NST: shape [numExamples, inOutSize, timeLength]
+ // NTS: shape [numExamples, timeLength, inOutSize]
+ // for bidirectional: + // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn)); + + + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.TNS) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(false) + .retLastH(false) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .build(), + c), c); + + + long[] out = new long[]{sL, bS, numUnits}; + assertArrayEquals(out, outputs.getOutput().eval().shape()); + + } + + @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + public void LSTMLayerTestCase3() { + int bS = 5; + int nIn = 3; + int numUnits = 7; + int sL = 10; //small just for test + + SameDiff sd = SameDiff.create(); + + // notations: + // bS - batch size, numExamples + // sL - sequence length, number of time steps, timeLength + // nIn - input size, inOutSize + + // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ // NST: shape [numExamples, inOutSize, timeLength]
+ // NTS: shape [numExamples, timeLength, inOutSize]
+ // for bidirectional: + // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn)); + + + // when directionMode >= 2 (BIDIR_CONCAT=3) + // Wx, Wr [2, nIn, 4*nOut] + // hI, cI [2, bS, nOut] + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NTS) + .directionMode(LSTMDirectionMode.BIDIR_CONCAT) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.SOFTPLUS) + .outAct(LSTMActivations.SOFTPLUS) + .retFullSequence(true) + .retLastC(false) + .retLastH(false) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"}, + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits))) + .build(), + c), c); + + + long[] out = new long[]{bS, sL, 2 * numUnits}; + + assertArrayEquals(out, outputs.getOutput().eval().shape()); + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 3998bc184..47f383f52 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -548,7 +548,7 @@ public class MiscOpValidation extends BaseOpValidation { INDArray arr2 = Nd4j.rand(new long[]{2, 2, 2}); SDVariable x = sameDiff.var("x", arr); SDVariable y = sameDiff.var("y", arr2); - SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}}); + SDVariable result = sameDiff.tensorMmul(x, y, new int[]{0}, new int[]{1}); assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.eval().shape()); assertEquals(16, sameDiff.numElements()); @@ -689,13 +689,7 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable a = sd.var("a", aArr); SDVariable b = sd.var("b", bArr); - MMulTranspose mt = MMulTranspose.builder() - .transposeA(transposeA) - .transposeB(transposeB) - .transposeResult(transposeResult) - .build(); - - SDVariable mmul = sd.mmul(a, b, mt); + SDVariable mmul = sd.mmul(a, b, transposeA, transposeB, transposeResult); INDArray exp = (transposeA ? aArr.transpose() : aArr); exp = exp.mmul(transposeB ? 
bArr.transpose() : bArr); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 69385f814..c47d02b04 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -70,7 +70,7 @@ public class RnnOpValidation extends BaseOpValidation { LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); - LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y + LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.name()); @@ -173,7 +173,7 @@ public class RnnOpValidation extends BaseOpValidation { LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); - LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y + LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.name()); @@ -227,7 +227,7 @@ public class RnnOpValidation extends BaseOpValidation { .cBias(bc) .build(); - List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); + SDVariable[] v = sd.rnn().gru(x, hLast, weights); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ toExec.add(sdv.name()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 0cbe52479..47394de1e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -119,7 +119,7 @@ public class ShapeOpValidation extends BaseOpValidation { List failed = new ArrayList<>(); - for (int[] toShape : new int[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) { + for (long[] toShape : new long[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) { for(char order : new char[]{'c','f'}){ INDArray inArr = Nd4j.rand(DataType.DOUBLE, origShape, order).muli(100); @@ -388,10 +388,10 @@ public class ShapeOpValidation extends BaseOpValidation { @Builder(builderClassName = "Builder") @Data private static class SSCase { - private int[] shape; - private int[] begin; - private int[] end; - private int[] strides; + private long[] shape; + private long[] begin; + private long[] end; + private long[] strides; private int beginMask; private int endMask; private int ellipsisMask; @@ -400,22 +400,22 @@ public class ShapeOpValidation extends BaseOpValidation { public static class Builder { - public Builder shape(int... shape) { + public Builder shape(long... shape) { this.shape = shape; return this; } - public Builder begin(int... begin) { + public Builder begin(long... begin) { this.begin = begin; return this; } - public Builder end(int... 
end) { + public Builder end(long... end) { this.end = end; return this; } - public Builder strides(int... strides) { + public Builder strides(long... strides) { this.strides = strides; return this; } @@ -1571,7 +1571,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); SDVariable x1 = sameDiff.var("x1", arr1); SDVariable x2 = sameDiff.var("x2", arr2); - SDVariable result = sameDiff.parallel_stack(new SDVariable[]{x1, x2}); + SDVariable result = sameDiff.stack(0, new SDVariable[]{x1, x2}); assertArrayEquals(new long[]{2, 3, 2}, result.eval().shape()); assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval()); } @@ -1661,9 +1661,9 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice_full = sd.stridedSlice(in, new int[]{0, 0}, new int[]{3, 4}, new int[]{1, 1}); - SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1}); - // SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2}); + SDVariable slice_full = sd.stridedSlice(in,new long[]{0, 0},new long[]{3, 4},new long[]{1, 1}); + SDVariable subPart = sd.stridedSlice(in,new long[]{1, 2},new long[]{3, 4},new long[]{1, 1}); + // SDVariable subPart2 = sd.stridedSlice(in,new long[]{0, 0},new long[]{4, 5},new long[]{2, 2}); sd.outputAll(null); @@ -1679,8 +1679,8 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0); - SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0); + SDVariable slice1 = sd.stridedSlice(in,new long[]{-999, 0},new long[]{2, 4},new long[]{1, 1}, 1 << 1, 0, 0, 0, 0); + SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 0},new long[]{-999, 4},new long[]{1, 1}, 0, 1, 0, 0, 0); sd.outputAll(null); @@ -1695,9 +1695,9 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); //[1:3,...] 
-> [1:3,:,:] - SDVariable slice = sd.stridedSlice(in, new int[]{1}, new int[]{3}, new int[]{1}, 0, 0, 1 << 1, 0, 0); + SDVariable slice = sd.stridedSlice(in,new long[]{1},new long[]{3},new long[]{1}, 0, 0, 1 << 1, 0, 0); //[1:3,...,1:4] -> [1:3,:,1:4] - SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0); + SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 1},new long[]{3, 4},new long[]{1, 1}, 0, 0, 1 << 1, 0, 0); sd.outputAll(Collections.emptyMap()); @@ -1710,7 +1710,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); + SDVariable slice = sd.stridedSlice(in,new long[]{-999, 0, 0, 0},new long[]{-999, 3, 4, 5},new long[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); INDArray out = slice.eval(); @@ -1723,7 +1723,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); + SDVariable slice = sd.stridedSlice(in,new long[]{1, 1, -999, 1},new long[]{3, 3, -999, 4},new long[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); INDArray out = slice.eval(); assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape()); @@ -1735,9 +1735,9 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice = sd.stridedSlice(in, new int[]{0, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); - SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); - SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); + SDVariable slice = sd.stridedSlice(in,new long[]{0, 0, 0},new long[]{-999, 4, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1); + SDVariable slice2 = sd.stridedSlice(in,new long[]{2, 0, 0},new long[]{-999, 4, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1); + SDVariable slice3 = sd.stridedSlice(in,new long[]{1, 2, 1},new long[]{-999, -999, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); sd.outputAll(null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 9be66f484..27a15b517 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -1920,7 +1920,7 @@ public class TransformOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable sdA = sd.var("a", a); SDVariable sdB = sd.var("b", b); - SDVariable t = sd.mmul(sdA, sdB, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).transposeResult(transposeResult).build()); + SDVariable t = sd.mmul(sdA, sdB, transposeA, transposeB, transposeResult); t.norm1("out"); String err = OpValidation.validate(new TestCase(sd) 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 88915e35b..3e33534e1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -759,8 +759,7 @@ public class SameDiffTests extends BaseNd4jTest { val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1); val input1 = sd.var("input", matrix); val input2 = sd.var("input2", vector); - val output = sd - .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build()); + val output = sd.mmul("output", input1, input2, true, false, false); INDArray out = output.eval(); assertArrayEquals(new long[]{3, 1}, out.shape()); } @@ -2675,7 +2674,7 @@ public class SameDiffTests extends BaseNd4jTest { final long timeSteps = sdInput.getShape()[2]; SDVariable[] outputSlices = new SDVariable[(int) timeSteps]; - final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2); + final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2); final val x_0 = inputSlices[0]; outputSlices[0] = x_0; @@ -2702,7 +2701,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sd = SameDiff.create(); final SDVariable sdInput = sd.var("input", input); - final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2); + final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2); final val temp = inputSlices[0].add(inputSlices[1]).div(inputSlices[1]).mul(inputSlices[0]); final val out = temp.add(temp).add(inputSlices[1]); out.norm2("out"); @@ -3242,61 +3241,61 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testNestedIf() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable a = SD.var("a", Nd4j.createFromArray(2.0)); - SDVariable b = SD.var("b", Nd4j.createFromArray(5.0)); - SDVariable c = SD.var("c", Nd4j.createFromArray(9.0)); - SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0)); + SameDiff sd = SameDiff.create(); + SDVariable a = sd.var("a", Nd4j.createFromArray(2.0)); + SDVariable b = sd.var("b", Nd4j.createFromArray(5.0)); + SDVariable c = sd.var("c", Nd4j.createFromArray(9.0)); + SDVariable d = sd.var("d", Nd4j.createFromArray(-7.0)); - SDVariable output = SD.ifCond("out", null, - (sd) -> a.lt(b), - (sd) -> sd.ifCond( + SDVariable output = sd.ifCond("out", null, + (s) -> a.lt(b), + (s) -> s.ifCond( (sd2) -> d.lte(0), (sd2) -> c.add(1), (sd2) -> d), - (sd) -> c.add(5)); + (s) -> c.add(5)); INDArray out = output.eval(); assertEquals(Nd4j.createFromArray(10.0), out); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(Nd4j.createFromArray(10.0), SD.output(Collections.emptyMap(), "out").get("out")); + assertEquals(Nd4j.createFromArray(10.0), sd.output(Collections.emptyMap(), "out").get("out")); } @Test public void testWhile() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable countIn = SD.constant(5); - SDVariable sumIn = SD.constant(0); + SameDiff sd = SameDiff.create(); + SDVariable countIn = sd.constant(5); + SDVariable sumIn = sd.constant(0); - SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn}, - (sd, vars) -> vars[0].gt(0), - (sd, vars) -> new SDVariable[]{vars[0].sub(1), 
vars[1].add(vars[0])}); + SDVariable[] sum = sd.whileLoop("while_1", new SDVariable[]{countIn, sumIn}, + (s, vars) -> vars[0].gt(0), + (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])}); INDArray out = sum[1].eval(); assertEquals(15, out.getInt(0)); String outName = sum[1].name(); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(15, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); + assertEquals(15, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test @Ignore public void testNestedWhile() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable countIn = SD.constant(5); - SDVariable sumIn = SD.constant(0); - SDVariable sum2 = SD.constant(0); + SameDiff sd = SameDiff.create(); + SDVariable countIn = sd.constant(5); + SDVariable sumIn = sd.constant(0); + SDVariable sum2 = sd.constant(0); //TODO creating constant instead of using sum2 causes errors - SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn}, - (sd, vars) -> vars[0].gt(0), - (sd, vars) -> new SDVariable[]{vars[0].sub(1), - vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2}, + SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn}, + (s, vars) -> vars[0].gt(0), + (s, vars) -> new SDVariable[]{vars[0].sub(1), + vars[1].add(s.whileLoop(new SDVariable[]{vars[0], sum2}, (sd2, vars2) -> vars2[0].gt(0), (sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])}); @@ -3305,23 +3304,23 @@ public class SameDiffTests extends BaseNd4jTest { String outName = sum[1].name(); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(35, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); + assertEquals(35, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test public void testNestedWhileIf() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable countIn = SD.constant(5); - SDVariable sumIn = SD.constant(0); - SDVariable hundred = SD.constant(100); + SameDiff sd = SameDiff.create(); + SDVariable countIn = sd.constant(5); + SDVariable sumIn = sd.constant(0); + SDVariable hundred = sd.constant(100); - SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn}, - (sd, vars) -> vars[0].gte(0), - (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add( - sd.ifCond((sd2) -> vars[0].eq(0), + SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn}, + (s, vars) -> vars[0].gte(0), + (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add( + s.ifCond((sd2) -> vars[0].eq(0), (sd2) -> vars[0].add(100), //TODO replace with hundred and things break (sd2) -> vars[0]) )}); @@ -3331,9 +3330,9 @@ public class SameDiffTests extends BaseNd4jTest { String outName = sum[1].name(); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(115, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); + assertEquals(115, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java index 03b469e70..2c1c284bc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java +++ 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java @@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest { @Override public long getTimeoutMilliseconds() { - return 180000L; //Can be slow on some CI machines such as PPC + return 360000L; //Can be very slow on some CI machines (PPC) } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 867fe1611..22c6e3a52 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -29,7 +29,10 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; +import org.nd4j.linalg.api.ops.impl.shape.Concat; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -473,6 +476,7 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.exec(op); //Should trigger NaN panic fail(); } catch (Exception e){ + e.printStackTrace(); assertTrue(e.getMessage(), e.getMessage().contains("Inf")); } @@ -488,4 +492,55 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build()); } } + + + @Test + public void testOpProfilerOpContextLegacy(){ + + for(boolean nan : new boolean[]{true, false}) { + + INDArray in = Nd4j.valueArrayOf(10, nan ? -1 : 0).castTo(DataType.FLOAT); + + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build()); + + OpContext oc = Nd4j.getExecutioner().buildContext(); + oc.setInputArray(0, in); + oc.setOutputArray(0, in.ulike()); + try { + Nd4j.exec(new Log(), oc); + System.out.println(oc.getOutputArray(0)); + fail("Expected op profiler exception"); + } catch (Throwable t) { + //OK + assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf")); + } + } + } + + @Test + public void testOpProfilerOpContextCustomOp(){ + + for(boolean nan : new boolean[]{true, false}) { + + INDArray in = Nd4j.create(DataType.DOUBLE, 10).assign(nan ? Double.NaN : Double.POSITIVE_INFINITY); + INDArray in2 = in.dup(); + + + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build()); + + OpContext oc = Nd4j.getExecutioner().buildContext(); + oc.setIArguments(0); + oc.setInputArray(0, in); + oc.setInputArray(1, in2); + oc.setOutputArray(0, Nd4j.create(DataType.DOUBLE, 20)); + try { + Nd4j.exec(new Concat(), oc); + System.out.println(oc.getOutputArray(0)); + fail("Expected op profiler exception"); + } catch (Throwable t) { + //OK + assertTrue(t.getMessage(), t.getMessage().contains(nan ? 
"NaN" : "Inf")); + } + } + } } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index 35e5607a2..39c09e627 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -3579,4 +3579,19 @@ public class ArrayUtil { } return false; } + + public static T[] filterNull(T... in){ + int count = 0; + for( int i=0; i Date: Thu, 9 Apr 2020 10:00:27 +1000 Subject: [PATCH 15/19] Add javacpp classifier dependency to nd4j-native and nd4j-cuda (#366) Signed-off-by: Alex Black --- nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml | 6 ++++++ nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index 14cb6af6e..b450e58b6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -242,6 +242,12 @@ javacpp ${javacpp.version} + + org.bytedeco + javacpp + ${javacpp.version} + ${dependency.platform} + org.bytedeco cuda diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index 0aab8c241..91888f400 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -39,6 +39,12 @@ javacpp ${javacpp.version} + + org.bytedeco + javacpp + ${javacpp.version} + ${dependency.platform} + org.bytedeco openblas From 3e2dbc65ddeb316c77f53b8c011d30a35324bbf6 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 10 Apr 2020 17:57:02 +0300 Subject: [PATCH 16/19] MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg * libnd4j typos fixes Signed-off-by: Oleg * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg * libnd4j one more typo fix Signed-off-by: Oleg * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg * MatMul alpha/beta on java side Signed-off-by: raver119 * alpha/beta fix in libnd4j Signed-off-by: raver119 * alpha/beta fix in matmul_bp Signed-off-by: raver119 * restored view validation Signed-off-by: raver119 * gemv/gemm now use MatMul op Signed-off-by: raver119 * few tests fixed Signed-off-by: raver119 * additional INDArray.mmul signature Signed-off-by: raver119 * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 * Nd4j.gemm validation fix Signed-off-by: raver119 * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black * two more tests + minor fix in matmul platform check Signed-off-by: raver119 * Flaky test fixes Signed-off-by: Alex Black * propagate testresources profile Signed-off-by: raver119 * Resources fix + flaky test fix Signed-off-by: Alex Black Co-authored-by: Oleg Co-authored-by: Alex Black --- .../deeplearning4j/datasets/TestDataSets.java | 5 + .../listener/TestCheckpointListener.java | 11 ++- .../regressiontest/RegressionTest100a.java | 5 + .../regressiontest/RegressionTest100b3.java | 5 + .../regressiontest/RegressionTest100b4.java | 8 +- .../regressiontest/RegressionTest100b6.java | 8 +- .../keras/e2e/KerasModelEndToEndTest.java | 7 +- .../nn/layers/recurrent/SimpleRnn.java | 2 +- 
.../deeplearning4j-json-server/pom.xml | 7 ++ .../remote/BinaryModelServerTest.java | 7 +- deeplearning4j/deeplearning4j-remote/pom.xml | 3 + libnd4j/include/helpers/MmulHelper.h | 2 +- libnd4j/include/helpers/impl/MmulHelper.cpp | 6 +- .../ops/declarable/generic/blas/matmul.cpp | 20 +++- .../ops/declarable/platform/mkldnn/matmul.cpp | 23 ++++- .../layers_tests/DeclarableOpsTests19.cpp | 93 ++++++++++++++++++- .../nd4j/linalg/api/blas/impl/BaseLevel2.java | 8 +- .../nd4j/linalg/api/blas/impl/BaseLevel3.java | 11 ++- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 14 ++- .../org/nd4j/linalg/api/ndarray/INDArray.java | 8 ++ .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 50 ++++++++-- .../java/org/nd4j/linalg/factory/Nd4j.java | 11 +-- .../listeners/CheckpointListenerTest.java | 9 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 18 ---- .../org/nd4j/linalg/api/blas/Level3Test.java | 2 +- .../api/buffer/DataTypeValidationTests.java | 2 +- .../java/org/nd4j/linalg/blas/BlasTests.java | 19 +--- 27 files changed, 279 insertions(+), 85 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java index bc892905c..44aa9a0b3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java @@ -23,6 +23,11 @@ import org.junit.Test; public class TestDataSets extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testTinyImageNetExists() throws Exception { //Simple sanity check on extracting diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java index 91ec8c98e..721786eef 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java @@ -44,6 +44,11 @@ import static org.junit.Assert.*; public class TestCheckpointListener extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Rule public TemporaryFolder tempDir = new TemporaryFolder(); @@ -57,7 +62,7 @@ public class TestCheckpointListener extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator iter = new IrisDataSetIterator(75,150); + DataSetIterator iter = new IrisDataSetIterator(25,50); return new Pair<>(net, iter); } @@ -178,13 +183,13 @@ public class TestCheckpointListener extends BaseDL4JTest { CheckpointListener l = new CheckpointListener.Builder(f) .keepLast(3) - .saveEvery(3, TimeUnit.SECONDS) + .saveEvery(4, TimeUnit.SECONDS) .build(); net.setListeners(l); for(int i=0; i<5; i++ ){ //10 iterations total net.fit(iter); - Thread.sleep(4000); + Thread.sleep(5000); } //Expect models saved at iterations: 2, 4, 6, 8 (iterations 0 and 1 shoud happen before first 3 seconds is up) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index a66914cd7..05bb8b5eb 
100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -54,6 +54,11 @@ import static org.junit.Assert.*; @Slf4j public class RegressionTest100a extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index b1bef2bfc..a28c1b845 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -52,6 +52,11 @@ import static org.junit.Assert.*; public class RegressionTest100b3 extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index a4883ea07..ec8531eb2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -69,6 +69,11 @@ import org.nd4j.resources.Resources; public class RegressionTest100b4 extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Override public DataType getDataType() { return DataType.FLOAT; @@ -123,7 +128,8 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); - assertEquals("Test for dtype: " + dtypeName, outExp, outAct); + boolean eq = outExp.equalsWithEps(outAct, 0.01); + assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 637f5860f..22ac01c14 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -56,6 +56,11 @@ public class RegressionTest100b6 extends BaseDL4JTest { return DataType.FLOAT; } + @Override + public long getTimeoutMilliseconds() { + return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Test public void testCustomLayer() throws Exception { @@ -106,7 +111,8 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(dtype, 
net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); - assertTrue(outExp + " vs " + outAct, eq); } + assertTrue(outExp + " vs " + outAct, eq); + } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index b17c215cb..7538d39bc 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -96,7 +96,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } }; - @Test(expected = IllegalStateException.class) + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + + @Test(expected = IllegalStateException.class) public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; importEndModelTest(modelPath, null, true, true, false, false); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index 87d88efcb..cc387446a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -72,7 +72,7 @@ public class SimpleRnn extends BaseRecurrentLayer + + testresources + + true + + + test-nd4j-native diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java index c57b0fa30..8e109689a 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java @@ -19,6 +19,7 @@ import org.nd4j.remote.clients.serde.BinarySerializer; import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.impl.IntegerSerde; +import org.nd4j.resources.Resources; import org.nd4j.shade.jackson.databind.ObjectMapper; import javax.imageio.ImageIO; @@ -65,7 +66,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { @Test public void testMlnMnist_ImageInput() throws Exception { - val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); + val modelFile = Resources.asFile("models/mnist/mnist-model.zip"); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); val server = new JsonModelServer.Builder(net) @@ -129,7 +130,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { @Test public void testMlnMnist_ImageInput_Async() throws Exception { - val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); + val modelFile = Resources.asFile("models/mnist/mnist-model.zip"); MultiLayerNetwork net = 
ModelSerializer.restoreMultiLayerNetwork(modelFile); val server = new JsonModelServer.Builder(net) @@ -198,7 +199,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { @Test public void testBinaryIn_BinaryOut() throws Exception { - val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile(); + val modelFile = Resources.asFile("models/mnist/mnist-model.zip"); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile); val server = new JsonModelServer.Builder(net) diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml index 4ef2e06dd..4329a1554 100644 --- a/deeplearning4j/deeplearning4j-remote/pom.xml +++ b/deeplearning4j/deeplearning4j-remote/pom.xml @@ -20,6 +20,9 @@ deeplearning4j-remote + + testresources + test-nd4j-native diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 6e38be5c1..517ca9888 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -60,7 +60,7 @@ namespace sd { static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector>& modifA, const std::vector>& modifB); #endif - static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY); + static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0); }; } diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index bc525622a..f5b9bc829 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -239,7 +239,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND ////////////////////////////////////////////////////////////////////////// - void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY) { + void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) { int xRank = x->rankOf(); int yRank = y->rankOf(); @@ -276,7 +276,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); } - mmul(xT, yT, zT, 1., 0.); + mmul(xT, yT, zT, alpha, beta); } else { // rest cases - batched mmul @@ -292,7 +292,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND auto xSubArr = (*xT)(i, dimsToExclude); auto ySubArr = (*yT)(i, dimsToExclude); auto zSubArr = (*zT)(i, dimsToExclude); - mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.); + mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta); } } diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index 6209e7bbf..370aa50c6 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -36,10 +36,14 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - const int iSize = (int) block.getIArguments()->size(); + int iSize = (int) block.getIArguments()->size(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? 
INT_ARG(2) : 0; + // optional use alpha nad beta + iSize = (int)block.getTArguments()->size(); + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; const int xRank = x->rankOf(); const int yRank = y->rankOf(); @@ -77,7 +81,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { } // ******* end of input validation ******* // - MmulHelper::matmul(x, y, z, transX, transY); + MmulHelper::matmul(x, y, z, transX, transY, alpha, beta); return Status::OK(); } @@ -147,11 +151,17 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { auto dldx = OUTPUT_VARIABLE(0); auto dldy = OUTPUT_VARIABLE(1); - const int iSize = (int) block.getIArguments()->size(); + int iSize = (int) block.getIArguments()->size(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optional use alpha nad beta + iSize = (int)block.getTArguments()->size(); + + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; + /* In: x=[a,b], y=[b,c] tX tY tZ x y z dz dLdx dLdy @@ -164,8 +174,8 @@ F F T [a,b] [b,c] [c,a] [c,a] sd::ops::matmul op; - op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {}); - op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {}); + op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {}); + op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {}); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 91e56d801..f3ef84e2f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -32,7 +32,7 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) { +static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) { // mkl works with following // [M,K] x [K,N] = [M,N] @@ -150,6 +150,12 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b // Create attributes (to handle alpha and beta if necessary) dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) + if (alpha != 1.f) attr.set_output_scales(0, {alpha}); + if (beta != 0.f) { + dnnl::post_ops po; + po.append_sum(beta); + attr.set_post_ops(po); + } // operation primitive description dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); @@ -224,11 +230,16 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) { if(x->isEmpty() || y->isEmpty()) return Status::OK(); - const int iSize = (int) block.getIArguments()->size(); + int iSize = (int) block.getIArguments()->size(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optional use alpha nad beta + iSize = (int)block.getTArguments()->size(); + float alpha = iSize > 0 ? T_ARG(0) : 1.0; + float beta = iSize > 1 ? 
T_ARG(1) : 0.0; + const int xRank = x->rankOf(); const int yRank = y->rankOf(); const int zRank = z->rankOf(); @@ -265,7 +276,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) { } // ******* end of input validation ******* // - matmulMKLDNN(x, y, z, transX, transY); + matmulMKLDNN(x, y, z, transX, transY, alpha, beta); return Status::OK(); } @@ -276,14 +287,16 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); - auto z = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); const DataType xType = x->dataType(); const DataType yType = y->dataType(); const DataType zType = z->dataType(); + float alpha = block.numT() > 0 ? T_ARG(0) : 1.0; + float beta = block.numT() > 1 ? T_ARG(1) : 0.0; - return block.isUseMKLDNN() && x->rankOf() < 3 && + return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 && ( (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 8b2bd0071..f48e3d946 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -39,6 +39,97 @@ public: } }; +TEST_F(DeclarableOpsTests19, test_matmul_ccc) { + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('c', {10, 10}); + auto z = NDArrayFactory::create('c', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_fcf) { + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_cff) { + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + + +TEST_F(DeclarableOpsTests19, test_matmul_ccf) { + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_fff) { + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + 
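// t-args here are {alpha, beta}: the op computes z = alpha * (x mmul y) + beta * z,
+    // so each element becomes 1 * 10 + 1 * 100 = 110, which is why e is set to 110
+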
sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) { /* DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") @@ -74,4 +165,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) { sd::ops::squeeze op; auto status = op.execute({&x}, {&e}, {axis}); ASSERT_EQ(Status::OK(), status); -} \ No newline at end of file +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java index 824f46c82..8736c2363 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java @@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl; import lombok.val; import org.nd4j.linalg.api.blas.Level2; import org.nd4j.linalg.api.blas.params.GemvParameters; +import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -57,6 +59,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { OpProfiler.getInstance().processBlasCall(false, A, X, Y); GemvParameters parameters = new GemvParameters(A, X, Y); + + Nd4j.exec(new Mmul(A, X, Y, alpha, beta, MMulTranspose.builder().transposeA(false).build())); + + /* if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(), parameters.getY()); @@ -86,7 +92,7 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { } else { throw new ND4JIllegalStateException("Unsupported data type " + A.dataType()); } - + */ OpExecutionerUtil.checkForAny(Y); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java index e38a0e618..8d9765aee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java @@ -19,11 +19,13 @@ package org.nd4j.linalg.api.blas.impl; import lombok.extern.slf4j.Slf4j; import org.nd4j.linalg.api.blas.Level3; import org.nd4j.linalg.api.blas.params.GemmParams; +import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.OpProfiler; @@ -59,6 +61,9 @@ public abstract class BaseLevel3 
extends BaseLevel implements Level3 { GemmParams params = new GemmParams(A, B, C); + Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(false).transposeB(false).build())); + + /* int charOder = Order; if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), params.getC()); @@ -73,6 +78,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), 1.0f, params.getA(), params.getLda(), params.getB(), params.getLdb(), 0, C, params.getLdc()); } + */ OpExecutionerUtil.checkForAny(C); } @@ -85,6 +91,9 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(true, A, B, C); + Nd4j.exec(new Mmul(A, B, C, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build())); + + /* GemmParams params = new GemmParams(A, B, C, transposeA, transposeB); if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, params.getA(), params.getB(), C); @@ -102,7 +111,7 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc()); } - +*/ OpExecutionerUtil.checkForAny(C); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 3edbd2682..07a2bf9b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2866,16 +2866,22 @@ public abstract class BaseNDArray implements INDArray, Iterable { } @Override - public INDArray mmul(INDArray other) { + public INDArray mmul(INDArray other, char resultOrder) { + Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); - // FIXME: for 1D case, we probably want vector output here? - long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()}; - INDArray result = createUninitialized(this.dataType(), shape, 'f'); + // FIXME: add support for 3D+ here? + long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; + INDArray result = createUninitialized(this.dataType(), shape, resultOrder); if (result.isScalar()) return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); return mmuli(other, result); } + @Override + public INDArray mmul(INDArray other) { + return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 
'f' : 'c'); + } + protected INDArray create(int[] shape, char ordering) { return Nd4j.create(shape, ordering); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index de80e9413..08aa613fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1232,6 +1232,14 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mmul(INDArray other); + /** + * Perform a copy matrix multiplication + * @param other other the other matrix to perform matrix multiply with + * @param resultOrder either C or F order for result array + * @return the result of the matrix multiplication + */ + INDArray mmul(INDArray other, char resultOrder); + /** * Convert this ndarray to a 2d double matrix. * Note that THIS SHOULD NOT BE USED FOR SPEED. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 30ca8ebc5..46310893d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -44,6 +44,8 @@ import java.util.*; public class Mmul extends DynamicCustomOp { protected MMulTranspose mt; + protected double alpha = 1.0; + protected double beta = 0.0; /** * @@ -59,6 +61,7 @@ public class Mmul extends DynamicCustomOp { super(null,sameDiff,new SDVariable[]{i_v1,i_v2}); this.mt = mt; addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()), ArrayUtil.fromBoolean(mt.isTransposeResult())); + addTArgument(alpha, beta); } @@ -74,6 +77,30 @@ public class Mmul extends DynamicCustomOp { this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse()); } + public Mmul(INDArray x, + INDArray y, + INDArray z, + double alpha, + double beta, + MMulTranspose mt) { + addInputArgument(x, y); + + if (z != null) + addOutputArgument(z); + + if (mt != null) { + this.mt = mt; + addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), + ArrayUtil.fromBoolean(mt.isTransposeB()), + ArrayUtil.fromBoolean(mt.isTransposeResult())); + } + + this.alpha = alpha; + this.beta = beta; + + addTArgument(alpha, beta); + } + /** * * @param x @@ -84,25 +111,30 @@ public class Mmul extends DynamicCustomOp { INDArray y, INDArray z, MMulTranspose mt) { - super(null, new INDArray[]{x, y}, z == null ? 
null : new INDArray[]{z}); - if (mt != null) { - this.mt = mt; - addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), - ArrayUtil.fromBoolean(mt.isTransposeB()), - ArrayUtil.fromBoolean(mt.isTransposeResult())); - } + this(x, y, z, 1.0, 0.0, mt); } public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) { + this(x, y, 1.0, 0.0, transposeX, transposeY, transposeZ); + } + + public Mmul(INDArray x, INDArray y, double alpha, double beta, boolean transposeX, boolean transposeY, boolean transposeZ) { addInputArgument(x, y); addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); + addTArgument(alpha, beta); + this.alpha = alpha; + this.beta = beta; + } + + public Mmul(INDArray x, INDArray y, double alpha, double beta) { + this(x,y,null, alpha, beta,null); } public Mmul(INDArray x, INDArray y) { - this(x,y,null,null); + this(x, y, 1.0, 0.0); } public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, @@ -111,6 +143,8 @@ public class Mmul extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); + + addTArgument(alpha, beta); mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 5da64dadb..43181a3b2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -791,7 +791,7 @@ public class Nd4j { boolean transposeB) { long cRows = (transposeA ? a.columns() : a.rows()); long cCols = (transposeB ? b.rows() : b.columns()); - INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, 'f'); + INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, a.ordering() == 'c' && b.ordering() == 'c' ? 'c' : 'f'); return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0); } @@ -817,12 +817,9 @@ public class Nd4j { boolean transposeB, double alpha, double beta) { - //Note: some views have non-zero offset but 'default' strides (these are OK). And a 'c' order vector such as [10,1] is OK - same buffer as an 'f' order vector with same shape - Preconditions.checkState(c.length() == 1 || c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c) || - c.isVectorOrScalar() && c.elementWiseStride() == 1, - "C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order " + - "and not a view. 
C (result) array: [%ndSInfo]", c); - getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta); + Preconditions.checkArgument(c.elementWiseStride() == 1, "Nd4j.gemm() C array should NOT be a view"); + + Nd4j.exec(new Mmul(a, b, c, alpha, beta, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build())); return c; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java index 423887b64..997bf609c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -40,6 +40,11 @@ public class CheckpointListenerTest extends BaseNd4jTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + public static SameDiff getModel(){ Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -151,7 +156,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { CheckpointListener l = new CheckpointListener.Builder(dir) .keepLast(2) - .saveEvery(1, TimeUnit.SECONDS) + .saveEvery(4, TimeUnit.SECONDS) .build(); sd.setListeners(l); @@ -159,7 +164,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { for(int i=0; i<5; i++ ){ //10 iterations total sd.fit(iter, 1); - Thread.sleep(1000); + Thread.sleep(5000); } //Expect models saved at iterations: 10, 20, 30, 40 diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index da91fb6cf..162e123b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6192,24 +6192,6 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, output); } - @Test - public void testVectorGemv() { - val vectorL = Nd4j.create(new float[]{1, 2, 3}, new long[]{3, 1}); - val vectorN = Nd4j.create(new float[]{1, 2, 3}, new long[]{3}); - val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3}); - -// log.info("vectorN: {}", vectorN); -// log.info("vectorL: {}", vectorL); - - val outN = matrix.mmul(vectorN); - val outL = matrix.mmul(vectorL); - - assertEquals(outL, outN.reshape(3,1)); - - assertEquals(1, outN.rank()); - } - - @Test public void testMatrixReshape() { val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java index 8263cc07c..8fc247683 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java @@ -60,7 +60,7 @@ public class Level3Test extends BaseNd4jTest { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); - INDArray array3 = array1.mmul(array2); + INDArray array3 = array1.mmul(array2, Nd4j.createUninitialized(new long[]{10, 10}, 'f')); //System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); 
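For reference, the following is a minimal Java sketch (not part of the patch itself; the class name MatmulAlphaBetaExample and the shapes/values are invented for illustration) of how the signatures introduced above are meant to be called: the new INDArray.mmul(other, resultOrder) overload, the alpha/beta form of Nd4j.gemm, and the Mmul op carrying alpha/beta as t-args.

import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.factory.Nd4j;

public class MatmulAlphaBetaExample {
    public static void main(String[] args) {
        INDArray a = Nd4j.rand(3, 4);
        INDArray b = Nd4j.rand(4, 5);

        // new overload: request the ordering of the result array explicitly ('c' or 'f')
        INDArray ab = a.mmul(b, 'c');

        // gemm-style alpha/beta: c = alpha * (a mmul b) + beta * c; c must be a full (non-view) array
        INDArray c = Nd4j.createUninitialized(a.dataType(), new long[]{3, 5}, 'f').assign(1.0);
        Nd4j.gemm(a, b, c, false, false, 2.0, 0.5);

        // the equivalent call expressed directly through the Mmul op, with alpha/beta passed as t-args
        INDArray c2 = Nd4j.createUninitialized(a.dataType(), new long[]{3, 5}, 'f').assign(1.0);
        Nd4j.exec(new Mmul(a, b, c2, 2.0, 0.5,
                MMulTranspose.builder().transposeA(false).transposeB(false).build()));

        System.out.println(ab);
        System.out.println(c2);
    }
}

Note that, with the change to Nd4j.gemm above, passing a view as the result array now fails the Preconditions check with an IllegalArgumentException rather than an ND4JIllegalStateException.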
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java index 7dcd2285f..b3bd78979 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java @@ -83,7 +83,7 @@ public class DataTypeValidationTests extends BaseNd4jTest { /** * Testing level2 blas */ - @Test(expected = ND4JIllegalStateException.class) + @Test(expected = RuntimeException.class) public void testBlasValidation2() { INDArray a = Nd4j.create(100, 10); INDArray x = Nd4j.create(100); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java index b81b90133..d9f307abe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java @@ -83,22 +83,7 @@ public class BlasTests extends BaseNd4jTest { try { Nd4j.gemm(a, b, view, false, false, 1.0, 0.0); fail("Expected exception"); - } catch (IllegalStateException e) { - assertTrue(e.getMessage().contains("view")); - } - } - - @Test - public void testGemmInvalid2() { - final INDArray a = Nd4j.rand(4, 3); - final INDArray b = Nd4j.rand(4, 5); - - final INDArray target = Nd4j.zeros(3, 5, 'c'); - - try { - Nd4j.gemm(a, b, target, true, false, 1.0, 0.0); - fail("Expected exception"); - } catch (IllegalStateException e) { + } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("view")); } } @@ -114,7 +99,7 @@ public class BlasTests extends BaseNd4jTest { try { Nd4j.gemm(a, b, view, true, false, 1.0, 0.0); fail("Expected exception"); - } catch (IllegalStateException e) { + } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("view")); } } From f1debe8c077e97d616a9e76893529f037c148473 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Fri, 10 Apr 2020 19:50:40 -0400 Subject: [PATCH 17/19] RL4J: Add ExperienceHandler (#369) * Added ExperienceHandler Signed-off-by: Alexandre Boulanger * Added getTrainingBatchSize() Signed-off-by: Alexandre Boulanger --- .../rl4j/experience/ExperienceHandler.java | 54 +++++ .../ReplayMemoryExperienceHandler.java | 111 ++++++++++ .../StateActionExperienceHandler.java | 67 ++++++ .../rl4j/experience/StateActionPair.java | 49 +++++ .../learning/async/AsyncThreadDiscrete.java | 72 +++---- .../{MiniTrans.java => UpdateAlgorithm.java} | 66 +++--- .../async/a3c/discrete/A3CThreadDiscrete.java | 61 +----- .../a3c/discrete/A3CUpdateAlgorithm.java | 113 ++++++++++ .../AsyncNStepQLearningThreadDiscrete.java | 40 +--- .../discrete/QLearningUpdateAlgorithm.java | 88 ++++++++ .../rl4j/learning/sync/ExpReplay.java | 5 +- .../rl4j/learning/sync/IExpReplay.java | 6 +- .../learning/sync/qlearning/QLearning.java | 24 --- .../qlearning/discrete/QLearningDiscrete.java | 42 ++-- .../ReplayMemoryExperienceHandlerTest.java | 107 ++++++++++ .../StateActionExperienceHandlerTest.java | 82 ++++++++ .../async/AsyncThreadDiscreteTest.java | 82 +++----- .../a3c/discrete/A3CThreadDiscreteTest.java | 197 ------------------ .../a3c/discrete/A3CUpdateAlgorithmTest.java | 160 ++++++++++++++ ...AsyncNStepQLearningThreadDiscreteTest.java | 98 --------- 
.../QLearningUpdateAlgorithmTest.java | 115 ++++++++++ .../discrete/QLearningDiscreteTest.java | 62 +++--- .../deeplearning4j/rl4j/support/MockDQN.java | 6 +- .../rl4j/support/MockExpReplay.java | 22 -- .../rl4j/support/MockExperienceHandler.java | 46 ++++ .../rl4j/support/MockObservationSpace.java | 14 +- .../rl4j/support/MockUpdateAlgorithm.java | 19 ++ 27 files changed, 1183 insertions(+), 625 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/{MiniTrans.java => UpdateAlgorithm.java} (57%) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java new file mode 100644 index 000000000..1ec4f05c1 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.List; + +/** + * A common interface to all classes capable of handling experience generated by the agents in a learning context. + * + * @param Action type + * @param Experience type + * + * @author Alexandre Boulanger + */ +public interface ExperienceHandler { + void addExperience(Observation observation, A action, double reward, boolean isTerminal); + + /** + * Called when the episode is done with the last observation + * @param observation + */ + void setFinalObservation(Observation observation); + + /** + * @return The size of the list that will be returned by generateTrainingBatch(). + */ + int getTrainingBatchSize(); + + /** + * The elements are returned in the historical order (i.e. in the order they happened) + * @return The list of experience elements + */ + List generateTrainingBatch(); + + /** + * Signal the experience handler that a new episode is starting + */ + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java new file mode 100644 index 000000000..74b7e3f05 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import lombok.EqualsAndHashCode; +import org.deeplearning4j.rl4j.learning.sync.ExpReplay; +import org.deeplearning4j.rl4j.learning.sync.IExpReplay; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +/** + * A experience handler that stores the experience in a replay memory. 
See https://arxiv.org/abs/1312.5602 + * The experience container is a {@link Transition Transition} that stores the tuple observation-action-reward-nextObservation, + * as well as whether or the not the episode ended after the Transition + * + * @param Action type + */ +@EqualsAndHashCode +public class ReplayMemoryExperienceHandler implements ExperienceHandler> { + private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000; + private static final int DEFAULT_BATCH_SIZE = 32; + + private IExpReplay expReplay; + + private Transition pendingTransition; + + public ReplayMemoryExperienceHandler(IExpReplay expReplay) { + this.expReplay = expReplay; + } + + public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) { + this(new ExpReplay(maxReplayMemorySize, batchSize, random)); + } + + public void addExperience(Observation observation, A action, double reward, boolean isTerminal) { + setNextObservationOnPending(observation); + pendingTransition = new Transition<>(observation, action, reward, isTerminal); + } + + public void setFinalObservation(Observation observation) { + setNextObservationOnPending(observation); + pendingTransition = null; + } + + @Override + public int getTrainingBatchSize() { + return expReplay.getBatchSize(); + } + + /** + * @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call. + */ + @Override + public List> generateTrainingBatch() { + return expReplay.getBatch(); + } + + @Override + public void reset() { + pendingTransition = null; + } + + private void setNextObservationOnPending(Observation observation) { + if(pendingTransition != null) { + pendingTransition.setNextObservation(observation); + expReplay.store(pendingTransition); + } + } + + public class Builder { + private int maxReplayMemorySize = DEFAULT_MAX_REPLAY_MEMORY_SIZE; + private int batchSize = DEFAULT_BATCH_SIZE; + private Random random = Nd4j.getRandom(); + + public Builder maxReplayMemorySize(int value) { + maxReplayMemorySize = value; + return this; + } + + public Builder batchSize(int value) { + batchSize = value; + return this; + } + + public Builder random(Random value) { + random = value; + return this; + } + + public ReplayMemoryExperienceHandler build() { + return new ReplayMemoryExperienceHandler(maxReplayMemorySize, batchSize, random); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java new file mode 100644 index 000000000..39338c6c0 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java @@ -0,0 +1,67 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.ArrayList; +import java.util.List; + +/** + * A simple {@link ExperienceHandler experience handler} that stores the experiences. + * Note: Calling {@link StateActionExperienceHandler#generateTrainingBatch() generateTrainingBatch()} will clear the stored experiences + * + * @param Action type + * + * @author Alexandre Boulanger + */ +public class StateActionExperienceHandler implements ExperienceHandler> { + + private List> stateActionPairs; + + public void setFinalObservation(Observation observation) { + // Do nothing + } + + public void addExperience(Observation observation, A action, double reward, boolean isTerminal) { + stateActionPairs.add(new StateActionPair(observation, action, reward, isTerminal)); + } + + @Override + public int getTrainingBatchSize() { + return stateActionPairs.size(); + } + + /** + * The elements are returned in the historical order (i.e. in the order they happened) + * Note: the experience store is cleared after calling this method. + * + * @return The list of experience elements + */ + @Override + public List> generateTrainingBatch() { + List> result = stateActionPairs; + stateActionPairs = new ArrayList<>(); + + return result; + } + + @Override + public void reset() { + stateActionPairs = new ArrayList<>(); + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java new file mode 100644 index 000000000..49e9ad3b5 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.deeplearning4j.rl4j.observation.Observation; + +/** + * A simple experience container. Used by {@link StateActionExperienceHandler StateActionExperienceHandler}. + * + * @param Action type + * + * @author Alexandre Boulanger + */ +@AllArgsConstructor +public class StateActionPair { + + /** + * The observation before the action is taken + */ + @Getter + private final Observation observation; + + @Getter + private final A action; + + @Getter + private final double reward; + + /** + * True if the episode ended after the action has been taken. 
+ */ + @Getter + private final boolean terminal; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index a72abfa62..ac9853045 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -18,9 +18,12 @@ package org.deeplearning4j.rl4j.learning.async; +import lombok.AccessLevel; import lombok.Getter; +import lombok.Setter; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -28,10 +31,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Stack; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -45,13 +44,39 @@ public abstract class AsyncThreadDiscrete @Getter private NN current; - public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { + @Setter(AccessLevel.PROTECTED) + private UpdateAlgorithm updateAlgorithm; + + // TODO: Make it configurable with a builder + @Setter(AccessLevel.PROTECTED) + private ExperienceHandler experienceHandler = new StateActionExperienceHandler(); + + public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, + MDP mdp, + TrainingListenerList listeners, + int threadNumber, + int deviceNum) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); synchronized (asyncGlobal) { current = (NN)asyncGlobal.getCurrent().clone(); } } + // TODO: Add an actor-learner class and be able to inject the update algorithm + protected abstract UpdateAlgorithm buildUpdateAlgorithm(); + + @Override + public void setHistoryProcessor(IHistoryProcessor historyProcessor) { + super.setHistoryProcessor(historyProcessor); + updateAlgorithm = buildUpdateAlgorithm(); + } + + @Override + protected void preEpoch() { + experienceHandler.reset(); + } + + /** * "Subepoch" correspond to the t_max-step iterations * that stack rewards with t_max MiniTrans @@ -65,13 +90,11 @@ public abstract class AsyncThreadDiscrete synchronized (getAsyncGlobal()) { current.copy(getAsyncGlobal().getCurrent()); } - Stack> rewards = new Stack<>(); Observation obs = sObs; IPolicy policy = getPolicy(current); - Integer action; - Integer lastAction = getMdp().getActionSpace().noOp(); + Integer action = getMdp().getActionSpace().noOp(); IHistoryProcessor hp = getHistoryProcessor(); int skipFrame = hp != null ? 
hp.getConf().getSkipFrame() : 1; @@ -82,21 +105,15 @@ public abstract class AsyncThreadDiscrete while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) { //if step of training, just repeat lastAction - if (obs.isSkipped()) { - action = lastAction; - } else { + if (!obs.isSkipped()) { action = policy.nextAction(obs); } StepReply stepReply = getLegacyMDPWrapper().step(action); accuReward += stepReply.getReward() * getConf().getRewardFactor(); - //if it's not a skipped frame, you can do a step of training if (!obs.isSkipped()) { - - INDArray[] output = current.outputAll(obs.getData()); - rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); - + experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); accuReward = 0; } @@ -104,29 +121,14 @@ public abstract class AsyncThreadDiscrete reward += stepReply.getReward(); incrementStep(); - lastAction = action; } - //a bit of a trick usable because of how the stack is treated to init R - // FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored. - - if (getMdp().isDone() && getCurrentEpochStep() < lastStep) - rewards.add(new MiniTrans(obs.getData(), null, null, 0)); - else { - INDArray[] output = null; - if (getConf().getLearnerUpdateFrequency() == -1) - output = current.outputAll(obs.getData()); - else synchronized (getAsyncGlobal()) { - output = getAsyncGlobal().getTarget().outputAll(obs.getData()); - } - double maxQ = Nd4j.max(output[0]).getDouble(0); - rewards.add(new MiniTrans(obs.getData(), null, output, maxQ)); + if (getMdp().isDone() && getCurrentEpochStep() < lastStep) { + experienceHandler.setFinalObservation(obs); } - getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep()); + getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep()); return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore()); } - - public abstract Gradient[] calcGradient(NN nn, Stack> rewards); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/MiniTrans.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java similarity index 57% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/MiniTrans.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java index 88bca6b0e..16ca1c3f8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/MiniTrans.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java @@ -1,40 +1,26 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.async; - -import lombok.AllArgsConstructor; -import lombok.Value; -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * - * Its called a MiniTrans because it is similar to a Transition - * but without a next observation - * - * It is stacked and then processed by AsyncNStepQL or A3C - * following the paper implementation https://arxiv.org/abs/1602.01783 paper. - * - */ -@AllArgsConstructor -@Value -public class MiniTrans { - INDArray obs; - A action; - INDArray[] output; - double reward; -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.async; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.network.NeuralNet; + +import java.util.List; + +public interface UpdateAlgorithm { + Gradient[] computeGradients(NN current, List> experience); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index c2a16d6b4..d189edca1 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -18,11 +18,7 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import lombok.Getter; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; -import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.async.*; import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -34,9 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; - -import java.util.Stack; +import org.nd4j.linalg.api.rng.Random; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. 
@@ -67,6 +61,8 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< if(seed != null) { rnd.setSeed(seed + threadNumber); } + + setUpdateAlgorithm(buildUpdateAlgorithm()); } @Override @@ -74,52 +70,9 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< return new ACPolicy(net, rnd); } - /** - * calc the gradients based on the n-step rewards - */ @Override - public Gradient[] calcGradient(IActorCritic iac, Stack> rewards) { - MiniTrans minTrans = rewards.pop(); - - int size = rewards.size(); - - //if recurrent then train as a time serie with a batch size of 1 - boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent(); - - int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() - : getHistoryProcessor().getConf().getShape(); - int[] nshape = recurrent ? Learning.makeShape(1, shape, size) - : Learning.makeShape(size, shape); - - INDArray input = Nd4j.create(nshape); - INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1); - INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size) - : Nd4j.zeros(size, getMdp().getActionSpace().getSize()); - - double r = minTrans.getReward(); - for (int i = size - 1; i >= 0; i--) { - minTrans = rewards.pop(); - - r = minTrans.getReward() + conf.getGamma() * r; - if (recurrent) { - input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(minTrans.getObs()); - } else { - input.putRow(i, minTrans.getObs()); - } - - //the critic - targets.putScalar(i, r); - - //the actor - double expectedV = minTrans.getOutput()[0].getDouble(0); - double advantage = r - expectedV; - if (recurrent) { - logSoftmax.putScalar(0, minTrans.getAction(), i, advantage); - } else { - logSoftmax.putScalar(i, minTrans.getAction(), advantage); - } - } - - return iac.gradient(input, new INDArray[] {targets, logSoftmax}); + protected UpdateAlgorithm buildUpdateAlgorithm() { + int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); + return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java new file mode 100644 index 000000000..261cc788f --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
+
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.experience.StateActionPair;
+import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
+import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
+import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+import java.util.List;
+
+public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
+
+    private final IAsyncGlobal<IActorCritic> asyncGlobal;
+    private final int[] shape;
+    private final int actionSpaceSize;
+    private final int targetDqnUpdateFreq;
+    private final double gamma;
+    private final boolean recurrent;
+
+    public A3CUpdateAlgorithm(IAsyncGlobal<IActorCritic> asyncGlobal,
+                              int[] shape,
+                              int actionSpaceSize,
+                              int targetDqnUpdateFreq,
+                              double gamma) {
+
+        this.asyncGlobal = asyncGlobal;
+
+        //if recurrent then train as a time series with a batch size of 1
+        recurrent = asyncGlobal.getCurrent().isRecurrent();
+        this.shape = shape;
+        this.actionSpaceSize = actionSpaceSize;
+        this.targetDqnUpdateFreq = targetDqnUpdateFreq;
+        this.gamma = gamma;
+    }
+
+    @Override
+    public Gradient[] computeGradients(IActorCritic current, List<StateActionPair<Integer>> experience) {
+        int size = experience.size();
+
+        int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
+                : Learning.makeShape(size, shape);
+
+        INDArray input = Nd4j.create(nshape);
+        INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
+        INDArray logSoftmax = recurrent ?
Nd4j.zeros(1, actionSpaceSize, size) + : Nd4j.zeros(size, actionSpaceSize); + + StateActionPair stateActionPair = experience.get(size - 1); + double r; + if(stateActionPair.isTerminal()) { + r = 0; + } + else { + INDArray[] output = null; + if (targetDqnUpdateFreq == -1) + output = current.outputAll(stateActionPair.getObservation().getData()); + else synchronized (asyncGlobal) { + output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData()); + } + r = output[0].getDouble(0); + } + + for (int i = size - 1; i >= 0; --i) { + stateActionPair = experience.get(i); + + INDArray observationData = stateActionPair.getObservation().getData(); + + INDArray[] output = current.outputAll(observationData); + + r = stateActionPair.getReward() + gamma * r; + if (recurrent) { + input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData); + } else { + input.putRow(i, observationData); + } + + //the critic + targets.putScalar(i, r); + + //the actor + double expectedV = output[0].getDouble(0); + double advantage = r - expectedV; + if (recurrent) { + logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage); + } else { + logSoftmax.putScalar(i, stateActionPair.getAction(), advantage); + } + } + + // targets -> value, critic + // logSoftmax -> policy, actor + return current.gradient(input, new INDArray[] {targets, logSoftmax}); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 71199efaf..bd4dc16e8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -18,11 +18,9 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import lombok.Getter; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -32,12 +30,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; -import java.util.Stack; - /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. 
*/ @@ -65,6 +60,8 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn if(seed != null) { rnd.setSeed(seed + threadNumber); } + + setUpdateAlgorithm(buildUpdateAlgorithm()); } public Policy getPolicy(IDQN nn) { @@ -72,32 +69,9 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn rnd, conf.getMinEpsilon(), this); } - - - //calc the gradient based on the n-step rewards - public Gradient[] calcGradient(IDQN current, Stack> rewards) { - - MiniTrans minTrans = rewards.pop(); - - int size = rewards.size(); - - int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() - : getHistoryProcessor().getConf().getShape(); - int[] nshape = Learning.makeShape(size, shape); - INDArray input = Nd4j.create(nshape); - INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize()); - - double r = minTrans.getReward(); - for (int i = size - 1; i >= 0; i--) { - minTrans = rewards.pop(); - - r = minTrans.getReward() + conf.getGamma() * r; - input.putRow(i, minTrans.getObs()); - INDArray row = minTrans.getOutput()[0]; - row = row.putScalar(minTrans.getAction(), r); - targets.putRow(i, row); - } - - return current.gradient(input, targets); + @Override + protected UpdateAlgorithm buildUpdateAlgorithm() { + int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); + return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java new file mode 100644 index 000000000..beae271b1 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java @@ -0,0 +1,88 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.async.nstep.discrete; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; +import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +public class QLearningUpdateAlgorithm implements UpdateAlgorithm { + + private final IAsyncGlobal asyncGlobal; + private final int[] shape; + private final int actionSpaceSize; + private final int targetDqnUpdateFreq; + private final double gamma; + + public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal, + int[] shape, + int actionSpaceSize, + int targetDqnUpdateFreq, + double gamma) { + + this.asyncGlobal = asyncGlobal; + this.shape = shape; + this.actionSpaceSize = actionSpaceSize; + this.targetDqnUpdateFreq = targetDqnUpdateFreq; + this.gamma = gamma; + } + + @Override + public Gradient[] computeGradients(IDQN current, List> experience) { + int size = experience.size(); + + int[] nshape = Learning.makeShape(size, shape); + INDArray input = Nd4j.create(nshape); + INDArray targets = Nd4j.create(size, actionSpaceSize); + + StateActionPair stateActionPair = experience.get(size - 1); + + double r; + if(stateActionPair.isTerminal()) { + r = 0; + } + else { + INDArray[] output = null; + if (targetDqnUpdateFreq == -1) + output = current.outputAll(stateActionPair.getObservation().getData()); + else synchronized (asyncGlobal) { + output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData()); + } + r = Nd4j.max(output[0]).getDouble(0); + } + + for (int i = size - 1; i >= 0; i--) { + stateActionPair = experience.get(i); + + input.putRow(i, stateActionPair.getObservation().getData()); + + r = stateActionPair.getReward() + gamma * r; + INDArray[] output = current.outputAll(stateActionPair.getObservation().getData()); + INDArray row = output[0]; + row = row.putScalar(stateActionPair.getAction(), r); + targets.putRow(i, row); + } + + return current.gradient(input, targets); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java index 2defc1d75..93b4d1bb5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java @@ -80,6 +80,9 @@ public class ExpReplay implements IExpReplay { //log.info("size: "+storage.size()); } - + public int getBatchSize() { + int storageSize = storage.size(); + return Math.min(storageSize, batchSize); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java index 02a4c8af5..eaef5f0f8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java @@ -32,6 +32,11 @@ import java.util.ArrayList; */ public interface IExpReplay { + /** + * @return The size of the batch that will be returned by getBatch() + */ + int getBatchSize(); + 
/** * @return a batch of uniformly sampled transitions */ @@ -42,5 +47,4 @@ public interface IExpReplay { * @param transition a new transition to store */ void store(Transition transition); - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 40704d4e9..7bef13e59 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -60,32 +60,8 @@ public abstract class QLearning implements TargetQNetworkSource, EpochStepCounter { - // FIXME Changed for refac - // @Getter - // final private IExpReplay expReplay; - @Getter - @Setter(AccessLevel.PROTECTED) - protected IExpReplay expReplay; - protected abstract LegacyMDPWrapper getLegacyMDPWrapper(); - public QLearning(QLearningConfiguration conf) { - this(conf, getSeededRandom(conf.getSeed())); - } - - public QLearning(QLearningConfiguration conf, Random random) { - expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); - } - - private static Random getSeededRandom(Long seed) { - Random rnd = Nd4j.getRandom(); - if(seed != null) { - rnd.setSeed(seed); - } - - return rnd; - } - protected abstract EpsGreedy getEgPolicy(); public abstract MDP getMdp(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index e97415e29..1b9e667ae 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -21,6 +21,8 @@ import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; @@ -42,7 +44,7 @@ import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; -import java.util.ArrayList; +import java.util.List; /** @@ -71,10 +73,12 @@ public abstract class QLearningDiscrete extends QLearning> experienceHandler; + protected LegacyMDPWrapper getLegacyMDPWrapper() { return mdp; } @@ -85,7 +89,6 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) { - super(conf); this.configuration = conf; this.mdp = new LegacyMDPWrapper<>(mdp, null, this); qNetwork = dqn; @@ -98,6 +101,7 @@ public abstract class QLearningDiscrete extends QLearning getMdp() { @@ -114,7 +118,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning trainStep(Observation obs) { - Integer action; - boolean isHistoryProcessor = getHistoryProcessor() != null; int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; int historyLength = isHistoryProcessor ? 
getHistoryProcessor().getConf().getHistoryLength() : 1; @@ -142,37 +144,28 @@ public abstract class QLearningDiscrete extends QLearning stepReply = mdp.step(action); - + StepReply stepReply = mdp.step(lastAction); accuReward += stepReply.getReward() * configuration.getRewardFactor(); //if it's not a skipped frame, you can do a step of training if (!obs.isSkipped()) { // Add experience - if (pendingTransition != null) { - pendingTransition.setNextObservation(obs); - getExpReplay().store(pendingTransition); - } - pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone()); + experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone()); accuReward = 0; // Update NN // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"? if (getStepCounter() > updateStart) { - DataSet targets = setTarget(getExpReplay().getBatch()); + DataSet targets = setTarget(experienceHandler.generateTrainingBatch()); getQNetwork().fit(targets.getFeatures(), targets.getLabels()); } } @@ -180,7 +173,7 @@ public abstract class QLearningDiscrete extends QLearning(maxQ, getQNetwork().getLatestScore(), stepReply); } - protected DataSet setTarget(ArrayList> transitions) { + protected DataSet setTarget(List> transitions) { if (transitions.size() == 0) throw new IllegalArgumentException("too few transitions"); @@ -189,9 +182,6 @@ public abstract class QLearningDiscrete extends QLearning { + + public final List> addedTransitions = new ArrayList<>(); + + @Override + public ArrayList> getBatch() { + return null; + } + + @Override + public void store(Transition transition) { + addedTransitions.add(transition); + } + + @Override + public int getBatchSize() { + return addedTransitions.size(); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java new file mode 100644 index 000000000..7334ff87a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java @@ -0,0 +1,82 @@ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +import static org.junit.Assert.*; + +public class StateActionExperienceHandlerTest { + + @Test + public void when_addingExperience_expect_generateTrainingBatchReturnsIt() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + Observation observation = new Observation(Nd4j.zeros(1)); + sut.addExperience(observation, 123, 234.0, true); + + // Act + List> result = sut.generateTrainingBatch(); + + // Assert + assertEquals(1, result.size()); + assertSame(observation, result.get(0).getObservation()); + assertEquals(123, (int)result.get(0).getAction()); + assertEquals(234.0, result.get(0).getReward(), 0.00001); + assertTrue(result.get(0).isTerminal()); + } + + @Test + public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + sut.addExperience(null, 2, 2.0, false); + sut.addExperience(null, 3, 3.0, false); + + // Act + List> result = sut.generateTrainingBatch(); + + // Assert + assertEquals(3, result.size()); + assertEquals(1, 
(int)result.get(0).getAction()); + assertEquals(2, (int)result.get(1).getAction()); + assertEquals(3, (int)result.get(2).getAction()); + } + + @Test + public void when_gettingExperience_expect_experienceStoreIsCleared() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + + // Act + List> firstResult = sut.generateTrainingBatch(); + List> secondResult = sut.generateTrainingBatch(); + + // Assert + assertEquals(1, firstResult.size()); + assertEquals(0, secondResult.size()); + } + + @Test + public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + sut.addExperience(null, 2, 2.0, false); + sut.addExperience(null, 3, 3.0, false); + + // Act + int size = sut.getTrainingBatchSize(); + + // Assert + assertEquals(3, size); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 72f374db5..320b53a0e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -18,9 +18,12 @@ package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; +import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; @@ -31,7 +34,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; import java.util.List; -import java.util.Stack; import static org.junit.Assert.assertEquals; @@ -51,7 +53,9 @@ public class AsyncThreadDiscreteTest { TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5); - TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); + MockExperienceHandler experienceHandlerMock = new MockExperienceHandler(); + MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm(); + TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock); sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); // Act @@ -60,8 +64,8 @@ public class AsyncThreadDiscreteTest { // Assert assertEquals(2, sut.trainSubEpochResults.size()); double[][] expectedLastObservations = new double[][] { - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, }; double[] expectedSubEpochReturnRewards = new 
double[] { 42.0, 58.0 }; for(int i = 0; i < 2; ++i) { @@ -102,62 +106,22 @@ public class AsyncThreadDiscreteTest { } } - // NeuralNetwork - assertEquals(2, nnMock.copyCallCount); - double[][] expectedNNInputs = new double[][] { + // ExperienceHandler + double[][] expectedExperienceHandlerInputs = new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans }; - assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size()); - for(int i = 0; i < expectedNNInputs.length; ++i) { - double[] expectedRow = expectedNNInputs[i]; - INDArray input = nnMock.outputAllInputs.get(i); + assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size()); + for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) { + double[] expectedRow = expectedExperienceHandlerInputs[i]; + INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData(); assertEquals(expectedRow.length, input.shape()[1]); for(int j = 0; j < expectedRow.length; ++j) { assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); } } - - int arrayIdx = 0; - double[][][] expectedMinitransObs = new double[][][] { - new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation - }, - new double[][] { - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation - } - }; - double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; - double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 }; - - assertEquals(2, sut.rewards.size()); - for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) { - Stack> miniTransStack = sut.rewards.get(rewardIdx); - - for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) { - MiniTrans minitrans = miniTransStack.get(i); - - // Observation - double[] expectedRow = expectedMinitransObs[rewardIdx][i]; - INDArray realRewards = minitrans.getObs(); - assertEquals(expectedRow.length, realRewards.shape()[1]); - for (int j = 0; j < expectedRow.length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001); - } - - assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001); - assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001); - ++arrayIdx; - } - } } public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete { @@ -167,22 +131,19 @@ public class AsyncThreadDiscreteTest { private final MockAsyncConfiguration config; public final List trainSubEpochResults = new ArrayList(); - public final List>> rewards = new ArrayList>>(); public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy, - MockAsyncConfiguration config, IHistoryProcessor hp) { + MockAsyncConfiguration config, IHistoryProcessor hp, + ExperienceHandler> experienceHandler, + UpdateAlgorithm 
updateAlgorithm) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.asyncGlobal = asyncGlobal; this.policy = policy; this.config = config; setHistoryProcessor(hp); - } - - @Override - public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack> rewards) { - this.rewards.add(rewards); - return new Gradient[0]; + setExperienceHandler(experienceHandler); + setUpdateAlgorithm(updateAlgorithm); } @Override @@ -200,6 +161,11 @@ public class AsyncThreadDiscreteTest { return policy; } + @Override + protected UpdateAlgorithm buildUpdateAlgorithm() { + return null; + } + @Override public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { asyncGlobal.increaseCurrentLoop(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java deleted file mode 100644 index b812a5582..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java +++ /dev/null @@ -1,197 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.async.a3c.discrete; - -import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; -import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; -import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.network.ac.IActorCritic; -import org.deeplearning4j.rl4j.support.*; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.Stack; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -public class A3CThreadDiscreteTest { - - @Test - public void refac_calcGradient() { - // Arrange - double gamma = 0.9; - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdpMock = new MockMDP(observationSpace); - A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build(); - MockActorCritic actorCriticMock = new MockActorCritic(); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); - A3CThreadDiscrete sut = new A3CThreadDiscrete(mdpMock, asyncGlobalMock, config, 0, null, 0); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); - sut.setHistoryProcessor(hpMock); - - double[][] minitransObs = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - }; - double[] outputs = new double[] { 1.0, 2.0, 3.0 }; - double[] rewards = new double[] { 0.0, 0.0, 3.0 }; - - Stack> minitransList = new Stack>(); - for(int i = 0; i < 3; ++i) { - INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); - INDArray[] output = new INDArray[] { - Nd4j.zeros(5) - }; - output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); - } - minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans - - // Act - sut.calcGradient(actorCriticMock, minitransList); - - // Assert - assertEquals(1, actorCriticMock.gradientParams.size()); - INDArray input = actorCriticMock.gradientParams.get(0).getFirst(); - INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond(); - - assertEquals(minitransObs.length, input.shape()[0]); - for(int i = 0; i < minitransObs.length; ++i) { - double[] expectedRow = minitransObs[i]; - assertEquals(expectedRow.length, input.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); - } - } - - double latestReward = (gamma * 4.0) + 3.0; - double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward }; - for(int i = 0; i < expectedLabels0.length; ++i) { - assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001); - } - double[][] expectedLabels1 = new double[][] { - new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 }, - new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, - new double[] { 0.0, 
0.0, latestReward, 0.0, 0.0 }, - }; - - assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape()); - - for(int i = 0; i < expectedLabels1.length; ++i) { - double[] expectedRow = expectedLabels1[i]; - assertEquals(expectedRow.length, labels[1].shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001); - } - } - - } - - public class MockActorCritic implements IActorCritic { - - public final List> gradientParams = new ArrayList<>(); - - @Override - public NeuralNetwork[] getNeuralNetworks() { - return new NeuralNetwork[0]; - } - - @Override - public boolean isRecurrent() { - return false; - } - - @Override - public void reset() { - - } - - @Override - public void fit(INDArray input, INDArray[] labels) { - - } - - @Override - public INDArray[] outputAll(INDArray batch) { - return new INDArray[0]; - } - - @Override - public IActorCritic clone() { - return this; - } - - @Override - public void copy(NeuralNet from) { - - } - - @Override - public void copy(IActorCritic from) { - - } - - @Override - public Gradient[] gradient(INDArray input, INDArray[] labels) { - gradientParams.add(new Pair(input, labels)); - return new Gradient[0]; - } - - @Override - public void applyGradient(Gradient[] gradient, int batchSize) { - - } - - @Override - public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { - - } - - @Override - public void save(String pathValue, String pathPolicy) throws IOException { - - } - - @Override - public double getLatestScore() { - return 0; - } - - @Override - public void save(OutputStream os) throws IOException { - - } - - @Override - public void save(String filename) throws IOException { - - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java new file mode 100644 index 000000000..1434796f3 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java @@ -0,0 +1,160 @@ +package org.deeplearning4j.rl4j.learning.async.a3c.discrete; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class A3CUpdateAlgorithmTest { + + @Test + public void refac_calcGradient_non_terminal() { + // Arrange + double gamma = 0.9; + MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 }); + MockMDP mdpMock = new MockMDP(observationSpace); + MockActorCritic actorCriticMock = new MockActorCritic(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); + A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), 
mdpMock.getActionSpace().getSize(), -1, gamma); + + + INDArray[] originalObservations = new INDArray[] { + Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }), + Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }), + Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }), + Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }), + }; + int[] actions = new int[] { 0, 1, 2, 1 }; + double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 }; + + List> experience = new ArrayList>(); + for(int i = 0; i < originalObservations.length; ++i) { + experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false)); + } + + // Act + sut.computeGradients(actorCriticMock, experience); + + // Assert + assertEquals(1, actorCriticMock.gradientParams.size()); + + // Inputs + INDArray input = actorCriticMock.gradientParams.get(0).getLeft(); + for(int i = 0; i < 4; ++i) { + for(int j = 0; j < 5; ++j) { + assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001); + } + } + + INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0]; + INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1]; + + assertEquals(4, targets.shape()[0]); + assertEquals(1, targets.shape()[1]); + + // FIXME: check targets values once fixed + + assertEquals(4, logSoftmax.shape()[0]); + assertEquals(5, logSoftmax.shape()[1]); + + // FIXME: check logSoftmax values once fixed + + } + + public class MockActorCritic implements IActorCritic { + + public final List> gradientParams = new ArrayList<>(); + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[] { batch.mul(-1.0) }; + } + + @Override + public IActorCritic clone() { + return this; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public void copy(IActorCritic from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] labels) { + gradientParams.add(new Pair(input, labels)); + return new Gradient[0]; + } + + @Override + public void applyGradient(Gradient[] gradient, int batchSize) { + + } + + @Override + public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { + + } + + @Override + public void save(String pathValue, String pathPolicy) throws IOException { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java deleted file mode 100644 index 2a8c5b832..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2020 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.async.nstep.discrete; - -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; -import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; -import org.deeplearning4j.rl4j.support.*; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Stack; - -import static org.junit.Assert.assertEquals; - -public class AsyncNStepQLearningThreadDiscreteTest { - - @Test - public void refac_calcGradient() { - // Arrange - double gamma = 0.9; - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdpMock = new MockMDP(observationSpace); - AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build(); - MockDQN dqnMock = new MockDQN(); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); - AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete(mdpMock, asyncGlobalMock, config, null, 0, 0); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); - sut.setHistoryProcessor(hpMock); - - double[][] minitransObs = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - }; - double[] outputs = new double[] { 1.0, 2.0, 3.0 }; - double[] rewards = new double[] { 0.0, 0.0, 3.0 }; - - Stack> minitransList = new Stack>(); - for(int i = 0; i < 3; ++i) { - INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); - INDArray[] output = new INDArray[] { - Nd4j.zeros(5) - }; - output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); - } - minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans - - // Act - sut.calcGradient(dqnMock, minitransList); - - // Assert - assertEquals(1, dqnMock.gradientParams.size()); - INDArray input = dqnMock.gradientParams.get(0).getFirst(); - INDArray labels = dqnMock.gradientParams.get(0).getSecond(); - - assertEquals(minitransObs.length, input.shape()[0]); - for(int i = 0; i < minitransObs.length; ++i) { - double[] expectedRow = minitransObs[i]; - assertEquals(expectedRow.length, input.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); - } - } - - double latestReward = (gamma * 4.0) + 3.0; - double[][] expectedLabels = new double[][] { - new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 }, - new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, - new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 }, - }; - assertEquals(minitransObs.length, labels.shape()[0]); - for(int 
i = 0; i < minitransObs.length; ++i) { - double[] expectedRow = expectedLabels[i]; - assertEquals(expectedRow.length, labels.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001); - } - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java new file mode 100644 index 000000000..35465d26a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java @@ -0,0 +1,115 @@ +package org.deeplearning4j.rl4j.learning.async.nstep.discrete; + +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockDQN; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class QLearningUpdateAlgorithmTest { + + @Test + public void when_isTerminal_expect_initRewardIs0() { + // Arrange + MockDQN dqnMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.zeros(1)), 0, 0.0, true)); + } + }; + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + } + + @Test + public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() { + // Arrange + MockDQN globalDQNMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false)); + } + }; + MockDQN dqnMock = new MockDQN(); + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + assertEquals(2, dqnMock.outputAllParams.size()); + assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001); + assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + } + + @Test + public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() { + // Arrange + MockDQN globalDQNMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false)); + } + }; + MockDQN dqnMock = new MockDQN(); + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + assertEquals(1, globalDQNMock.outputAllParams.size()); + assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001); + assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + } + + @Test + public void 
when_callingWithMultipleExperiences_expect_gradientsAreValid() { + // Arrange + double gamma = 0.9; + MockDQN globalDQNMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false)); + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true)); + } + }; + MockDQN dqnMock = new MockDQN(); + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + // input side -- should be a stack of observations + INDArray input = dqnMock.gradientParams.get(0).getLeft(); + assertEquals(-1.1, input.getDouble(0, 0), 0.00001); + assertEquals(-1.2, input.getDouble(0, 1), 0.00001); + assertEquals(-2.1, input.getDouble(1, 0), 0.00001); + assertEquals(-2.2, input.getDouble(1, 1), 0.00001); + + // target side + INDArray target = dqnMock.gradientParams.get(0).getRight(); + assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001); + assertEquals(1.2, target.getDouble(0, 1), 0.00001); + assertEquals(2.1, target.getDouble(1, 0), 0.00001); + assertEquals(2.0, target.getDouble(1, 1), 0.00001); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index fe8dd6acc..9d77084d5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; @@ -75,8 +77,8 @@ public class QLearningDiscreteTest { .build(); MockDataManager dataManager = new MockDataManager(false); - MockExpReplay expReplay = new MockExpReplay(); - TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); + MockExperienceHandler experienceHandler = new MockExperienceHandler(); + TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); @@ -93,7 +95,6 @@ public class QLearningDiscreteTest { for (int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } - assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount); @@ -133,30 +134,31 @@ public class QLearningDiscreteTest { // MDP calls assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray()); - // ExpReplay calls - double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0}; - int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4}; - double[] 
expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0}; - double[][] expectedTrObservations = new double[][]{ - new double[]{0.0, 2.0, 4.0, 6.0, 8.0}, - new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, - new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, - new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, - new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, - new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, - new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, - new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, + // ExperienceHandler calls + double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; + int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; + double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; + double[][] expectedTrObservations = new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, + new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, + new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, }; - assertEquals(expectedTrObservations.length, expReplay.transitions.size()); - for (int i = 0; i < expectedTrRewards.length; ++i) { - Transition tr = expReplay.transitions.get(i); - assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); - assertEquals(expectedTrActions[i], tr.getAction()); - assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); - for (int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); + + assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size()); + for(int i = 0; i < expectedTrRewards.length; ++i) { + StateActionPair stateActionPair = experienceHandler.addExperienceArgs.get(i); + assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001); + assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction()); + for(int j = 0; j < expectedTrObservations[i].length; ++j) { + assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001); } } + assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001); // trainEpoch result assertEquals(initStepCount + 16, result.getStepCounter()); @@ -167,20 +169,16 @@ public class QLearningDiscreteTest { public static class TestQLearningDiscrete extends QLearningDiscrete { public TestQLearningDiscrete(MDP mdp, IDQN dqn, - QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, + QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler> experienceHandler, int epsilonNbStep, Random rnd) { super(mdp, dqn, conf, epsilonNbStep, rnd); addListener(new DataManagerTrainingListener(dataManager)); - setExpReplay(expReplay); + setExperienceHandler(experienceHandler); } @Override - protected DataSet setTarget(ArrayList> transitions) { - return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0})); - } - - public void setExpReplay(IExpReplay exp) { - this.expReplay = exp; + protected DataSet setTarget(List> transitions) { + return new 
org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); } @Override diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java index 28d7f3914..6f20d82ca 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java @@ -19,6 +19,7 @@ public class MockDQN implements IDQN { public final List outputParams = new ArrayList<>(); public final List> fitParams = new ArrayList<>(); public final List> gradientParams = new ArrayList<>(); + public final List outputAllParams = new ArrayList<>(); @Override public NeuralNetwork[] getNeuralNetworks() { @@ -58,7 +59,8 @@ public class MockDQN implements IDQN { @Override public INDArray[] outputAll(INDArray batch) { - return new INDArray[0]; + outputAllParams.add(batch); + return new INDArray[] { batch.mul(-1.0) }; } @Override @@ -109,4 +111,4 @@ public class MockDQN implements IDQN { public void save(String filename) throws IOException { } -} +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java deleted file mode 100644 index d1fa84c04..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.deeplearning4j.rl4j.support; - -import org.deeplearning4j.rl4j.learning.sync.IExpReplay; -import org.deeplearning4j.rl4j.learning.sync.Transition; - -import java.util.ArrayList; -import java.util.List; - -public class MockExpReplay implements IExpReplay { - - public List> transitions = new ArrayList<>(); - - @Override - public ArrayList> getBatch() { - return null; - } - - @Override - public void store(Transition transition) { - transitions.add(transition); - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java new file mode 100644 index 000000000..13ea5d93a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java @@ -0,0 +1,46 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.ArrayList; +import java.util.List; + +public class MockExperienceHandler implements ExperienceHandler> { + public List> addExperienceArgs = new ArrayList>(); + public Observation finalObservation; + public boolean isGenerateTrainingBatchCalled; + public boolean isResetCalled; + + @Override + public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) { + addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal)); + } + + @Override + public void setFinalObservation(Observation observation) { + finalObservation = observation; + } + + @Override + public List> generateTrainingBatch() { + isGenerateTrainingBatchCalled = true; + return new ArrayList>() { + { + add(new Transition(null, 0, 0.0, false)); + } + }; + } + + @Override + public void reset() { + isResetCalled = true; + } + + @Override + public int 
getTrainingBatchSize() { + return 1; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java index 5395242b2..ffba71b5a 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java @@ -5,6 +5,16 @@ import org.nd4j.linalg.api.ndarray.INDArray; public class MockObservationSpace implements ObservationSpace { + private final int[] shape; + + public MockObservationSpace() { + this(new int[] { 1 }); + } + + public MockObservationSpace(int[] shape) { + this.shape = shape; + } + @Override public String getName() { return null; @@ -12,7 +22,7 @@ public class MockObservationSpace implements ObservationSpace { @Override public int[] getShape() { - return new int[] { 1 }; + return shape; } @Override @@ -24,4 +34,4 @@ public class MockObservationSpace implements ObservationSpace { public INDArray getHigh() { return null; } -} +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java new file mode 100644 index 000000000..dbe2fe1fc --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java @@ -0,0 +1,19 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; + +import java.util.ArrayList; +import java.util.List; + +public class MockUpdateAlgorithm implements UpdateAlgorithm { + + public final List>> experiences = new ArrayList>>(); + + @Override + public Gradient[] computeGradients(MockNeuralNet current, List> experience) { + experiences.add(experience); + return new Gradient[0]; + } +} From 23e4aa99ad8b254c7dd32a632be4658c540bd5af Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Mon, 13 Apr 2020 13:21:51 +0300 Subject: [PATCH 18/19] Shyrma lstm layer bp (#370) * - start working on bp for lstm Signed-off-by: Yurii * - further working on bp for lstmLayer Signed-off-by: Yurii * - minor change Signed-off-by: Yurii * - further work on bp for lstmLayer 2 Signed-off-by: Yurii * - further work on bp for lstmLayer 3 Signed-off-by: Yurii * - further work on bp for lstmLayer 4 Signed-off-by: Yurii * - further work on bp for lstmLayer 5 Signed-off-by: Yurii * - further work on bp for lstmLayer 6 Signed-off-by: Yurii * - further work on bp for lstmLayer 7 Signed-off-by: Yurii * - further work on bp for lstmLayer 8 Signed-off-by: Yurii * - further work on bp for lstmLayer 9 Signed-off-by: Yurii * - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs Signed-off-by: Yurii * - testing and fixing lstmLayerCellBp helper Signed-off-by: Yurii * - implement lstmLayerCellBp as separate op Signed-off-by: Yurii * - implement lstmLayerBp as separate op (not tested) Signed-off-by: Yurii * - fixing calculations of dLdWp and dLdb in lstmLayerCellBp Signed-off-by: Yurii * - further work on bp for lstmLayer 10 Signed-off-by: Yurii * - fixing typo in lstmLayerTimeLoop Signed-off-by: Yurii * - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp Signed-off-by: Yurii * - further work on bp for lstmLayer 10 Signed-off-by: Yurii * - testing and fixing 
bugs in lstmLayer_bp op 1 Signed-off-by: Yurii * - testing and fixing bugs in lstmLayer_bp op 2 Signed-off-by: Yurii * - turn off heavy tests for cuda for lstmLayer_bp op Signed-off-by: Yurii * - forgot to nullify gradients at eliminated time steps (when sequnce length array is present ) Signed-off-by: Yurii --- libnd4j/include/array/NDArray.hXX | 1 - libnd4j/include/array/cpu/NDArrayLambda.hpp | 2 +- libnd4j/include/helpers/GradCheck.h | 6 +- libnd4j/include/helpers/cpu/MmulHelper.cpp | 8 +- libnd4j/include/helpers/impl/GradCheck.cpp | 52 +- libnd4j/include/loops/legacy_ops.h | 3 +- .../generic/nn/recurrent/lstmLayer.cpp | 444 ++++- .../generic/nn/recurrent/lstmLayerCell.cpp | 339 ++++ .../ops/declarable/headers/recurrent.h | 12 + .../ops/declarable/helpers/impl/lstmLayer.cpp | 1460 ++++++++++++++++- .../ops/declarable/helpers/lstmLayer.h | 85 +- libnd4j/include/ops/ops.h | 15 +- .../layers_tests/DeclarableOpsTests13.cpp | 563 ++++++- .../layers_tests/PlaygroundTests.cpp | 73 + 14 files changed, 2896 insertions(+), 167 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 1caae85a4..7756fb7ae 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -403,7 +403,6 @@ NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::Launch ///////////////////////////////////////////////////////////////////////// // u8 string constructors -///////////////////////////////////////////////////////////////////////// NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { if (!DataTypeUtils::isS(dtype)) { diff --git a/libnd4j/include/array/cpu/NDArrayLambda.hpp b/libnd4j/include/array/cpu/NDArrayLambda.hpp index 8bced3de4..bd8742288 100644 --- a/libnd4j/include/array/cpu/NDArrayLambda.hpp +++ b/libnd4j/include/array/cpu/NDArrayLambda.hpp @@ -10,7 +10,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std:: throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { - nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); + nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h index 0d184a5a1..f5fd1f3df 100644 --- a/libnd4j/include/helpers/GradCheck.h +++ b/libnd4j/include/helpers/GradCheck.h @@ -47,13 +47,13 @@ class ND4J_EXPORT GradCheck { * opBP - back propagation operation * argsHolderFF - argument holder for feed forward operation * argsHolderBP - argument holder for back propagation operation - * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty array which means to check all arrays + * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays * IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range 
[0.2*array_length, 0.7*array_length), default value is {0., 1.} * loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM + * outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent */ static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM); - + const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector& outArrsFFIdx = {}); }; diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 62d8153ef..73f3e54bd 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -372,16 +372,16 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con int xLenDim(0), yLenDim(0); if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); + throw std::runtime_error("MmulHelper::dot: X array must be vector !"); if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); + throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); + throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); const auto length = X->lengthOf(); if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); + throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !"); if(Z == nullptr) Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index 2643a7b6d..f3daa798c 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector& ////////////////////////////////////////////////////////////////////////// bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss ) { + const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss, const std::vector& outArrsFFIdx) { const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP @@ -82,12 +82,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons int numOutArrs = outArrsFF.size(); double scorePlus = 0.; - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); + if(!outArrsFFIdx.empty()) { + 
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); + } + } + else { + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); + } } // subtract epsilon, feed forward @@ -95,12 +106,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons outArrsFF = opFF.execute(argsHolderFF); double scoreMinus = 0.; - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); + if(!outArrsFFIdx.empty()) { + for(const auto& k : outArrsFFIdx) { // loop through independent output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); + } + } + else { + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); + } } // restore initial element value @@ -120,7 +142,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons throw std::runtime_error(""); } - // printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad); + // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad); // calculate relative error double relError; @@ -134,7 +156,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons if(math::nd4j_abs(analyticGrad - numericalGrad) < MINABSERR) continue; - printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad); + printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad); printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! 
\n", relError, MAXRELERR, i, j); return false; } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 95f83be1a..001f8806c 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -253,7 +253,8 @@ (45, ReversePow), \ (46, DivideNoNan), \ (47, IGamma), \ - (48, IGammac) + (48, IGammac), \ + (49, RELUDerivative) diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp index 3a02b8a70..8637fe990 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp @@ -24,10 +24,10 @@ #include #include + namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { @@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't + // ct = clip(ft ◦ ct-1 + it ◦ c't) // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) // ht = ot ◦ tanh(ct) @@ -72,26 +72,26 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // 2) [2, nOut, 4*nOut] when directionMode >= 2 // ******* - // peephole weights Wp: + // peephole weights Wp, optional: // 1) [3*nOut] when directionMode < 2 // 2) [2, 3*nOut] when directionMode >= 2 // ******* - // biases b: + // biases b, optional: // 1) [4*nOut] when directionMode < 2 // 2) [2, 4*nOut] when directionMode >= 2 // ******* - // sequence length array seqLen: - // 1) [bS] always + // sequence length array seqLen, optional: + // 1) [bS] // ******* - // initial output hI: + // initial output hI, optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 // ******* - // initial cell state cI (same shape as in hI): + // initial cell state cI (same shape as in hI), optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 @@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // OUTPUTS: // ******* - // output h: + // output h, optional: // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 @@ -109,19 +109,19 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 // ******* - // output at last step hL: + // output at last step hL, optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 // ******* - // cell state at last step cL (same shape as in hL): + // cell state at last step cL (same shape as in hL), optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 // !!! dimension 4*nOut implies order it, ft, c't, ot // !!! 
dimension 3*nOut implies order it, ft, ot - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus @@ -135,8 +135,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided const auto hasPH = B_ARG(4); // indicates whether peephole connections are present const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; @@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // evaluate dimensions const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; // inputs validations @@ -323,9 +323,9 @@ DECLARE_SHAPE_FN(lstmLayer) { const auto Wr = INPUT_VARIABLE(2); // recurrent weights // evaluate dimensions - const Nd4jLong sL = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) ); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? 
x->sizeAt(1) : x->sizeAt(2); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; DataType type; @@ -398,6 +398,412 @@ DECLARE_SHAPE_FN(lstmLayer) { } +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen, optional: + // 1) [bS] + + // ******* + // initial output hI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. output dLdh, optional: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // gradient vs output at last time step dLdhL, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + + // OUTPUTS: + + // ******* + // gradient vs. input dLdx: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // gradient vs. input weights dLdWx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. recurrent weights dLdWr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. 
peephole weights dLdWp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // gradient vs. biases dLdb, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // gradient vs. sequence length array dLdsL, optional (do not calculate it!!!): + // 1) [bS] always + + // ******* + // gradient vs. initial output dLdhI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1} + const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given + const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? 
T_ARG(count++) : 0; + + REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !"); + REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !"); + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output + const auto dLdhL = retLastH ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step + const auto dLdcL = retLastC ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input + auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights + auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights + auto dLdb = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases + auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. seqLen vector, we don't calculate it !!! + auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output + auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state + auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) : nullptr; // gradient vs. peephole weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? 
x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if(directionMode < 2) { // no bidirectional + + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. 
cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str()); + } + else { // bidirectional + // Wx validation + if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str()); + } + + // gradient vs. output validation + if(dLdh) { + int factor = directionMode <= 2 ? 1 : 2; + std::vector expdLdhShape; + if(dataFormat == 0) expdLdhShape = std::vector{sL, bS, factor*nOut}; + else if(dataFormat == 1) expdLdhShape = std::vector{bS, sL, factor*nOut}; + else if(dataFormat == 2) expdLdhShape = std::vector{bS, factor*nOut, sL}; + else expdLdhShape = std::vector{sL, 2, bS, nOut}; + REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. 
output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + } + + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + if(directionMode == 0) { // forward + + helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp); + } + else if(directionMode == 1) { // backward + + helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp); + } + else { // bidirectional + + NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); + NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); + NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0}); + NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0}); + + NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); + NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); + NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0}); + NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr), + *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr), + *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr); + + if(Wp) { + WpFwd = new NDArray((*Wp)({0,1, 0,0})); + WpBwd = new NDArray((*Wp)({1,2, 0,0})); + dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0})); + dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0})); + } + if(b) { + bFwd = new NDArray((*b)({0,1, 0,0})); + bBwd = new NDArray((*b)({1,2, 0,0})); + dLdbFwd = new NDArray((*dLdb)({0,1, 0,0})); + dLdbBwd = new NDArray((*dLdb)({1,2, 0,0})); + } + if(hI) { + hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); + hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); + dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0})); + dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0})); + } + if(cI) { + cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); + cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); + dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0})); + dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0})); + } + if(dLdhL) { + dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0})); + dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0})); + } + if(dLdcL) { + dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0})); + dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0})); + } + + // FIXME looks like sum (directionMode == 2) is impossible for backprop + if(dLdh) { + if(directionMode == 2) { // sum + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !"); + // dLdhFwd = dLdh; + // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content + } + else if(directionMode == 3) { // concat + dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0})); + dLdhBwd = new NDArray(dataFormat <= 1 ? 
(*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0})); + } + else { // directionMode == 4 + dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0})); + dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0})); + } + } + + + + helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd); + NDArray dLdxBwd = dLdx->ulike(); + helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd); + + *dLdx += dLdxBwd; + + delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd; + delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; + delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; + delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd; + + if(dLdhFwd != dLdh) + delete dLdhFwd; + } + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayer_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(lstmLayer_bp) { + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + + int count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + std::vector outShapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()}; + + if(b != nullptr) + outShapes.push_back(b->getShapeInfo()); + if(seqLen != nullptr) + outShapes.push_back(seqLen->getShapeInfo()); + if(hI != nullptr) + outShapes.push_back(hI->getShapeInfo()); + if(cI != nullptr) + outShapes.push_back(cI->getShapeInfo()); + if(Wp != nullptr) + outShapes.push_back(Wp->getShapeInfo()); + + return new ShapeList(outShapes); +} + } } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp new file mode 100644 index 000000000..46f32e399 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp @@ -0,0 +1,339 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#if NOT_EXCLUDED(OP_lstmLayerCell) + +#include +#include + +namespace sd { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // biases b (optional): [4*nOut] + // peephole weights Wp (optional): [3*nOut] + + // OUTPUTS: + // current output h: [bS, nOut] or [nOut] + // current cell state c: [bS, nOut] or [nOut] + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasPH = B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? 
T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !"); + + auto h = OUTPUT_VARIABLE(0); + auto c = OUTPUT_VARIABLE(1); + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 ? std::vector{nOut} : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + + std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c); + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayerCell) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + + +DECLARE_SHAPE_FN(lstmLayerCell) { + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + + uint count = hasBiases ? 
4 : 3; + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count); // initial cell state + + return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()}); +} + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // gradient wrt output dLdh: [bS, nOut] or [nOut] + // gradient wrt cell state dLdc: [bS, nOut] or [nOut] + // peephole weights Wp (optional): [3*nOut] + // biases b (optional): [4*nOut] + + // OUTPUTS: + // gradient wrt x dLdx: [bS, nIn] or [nIn] + // gradient wrt Wx dLdWx: [nIn, 4*nOut] + // gradient wrt Wr dLdWr: [nOut, 4*nOut] + // gradient wrt hI dLdhI: [bS, nOut] or [nOut] + // gradient wrt cI dLdcI: [bS, nOut] or [nOut] + // gradient wrt b dLdb (optional): [4*nOut] + // gradient wrt Wp dLdWp (optional): [3*nOut] + + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasPH = B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? 
T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output + + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !"); + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdWx = OUTPUT_VARIABLE(1); + auto dLdWr = OUTPUT_VARIABLE(2); + auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr; + auto dLdhI = OUTPUT_VARIABLE(count++); + auto dLdcI = OUTPUT_VARIABLE(count++); + auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr; + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 ? 
std::vector{nOut} : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str()); + + + std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + std::vector zShape = x->rankOf() == 1 ? std::vector({4*nOut}) : std::vector({bS, 4*nOut}); + + NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext()); + NDArray a = z.ulike(); + NDArray h = cI->ulike(); + NDArray c = cI->ulike(); + + helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); + + helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayerCellBp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + + +DECLARE_SHAPE_FN(lstmLayerCellBp) { + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasPH = B_ARG(1); // indicates whether peephole connections are present + + uint count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? 
INPUT_VARIABLE(count) : nullptr; // peephole weights + + std::vector shapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()}; + + if(b != nullptr) + shapes.push_back(b->getShapeInfo()); + + shapes.push_back(hI->getShapeInfo()); + shapes.push_back(cI->getShapeInfo()); + + if(Wp != nullptr) + shapes.push_back(Wp->getShapeInfo()); + + return new ShapeList(shapes); +} + +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index 55138bb60..dd219867f 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -149,6 +149,13 @@ namespace ops { DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); #endif + #if NOT_EXCLUDED(OP_lstmLayerCell) + DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3); + #endif + #if NOT_EXCLUDED(OP_lstmLayerCell) + DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3); + #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -236,6 +243,11 @@ namespace ops { DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); #endif + ////////////////////////////////////////////////////////////////////////// + #if NOT_EXCLUDED(OP_lstmLayer) + DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5); + #endif + ////////////////////////////////////////////////////////////////////////// /** diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 435a3e32d..9fce17c4b 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -27,19 +28,215 @@ #include +#include +#include #include +#include // #include // #include // #include // #include // #include // #include -// #include + namespace sd { namespace ops { namespace helpers { +////////////////////////////////////////////////////////////////////////// +static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { + + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::Tanh, z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELU, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::Sigmoid, z); + break; + case 3: { + ExtraArguments args({ static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::Affine, z, &args); + break; + } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); + break; + case 5: + thresholdRelu(x.getContext(), x, alpha, z); + break; + case 6: { + ExtraArguments args({ static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoid, z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELU, alpha, z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSign, z); + break; + case 10: + (const_cast(x)).applyTransform(transform::SoftPlus, z); + break; + default: + throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + } +} + +////////////////////////////////////////////////////////////////////////// +static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { + + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::TanhDerivative, z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELUDerivative, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); + break; + case 3: { + z = alpha; + break; + } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELUDerivative, alpha, z); + break; + case 5: + (const_cast(x)).applyScalar(scalar::RELUDerivative, alpha, z); + break; + case 6: { + auto func = PRAGMA_THREADS_FOR { + for(Nd4jLong i = start; i < stop; ++i) { + auto val = beta * x.e(i); + z.p(i, alpha * beta * (1.f - sd::math::nd4j_tanh(val) * sd::math::nd4j_tanh(val))); + } + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoidDerivative, z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELUDerivative, alpha, z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSignDerivative, z); + break; + case 10: { + auto func = PRAGMA_THREADS_FOR { + for(Nd4jLong i = start; i < stop; ++i) { + auto val = sd::math::nd4j_exp(x.e(i)); + z.p(i, val / (1.f + val)); + } + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; + } + default: + throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + } +} + +////////////////////////////////////////////////////////////////////////// +// FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal +static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) { + + if(clipVal == 0) + return; + + auto func = 
PRAGMA_THREADS_FOR { + for(Nd4jLong i = start; i < stop; ++i) { + const auto val = c.e(i); + if(val == -clipVal || val == clipVal) { + z0.p(i, 0.f); + z1.p(i, 0.f); + z2.p(i, 0.f); + z3.p(i, 0.f); + } + } + }; + samediff::Threads::parallel_for(func, 0, c.lengthOf()); +} + +////////////////////////////////////////////////////////////////////////// +static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { + + if(dataFormat == 0 || dataFormat == 3) + return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] + + if(dataFormat == 1) + return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] + + return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] +} + +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) { + + if(dataFormat == 0 || dataFormat == 3) + return t * bS + b; // TNS: shape [sL, bS, nIn] + + return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] +} + +////////////////////////////////////////////////////////////////////////// +// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} - Jacobian derivative -> if x.rankOf() == 2 +// x{K} x y{K,N} = z{N}, dzdy{K,N,N} - Jacobian derivative -> if x.rankOf() == 1 +static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) { + + std::vector outShape = x.rankOf() == 1 ? std::vector({x.sizeAt(0), nOut, nOut}) : std::vector({x.sizeAt(1), nOut, x.sizeAt(0), nOut}); + + NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext()); + + if(x.rankOf() == 1) { + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + if(i1 == i2) + dzdy.p(i0,i1,i2, x.e(i0)); + else + dzdy.p(i0,i1,i2, 0); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); + } + else { + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (auto i3 = 0; i3 < dzdy.sizeAt(3); ++i3) { + if(i1 == i3) + dzdy.p(i0,i1,i2,i3, x.e(i2,i0)); + else + dzdy.p(i0,i1,i2,i3, 0); + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); + } + + return dzdy; +} ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -47,25 +244,27 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const std::vector& params, NDArray* h, NDArray* c) { + // * -> means element-wise multiplication + // ^ -> means matrix multiplication /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) + // it = σ(Wxi ^ xt + Wri ^ ht-1 + bi) + // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf) + // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo) + // ht = ot * tanh(ct) // equations (peephole connections are 
present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) + // it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo) + // ht = ot * tanh(ct) // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus @@ -91,8 +290,8 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // Wx - input weights [nIn, 4*nOut] // Wr - recurrent weights [nOut, 4*nOut] // b - biases [4*nOut], optional, may be nullptr - // hI - previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // Wp - peephole weights [3*nOut], optional, may be nullptr // OUTPUTS: @@ -109,24 +308,24 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // add biases if they are given if(b != nullptr) - z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut] + z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] - auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut] - auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut] - auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut] - auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut] + auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 ? 
z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) // peephole connections for input and forget gates if(Wp != nullptr) { - zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut] - zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut] + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) + zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) } applyActivation(zi, params[3], params[4], params[5], zi); // inplace applyActivation(zf, params[3], params[4], params[5], zf); // inplace - applyActivation(zc, params[6], params[7], params[8], zc); // inplace + applyActivation(zg, params[6], params[7], params[8], zg); // inplace - c->assign(zf * *cI + zi * zc); // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut] + c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation if(params[2] != 0) @@ -134,15 +333,300 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // peephole connections for output gate if(Wp != nullptr) - zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut] + zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) applyActivation(zo, params[3], params[4], params[5], zo); applyActivation(*c, params[9], params[10], params[11], *h); - *h *= zo; // [bS, nOut] ◦ [bS, nOut] + *h *= zo; // [bS, nOut] * [bS, nOut](or[nOut]) } +////////////////////////////////////////////////////////////////////////// +// this auxiliary ff should be running before backprop +void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, + NDArray* z, NDArray* a, NDArray* h, NDArray* c) { + + // z - zi, zf, zg, zo + // a - i, f, g, o + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] + //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] + // add biases if they are given + if(b != nullptr) + *z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] + + auto zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) + + auto i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto o = x->rankOf() == 1 ? 
(*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) + + // peephole connections for input and forget gates + if(Wp != nullptr) { + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) + zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) + } + + applyActivation(zi, params[3], params[4], params[5], i); + applyActivation(zf, params[3], params[4], params[5], f); + applyActivation(zg, params[6], params[7], params[8], g); + + c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) + + // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation + if(params[2] != 0) + c->applyScalar(scalar::LstmClip, params[2], *c); + + // peephole connections for output gate + if(Wp != nullptr) + zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) + + applyActivation(zo, params[3], params[4], params[5], o); + + applyActivation(*c, params[9], params[10], params[11], *h); + *h *= o; // [bS, nOut] * [bS, nOut](or[nOut]) +} + + +////////////////////////////////////////////////////////////////////////// +void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdc, + const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + + /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ + /** the objective is to provide math-readable code **/ + + // equations (no peephole connections) + // zi = x ^ Wxi + hI ^ Wri + bi + // zf = x ^ Wxf + hI ^ Wrf + bf + // zg = x ^ Wxg + hI ^ Wrg + bg + // zo = x ^ Wxo + hI ^ Wro + bo + // i = act(zi) + // f = act(zf) + // g = actC(zg) + // o = act(zo) + // c = clip(f * cI + i * g) + // h = o * actH(c) + + // equations (peephole connections are present) + // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi + // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf + // zg = x ^ Wxg + hI ^ Wrg + bg + // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo + // i = act(zi) + // f = act(zf) + // g = actC(zg) + // o = act(zo) + // c = clip(f * cI + i * g) + // h = o * actH(c) + + // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + + // params[0] - dataFormat, ignore + // params[1] - directionMode, ignore + // params[2] - cell clipping value, if it = 0 then do not apply clipping + + // params[3] - activation ID for input (i), forget (f) and output (o) gates + // params[4] - alpha value for gates activation + // params[5] - beta value for gates activation + + // params[6] - activation ID for cell state (c) + // params[7] - alpha value for cell state activation + // params[8] - beta value for cell state activation + + // params[9] - activation ID for output (h) + // params[10] - alpha value for output activation + // params[11] - beta value for output activation + + // INPUTS: + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, 
[bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr + // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] + + // OUTPUTS: + // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr + // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] + // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] + // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] + // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] + + // !!! dimension 4*nOut implies order i, f, g, o + // !!! dimension 3*nOut implies order i, f, o + + // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] + // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] + + // dLdhI += dLdh; [bS, nOut] + // dLdcI += dLdhI * dhdc; [bS, nOut] + + // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) + // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) + // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) + // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) + + // dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn] + // dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut] + // dLdcI = dLdcI*dcdcI, [bS, nOut] + + // dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxg = xT^dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] + + // dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] + + // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + const Nd4jLong nIn = x->sizeAt(-1); + + NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray zo = x->rankOf() == 1 ? 
(*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) + + NDArray i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) + + NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut]) + NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0,0, 0, nOut}); + NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut, 2*nOut}) : dLdz({0,0, nOut, 2*nOut}); + NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut}); + NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut}); + + // dcdzi = dcdi*didzi, [bS, nOut](or[nOut]) + activationDeriv(zi, params[3], params[4], params[5], dLdzi); // didzi, inplace + dLdzi *= g; // dcdi = g*clipDeriv + + // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut]) + activationDeriv(zf, params[3], params[4], params[5], dLdzf); // dfdzf, inplace + dLdzf *= *cI; // dcdf = cI*clipDeriv + + // dcdzg = dcdg*dgdzg, [bS, nOut](or[nOut]) + activationDeriv(zg, params[6], params[7], params[8], dLdzg); // dgdzg, inplace + dLdzg *= i; // dcdg = i*clipDeriv + + // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut]) + activationDeriv(zo, params[3], params[4], params[5], dLdzo); + NDArray temp = dLdzo.ulike(); + applyActivation(*c, params[9], params[10], params[11], temp); // actH(c), inplace + dLdzo *= temp; + + // dcdcI + NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut]) + + // take into account possible deposit from clipping derivative + clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI); + + // dhdc + NDArray dhdc = c->ulike(); + activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut] + dhdc *= o; + + if(Wp) { + dhdc += dLdzo*(*Wp)({2*nOut, 3*nOut}); + dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut}); // broadcast [bS, nOut] * nOut + ... 
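+        // scalar illustration of the two peephole corrections above, following the equations listed at the top of this function:
+        //     zo = ... + c*Wpo   =>  dzo/dc  = Wpo,  hence  dhdc  += dhdzo*Wpo   (dLdzo holds dhdzo at this point)
+        //     zi = ... + cI*Wpi  =>  dzi/dcI = Wpi,  hence  dcdcI += dcdzi*Wpi   (dLdzi holds dcdzi), and likewise dcdzf*Wpf for zf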
+ } + + if(dLdh) + *dLdhI += *dLdh; + if(dLdc) + *dLdcI += *dLdc; + else + *dLdcI += *dLdhI * dhdc; + + dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) + + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) + + // dLdhI + NDArray WrT = Wr->transpose(); + MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) + + // dLdcI + dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut]) + + if(x->rankOf() == 1) { + + NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1] + NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1] + NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut] + + // dLdWx + *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] + } + else { + + // dLdWx + *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] + } + + // dLdb + if(b && x->rankOf() == 1) + *dLdb += dLdz; // [4*nOut] + else if(b) + *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; + + // dLdWp + if(Wp && x->rankOf() == 1) { + (*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut] + (*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut] + (*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut] + } + else if(Wp) { + NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); + (std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({0,nOut}) += temp; + (std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({nOut,2*nOut}) += temp; + (std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({2*nOut,3*nOut}) += temp; + } +} ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -172,7 +656,7 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const int dataFormat = params[0]; const int directionMode = params[1]; - const Nd4jLong sL = x->sizeAt(dataFormat); + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? 
x->sizeAt(0) : x->sizeAt(1); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; @@ -192,7 +676,7 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, auto ct = cL; if(!cL) - cL = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); auto ht = hL; if(!h && !hL) @@ -300,7 +784,8 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + if(limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } @@ -380,7 +865,8 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + if(limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } @@ -439,7 +925,8 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + if(limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } @@ -451,10 +938,915 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, delete c0Set; delete htSet; delete ctSet; + + if(!hI) + delete h0; + if(!cI) + delete c0; + if(!cL) + delete ct; + if(!h && !hL) + delete ht; +} + + +////////////////////////////////////////////////////////////////////////// +void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const std::vector& params, const bool forward, + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) { + + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr + // dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr + // dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr + + // OUTPUTS: + // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] + // dLdWx - gradient vs. input weights [nIn, 4*nOut] + // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] + // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr + // dLdhI - gradient vs. 
initial output [bS, nOut], optional, may be nullptr + // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const int nOut = Wx->sizeAt(-1) / 4; + + auto dLdh0 = dLdhI; + if(!hI) + dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + + auto dLdc0 = dLdcI; + if(!cI) + dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + + NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext()); + NDArray a = z.ulike(); + NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext()); + NDArray c = h.ulike(); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), + *dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr); + + if(!seqLen) { + + dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] + + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] + cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] + zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] + aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] + if(dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] + } + else { + + dims = dataFormat == 2 ? 
std::vector({1}) : std::vector({2}); // points on nIn/nOut axis + + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + + if(hI) + hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if(cI) + cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + if(dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] + if(!dLdh && dLdhL) + dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if(!dLdh && !dLdhL) + dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + } + + + // loops + if(forward) { + + if(!seqLen) { // seqLen is absent + + if(hI) + h({0,1, 0,0, 0,0}).assign(hI); + else + h({0,1, 0,0, 0,0}).nullify(); + if(cI) + c({0,1, 0,0, 0,0}).assign(cI); + else + c({0,1, 0,0, 0,0}).nullify(); + + // ff + for (int t = 0; t < sL; ++t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1)); + + // bp + for (int t = sL-1; t >= 0; --t) { + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr); + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc, + zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + } + } + else { // seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + continue; + } + + if(hI) + h({0,1, e,e+1, 0,0}).assign(hISet->at(e)); + else + h({0,1, e,e+1, 0,0}).nullify(); + if(cI) + c({0,1, e,e+1, 0,0}).assign(cISet->at(e)); + else + c({0,1, e,e+1, 0,0}).nullify(); + + // ff + for (int t = 0; t < limit; ++t) + lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params, + zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e)); + + // bp + for (int t = limit-1; t >= 0; --t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? 
dLdcLSet->at(e) : nullptr); + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc, + zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, + dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + } + + if(limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + else { // backward or bidirectional + + if(!seqLen) { // backward or bidirectional, seqLen is absent + + if(hI) + h({sL,sL+1, 0,0, 0,0}).assign(hI); + else + h({sL,sL+1, 0,0, 0,0}).nullify(); + if(cI) + c({sL,sL+1, 0,0, 0,0}).assign(cI); + else + c({sL,sL+1, 0,0, 0,0}).nullify(); + + // ff + for (int t = sL-1; t >= 0; --t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t)); + + // bp + for (int t = 0; t < sL; ++t) { + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr); + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc, + zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + } + } + else if(directionMode == 1) { // backward, seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + continue; + } + + if(hI) + h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e)); + else + h({sL,sL+1, e,e+1, 0,0}).nullify(); + if(cI) + c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e)); + else + c({sL,sL+1, e,e+1, 0,0}).nullify(); + + // ff + for (int t = sL - 1; t >= sL-limit; --t) + lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, + zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); + + // bp + for (int t = sL-limit; t < sL; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? 
dLdcLSet->at(e) : nullptr); + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, + dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + } + + if(limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + else { // bidirectional mode, seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + continue; + } + + if(hI) + h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e)); + else + h({limit,limit+1, e,e+1, 0,0}).nullify(); + if(cI) + c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e)); + else + c({limit,limit+1, e,e+1, 0,0}).nullify(); + + // ff + for (int t = limit - 1; t >= 0; --t) + lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, + zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); + + // bp + for (int t = 0; t < limit; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr); + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, + dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + } + + if(limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + + delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet; + delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet; + + if(!hI) + delete dLdh0; + if(!cI) + delete dLdc0; +} + + +} +} } -} -} -} +////////////////////////////////////////////////////////////////////////// +// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +// const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, +// const std::vector& params, const bool firstIter, + +// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr, +// NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, +// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + +// /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ +// /** the objective is to provide math-readable code **/ + +// // equations (no peephole connections) +// // zi = x ^ Wxi + hI ^ Wri + bi +// // zf = x ^ Wxf + hI ^ Wrf + bf +// // zg = x ^ Wxg + hI ^ Wrg + bg +// // zo = x ^ Wxo + hI ^ Wro + bo +// // i = act(zi) +// // f = act(zf) +// // g = actC(zg) +// // o = act(zo) +// // c = clip(f * cI + i * g) +// // h = o * actH(c) + +// // equations (peephole connections are present) +// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi +// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf +// // zg = x ^ Wxg + hI ^ Wrg + bg +// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo +// // i = act(zi) +// // f = act(zf) +// // g = 
actC(zg) +// // o = act(zo) +// // c = clip(f * cI + i * g) +// // h = o * actH(c) + +// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + +// // params[0] - dataFormat, ignore +// // params[1] - directionMode, ignore +// // params[2] - cell clipping value, if it = 0 then do not apply clipping + +// // params[3] - activation ID for input (i), forget (f) and output (o) gates +// // params[4] - alpha value for gates activation +// // params[5] - beta value for gates activation + +// // params[6] - activation ID for cell state (c) +// // params[7] - alpha value for cell state activation +// // params[8] - beta value for cell state activation + +// // params[9] - activation ID for output (h) +// // params[10] - alpha value for output activation +// // params[11] - beta value for output activation + +// // INPUTS: +// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr +// // Wx - input weights [nIn, 4*nOut] +// // Wr - recurrent weights [nOut, 4*nOut] +// // b - biases [4*nOut], optional, may be nullptr +// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr +// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr +// // Wp - peephole weights [3*nOut], optional, may be nullptr +// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr +// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr +// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr +// // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr +// // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr +// // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr + +// // OUTPUTS: +// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr +// // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] +// // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] +// // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] +// // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] + +// // !!! dimension 4*nOut implies order i, f, g, o +// // !!! dimension 3*nOut implies order i, f, o + +// // dcdzi = dcdi*didzi +// // dcdzf = dcdf*dfdzf +// // dcdzg = dcdg*dgdzg +// // dhdzo = dhdo*dodzo + +// // dhdc = dhdc + Wp ? 
dhdzo*dzodc : 0 [bS, nOut] +// // factor = dLdh*dhdc [bS, nOut] +// // iFactor = factor*dcdzi [bS, nOut] +// // fFactor = factor*dcdzf [bS, nOut] +// // eFactor = factor*dcdzg [bS, nOut] +// // oFactor = *dLdh*dhdzo [bS, nOut] + +// // tempC = dcdcI + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0; +// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT +// // tempO = dhdzo^WroT + +// // dhIdcI = dhdc_from_previous_time_step + +// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn] +// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut] +// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] + +// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcdbf(dcIdbf) = (dcdzf + 
tempIFE*dhIdbf + tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + +// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr); +// if(Wp) { +// Wpi = new NDArray((*Wp)({0, nOut})); +// Wpf = new NDArray((*Wp)({nOut, 2*nOut})); +// Wpo = new NDArray((*Wp)({2*nOut, 3*nOut})); +// dhIdWpi = new NDArray((*dhIdWp)({0, nOut})); +// dhIdWpf = new NDArray((*dhIdWp)({nOut, 2*nOut})); +// dhIdWpo = new NDArray((*dhIdWp)({2*nOut, 3*nOut})); +// dcIdWpi = new NDArray((*dcIdWp)({0, nOut})); +// dcIdWpf = new NDArray((*dcIdWp)({nOut, 2*nOut})); +// dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut})); +// } + +// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr); +// if(b) { +// dhIdbi = new NDArray((*dhIdb)({0, nOut})); +// dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut})); +// dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut})); +// dhIdbo = new NDArray((*dhIdb)({3*nOut, 4*nOut})); +// dcIdbi = new NDArray((*dcIdb)({0, nOut})); +// dcIdbf = new NDArray((*dcIdb)({nOut, 2*nOut})); +// dcIdbg = new NDArray((*dcIdb)({2*nOut, 3*nOut})); +// dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut})); +// } + +// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWro = x->rankOf() == 1 ? 
(*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn] +// NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn] +// NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn] +// NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn] + +// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut] +// NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut] +// NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut] +// NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut] + +// // ***** feed forward step ***** // + +// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] +// //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] +// // add biases if they are given +// if(b) +// z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut]) + +// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) +// auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) +// auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) +// auto zo = x->rankOf() == 1 ? 
z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) + +// // peephole connections for input and forget gates +// if(Wp) { +// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// } + +// NDArray i = zi.ulike(); // [bS, nOut] +// NDArray f = zf.ulike(); // [bS, nOut] +// NDArray g = zg.ulike(); // [bS, nOut] +// applyActivation(zi, params[3], params[4], params[5], i); +// applyActivation(zf, params[3], params[4], params[5], f); +// applyActivation(zg, params[6], params[7], params[8], g); + +// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) + +// // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation +// if(params[2] != 0) +// c.applyScalar(scalar::LstmClip, params[2], c); + +// // peephole connections for output gate +// if(Wp) +// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) + +// NDArray o = zo.ulike(); // [bS, nOut](or[nOut]) +// applyActivation(zo, params[3], params[4], params[5], o); + +// // ***** back prop step ***** // + +// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut]) +// NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut]) + +// // dodzo +// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zo, params[3], params[4], params[5], dodzo); + +// // dhdzo = dhdo*dodzo = actH(c)*dodzo +// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) +// applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c) +// hI->assign(o*dhdzo); +// dhdzo *= dodzo; + +// // dcdzi = dcdi*didzi +// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi +// dcdzi *= g; // dcdi = g*clipDeriv + +// // dcdzf = dcdf*dfdzf +// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf +// dcdzf *= *cI; // dcdf = cI*clipDeriv + +// // dcdzg = dcde*dedzg +// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dedzg +// dcdzg *= i; // dcdf = i*clipDeriv + +// // dcdcI +// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) + +// // take into account possible deposit from clipping derivative +// clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI); + +// // dzodc +// NDArray* dzodc = Wpo; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) + +// // dzidcI +// NDArray* dzidcI = Wpi; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) + +// // dzfdcI +// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) + +// // dhdc +// NDArray dhdc = c.ulike(); +// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut] +// dhdc *= o; +// if(Wp) +// dhdc += dhdzo* *dzodc; + +// NDArray factor = *dLdh * dhdc; + +// NDArray iFactor = factor*dcdzi; // [bS, nOut](or[nOut]) +// NDArray fFactor = factor*dcdzf; // [bS, nOut](or[nOut]) +// NDArray eFactor = factor*dcdzg; // [bS, nOut](or[nOut]) 
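+// // oFactor: gradient routed directly through the output gate (dhdzo), bypassing the cell-state path dhdc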
+// NDArray oFactor = *dLdh *dhdzo; // [bS, nOut](or[nOut]) + +// NDArray tempC = dcdcI; +// if(Wp) +// tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI); + +// // dLdx +// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) +// // NDArray temp = c.ulike(); +// // applyActivation(c, params[9], params[10], params[11], temp); // actH(c) +// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut]) + +// // dLdhI +// NDArray* dLdhII = dLdhI; +// if(dLdcI && !dLdhI) +// dLdhII = new NDArray(dLdcI->ulike()); +// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) + +// if(firstIter) { + +// // dLdcI +// if(dLdcI) +// dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut]) + +// // dcIdWxi(dcdWxi) +// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// // dcIdWxf(dcdWxf) +// dcIdWxf.assign(dcdzf*dWxJacobian); +// // dcIdWxg(dcdWxg) +// dcIdWxg.assign(dcdzg*dWxJacobian); +// // dcIdWxo(dcdWxo) = 0 +// dcIdWxo.nullify(); + +// // dhIdWxi +// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// // dhIdWxf +// dhIdWxf.assign(dhdc*dcIdWxf); +// // dhIdWxg +// dhIdWxg.assign(dhdc*dcIdWxg); +// // dhIdWxo +// dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */); + +// // dcIdWri(dcdWri) +// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; +// // dcIdWrf(dcdWrf) +// dcIdWrf.assign(dcdzf*dWrJacobian); +// // dcIdWrg(dcdWrg) +// dcIdWrg.assign(dcdzg*dWrJacobian); +// // dcIdWro(dcdWro) = 0 +// dcIdWro.nullify(); + +// // dhIdWri +// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// // dhIdWrf +// dhIdWrf.assign(dhdc*dcIdWrf); +// // dhIdWrg +// dhIdWrg.assign(dhdc*dcIdWrg); +// // dhIdWro +// dhIdWro.assign(dhdzo*dWrJacobian /*+ 0 */); + +// if(Wp && x->rankOf() == 1) { +// // dcIdWpi +// dcIdWpi->assign(dcdzi*(*cI)); // [nOut] * [nOut] +// // dcIdWpf +// dcIdWpf->assign(dcdzf*(*cI)); // [nOut] * [nOut] +// // dcIdWpo +// dcIdWpo->nullify(); // [nOut] + +// // dhdWpi +// dhIdWpi->assign(dhdc*(*dcIdWpi)); // [nOut] * [nOut] +// // dhdWpf +// dhIdWpf->assign(dhdc*(*dcIdWpf)); // [nOut] * [nOut] +// // dhdWpo +// dhIdWpo->assign(dhdzo*c /* +0*/); // [nOut] * [nOut] +// } +// else if(Wp) { +// // dcIdWpi +// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdWpf +// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdWpo +// dcIdWpo->nullify(); // [nOut] + +// // dhIdWpi +// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpf +// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpo +// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// } + +// if(b && x->rankOf() == 1) { +// // dcIdbi +// dcIdbi->assign(dcdzi); // [nOut] +// // dcIdbf +// dcIdbf->assign(dcdzf); // [nOut] +// // dcIdbg +// dcIdbg->assign(dcdzg); // [nOut] +// // dcIdbo +// dcIdbo->nullify(); // [nOut] 
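+// // first iteration: the dhIdb*/dcIdb* accumulators are still zero, so dc/db collapses to the plain gate derivatives and dh/db to dhdc*dcIdb* (dhdzo for the output-gate bias)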
+ +// //dhIdbi +// dhIdbi->assign(dhdc*(*dcIdbi)); // [nOut] +// //dhIdbf +// dhIdbf->assign(dhdc*(*dcIdbf)); // [nOut] +// //dhIdbg +// dhIdbg->assign(dhdc*(*dcIdbg)); // [nOut] +// //dhIdbo +// dhIdbo->assign(dhdzo); // [nOut] + +// } +// else if(b) { +// // dcIdbi +// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbf +// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbg +// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbo +// dcIdbo->nullify(); // [nOut] + +// //dhIdbi +// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbf +// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbg +// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbo +// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] + +// } +// } +// else { + +// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT); +// NDArray tempO = mmul(dhdzo, WroT); + +// // dLdcI +// if(dLdcI) +// dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI)); + +// // dcIdWxi(dcdWxi) +// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dcIdWxf(dcdWxf) +// dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf); +// // dcIdWxg(dcdWxg) +// dcIdWxg.assign(dcdzg*dWxJacobian + tempIFE*dhIdWxg + tempC*dcIdWxg); +// // dcIdWxo(dcdWxo) +// dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo); + +// // dhIdWxi +// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dhIdWxf +// dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf); +// // dhIdWxg +// dhIdWxg.assign(dhdc*dcIdWxg + tempO*dhIdWxg); +// // dhIdWxo +// dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo); + +// // dcIdWri(dcdWri) +// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dcIdWrf(dcdWrf) +// dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf); +// // dcIdWrg(dcdWrg) +// dcIdWrg.assign(dcdzg*dWrJacobian + tempIFE*dhIdWrg + tempC*dcIdWrg); +// // dcIdWro(dcdWro) +// dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro); + +// // dhIdWri +// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dhIdWrf +// dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf); +// // dhIdWrg +// dhIdWrg.assign(dhdc*dcIdWrg + tempO*dhIdWrg); +// // dhIdWro +// dhIdWro.assign(dhdzo*dWrJacobian + dhdc*dcIdWro + tempO*dhIdWro); + +// if(Wp && x->rankOf() == 1) { +// // dcIdWpi +// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut] +// // dcIdWpf +// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut] +// // dcIdWpo +// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut] + +// // dhdWpi +// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut] +// // dhdWpf +// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut] +// // dhdWpo +// 
dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut] +// } +// else if(Wp) { +// // dcIdWpi +// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dcIdWpf +// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dcIdWpo +// (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] + +// // dhIdWpi +// (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpf +// (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpo +// (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// } + +// if(b && x->rankOf() == 1) { +// // dcIdbi +// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut] +// // dcIdbf +// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut] +// // dcIdbg +// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut] +// // dcIdbo +// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut] + +// //dhIdbi +// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut] +// //dhIdbf +// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut] +// //dhIdbg +// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut] +// //dhIdbo +// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut] + +// } +// else if(b) { +// // dcIdbi +// (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbf +// (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbg +// (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbo +// (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut] + +// //dhIdbi +// (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbf +// (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbg +// (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbo +// (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] + +// } +// } + +// const std::vector dimsToExclude = x->rankOf() == 1 ? 
std::vector({2}) : std::vector({2, 3}); + +// // dLdWxi, dLdWxf, dLdWxg, dLdWxo +// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude); + +// // dLdWri, dLdWrf, dLdWrg, dLdWro +// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude); + +// // dLdWpi, dLdWpf, dLdWpo +// if(Wp) { +// if(x->rankOf() == 1) { +// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] * [nOut] +// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] * [nOut] +// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut] +// } +// else { +// // NDArray temp1 = (*dLdWp)({0, nOut}); +// // NDArray temp2 = (*dLdWp)({nOut, 2*nOut}); +// // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut}); +// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdWp)({0, nOut}).assign(dhIdWpi); +// (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf); +// (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo); +// } +// } + +// // dLdbi, dLdbf, dLdbg, dLdbo +// if(b) { +// if(x->rankOf() == 1) { +// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * [nOut] +// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * [nOut] +// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut] +// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut] +// } +// else { +// // NDArray temp1 = (*dLdb)({0, nOut}); +// // NDArray temp2 = (*dLdb)({nOut, 2*nOut}); +// // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut}); +// // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut}); +// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdb)({0, nOut}).assign(dhIdbi); +// (*dLdb)({nOut, 2*nOut}).assign(dhIdbf); +// (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg); +// (*dLdb)({3*nOut, 4*nOut}).assign(dhIdbo); +// } +// } + +// //dhIdcI +// if(dLdcI) +// dhIdcI->assign(dhdc); + +// cI->assign(c); + +// if(dLdcI && !dLdhI) +// delete dLdhII; +// if(Wp) { +// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo; +// } +// if(b) { +// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo; +// } +// } diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index dfa9268b4..3a2d173b5 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -22,7 +22,6 @@ #define LIBND4J_LSTMLAYER_H #include -#include namespace sd { namespace ops { @@ -34,6 +33,20 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra const std::vector& params, NDArray* h, NDArray* c); +////////////////////////////////////////////////////////////////////////// +// this auxiliary ff should be running before backprop +void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* 
Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, + NDArray* z, NDArray* a, NDArray* h, NDArray* c); + +////////////////////////////////////////////////////////////////////////// +void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdc, + const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); + + ////////////////////////////////////////////////////////////////////////// void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, @@ -42,71 +55,11 @@ void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const ND NDArray* h, NDArray* hL, NDArray* cL); ////////////////////////////////////////////////////////////////////////// -static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { - - switch (opId) { - case 0: - (const_cast(x)).applyTransform(transform::Tanh, z); - break; - case 1: - (const_cast(x)).applyScalar(scalar::RELU, 0, z); - break; - case 2: - (const_cast(x)).applyTransform(transform::Sigmoid, z); - break; - case 3: { - ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::Affine, z, &args); - break; - } - case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); - break; - case 5: - helpers::thresholdRelu(x.getContext(), x, alpha, z); - break; - case 6: { - ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); - break; - } - case 7: - (const_cast(x)).applyTransform(transform::HardSigmoid, z); - break; - case 8: - (const_cast(x)).applyScalar(scalar::ELU, alpha, z); - break; - case 9: - (const_cast(x)).applyTransform(transform::SoftSign, z); - break; - case 10: - (const_cast(x)).applyTransform(transform::SoftPlus, z); - break; - default: - throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); - } -} - -////////////////////////////////////////////////////////////////////////// -static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { - - if(dataFormat == 0 || dataFormat == 3) - return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] - - if(dataFormat == 1) - return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] - - return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] -} - -////////////////////////////////////////////////////////////////////////// -static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) { - - if(dataFormat == 0 || dataFormat == 3) - return t * bS + b; // TNS: shape [sL, bS, nIn] - - return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] -} +void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const std::vector& params, const bool forward, + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, 
NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp); } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index e49165e78..2f02af11b 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1441,7 +1441,7 @@ namespace simdOps { } op_def static Z op(X d1) { - return d1; + return static_cast(d1); } }; @@ -2434,6 +2434,19 @@ namespace simdOps { } }; + template + class RELUDerivative { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static Z op(X d1, Y d2, Z *params) { + auto xt = static_cast(d1); + auto xf = static_cast(d2); + return xt > xf ? static_cast(1.f) : static_cast(0.f); + } + }; + template class SXELogitsSmoother { public: diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 4b5a24bb9..cee574dec 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) { auto z = result.at(0); ASSERT_TRUE(z->isEmpty()); - + } TEST_F(DeclarableOpsTests13, test_empty_range_2) { @@ -262,7 +262,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { ASSERT_EQ(result.status(), Status::OK()); //result.at(0)->printBuffer("Output"); ASSERT_TRUE(exp1.equalsTo(result.at(0))); - + } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) { @@ -286,7 +286,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) { ASSERT_EQ(result.status(), Status::OK()); //result.at(0)->printBuffer("Output"); ASSERT_TRUE(exp.equalsTo(result.at(0))); - + } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { @@ -312,7 +312,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { //exp.printBuffer("Expect"); //result.at(0)->printShapeInfo("Shape output"); ASSERT_TRUE(exp.equalsTo(result.at(0))); - + } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { @@ -349,7 +349,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { //result.at(2)->printBuffer("Symmetrized2"); // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); ASSERT_TRUE(exp.equalsTo(result.at(2))); - + } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { @@ -369,7 +369,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); //ASSERT_TRUE(exp.equalsTo(result.at(0))); - + } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { @@ -398,7 +398,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); ASSERT_TRUE(exp4.equalsTo(res)); - + } TEST_F(DeclarableOpsTests13, CellContains_test_1) { @@ -420,7 +420,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); //ASSERT_TRUE(exp.equalsTo(result.at(0))); - + } //////////////////////////////////////////////////////////////////// @@ -712,7 +712,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { @@ -1109,6 +1109,7 @@ TEST_F(DeclarableOpsTests13, mergeavg_bp_1) { } } + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_1) { @@ -1200,7 +1201,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = false; // 
peephole connections are absent const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step + const auto retLastH = true; // return output at last time step const auto retLastC = true; // return cells state at last time step const double cellClip = 0; // do not apply clipping @@ -1398,7 +1399,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - + } /////////////////////////////////////////////////////////////////// @@ -1640,7 +1641,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_7) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - + #endif } @@ -1718,7 +1719,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - + #endif } @@ -1805,7 +1806,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - + #endif } @@ -1890,7 +1891,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - + #endif } @@ -1970,7 +1971,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - + #endif } @@ -2061,10 +2062,528 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.equalsTo(cL)); - #endif } +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = false; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, 
&Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::SUM, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = false; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { + + const int sL = 4; + const int bS = 3; + const int nIn = 3; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = false; // output at last time step + const auto retLastC = false; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, 
bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = false; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = false; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per 
each time step + const auto retLastH = false; // output at last time step + const auto retLastC = false; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 2; // bidirectional sum + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = false; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = false; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder 
argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = false; // output at last time step + const auto retLastC = false; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 3; // [sL, bS, nIn] + const int directionMode = 4; // bidirectional extra output dim + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not 
provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = false; // output at last time step + const auto retLastC = false; // cells state at last time step + + const double cellClip = 0.5; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + + ASSERT_TRUE(isGradCorrect); +} //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test1) { @@ -2091,7 +2610,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////// @@ -2233,7 +2752,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////// @@ -2345,7 +2864,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) { ASSERT_TRUE(expected.isSameShape(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2387,7 +2906,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + } @@ -2642,7 +3161,7 @@ return; ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + } //////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 5636b2e29..166ba058f 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -844,5 +844,78 @@ TEST_F(PlaygroundTests, my) { printf("time: %i \n", time); } +/////////////////////////////////////////////////////////////////// +TEST_F(PlaygroundTests, lstmLayerCellBp_1) { + + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + // const int nIn = 8; 
+ // const int nOut = 6; + + const float cellClip = 1.1; // clipping value + const Nd4jLong gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const Nd4jLong cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const Nd4jLong outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdc('c', {bS, nOut}, sd::DataType::DOUBLE); + + // NDArray x ('c', {nIn}, sd::DataType::DOUBLE); + // NDArray hI('c', {nOut}, sd::DataType::DOUBLE); + // NDArray cI('c', {nOut}, sd::DataType::DOUBLE); + // NDArray dLdh('c', {nOut}, sd::DataType::DOUBLE); + // NDArray dLdc('c', {nOut}, sd::DataType::DOUBLE); + + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b ('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + + x.linspace(-4,1); + hI.linspace(-2.5,0.5); + cI.linspace(-3,0.5); + Wx.linspace(0,0.1); + Wr.linspace(3,-0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + // x.assign(1.); + // hI.assign(2.); + // cI.assign(3.); + // Wx.assign(0.5); + // Wr.assign(0.5); + // Wp.assign(0.75); + // b.assign(0.7); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {gateAct, cellAct, outAct}; + + // std::vector bArgs = {false, false}; + // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs); + + std::vector bArgs = {true, true}; + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayerCell opFF; + sd::ops::lstmLayerCellBp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true}); +} + + */ + + From 6fcd078c5e0be53712564091768219c820881c05 Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Tue, 14 Apr 2020 18:36:14 +0900 Subject: [PATCH 19/19] Update dependencies to just released JavaCPP and JavaCV 1.5.3 (#374) Signed-off-by: Samuel Audet --- change-cuda-versions.sh | 2 +- deeplearning4j/deeplearning4j-cuda/pom.xml | 2 +- .../templates/android-image-classification.md | 30 +++++++++---------- .../templates/android-linear-classifier.md | 30 +++++++++---------- .../templates/android-prerequisites.md | 30 +++++++++---------- docs/deeplearning4j/templates/android.md | 30 +++++++++---------- docs/deeplearning4j/templates/config-cudnn.md | 2 +- .../nd4j-cuda-platform/pom.xml | 2 +- .../nd4j-backend-impls/nd4j-cuda/pom.xml | 2 +- pom.xml | 6 ++-- pydl4j/pydl4j/pom.py | 6 ++-- 11 files changed, 71 insertions(+), 71 deletions(-) diff --git a/change-cuda-versions.sh b/change-cuda-versions.sh 
index 21f17bb72..7b354d68b 100755 --- a/change-cuda-versions.sh +++ b/change-cuda-versions.sh @@ -49,7 +49,7 @@ check_cuda_version "$VERSION" case $VERSION in 10.2) VERSION2="7.6" - VERSION3="1.5.2" + VERSION3="1.5.3" ;; 10.1) VERSION2="7.6" diff --git a/deeplearning4j/deeplearning4j-cuda/pom.xml b/deeplearning4j/deeplearning4j-cuda/pom.xml index dfdc76efb..30373db3a 100644 --- a/deeplearning4j/deeplearning4j-cuda/pom.xml +++ b/deeplearning4j/deeplearning4j-cuda/pom.xml @@ -28,7 +28,7 @@ 10.2 7.6 - 1.5.2 + 1.5.3 diff --git a/docs/deeplearning4j/templates/android-image-classification.md b/docs/deeplearning4j/templates/android-image-classification.md index 00931b17d..d0cc8f558 100644 --- a/docs/deeplearning4j/templates/android-image-classification.md +++ b/docs/deeplearning4j/templates/android-image-classification.md @@ -40,21 +40,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version} implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3' +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3' +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'opencv', 
version: '4.3.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3' +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64" implementation 'com.google.code.gson:gson:2.8.2' annotationProcessor 'org.projectlombok:lombok:1.16.16' diff --git a/docs/deeplearning4j/templates/android-linear-classifier.md b/docs/deeplearning4j/templates/android-linear-classifier.md index b362279a7..b6fe5352c 100644 --- a/docs/deeplearning4j/templates/android-linear-classifier.md +++ b/docs/deeplearning4j/templates/android-linear-classifier.md @@ -35,21 +35,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version} implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3' +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: 
"android-x86" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3' +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3' +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64" ``` Compiling these dependencies involves a large number of files, thus it is necessary to set multiDexEnabled to true in defaultConfig. diff --git a/docs/deeplearning4j/templates/android-prerequisites.md b/docs/deeplearning4j/templates/android-prerequisites.md index d0347fc01..43b4d26bd 100644 --- a/docs/deeplearning4j/templates/android-prerequisites.md +++ b/docs/deeplearning4j/templates/android-prerequisites.md @@ -43,21 +43,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version} implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" 
+implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3' +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3' +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3' +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64" testimplementation 'junit:junit:4.12' ``` diff --git a/docs/deeplearning4j/templates/android.md b/docs/deeplearning4j/templates/android.md index cf705ba5e..92d302619 100644 --- a/docs/deeplearning4j/templates/android.md +++ b/docs/deeplearning4j/templates/android.md @@ -46,21 +46,21 @@ implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version} implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-arm64" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86" implementation group: 'org.nd4j', name: 'nd4j-native', version: '{{page.version}}', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2' -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2' -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2' -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: 
"android-arm64" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86" -implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3' +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.9-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3' +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'opencv', version: '4.3.0-1.5.3', classifier: "android-x86_64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3' +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-arm64" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86" +implementation group: 'org.bytedeco', name: 'leptonica', version: '1.79.0-1.5.3', classifier: "android-x86_64" ``` diff --git a/docs/deeplearning4j/templates/config-cudnn.md b/docs/deeplearning4j/templates/config-cudnn.md index 24f69da87..64d248fe3 100644 --- a/docs/deeplearning4j/templates/config-cudnn.md +++ b/docs/deeplearning4j/templates/config-cudnn.md @@ -62,7 +62,7 @@ Alternatively, in the case of CUDA 10.2, cuDNN comes bundled with the "redist" p org.bytedeco cuda-platform-redist - 10.2-7.6-1.5.2 + 10.2-7.6-1.5.3 Also note that, by default, Deeplearning4j will use the fastest algorithms available according to cuDNN, but memory usage may be excessive, causing strange launch errors. 
When this happens, try to reduce memory usage by using the [`NO_WORKSPACE` mode settable via the network configuration](/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.Builder.html#cudnnAlgoMode-org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode-), instead of the default of `ConvolutionLayer.AlgoMode.PREFER_FASTEST`, for example: diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml index 027b49844..344e77861 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml @@ -29,7 +29,7 @@ 10.2 7.6 - 1.5.2 + 1.5.3 nd4j-cuda-${cuda.version} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index b450e58b6..371386898 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -29,7 +29,7 @@ 10.2 7.6 - ${javacpp-presets.version} + 1.5.3 diff --git a/pom.xml b/pom.xml index 17708b222..15af4658d 100644 --- a/pom.xml +++ b/pom.xml @@ -288,9 +288,9 @@ ${javacpp.platform} - 1.5.3-SNAPSHOT - 1.5.3-SNAPSHOT - 1.5.3-SNAPSHOT + 1.5.3 + 1.5.3 + 1.5.3 3.7.7 ${python.version}-${javacpp-presets.version} diff --git a/pydl4j/pydl4j/pom.py b/pydl4j/pydl4j/pom.py index ec6e3ecb3..ad76dca97 100644 --- a/pydl4j/pydl4j/pom.py +++ b/pydl4j/pydl4j/pom.py @@ -118,9 +118,9 @@ def pom_template(): 3.0.0 - 1.5 - 1.5 - 0.3.5 + 1.5.3 + 1.5.3 + 0.3.9
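For reference, the `NO_WORKSPACE` mode mentioned in the config-cudnn.md hunk above is set per layer through `ConvolutionLayer.Builder#cudnnAlgoMode`, the method the doc's Javadoc link points to. Below is a minimal sketch, not the doc's own elided example; the class name, kernel size, and channel counts are illustrative assumptions.

```java
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

public class CudnnAlgoModeSketch {
    public static void main(String[] args) {
        // Override the default PREFER_FASTEST selection so cuDNN does not
        // allocate large workspaces, trading some speed for lower GPU memory.
        ConvolutionLayer conv = new ConvolutionLayer.Builder(3, 3)          // 3x3 kernel (illustrative)
                .nIn(1)                                                     // input channels (illustrative)
                .nOut(16)                                                   // output feature maps (illustrative)
                .cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)
                .build();
    }
}
```

The same builder call can be applied to every convolution layer in a network configuration when the memory pressure affects more than one layer.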